Bug 1304604 - Send client-side handshake alerts with handshake keys, r=mt NSS_TLS13_DRAFT19_BRANCH
authorEKR <ekr@rtfm.com>
Thu, 10 Aug 2017 08:50:24 +1000
branchNSS_TLS13_DRAFT19_BRANCH
changeset 13519 1572d2ca75b70d3234253263e4ad1db37767e9f9
parent 13518 481dd1cb873c4abe0cea3dd12cd1572a46ddaeee
child 13520 7032f2210e59373cf24b1578fe12f11c9c7f7bb6
push id2314
push usermartin.thomson@gmail.com
push dateWed, 09 Aug 2017 23:01:37 +0000
reviewersmt
bugs1304604
Bug 1304604 - Send client-side handshake alerts with handshake keys, r=mt Differential Revision: https://nss-review.dev.mozaws.net/D397
gtests/ssl_gtest/ssl_0rtt_unittest.cc
gtests/ssl_gtest/ssl_damage_unittest.cc
gtests/ssl_gtest/ssl_extension_unittest.cc
gtests/ssl_gtest/ssl_skip_unittest.cc
gtests/ssl_gtest/tls_agent.cc
lib/ssl/ssl3con.c
lib/ssl/tls13con.c
lib/ssl/tls13con.h
--- a/gtests/ssl_gtest/ssl_0rtt_unittest.cc
+++ b/gtests/ssl_gtest/ssl_0rtt_unittest.cc
@@ -259,16 +259,24 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRtt
     return true;
   });
   Handshake();
   CheckConnected();
   SendReceive();
   CheckAlpn("a");
 }
 
+// NOTE: In this test and those below, the client always sends
+// post-ServerHello alerts with the handshake keys, even if the server
+// has accepted 0-RTT.  In some cases, as with errors in
+// EncryptedExtensions, the client can't know the server's behavior,
+// and in others it's just simpler.  What the server is expecting
+// depends on whether it accepted 0-RTT or not. Eventually, we may
+// make the server trial decrypt.
+//
 // Have the server negotiate a different ALPN value, and therefore
 // reject 0-RTT.
 TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpnChangeServer) {
   EnableAlpn();
   SetupForZeroRtt();
   static const uint8_t client_alpn[] = {0x01, 0x61, 0x01, 0x62};  // "a", "b"
   static const uint8_t server_alpn[] = {0x01, 0x62};              // "b"
   client_->EnableAlpn(client_alpn, sizeof(client_alpn));
@@ -297,42 +305,52 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRtt
   server_->Set0RttEnabled(true);
   EnableAlpn();
   ExpectResumption(RESUME_TICKET);
   ZeroRttSendReceive(true, true, [this]() {
     PRUint8 b[] = {'b'};
     client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a");
     EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, sizeof(b)));
     client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b");
-    ExpectAlert(client_, kTlsAlertIllegalParameter);
+    client_->ExpectSendAlert(kTlsAlertIllegalParameter);
     return true;
   });
-  Handshake();
+  if (variant_ == ssl_variant_stream) {
+    server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+    Handshake();
+    server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+  } else {
+    client_->Handshake();
+  }
   client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID);
-  server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
 }
 
 // Set up with no ALPN and then set the client so it thinks it has ALPN.
 // The server responds without the extension and the client returns an
 // error.
 TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnClient) {
   SetupForZeroRtt();
   client_->Set0RttEnabled(true);
   server_->Set0RttEnabled(true);
   ExpectResumption(RESUME_TICKET);
   ZeroRttSendReceive(true, true, [this]() {
     PRUint8 b[] = {'b'};
     EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, 1));
     client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b");
-    ExpectAlert(client_, kTlsAlertIllegalParameter);
+    client_->ExpectSendAlert(kTlsAlertIllegalParameter);
     return true;
   });
-  Handshake();
+  if (variant_ == ssl_variant_stream) {
+    server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+    Handshake();
+    server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+  } else {
+    client_->Handshake();
+  }
   client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID);
-  server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
 }
 
 // Remove the old ALPN value and so the client will not offer early data.
 TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpnChangeBoth) {
   EnableAlpn();
   SetupForZeroRtt();
   static const uint8_t alpn[] = {0x01, 0x62};  // "b"
   EnableAlpn(alpn, sizeof(alpn));
--- a/gtests/ssl_gtest/ssl_damage_unittest.cc
+++ b/gtests/ssl_gtest/ssl_damage_unittest.cc
@@ -46,26 +46,22 @@ TEST_F(TlsConnectTest, DamageSecretHandl
   client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
 }
 
 TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) {
   client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
                            SSL_LIBRARY_VERSION_TLS_1_3);
   server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
                            SSL_LIBRARY_VERSION_TLS_1_3);
-  client_->ExpectSendAlert(kTlsAlertDecryptError);
-  // The server can't read the client's alert, so it also sends an alert.
-  server_->ExpectSendAlert(kTlsAlertBadRecordMac);
   server_->SetPacketFilter(std::make_shared<AfterRecordN>(
       server_, client_,
       0,  // ServerHello.
       [this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); }));
