gtests/ssl_gtest/ssl_record_unittest.cc
author Martin Thomson <martin.thomson@gmail.com>
Wed, 03 Jan 2018 15:36:18 +1100
changeset 14254 74e4cd34e49fa72806ed7405926591d90f4a2d2a
parent 13485 19a655263f2c349d655a48d879aa0af89f821976
child 14255 54a2412cbcc3c1016fa3cd0cf1fbcb52cb116ebc
permissions -rw-r--r--
Bug 1427675 - Add TlsAgent argument to TlsRecordFilter, r=ekr This is a fairly disruptive change, but mostly just mechanical. There are a few extra changes: - I have renamed the TlsInspector* filters for consistency. This was purely mechanical. - I renamed the SetPacketFilter function to just SetFilter. Also mechanical. - TlsRecordFilter maintains a weak pointer reference to the TlsAgent now rather than using a bare pointer. This meant that I had to change TlsAgentTestBase to use shared_ptr rather than unique_ptr to support the use of filters with those tests. - I removed the helper function that enables decryption. Enabling decryption is now more explicit. - I ran a newer clang-format version and it fixed a few extra things, like the comments at the end of namespace {} blocks, some of which were wrong. - I discovered a bug in some of the drop tests: in the 0-RTT tests, the filters were being installed on the client and server right at the start, which meant that they were capturing the first handshake and not the second one. This was clearly against intent, but the tests were still mostly right; it was only the expected ACKs that were wrong. We were expecting just one record to be ACKed by a server (Finished), but the record with EndOfEarlyData should have been acknowledged as well. - In TlsSkipTest and Tls13SkipTest, I had to override SetUp() so that client_ and server_ are initialized prior to constructing filters. In doing so, I noticed that we weren't being consistent about overriding SetUp properly, so I fixed the small number of instances of that by adding an override label to each and marking the base method virtual. - The stateless HRR test for TLS 1.3 compat mode was replacing the server, but expecting to retain the same filters. That wasn't a problem in that case, but I didn't want to have any places where the filter was set on a different agent from the one that was passed to it.

/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "nss.h"
#include "ssl.h"
#include "sslimpl.h"

#include "databuffer.h"
#include "gtest_utils.h"
#include "tls_connect.h"
#include "tls_filter.h"

namespace nss_test {

// MAC length, in bytes, passed to ssl_RemoveTLSCBCPadding by the padding tests
// (20 bytes; presumably sized for an HMAC-SHA1 tag).
static const size_t kMacSize = 20;

// Value-parameterized fixture for ssl_RemoveTLSCBCPadding.
//
// The parameter tuple is (content length, add maximal padding).  The
// constructor builds `plaintext_` as follows:
//   - `plaintext_len_` bytes of 'A';
//   - `pad_len_` padding bytes;
//   - one final padding-length byte.
// Every padding byte and the length byte hold the value `pad_len_`.
//
// `pad_len_` is the minimum needed to make the total length a multiple of 16.
// When the bool parameter is set, 240 more bytes are added, so `pad_len_`
// never exceeds 15 + 240 = 255 and always fits in the length byte.
class TlsPaddingTest
    : public ::testing::Test,
      public ::testing::WithParamInterface<std::tuple<size_t, bool>> {
 public:
  TlsPaddingTest() : plaintext_len_(std::get<0>(GetParam())) {
    // Bytes past a 16-byte block boundary, counting the padding-length byte.
    size_t extra = (plaintext_len_ + 1) % 16;
    // Minimal padding.
    pad_len_ = extra ? 16 - extra : 0;
    if (std::get<1>(GetParam())) {
      // Maximal padding: 240 is the largest multiple of 16 that keeps the
      // padding length within a single byte.
      pad_len_ += 240;
    }
    MakePaddedPlaintext();
  }

  // Makes a plaintext record with correct padding: content bytes of 'A',
  // followed by (pad_len_ + 1) bytes each equal to pad_len_.
  void MakePaddedPlaintext() {
    EXPECT_EQ(0UL, (plaintext_len_ + pad_len_ + 1) % 16);
    size_t i = 0;
    plaintext_.Allocate(plaintext_len_ + pad_len_ + 1);
    for (; i < plaintext_len_; ++i) {
      plaintext_.Write(i, 'A', 1);
    }

    for (; i < plaintext_len_ + pad_len_ + 1; ++i) {
      plaintext_.Write(i, pad_len_, 1);
    }
  }

  // Runs ssl_RemoveTLSCBCPadding over `plaintext_` (in place) with kMacSize.
  //
  // expect_success: whether the call should return SECSuccess.  On success,
  // the remaining length must equal the content length.  Diagnostics about
  // the buffer are written to stderr first.
  void Unpad(bool expect_success) {
    std::cerr << "Content length=" << plaintext_len_
              << " padding length=" << pad_len_
              << " total length=" << plaintext_.len() << std::endl;
    std::cerr << "Plaintext: " << plaintext_ << std::endl;
    // Value-initialize so that sslBuffer fields not assigned below are zero
    // rather than indeterminate.
    sslBuffer s = {};
    // The non-const DataBuffer::data() already yields a mutable pointer, so
    // no const_cast is needed.
    s.buf = plaintext_.data();
    s.len = plaintext_.len();
    SECStatus rv = ssl_RemoveTLSCBCPadding(&s, kMacSize);
    if (expect_success) {
      EXPECT_EQ(SECSuccess, rv);
      EXPECT_EQ(plaintext_len_, static_cast<size_t>(s.len));
    } else {
      EXPECT_EQ(SECFailure, rv);
    }
  }

 protected:
  size_t plaintext_len_;  // Content bytes before the padding.
  size_t pad_len_ = 0;    // Padding bytes, excluding the length byte.
  DataBuffer plaintext_;  // Content + padding + padding-length byte.
};

// A correctly padded record is only accepted when the content is at least
// kMacSize bytes long; shorter content must make unpadding fail.
TEST_P(TlsPaddingTest, Correct) {
  const bool room_for_mac = plaintext_len_ >= kMacSize;
  Unpad(room_for_mac);
}

// Overwrites the padding-length byte with the full record length.  This
// claims more padding than the record holds, so unpadding must fail.
// Skipped when the record length is 255 or more.
TEST_P(TlsPaddingTest, PadTooLong) {
  const size_t total_len = plaintext_.len();
  if (total_len >= 255) {
    return;
  }
  plaintext_.Write(total_len - 1, total_len, 1);
  Unpad(false);
}

// Increments the first padding byte (the one right after the content), so
// unpadding must fail.  Skipped when there is no padding besides the length
// byte.
TEST_P(TlsPaddingTest, FirstByteOfPadWrong) {
  if (pad_len_ == 0) {
    return;
  }
  const size_t first_pad = plaintext_len_;
  plaintext_.Write(first_pad, plaintext_.data()[first_pad] + 1, 1);
  Unpad(false);
}

// Corrupts the last padding byte (the one just before the padding-length
// byte), so unpadding must fail.  Skipped when there is no padding besides
// the length byte.
TEST_P(TlsPaddingTest, LastByteOfPadWrong) {
  if (pad_len_ == 0) {
    return;
  }
  // The final byte holds the padding length.  The byte before it gets that
  // length plus one.
  const size_t length_byte = plaintext_.len() - 1;
  plaintext_.Write(length_byte - 1, plaintext_.data()[length_byte] + 1, 1);
  Unpad(false);
}

// Record filter that stays inert until Enable() is called.  It then replaces
// the next record it sees with `size` bytes of filler (0x00, 0x01, ...,
// wrapping after 0xff) and disarms itself.  That record is expected to carry
// application data.
class RecordReplacer : public TlsRecordFilter {
 public:
  RecordReplacer(const std::shared_ptr<TlsAgent>& agent, size_t size)
      : TlsRecordFilter(agent), size_(size) {}

  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                    const DataBuffer& data,
                                    DataBuffer* changed) override {
    if (enabled_) {
      EXPECT_EQ(kTlsApplicationDataType, header.content_type());
      changed->Allocate(size_);

      uint8_t* out = changed->data();
      for (size_t pos = 0; pos != size_; ++pos) {
        out[pos] = static_cast<uint8_t>(pos & 0xff);
      }

      // One-shot: later records pass through untouched.
      enabled_ = false;
      return CHANGE;
    }
    return KEEP;
  }

