Bug 1348720 - Move TlsAlertRecorder to some specific alert tests, r=ttaubert
authorMartin Thomson <martin.thomson@gmail.com>
Thu, 23 Mar 2017 14:08:36 +1100
changeset 13243 d8efc3d7a72f7080c9a86bc678a096f559aea2a4
parent 13242 85e168b055dbd4130dfacc218374b808f063c4fa
child 13244 73eac4c4656c4850d8ce0bed21fea48b2f61b0fa
push id2111
push usermartin.thomson@gmail.com
push dateThu, 23 Mar 2017 11:00:50 +0000
reviewersttaubert
bugs1348720
Bug 1348720 - Move TlsAlertRecorder to some specific alert tests, r=ttaubert
gtests/ssl_gtest/ssl_gather_unittest.cc
gtests/ssl_gtest/ssl_loopback_unittest.cc
gtests/ssl_gtest/ssl_skip_unittest.cc
gtests/ssl_gtest/ssl_staticrsa_unittest.cc
gtests/ssl_gtest/tls_filter.cc
gtests/ssl_gtest/tls_filter.h
--- a/gtests/ssl_gtest/ssl_gather_unittest.cc
+++ b/gtests/ssl_gtest/ssl_gather_unittest.cc
@@ -11,27 +11,21 @@ namespace nss_test {
 
 class GatherV2ClientHelloTest : public TlsConnectTestBase {
  public:
   GatherV2ClientHelloTest() : TlsConnectTestBase(STREAM, 0) {}
 
   void ConnectExpectMalformedClientHello(const DataBuffer &data) {
     EnsureTlsSetup();
     server_->ExpectSendAlert(kTlsAlertIllegalParameter);
-    auto alert_recorder = std::make_shared<TlsAlertRecorder>();
-    server_->SetPacketFilter(alert_recorder);
-
     client_->SendDirect(data);
     server_->StartConnect();
     server_->Handshake();
     ASSERT_TRUE_WAIT(
         (server_->error_code() == SSL_ERROR_RX_MALFORMED_CLIENT_HELLO), 2000);
-
-    EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
-    EXPECT_EQ(kTlsAlertIllegalParameter, alert_recorder->description());
   }
 };
 
 // Gather a 5-byte v3 record, with a zero fragment length. The empty handshake
 // message should be ignored, and the connection will succeed afterwards.
 TEST_F(TlsConnectTest, GatherEmptyV3Record) {
   DataBuffer buffer;
 
@@ -51,26 +45,21 @@ TEST_F(TlsConnectTest, GatherExcessiveV3
 
   size_t idx = 0;
   idx = buffer.Write(idx, 0x16, 1);                            // handshake
   idx = buffer.Write(idx, 0x0301, 2);                          // record_version
   (void)buffer.Write(idx, MAX_FRAGMENT_LENGTH + 2048 + 1, 2);  // length=max+1
 
   EnsureTlsSetup();
   server_->ExpectSendAlert(kTlsAlertRecordOverflow);
-  auto alert_recorder = std::make_shared<TlsAlertRecorder>();
-  server_->SetPacketFilter(alert_recorder);
   client_->SendDirect(buffer);
   server_->StartConnect();
   server_->Handshake();
   ASSERT_TRUE_WAIT((server_->error_code() == SSL_ERROR_RX_RECORD_TOO_LONG),
                    2000);
-
-  EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
-  EXPECT_EQ(kTlsAlertRecordOverflow, alert_recorder->description());
 }
 
 // Gather a 3-byte v2 header, with a fragment length of 2.
 TEST_F(GatherV2ClientHelloTest, GatherV2RecordLongHeader) {
   DataBuffer buffer;
 
   size_t idx = 0;
   idx = buffer.Write(idx, 0x0002, 2);  // length=2 (long header)
--- a/gtests/ssl_gtest/ssl_loopback_unittest.cc
+++ b/gtests/ssl_gtest/ssl_loopback_unittest.cc
@@ -34,30 +34,116 @@ TEST_P(TlsConnectGeneric, Connect) {
 
 TEST_P(TlsConnectGeneric, ConnectEcdsa) {
   SetExpectedVersion(std::get<1>(GetParam()));
   Reset(TlsAgent::kServerEcdsa256);
   Connect();
   CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa);
 }
 
-TEST_P(TlsConnectGenericPre13, CipherSuiteMismatch) {
+TEST_P(TlsConnectGeneric, CipherSuiteMismatch) {
   EnsureTlsSetup();
   if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
     client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
     server_->EnableSingleCipher(TLS_AES_256_GCM_SHA384);
   } else {
     client_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
     server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA);
   }
   ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
   client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
   server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
 }
 
+class TlsAlertRecorder : public TlsRecordFilter {
+ public:
+  TlsAlertRecorder() : level_(255), description_(255) {}
+
+  PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+                                    const DataBuffer& input,
+                                    DataBuffer* output) override {
+    if (level_ != 255) {  // Already captured.
+      return KEEP;
+    }
+    if (header.content_type() != kTlsAlertType) {
+      return KEEP;
+    }
+
+    std::cerr << "Alert: " << input << std::endl;
+
+    TlsParser parser(input);
+    EXPECT_TRUE(parser.Read(&level_));
+    EXPECT_TRUE(parser.Read(&description_));
+    return KEEP;
+  }
+
+  uint8_t level() const { return level_; }
+  uint8_t description() const { return description_; }
+
+ private:
+  uint8_t level_;
+  uint8_t description_;
+};
+
+class HelloTruncator : public TlsHandshakeFilter {
+  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+                                       const DataBuffer& input,
+                                       DataBuffer* output) override {
+    if (header.handshake_type() != kTlsHandshakeClientHello &&
+        header.handshake_type() != kTlsHandshakeServerHello) {
+      return KEEP;
+    }
+    output->Assign(input.data(), input.len() - 1);
+    return CHANGE;
+  }
+};
+
+// Verify that when NSS reports that an alert is sent, it is actually sent.
+TEST_P(TlsConnectGeneric, CaptureAlertServer) {
+  client_->SetPacketFilter(std::make_shared<HelloTruncator>());
+  auto alert_recorder = std::make_shared<TlsAlertRecorder>();
+  server_->SetPacketFilter(alert_recorder);
+
+  ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+  EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+  EXPECT_EQ(kTlsAlertIllegalParameter, alert_recorder->description());
+}
+
+TEST_P(TlsConnectGenericPre13, CaptureAlertClient) {
+  server_->SetPacketFilter(std::make_shared<HelloTruncator>());
+  auto alert_recorder = std::make_shared<TlsAlertRecorder>();
+  client_->SetPacketFilter(alert_recorder);
+
+  ConnectExpectAlert(client_, kTlsAlertDecodeError);
+  EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+  EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description());
+}
+
+// In TLS 1.3, the server can't read the client alert.
+TEST_P(TlsConnectTls13, CaptureAlertClient) {
+  server_->SetPacketFilter(std::make_shared<HelloTruncator>());
+  auto alert_recorder = std::make_shared<TlsAlertRecorder>();
+  client_->SetPacketFilter(alert_recorder);
+
+  server_->StartConnect();
+  client_->StartConnect();
+
+  client_->Handshake();
+  client_->ExpectSendAlert(kTlsAlertDecodeError);
+  server_->Handshake();
+  client_->Handshake();
+  if (mode_ == STREAM) {
+    // DTLS just drops the alert it can't decrypt.
+    server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+  }
+  server_->Handshake();
+  EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+  EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description());
+}
+
 TEST_P(TlsConnectGenericPre13, ConnectFalseStart) {
   client_->EnableFalseStart();
   Connect();
   SendReceive();
 }
 
 TEST_P(TlsConnectGeneric, ConnectAlpn) {
   EnableAlpn();
--- a/gtests/ssl_gtest/ssl_skip_unittest.cc
+++ b/gtests/ssl_gtest/ssl_skip_unittest.cc
@@ -82,22 +82,18 @@ class TlsSkipTest
     : public TlsConnectTestBase,
       public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
  protected:
   TlsSkipTest()
       : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
 
   void ServerSkipTest(std::shared_ptr<PacketFilter> filter,
                       uint8_t alert = kTlsAlertUnexpectedMessage) {
-    auto alert_recorder = std::make_shared<TlsAlertRecorder>();
-    client_->SetPacketFilter(alert_recorder);
     server_->SetPacketFilter(filter);
     ConnectExpectAlert(client_, alert);
-    EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
-    EXPECT_EQ(alert, alert_recorder->description());
   }
 };
 
 class Tls13SkipTest : public TlsConnectTestBase,
                       public ::testing::WithParamInterface<std::string> {
  protected:
   Tls13SkipTest()
       : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
@@ -125,16 +121,18 @@ class Tls13SkipTest : public TlsConnectT
     EnsureTlsSetup();
     client_->SetTlsRecordFilter(filter);
     filter->EnableDecryption();
     server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
     ConnectExpectFailOneSide(TlsAgent::SERVER);
 
     server_->CheckErrorCode(error);
     ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+
+    client_->Handshake();  // Make sure to consume the alert the server sends.
   }
 };
 
 TEST_P(TlsSkipTest, SkipCertificateRsa) {
   EnableOnlyStaticRsaCiphers();
   ServerSkipTest(
       std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
   client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
@@ -213,27 +211,25 @@ TEST_P(Tls13SkipTest, SkipServerCertific
 
 TEST_P(Tls13SkipTest, SkipClientCertificate) {
   client_->SetupClientAuth();
   server_->RequestClientAuth(true);
   client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
   ClientSkipTest(
       std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
       SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
-  client_->Handshake();  // Make sure to consume the alert.
 }
 
 TEST_P(Tls13SkipTest, SkipClientCertificateVerify) {
   client_->SetupClientAuth();
   server_->RequestClientAuth(true);
   client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
   ClientSkipTest(
       std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify),
       SSL_ERROR_RX_UNEXPECTED_FINISHED);
-  client_->Handshake();  // Make sure to consume the alert.
 }
 
 INSTANTIATE_TEST_CASE_P(SkipTls10, TlsSkipTest,
                         ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
                                            TlsConnectTestBase::kTlsV10));
 INSTANTIATE_TEST_CASE_P(SkipVariants, TlsSkipTest,
                         ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
                                            TlsConnectTestBase::kTlsV11V12));
--- a/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
+++ b/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
@@ -47,34 +47,26 @@ TEST_P(TlsConnectGenericPre13, ConnectSt
 // Test that a totally bogus EPMS is handled correctly.
 // This test is stream so we can catch the bad_record_mac alert.
 TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) {
   EnableOnlyStaticRsaCiphers();
   auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
       kTlsHandshakeClientKeyExchange,
       DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
   client_->SetPacketFilter(i1);
-  auto alert_recorder = std::make_shared<TlsAlertRecorder>();
-  server_->SetPacketFilter(alert_recorder);
   ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
-  EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
-  EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
 }
 
 // Test that a PMS with a bogus version number is handled correctly.
 // This test is stream so we can catch the bad_record_mac alert.
 TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) {
   EnableOnlyStaticRsaCiphers();
   client_->SetPacketFilter(
       std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
-  auto alert_recorder = std::make_shared<TlsAlertRecorder>();
-  server_->SetPacketFilter(alert_recorder);
   ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
-  EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
-  EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
 }
 
 // Test that a PMS with a bogus version number is ignored when
 // rollback detection is disabled. This is a positive control for
 // ConnectStaticRSABogusPMSVersionDetect.
 TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) {
   EnableOnlyStaticRsaCiphers();
   client_->SetPacketFilter(
@@ -86,35 +78,27 @@ TEST_P(TlsConnectGenericPre13, ConnectSt
 // This test is stream so we can catch the bad_record_mac alert.
 TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusCKE) {
   EnableOnlyStaticRsaCiphers();
   EnableExtendedMasterSecret();
   auto inspect = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
       kTlsHandshakeClientKeyExchange,
       DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
   client_->SetPacketFilter(inspect);
-  auto alert_recorder = std::make_shared<TlsAlertRecorder>();
-  server_->SetPacketFilter(alert_recorder);
   ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
-  EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
-  EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
 }
 
 // This test is stream so we can catch the bad_record_mac alert.
 TEST_P(TlsConnectStreamPre13,
        ConnectExtendedMasterSecretStaticRSABogusPMSVersionDetect) {
   EnableOnlyStaticRsaCiphers();
   EnableExtendedMasterSecret();
   client_->SetPacketFilter(
       std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
-  auto alert_recorder = std::make_shared<TlsAlertRecorder>();
-  server_->SetPacketFilter(alert_recorder);
   ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
-  EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
-  EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
 }
 
 TEST_P(TlsConnectStreamPre13,
        ConnectExtendedMasterSecretStaticRSABogusPMSVersionIgnore) {
   EnableOnlyStaticRsaCiphers();
   EnableExtendedMasterSecret();
   client_->SetPacketFilter(
       std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
--- a/gtests/ssl_gtest/tls_filter.cc
+++ b/gtests/ssl_gtest/tls_filter.cc
@@ -364,41 +364,16 @@ PacketFilter::Action TlsInspectorReplace
 
 PacketFilter::Action TlsConversationRecorder::FilterRecord(
     const TlsRecordHeader& header, const DataBuffer& input,
     DataBuffer* output) {
   buffer_.Append(input);
   return KEEP;
 }
 
-PacketFilter::Action TlsAlertRecorder::FilterRecord(
-    const TlsRecordHeader& header, const DataBuffer& input,
-    DataBuffer* output) {
-  if (level_ == kTlsAlertFatal) {  // already fatal
-    return KEEP;
-  }
-  if (header.content_type() != kTlsAlertType) {
-    return KEEP;
-  }
-
-  std::cerr << "Alert: " << input << std::endl;
-
-  TlsParser parser(input);
-  uint8_t lvl;
-  if (!parser.Read(&lvl)) {
-    return KEEP;
-  }
-  if (lvl == kTlsAlertWarning) {  // not strong enough
-    return KEEP;
-  }
-  level_ = lvl;
-  (void)parser.Read(&description_);
-  return KEEP;
-}
-
 PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input,
                                                  DataBuffer* output) {
   DataBuffer in(input);
   bool changed = false;
   for (auto it = filters_.begin(); it != filters_.end(); ++it) {
     PacketFilter::Action action = (*it)->Filter(in, output);
     if (action == DROP) {
       return DROP;
--- a/gtests/ssl_gtest/tls_filter.h
+++ b/gtests/ssl_gtest/tls_filter.h
@@ -227,34 +227,16 @@ class TlsConversationRecorder : public T
   virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
                                             const DataBuffer& input,
                                             DataBuffer* output);
 
  private:
   DataBuffer& buffer_;
 };
 
-// Records an alert.  If an alert has already been recorded, it won't save the
-// new alert unless the old alert is a warning and the new one is fatal.
-class TlsAlertRecorder : public TlsRecordFilter {
- public:
-  TlsAlertRecorder() : level_(255), description_(255) {}
-
-  virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
-                                            const DataBuffer& input,
-                                            DataBuffer* output);
-
-  uint8_t level() const { return level_; }
-  uint8_t description() const { return description_; }
-
- private:
-  uint8_t level_;
-  uint8_t description_;
-};
-
 // Runs multiple packet filters in series.
 class ChainedPacketFilter : public PacketFilter {
  public:
   ChainedPacketFilter() {}
   ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters)
       : filters_(filters.begin(), filters.end()) {}
   virtual ~ChainedPacketFilter() {}