Bug 1348720 - Move TlsAlertRecorder to some specific alert tests, r=ttaubert
--- a/gtests/ssl_gtest/ssl_gather_unittest.cc
+++ b/gtests/ssl_gtest/ssl_gather_unittest.cc
@@ -11,27 +11,21 @@ namespace nss_test {
class GatherV2ClientHelloTest : public TlsConnectTestBase {
public:
GatherV2ClientHelloTest() : TlsConnectTestBase(STREAM, 0) {}
void ConnectExpectMalformedClientHello(const DataBuffer &data) {
EnsureTlsSetup();
server_->ExpectSendAlert(kTlsAlertIllegalParameter);
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- server_->SetPacketFilter(alert_recorder);
-
client_->SendDirect(data);
server_->StartConnect();
server_->Handshake();
ASSERT_TRUE_WAIT(
(server_->error_code() == SSL_ERROR_RX_MALFORMED_CLIENT_HELLO), 2000);
-
- EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(kTlsAlertIllegalParameter, alert_recorder->description());
}
};
// Gather a 5-byte v3 record, with a zero fragment length. The empty handshake
// message should be ignored, and the connection will succeed afterwards.
TEST_F(TlsConnectTest, GatherEmptyV3Record) {
DataBuffer buffer;
@@ -51,26 +45,21 @@ TEST_F(TlsConnectTest, GatherExcessiveV3
size_t idx = 0;
idx = buffer.Write(idx, 0x16, 1); // handshake
idx = buffer.Write(idx, 0x0301, 2); // record_version
(void)buffer.Write(idx, MAX_FRAGMENT_LENGTH + 2048 + 1, 2); // length=max+1
EnsureTlsSetup();
server_->ExpectSendAlert(kTlsAlertRecordOverflow);
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- server_->SetPacketFilter(alert_recorder);
client_->SendDirect(buffer);
server_->StartConnect();
server_->Handshake();
ASSERT_TRUE_WAIT((server_->error_code() == SSL_ERROR_RX_RECORD_TOO_LONG),
2000);
-
- EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(kTlsAlertRecordOverflow, alert_recorder->description());
}
// Gather a 3-byte v2 header, with a fragment length of 2.
TEST_F(GatherV2ClientHelloTest, GatherV2RecordLongHeader) {
DataBuffer buffer;
size_t idx = 0;
idx = buffer.Write(idx, 0x0002, 2); // length=2 (long header)
--- a/gtests/ssl_gtest/ssl_loopback_unittest.cc
+++ b/gtests/ssl_gtest/ssl_loopback_unittest.cc
@@ -34,30 +34,116 @@ TEST_P(TlsConnectGeneric, Connect) {
TEST_P(TlsConnectGeneric, ConnectEcdsa) {
SetExpectedVersion(std::get<1>(GetParam()));
Reset(TlsAgent::kServerEcdsa256);
Connect();
CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa);
}
-TEST_P(TlsConnectGenericPre13, CipherSuiteMismatch) {
+TEST_P(TlsConnectGeneric, CipherSuiteMismatch) {
EnsureTlsSetup();
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
server_->EnableSingleCipher(TLS_AES_256_GCM_SHA384);
} else {
client_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA);
}
ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
}
+class TlsAlertRecorder : public TlsRecordFilter {
+ public:
+ TlsAlertRecorder() : level_(255), description_(255) {}
+
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ if (level_ != 255) { // Already captured.
+ return KEEP;
+ }
+ if (header.content_type() != kTlsAlertType) {
+ return KEEP;
+ }
+
+ std::cerr << "Alert: " << input << std::endl;
+
+ TlsParser parser(input);
+ EXPECT_TRUE(parser.Read(&level_));
+ EXPECT_TRUE(parser.Read(&description_));
+ return KEEP;
+ }
+
+ uint8_t level() const { return level_; }
+ uint8_t description() const { return description_; }
+
+ private:
+ uint8_t level_;
+ uint8_t description_;
+};
+
+class HelloTruncator : public TlsHandshakeFilter {
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ if (header.handshake_type() != kTlsHandshakeClientHello &&
+ header.handshake_type() != kTlsHandshakeServerHello) {
+ return KEEP;
+ }
+ output->Assign(input.data(), input.len() - 1);
+ return CHANGE;
+ }
+};
+
+// Verify that when NSS reports that an alert is sent, it is actually sent.
+TEST_P(TlsConnectGeneric, CaptureAlertServer) {
+ client_->SetPacketFilter(std::make_shared<HelloTruncator>());
+ auto alert_recorder = std::make_shared<TlsAlertRecorder>();
+ server_->SetPacketFilter(alert_recorder);
+
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertIllegalParameter, alert_recorder->description());
+}
+
+TEST_P(TlsConnectGenericPre13, CaptureAlertClient) {
+ server_->SetPacketFilter(std::make_shared<HelloTruncator>());
+ auto alert_recorder = std::make_shared<TlsAlertRecorder>();
+ client_->SetPacketFilter(alert_recorder);
+
+ ConnectExpectAlert(client_, kTlsAlertDecodeError);
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description());
+}
+
+// In TLS 1.3, the server can't read the client alert.
+TEST_P(TlsConnectTls13, CaptureAlertClient) {
+ server_->SetPacketFilter(std::make_shared<HelloTruncator>());
+ auto alert_recorder = std::make_shared<TlsAlertRecorder>();
+ client_->SetPacketFilter(alert_recorder);
+
+ server_->StartConnect();
+ client_->StartConnect();
+
+ client_->Handshake();
+ client_->ExpectSendAlert(kTlsAlertDecodeError);
+ server_->Handshake();
+ client_->Handshake();
+ if (mode_ == STREAM) {
+ // DTLS just drops the alert it can't decrypt.
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ }
+ server_->Handshake();
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description());
+}
+
TEST_P(TlsConnectGenericPre13, ConnectFalseStart) {
client_->EnableFalseStart();
Connect();
SendReceive();
}
TEST_P(TlsConnectGeneric, ConnectAlpn) {
EnableAlpn();
--- a/gtests/ssl_gtest/ssl_skip_unittest.cc
+++ b/gtests/ssl_gtest/ssl_skip_unittest.cc
@@ -82,22 +82,18 @@ class TlsSkipTest
: public TlsConnectTestBase,
public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
protected:
TlsSkipTest()
: TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
void ServerSkipTest(std::shared_ptr<PacketFilter> filter,
uint8_t alert = kTlsAlertUnexpectedMessage) {
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- client_->SetPacketFilter(alert_recorder);
server_->SetPacketFilter(filter);
ConnectExpectAlert(client_, alert);
- EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(alert, alert_recorder->description());
}
};
class Tls13SkipTest : public TlsConnectTestBase,
public ::testing::WithParamInterface<std::string> {
protected:
Tls13SkipTest()
: TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
@@ -125,16 +121,18 @@ class Tls13SkipTest : public TlsConnectT
EnsureTlsSetup();
client_->SetTlsRecordFilter(filter);
filter->EnableDecryption();
server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
ConnectExpectFailOneSide(TlsAgent::SERVER);
server_->CheckErrorCode(error);
ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+
+ client_->Handshake(); // Make sure to consume the alert the server sends.
}
};
TEST_P(TlsSkipTest, SkipCertificateRsa) {
EnableOnlyStaticRsaCiphers();
ServerSkipTest(
std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
@@ -213,27 +211,25 @@ TEST_P(Tls13SkipTest, SkipServerCertific
TEST_P(Tls13SkipTest, SkipClientCertificate) {
client_->SetupClientAuth();
server_->RequestClientAuth(true);
client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
ClientSkipTest(
std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
- client_->Handshake(); // Make sure to consume the alert.
}
TEST_P(Tls13SkipTest, SkipClientCertificateVerify) {
client_->SetupClientAuth();
server_->RequestClientAuth(true);
client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
ClientSkipTest(
std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify),
SSL_ERROR_RX_UNEXPECTED_FINISHED);
- client_->Handshake(); // Make sure to consume the alert.
}
INSTANTIATE_TEST_CASE_P(SkipTls10, TlsSkipTest,
::testing::Combine(TlsConnectTestBase::kTlsModesStream,
TlsConnectTestBase::kTlsV10));
INSTANTIATE_TEST_CASE_P(SkipVariants, TlsSkipTest,
::testing::Combine(TlsConnectTestBase::kTlsModesAll,
TlsConnectTestBase::kTlsV11V12));
--- a/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
+++ b/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
@@ -47,34 +47,26 @@ TEST_P(TlsConnectGenericPre13, ConnectSt
// Test that a totally bogus EPMS is handled correctly.
// This test is stream so we can catch the bad_record_mac alert.
TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) {
EnableOnlyStaticRsaCiphers();
auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
kTlsHandshakeClientKeyExchange,
DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
client_->SetPacketFilter(i1);
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- server_->SetPacketFilter(alert_recorder);
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
- EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
}
// Test that a PMS with a bogus version number is handled correctly.
// This test is stream so we can catch the bad_record_mac alert.
TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) {
EnableOnlyStaticRsaCiphers();
client_->SetPacketFilter(
std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- server_->SetPacketFilter(alert_recorder);
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
- EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
}
// Test that a PMS with a bogus version number is ignored when
// rollback detection is disabled. This is a positive control for
// ConnectStaticRSABogusPMSVersionDetect.
TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) {
EnableOnlyStaticRsaCiphers();
client_->SetPacketFilter(
@@ -86,35 +78,27 @@ TEST_P(TlsConnectGenericPre13, ConnectSt
// This test is stream so we can catch the bad_record_mac alert.
TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusCKE) {
EnableOnlyStaticRsaCiphers();
EnableExtendedMasterSecret();
auto inspect = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
kTlsHandshakeClientKeyExchange,
DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
client_->SetPacketFilter(inspect);
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- server_->SetPacketFilter(alert_recorder);
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
- EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
}
// This test is stream so we can catch the bad_record_mac alert.
TEST_P(TlsConnectStreamPre13,
ConnectExtendedMasterSecretStaticRSABogusPMSVersionDetect) {
EnableOnlyStaticRsaCiphers();
EnableExtendedMasterSecret();
client_->SetPacketFilter(
std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- server_->SetPacketFilter(alert_recorder);
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
- EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
}
TEST_P(TlsConnectStreamPre13,
ConnectExtendedMasterSecretStaticRSABogusPMSVersionIgnore) {
EnableOnlyStaticRsaCiphers();
EnableExtendedMasterSecret();
client_->SetPacketFilter(
std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
--- a/gtests/ssl_gtest/tls_filter.cc
+++ b/gtests/ssl_gtest/tls_filter.cc
@@ -364,41 +364,16 @@ PacketFilter::Action TlsInspectorReplace
PacketFilter::Action TlsConversationRecorder::FilterRecord(
const TlsRecordHeader& header, const DataBuffer& input,
DataBuffer* output) {
buffer_.Append(input);
return KEEP;
}
-PacketFilter::Action TlsAlertRecorder::FilterRecord(
- const TlsRecordHeader& header, const DataBuffer& input,
- DataBuffer* output) {
- if (level_ == kTlsAlertFatal) { // already fatal
- return KEEP;
- }
- if (header.content_type() != kTlsAlertType) {
- return KEEP;
- }
-
- std::cerr << "Alert: " << input << std::endl;
-
- TlsParser parser(input);
- uint8_t lvl;
- if (!parser.Read(&lvl)) {
- return KEEP;
- }
- if (lvl == kTlsAlertWarning) { // not strong enough
- return KEEP;
- }
- level_ = lvl;
- (void)parser.Read(&description_);
- return KEEP;
-}
-
PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input,
DataBuffer* output) {
DataBuffer in(input);
bool changed = false;
for (auto it = filters_.begin(); it != filters_.end(); ++it) {
PacketFilter::Action action = (*it)->Filter(in, output);
if (action == DROP) {
return DROP;
--- a/gtests/ssl_gtest/tls_filter.h
+++ b/gtests/ssl_gtest/tls_filter.h
@@ -227,34 +227,16 @@ class TlsConversationRecorder : public T
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
DataBuffer* output);
private:
DataBuffer& buffer_;
};
-// Records an alert. If an alert has already been recorded, it won't save the
-// new alert unless the old alert is a warning and the new one is fatal.
-class TlsAlertRecorder : public TlsRecordFilter {
- public:
- TlsAlertRecorder() : level_(255), description_(255) {}
-
- virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
- const DataBuffer& input,
- DataBuffer* output);
-
- uint8_t level() const { return level_; }
- uint8_t description() const { return description_; }
-
- private:
- uint8_t level_;
- uint8_t description_;
-};
-
// Runs multiple packet filters in series.
class ChainedPacketFilter : public PacketFilter {
public:
ChainedPacketFilter() {}
ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters)
: filters_(filters.begin(), filters.end()) {}
virtual ~ChainedPacketFilter() {}