  // Arms the filter so that the next record is replaced.
  void Enable() { enabled_ = true; }

 private:
  bool enabled_ = false;  // Whether the next record should be replaced.
  size_t size_;           // Length of the replacement record.
};

// A record of exactly 16384 bytes must be accepted.  After the handshake, the
// next record the client sends is replaced by one of that size.  The server
// must receive all of it.
TEST_F(TlsConnectStreamTls13, LargeRecord) {
  EnsureTlsSetup();

  const size_t max_plaintext = 16384;
  auto filter = std::make_shared<RecordReplacer>(client_, max_plaintext);
  filter->EnableDecryption();
  client_->SetFilter(filter);
  Connect();

  filter->Enable();
  client_->SendData(10);
  WAIT_(server_->received_bytes() == max_plaintext, 2000);
  ASSERT_EQ(max_plaintext, server_->received_bytes());
}

// A record one byte over 16384 bytes must be rejected.  The server's read
// fails with SSL_ERROR_RX_RECORD_TOO_LONG and it sends a record_overflow
// alert.  The client then sees that alert on its next read.
TEST_F(TlsConnectStreamTls13, TooLargeRecord) {
  EnsureTlsSetup();

  const size_t max_plaintext = 16384;
  auto filter = std::make_shared<RecordReplacer>(client_, max_plaintext + 1);
  filter->EnableDecryption();
  client_->SetFilter(filter);
  Connect();

  filter->Enable();
  ExpectAlert(server_, kTlsAlertRecordOverflow);
  client_->SendData(10);  // This is expanded.

  uint8_t read_buf[max_plaintext + 2];

  // Server: the oversized record is refused.
  PRInt32 result = PR_Read(server_->ssl_fd(), read_buf, sizeof(read_buf));
  EXPECT_GT(0, result);
  EXPECT_EQ(SSL_ERROR_RX_RECORD_TOO_LONG, PORT_GetError());

  // Client: read the server alert.
  result = PR_Read(client_->ssl_fd(), read_buf, sizeof(read_buf));
  EXPECT_GT(0, result);
  EXPECT_EQ(SSL_ERROR_RECORD_OVERFLOW_ALERT, PORT_GetError());
}

// Content lengths to test, including values just below, at, and above
// kMacSize.
static const size_t kContentSizesArr[] = {
    1, kMacSize - 1, kMacSize, 30, 31, 32, 36, 256, 257, 287, 288};
// Second parameter: whether to use maximal padding.
static const bool kTrueFalseArr[] = {true, false};

auto kContentSizes = ::testing::ValuesIn(kContentSizesArr);
auto kTrueFalse = ::testing::ValuesIn(kTrueFalseArr);

// Run every TlsPaddingTest for each (content length, maximal padding) pair.
INSTANTIATE_TEST_CASE_P(TlsPadding, TlsPaddingTest,
                        ::testing::Combine(kContentSizes, kTrueFalse));
}  // namespace nss_test