Bug 1381257 - Reinject TLS 1.3 HelloRetryRequest as a hash. r=mt NSS_TLS13_DRAFT19_BRANCH
authorEKR <ekr@rtfm.com>
Sat, 15 Jul 2017 09:05:13 -0700
branchNSS_TLS13_DRAFT19_BRANCH
changeset 13464 6dec11c45d4fee11c0f37f0635cc43dd39c5956b
parent 13463 7abf299e1e6fc9ec0303ea7291b0160d2917059b
child 13465 1bca0a132021760b7df57dc6953ad02d7c4a2c88
push id2273
push userekr@mozilla.com
push dateSat, 15 Jul 2017 18:23:08 +0000
reviewersmt
bugs1381257
Bug 1381257 - Reinject TLS 1.3 HelloRetryRequest as a hash. r=mt Summary: Reviewers: mt Differential Revision: https://nss-review.dev.mozaws.net/D372 Review comments
lib/ssl/ssl3con.c
lib/ssl/sslimpl.h
lib/ssl/sslt.h
lib/ssl/tls13con.c
--- a/lib/ssl/ssl3con.c
+++ b/lib/ssl/ssl3con.c
@@ -6558,17 +6558,17 @@ ssl_ClientConsumeCipherSuite(sslSocket *
     /* Don't let the server change its mind. */
     if (ss->ssl3.hs.helloRetry && temp != ss->ssl3.hs.cipher_suite) {
         (void)SSL3_SendAlert(ss, alert_fatal, illegal_parameter);
         PORT_SetError(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
         return SECFailure;
     }
 
     ss->ssl3.hs.cipher_suite = (ssl3CipherSuite)temp;
-    return SECSuccess;
+    return ssl3_SetupCipherSuite(ss, PR_FALSE);
 }
 
 /* Called from ssl3_HandleHandshakeMessage() when it has deciphered a complete
  * ssl3 ServerHello message.
  * Caller must hold Handshake and RecvBuf locks.
  */
 static SECStatus
 ssl3_HandleServerHello(sslSocket *ss, PRUint8 *b, PRUint32 length)
@@ -9724,17 +9724,18 @@ ssl3_HandleCertificateVerify(sslSocket *
 
     signed_hash.data = NULL;
 
     if (length != 0) {
         desc = isTLS ? decode_error : illegal_parameter;
         goto alert_loser; /* malformed */
     }
 
-    rv = ssl_HashHandshakeMessage(ss, savedMsg, savedLen);
+    rv = ssl_HashHandshakeMessage(ss, ssl_hs_certificate_verify,
+                                  savedMsg, savedLen);
     if (rv != SECSuccess) {
         PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
         return rv;
     }
 
     ss->ssl3.hs.ws = wait_change_cipher;
     return SECSuccess;
 
@@ -11329,17 +11330,17 @@ ssl3_HandleFinished(sslSocket *ss, PRUin
 
     rv = ssl3_ComputeHandshakeHashes(ss, ss->ssl3.crSpec, &hashes,
                                      isServer ? sender_client : sender_server);
     if (rv != SECSuccess) {
         PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
         return SECFailure;
     }
 
-    rv = ssl_HashHandshakeMessage(ss, b, length);
+    rv = ssl_HashHandshakeMessage(ss, ssl_hs_finished, b, length);
     if (rv != SECSuccess) {
         PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
         return rv;
     }
 
     isTLS = (PRBool)(ss->ssl3.crSpec->version > SSL_LIBRARY_VERSION_3_0);
     if (isTLS) {
         TLSFinished tlsFinished;
@@ -11566,23 +11567,24 @@ ssl3_FinishHandshake(sslSocket *ss)
     ss->ssl3.hs.ws = idle_handshake;
 
     ssl_FinishHandshake(ss);
 
     return SECSuccess;
 }
 
 SECStatus
-ssl_HashHandshakeMessage(sslSocket *ss, const PRUint8 *b, PRUint32 length)
+ssl_HashHandshakeMessage(sslSocket *ss, SSLHandshakeType type,
+                         const PRUint8 *b, PRUint32 length)
 {
     PRUint8 hdr[4];
     PRUint8 dtlsData[8];
     SECStatus rv;
 
-    hdr[0] = (PRUint8)ss->ssl3.hs.msg_type;
+    hdr[0] = (PRUint8)type;
     hdr[1] = (PRUint8)(length >> 16);
     hdr[2] = (PRUint8)(length >> 8);
     hdr[3] = (PRUint8)(length);
 
     rv = ssl3_UpdateHandshakeHashes(ss, (unsigned char *)hdr, 4);
     if (rv != SECSuccess)
         return rv; /* err code already set. */
 
@@ -11642,24 +11644,25 @@ ssl3_HandleHandshakeMessage(sslSocket *s
     }
     switch (ss->ssl3.hs.msg_type) {
         case ssl_hs_hello_request:
         case ssl_hs_hello_verify_request:
             /* We don't include hello_request and hello_verify_request messages
              * in the handshake hashes */
             break;
 
+        case ssl_hs_hello_retry_request:
         case ssl_hs_certificate_verify:
         case ssl_hs_finished:
             /* Defer hashing of these messages until the message handlers
              * we need to finalize the hashes there. */
             break;
 
         default:
-            rv = ssl_HashHandshakeMessage(ss, b, length);
+            rv = ssl_HashHandshakeMessage(ss, ss->ssl3.hs.msg_type, b, length);
             if (rv != SECSuccess) {
                 return SECFailure;
             }
     }
 
     PORT_SetError(0); /* each message starts with no error. */
 
     if (ss->ssl3.hs.ws == wait_certificate_status &&
--- a/lib/ssl/sslimpl.h
+++ b/lib/ssl/sslimpl.h
@@ -1340,18 +1340,18 @@ extern SECStatus ssl_CipherPrefSetDefaul
 
 extern SECStatus ssl3_ConstrainRangeByPolicy(void);
 
 extern void ssl3_InitState(sslSocket *ss);
 extern void ssl3_RestartHandshakeHashes(sslSocket *ss);
 extern SECStatus ssl3_UpdateHandshakeHashes(sslSocket *ss,
                                             const unsigned char *b,
                                             unsigned int l);
-SECStatus ssl_HashHandshakeMessage(sslSocket *ss, const PRUint8 *b,
-                                   PRUint32 length);
+SECStatus ssl_HashHandshakeMessage(sslSocket *ss, SSLHandshakeType type,
+                                   const PRUint8 *b, PRUint32 length);
 
 
 /* Returns PR_TRUE if we are still waiting for the server to complete its
  * response to our client second round. Once we've received the Finished from
  * the server then there is no need to check false start.
  */
 extern PRBool ssl3_WaitingForServerSecondRound(sslSocket *ss);
 
--- a/lib/ssl/sslt.h
+++ b/lib/ssl/sslt.h
@@ -25,17 +25,18 @@ typedef enum {
     ssl_hs_certificate = 11,
     ssl_hs_server_key_exchange = 12,
     ssl_hs_certificate_request = 13,
     ssl_hs_server_hello_done = 14,
     ssl_hs_certificate_verify = 15,
     ssl_hs_client_key_exchange = 16,
     ssl_hs_finished = 20,
     ssl_hs_certificate_status = 22,
-    ssl_hs_next_proto = 67
+    ssl_hs_next_proto = 67,
+    ssl_hs_message_hash = 254, /* Not a real message. */
 } SSLHandshakeType;
 
 typedef struct SSL3StatisticsStr {
     /* statistics from ssl3_SendClientHello (sch) */
     long sch_sid_cache_hits;
     long sch_sid_cache_misses;
     long sch_sid_cache_not_ok;
 
--- a/lib/ssl/tls13con.c
+++ b/lib/ssl/tls13con.c
@@ -58,16 +58,17 @@ static SECStatus tls13_SendHelloRetryReq
                                              const sslNamedGroupDef *selectedGroup);
 
 static SECStatus tls13_HandleServerKeyShare(sslSocket *ss);
 static SECStatus tls13_HandleEncryptedExtensions(sslSocket *ss, PRUint8 *b,
                                                  PRUint32 length);
 static SECStatus tls13_SendCertificate(sslSocket *ss);
 static SECStatus tls13_HandleCertificate(
     sslSocket *ss, PRUint8 *b, PRUint32 length);
+static SECStatus tls13_ReinjectHandshakeTranscript(sslSocket *ss);
 static SECStatus tls13_HandleCertificateRequest(sslSocket *ss, PRUint8 *b,
                                                 PRUint32 length);
 static SECStatus
 tls13_SendCertificateVerify(sslSocket *ss, SECKEYPrivateKey *privKey);
 static SECStatus tls13_HandleCertificateVerify(
                                                sslSocket *ss, PRUint8 *b, PRUint32 length);
 static SECStatus tls13_RecoverWrappedSharedSecret(sslSocket *ss,
                                                   sslSessionID *sid);
@@ -1489,16 +1490,23 @@ tls13_SendHelloRetryRequest(sslSocket *s
 
     /* We asked already, but made no progress. */
     if (ss->ssl3.hs.helloRetry) {
         FATAL_ERROR(ss, SSL_ERROR_BAD_2ND_CLIENT_HELLO, illegal_parameter);
         return SECFailure;
     }
 
     ssl_GetXmitBufLock(ss);
+    /* Reset the handshake hash. */
+    rv = tls13_ReinjectHandshakeTranscript(ss);
+    if (rv != SECSuccess) {
+        FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
+        goto loser;
+    }
+
     rv = ssl3_AppendHandshakeHeader(ss, ssl_hs_hello_retry_request,
                                     2 +     /* version */
                                         2 + /* cipher suite */
                                         2 + /* extension length */
                                         2 + /* group extension id */
                                         2 + /* group extension length */
                                         2 /* group */);
     if (rv != SECSuccess) {
@@ -1665,22 +1673,59 @@ tls13_SendCertificateRequest(sslSocket *
     sslBuffer_Clear(&extensionBuf);
     return SECSuccess;
 
 loser:
     sslBuffer_Clear(&extensionBuf);
     return SECFailure;
 }
 
+/* [draft-ietf-tls-tls13; S 4.4.1] says:
+ *
+ *     Transcript-Hash(ClientHello1, HelloRetryRequest, ... MN) =
+ *      Hash(message_hash ||        // Handshake type
+ *           00 00 Hash.length ||   // Handshake message length
+ *           Hash(ClientHello1) ||  // Hash of ClientHello1
+ *           HelloRetryRequest ... MN)
+ */
+static SECStatus
+tls13_ReinjectHandshakeTranscript(sslSocket *ss)
+{
+    SSL3Hashes hashes;
+    SECStatus rv;
+
+    // First compute the hash.
+    rv = tls13_ComputeHash(ss, &hashes,
+                           ss->ssl3.hs.messages.buf,
+                           ss->ssl3.hs.messages.len);
+    if (rv != SECSuccess) {
+        return SECFailure;
+    }
+
+    // Now re-init the handshake.
+    ssl3_RestartHandshakeHashes(ss);
+
+    // And reinject the message.
+    rv = ssl_HashHandshakeMessage(ss, ssl_hs_message_hash,
+                                  hashes.u.raw, hashes.len);
+    if (rv != SECSuccess) {
+        return SECFailure;
+    }
+
+    return SECSuccess;
+}
+
 SECStatus
 tls13_HandleHelloRetryRequest(sslSocket *ss, PRUint8 *b, PRUint32 length)
 {
     SECStatus rv;
     PRUint32 tmp;
     SSL3ProtocolVersion version;
+    const PRUint8 *savedMsg = b;
+    const PRUint32 savedLength = length;
 
     SSL_TRC(3, ("%d: TLS13[%d]: handle hello retry request",
                 SSL_GETPID(), ss->fd));
 
     PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
     PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
 
     if (ss->vrange.max < SSL_LIBRARY_VERSION_TLS_1_3) {
@@ -1748,18 +1793,29 @@ tls13_HandleHelloRetryRequest(sslSocket 
     }
 
     rv = ssl3_HandleExtensions(ss, &b, &length, ssl_hs_hello_retry_request);
     if (rv != SECSuccess) {
         return SECFailure; /* Error code set below */
     }
 
     ss->ssl3.hs.helloRetry = PR_TRUE;
+    rv = tls13_ReinjectHandshakeTranscript(ss);
+    if (rv != SECSuccess) {
+        return rv;
+    }
+
+    rv = ssl_HashHandshakeMessage(ss, ssl_hs_hello_retry_request,
+                                   savedMsg, savedLength);
+    if (rv != SECSuccess) {
+        return rv;
+    }
 
     ssl_GetXmitBufLock(ss);
+
     rv = ssl3_SendClientHello(ss, client_hello_retry);
     ssl_ReleaseXmitBufLock(ss);
     if (rv != SECSuccess) {
         return SECFailure;
     }
 
     return SECSuccess;
 }
@@ -3265,17 +3321,17 @@ tls13_HandleCertificateVerify(sslSocket 
         return SECFailure;
     }
 
     rv = tls13_ComputeHandshakeHashes(ss, &hashes);
     if (rv != SECSuccess) {
         return SECFailure;
     }
 
-    rv = ssl_HashHandshakeMessage(ss, b, length);
+    rv = ssl_HashHandshakeMessage(ss, ssl_hs_certificate_verify, b, length);
     if (rv != SECSuccess) {
         PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
         return SECFailure;
     }
 
     rv = ssl_ConsumeSignatureScheme(ss, &b, &length, &sigScheme);
     if (rv != SECSuccess) {
         PORT_SetError(SSL_ERROR_RX_MALFORMED_CERT_VERIFY);
@@ -3587,17 +3643,17 @@ tls13_ClientHandleFinished(sslSocket *ss
     }
 
     rv = tls13_ComputeHandshakeHashes(ss, &hashes);
     if (rv != SECSuccess) {
         LOG_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE);
         return SECFailure;
     }
 
-    rv = ssl_HashHandshakeMessage(ss, b, length);
+    rv = ssl_HashHandshakeMessage(ss, ssl_hs_finished, b, length);
     if (rv != SECSuccess) {
         PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
         return SECFailure;
     }
 
     rv = tls13_VerifyFinished(ss, ssl_hs_finished,
                               ss->ssl3.hs.serverHsTrafficSecret,
                               b, length, &hashes);
@@ -3632,17 +3688,17 @@ tls13_ServerHandleFinished(sslSocket *ss
     }
 
     rv = tls13_ComputeHandshakeHashes(ss, &hashes);
     if (rv != SECSuccess) {
         LOG_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE);
         return SECFailure;
     }
 
-    rv = ssl_HashHandshakeMessage(ss, b, length);
+    rv = ssl_HashHandshakeMessage(ss, ssl_hs_finished, b, length);
     if (rv != SECSuccess) {
         PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
         return SECFailure;
     }
 
     rv = tls13_VerifyFinished(ss, ssl_hs_finished, secret, b, length, &hashes);
     if (rv != SECSuccess)
         return SECFailure;