-  ConnectExpectFail();
+  ConnectExpectAlert(client_, kTlsAlertDecryptError);
   client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
-  server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
 }
 
 TEST_P(TlsConnectGenericPre13, DamageServerSignature) {
   EnsureTlsSetup();
   auto filter =
       std::make_shared<TlsLastByteDamager>(kTlsHandshakeServerKeyExchange);
   server_->SetTlsRecordFilter(filter);
   ExpectAlert(client_, kTlsAlertDecryptError);
@@ -74,25 +70,17 @@ TEST_P(TlsConnectGenericPre13, DamageSer
   server_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
 }
 
 TEST_P(TlsConnectTls13, DamageServerSignature) {
   EnsureTlsSetup();
   auto filter =
       std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify);
   server_->SetTlsRecordFilter(filter);
-  client_->ExpectSendAlert(kTlsAlertDecryptError);
-  // The server can't read the client's alert, so it also sends an alert.
-  if (variant_ == ssl_variant_stream) {
-    server_->ExpectSendAlert(kTlsAlertBadRecordMac);
-    ConnectExpectFail();
-    server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
-  } else {
-    ConnectExpectFailOneSide(TlsAgent::CLIENT);
-  }
+  ConnectExpectAlert(client_, kTlsAlertDecryptError);
   client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
 }
 
 TEST_P(TlsConnectGeneric, DamageClientSignature) {
   EnsureTlsSetup();
   client_->SetupClientAuth();
   server_->RequestClientAuth(true);
   auto filter =
--- a/gtests/ssl_gtest/ssl_extension_unittest.cc
+++ b/gtests/ssl_gtest/ssl_extension_unittest.cc
@@ -1026,17 +1026,17 @@ class TlsBogusExtensionTestPre13 : publi
   void ConnectAndFail(uint8_t) override {
     ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension);
   }
 };
 
 class TlsBogusExtensionTest13 : public TlsBogusExtensionTest {
  protected:
   void ConnectAndFail(uint8_t message) override {
-    if (message == kTlsHandshakeHelloRetryRequest) {
+    if (message != kTlsHandshakeServerHello) {
       ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension);
       return;
     }
 
     FailWithAlert(kTlsAlertUnsupportedExtension);
   }
 
   void FailWithAlert(uint8_t alert) {
@@ -1069,17 +1069,17 @@ TEST_P(TlsBogusExtensionTest13, AddBogus
 TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificate) {
   Run(kTlsHandshakeCertificate);
 }
 
 // It's perfectly valid to set unknown extensions in CertificateRequest.
 TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificateRequest) {
   server_->RequestClientAuth(false);
   AddFilter(kTlsHandshakeCertificateRequest, 0xff);
-  FailWithAlert(kTlsAlertDecryptError);
+  ConnectExpectAlert(client_, kTlsAlertDecryptError);
   client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
 }
 
 TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) {
   static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
   server_->ConfigNamedGroups(groups);
 
   Run(kTlsHandshakeHelloRetryRequest);
--- a/gtests/ssl_gtest/ssl_skip_unittest.cc
+++ b/gtests/ssl_gtest/ssl_skip_unittest.cc
@@ -96,29 +96,20 @@ class Tls13SkipTest : public TlsConnectT
                       public ::testing::WithParamInterface<SSLProtocolVariant> {
  protected:
   Tls13SkipTest()
       : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
 
   void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
     EnsureTlsSetup();
     server_->SetTlsRecordFilter(filter);
-    client_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
-    if (variant_ == ssl_variant_stream) {
-      server_->ExpectSendAlert(kTlsAlertBadRecordMac);
-      ConnectExpectFail();
-    } else {
-      ConnectExpectFailOneSide(TlsAgent::CLIENT);
-    }
+    ExpectAlert(client_, kTlsAlertUnexpectedMessage);
+    ConnectExpectFail();
     client_->CheckErrorCode(error);
-    if (variant_ == ssl_variant_stream) {
-      server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
-    } else {
-      ASSERT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
-    }
+    server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
   }
 
   void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
     EnsureTlsSetup();
     client_->SetTlsRecordFilter(filter);
     server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
     ConnectExpectFailOneSide(TlsAgent::SERVER);
 
--- a/gtests/ssl_gtest/tls_agent.cc
+++ b/gtests/ssl_gtest/tls_agent.cc
@@ -88,21 +88,21 @@ TlsAgent::~TlsAgent() {
 
   if (adapter_) {
     Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
   }
 
   // Add failures manually, if any, so we don't throw in a destructor.
   if (expected_received_alert_ != kTlsAlertCloseNotify ||
       expected_received_alert_level_ != kTlsAlertWarning) {
-    ADD_FAILURE() << "Wrong expected_received_alert status";
+    ADD_FAILURE() << "Wrong expected_received_alert status: " << role_str();
   }
   if (expected_sent_alert_ != kTlsAlertCloseNotify ||
       expected_sent_alert_level_ != kTlsAlertWarning) {
-    ADD_FAILURE() << "Wrong expected_sent_alert status";
+    ADD_FAILURE() << "Wrong expected_sent_alert status: " << role_str();
   }
 }
 
 void TlsAgent::SetState(State state) {
   if (state_ == state) return;
 
   LOG("Changing state from " << state_ << " to " << state);
   state_ = state;
--- a/lib/ssl/ssl3con.c
+++ b/lib/ssl/ssl3con.c
@@ -3104,16 +3104,25 @@ SSL3_SendAlert(sslSocket *ss, SSL3AlertL
     if (needHsLock) {
         ssl_GetSSL3HandshakeLock(ss);
     }
     if (level == alert_fatal) {
         if (!ss->opt.noCache && ss->sec.ci.sid) {
             ss->sec.uncache(ss->sec.ci.sid);
         }
     }
+
+    rv = tls13_SetAlertCipherSpec(ss);
+    if (rv != SECSuccess) {
+        if (needHsLock) {
+            ssl_ReleaseSSL3HandshakeLock(ss);
+        }
+        return rv;
+    }
+
     ssl_GetXmitBufLock(ss);
     rv = ssl3_FlushHandshake(ss, ssl_SEND_FLAG_FORCE_INTO_BUFFER);
     if (rv == SECSuccess) {
         PRInt32 sent;
         sent = ssl3_SendRecord(ss, NULL, content_alert, bytes, 2,
                                (desc == no_certificate) ? ssl_SEND_FLAG_FORCE_INTO_BUFFER : 0);
         rv = (sent >= 0) ? SECSuccess : (SECStatus)sent;
     }
--- a/lib/ssl/tls13con.c
+++ b/lib/ssl/tls13con.c
@@ -2929,16 +2929,50 @@ tls13_SetupPendingCipherSpec(sslSocket *
     SSL_TRC(3, ("%d: TLS13[%d]: Set Pending Cipher Suite to 0x%04x",
                 SSL_GETPID(), ss->fd, suite));
     pSpec->cipher_def = bulk;
 
     ssl_ReleaseSpecWriteLock(ss); /*******************************/
     return SECSuccess;
 }
 
+/*
+ * Called before sending alerts to set up the right key on the client.
+ * We might encounter errors during the handshake where the current
+ * key is ClearText or EarlyApplicationData. This
+ * function switches to the Handshake key if possible.
+ */
+SECStatus
+tls13_SetAlertCipherSpec(sslSocket *ss)
+{
+    SECStatus rv;
+
+    if (ss->sec.isServer) {
+        return SECSuccess;
+    }
+    if (ss->version < SSL_LIBRARY_VERSION_TLS_1_3) {
+        return SECSuccess;
+    }
+    if (TLS13_IN_HS_STATE(ss, wait_server_hello)) {
+        return SECSuccess;
+    }
+    if ((ss->ssl3.cwSpec->epoch != TrafficKeyClearText) &&
+        (ss->ssl3.cwSpec->epoch != TrafficKeyEarlyApplicationData)) {
+        return SECSuccess;
+    }
+
+    rv = tls13_SetCipherSpec(ss, TrafficKeyHandshake,
+                             CipherSpecWrite, PR_FALSE);
+    if (rv != SECSuccess) {
+        PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
+        return SECFailure;
+    }
+    return SECSuccess;
+}
+
 /* Install a new cipher spec for this direction. */
 static SECStatus
 tls13_SetCipherSpec(sslSocket *ss, TrafficKeyType type,
                     CipherSpecDirection direction, PRBool deleteSecret)
 {
     SECStatus rv;
     ssl3CipherSpec *spec = NULL;
     ssl3CipherSpec **specp = (direction == CipherSpecRead) ? &ss->ssl3.crSpec : &ss->ssl3.cwSpec;
--- a/lib/ssl/tls13con.h
+++ b/lib/ssl/tls13con.h
@@ -82,16 +82,17 @@ SECStatus tls13_HandleHelloRetryRequest(
                                         PRUint32 length);
 void tls13_DestroyKeyShareEntry(TLS13KeyShareEntry *entry);
 void tls13_DestroyKeyShares(PRCList *list);
 SECStatus tls13_CreateKeyShare(sslSocket *ss, const sslNamedGroupDef *groupDef);
 void tls13_DestroyEarlyData(PRCList *list);
 void tls13_CipherSpecAddRef(ssl3CipherSpec *spec);
 void tls13_CipherSpecRelease(ssl3CipherSpec *spec);
 void tls13_DestroyCipherSpecs(PRCList *list);
+SECStatus tls13_SetAlertCipherSpec(sslSocket *ss);
 tls13ExtensionStatus tls13_ExtensionStatus(PRUint16 extension,
                                            SSLHandshakeType message);
 SECStatus tls13_ProtectRecord(sslSocket *ss,
                               ssl3CipherSpec *cwSpec,
                               SSL3ContentType type,
                               const PRUint8 *pIn,
                               PRUint32 contentLen,
                               sslBuffer *wrBuf);