Bug 1369606 - Add cipher suite to HelloRetryRequest, r=ekr NSS_TLS13_DRAFT19_BRANCH
authorMartin Thomson <martin.thomson@gmail.com>
Fri, 02 Jun 2017 14:11:07 +1000
branchNSS_TLS13_DRAFT19_BRANCH
changeset 13426 de2d9bcedc24191c5e9a9eddb999fc75f84c7368
parent 13409 47deda0e74a2fcd763457753c4ec7a2df24588b7
child 13427 0c2f460814688174180a1d4f863e4a33a47b45ec
push id2241
push usermartin.thomson@gmail.com
push dateSun, 11 Jun 2017 14:21:58 +0000
reviewersekr
bugs1369606
Bug 1369606 - Add cipher suite to HelloRetryRequest, r=ekr
gtests/ssl_gtest/ssl_hrr_unittest.cc
gtests/ssl_gtest/ssl_resumption_unittest.cc
gtests/ssl_gtest/tls_filter.cc
gtests/ssl_gtest/tls_filter.h
lib/ssl/ssl3con.c
lib/ssl/sslimpl.h
lib/ssl/tls13con.c
--- a/gtests/ssl_gtest/ssl_hrr_unittest.cc
+++ b/gtests/ssl_gtest/ssl_hrr_unittest.cc
@@ -182,16 +182,33 @@ TEST_P(TlsConnectTls13, RetryWithSameKey
   static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
                                                     ssl_grp_ec_secp521r1};
   server_->ConfigNamedGroups(groups);
   ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
   EXPECT_EQ(SSL_ERROR_BAD_2ND_CLIENT_HELLO, server_->error_code());
   EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code());
 }
 
+// Stream because the server doesn't consume the alert and terminate.
+TEST_F(TlsConnectStreamTls13, RetryWithDifferentCipherSuite) {
+  EnsureTlsSetup();
+  // Force a HelloRetryRequest.
+  static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+  server_->ConfigNamedGroups(groups);
+  // Then switch out the default suite (TLS_AES_128_GCM_SHA256).
+  server_->SetPacketFilter(std::make_shared<SelectedCipherSuiteReplacer>(
+      TLS_CHACHA20_POLY1305_SHA256));
+
+  client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+  server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+  ConnectExpectFail();
+  EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
+  EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code());
+}
+
 // This tests that the second attempt at sending a ClientHello (after receiving
 // a HelloRetryRequest) is correctly retransmitted.
 TEST_F(TlsConnectDatagram13, DropClientSecondFlightWithHelloRetry) {
   static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
                                                     ssl_grp_ec_secp521r1};
   server_->ConfigNamedGroups(groups);
   server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x2));
   Connect();
@@ -271,19 +288,20 @@ class HelloRetryRequestAgentTest : publi
     TlsAgentTestClient::SetUp();
     EnsureInit();
     agent_->StartConnect();
   }
 
   void MakeCannedHrr(const uint8_t* body, size_t len, DataBuffer* hrr_record,
                      uint32_t seq_num = 0) const {
     DataBuffer hrr_data;
-    hrr_data.Allocate(len + 4);
+    hrr_data.Allocate(len + 6);
     size_t i = 0;
     i = hrr_data.Write(i, 0x7f00 | TLS_1_3_DRAFT_VERSION, 2);
+    i = hrr_data.Write(i, TLS_AES_128_GCM_SHA256, 2);
     i = hrr_data.Write(i, static_cast<uint32_t>(len), 2);
     if (len) {
       hrr_data.Write(i, body, len);
     }
     DataBuffer hrr;
     MakeHandshakeMessage(kTlsHandshakeHelloRetryRequest, hrr_data.data(),
                          hrr_data.len(), &hrr, seq_num);
     MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3, hrr.data(),
--- a/gtests/ssl_gtest/ssl_resumption_unittest.cc
+++ b/gtests/ssl_gtest/ssl_resumption_unittest.cc
@@ -432,46 +432,16 @@ TEST_P(TlsConnectGeneric, TestResumeServ
   Reset();
   ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
   ExpectResumption(RESUME_NONE);
   server_->EnableSingleCipher(ChooseAnotherCipher(version_));
   Connect();
   CheckKeys();
 }
 
-class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
- public:
-  SelectedCipherSuiteReplacer(uint16_t suite) : cipher_suite_(suite) {}
-
- protected:
-  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
-                                       const DataBuffer& input,
-                                       DataBuffer* output) override {
-    if (header.handshake_type() != kTlsHandshakeServerHello) {
-      return KEEP;
-    }
-
-    *output = input;
-    uint32_t temp = 0;
-    EXPECT_TRUE(input.Read(0, 2, &temp));
-    // Cipher suite is after version(2) and random(32).
-    size_t pos = 34;
-    if (temp < SSL_LIBRARY_VERSION_TLS_1_3) {
-      // In old versions, we have to skip a session_id too.
-      EXPECT_TRUE(input.Read(pos, 1, &temp));
-      pos += 1 + temp;
-    }
-    output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2);
-    return CHANGE;
-  }
-
- private:
-  uint16_t cipher_suite_;
-};
-
 // Test that the client doesn't tolerate the server picking a different cipher
 // suite for resumption.
 TEST_P(TlsConnectStream, TestResumptionOverrideCipher) {
   ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
   server_->EnableSingleCipher(ChooseOneCipher(version_));
   Connect();
   SendReceive();
   CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
--- a/gtests/ssl_gtest/tls_filter.cc
+++ b/gtests/ssl_gtest/tls_filter.cc
@@ -427,18 +427,17 @@ bool FindServerHelloExtensions(TlsParser
       return false;
     }
   }
   return true;
 }
 
 static bool FindHelloRetryExtensions(TlsParser* parser,
                                      const TlsVersioned& header) {
-  // TODO for -19 add cipher suite
-  if (!parser->Skip(2)) {  // version
+  if (!parser->Skip(4)) {  // version (2) + cipher suite (2)
     return false;
   }
   return true;
 }
 
 bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) {
   return true;
 }
@@ -642,9 +641,30 @@ PacketFilter::Action TlsInspectorClientH
   if (header.handshake_type() == kTlsHandshakeClientHello) {
     *output = input;
     output->Write(0, version_, 2);
     return CHANGE;
   }
   return KEEP;
 }
 
+PacketFilter::Action SelectedCipherSuiteReplacer::FilterHandshake(
+    const HandshakeHeader& header, const DataBuffer& input,
+    DataBuffer* output) {
+  if (header.handshake_type() != kTlsHandshakeServerHello) {
+    return KEEP;
+  }
+
+  *output = input;
+  uint32_t temp = 0;
+  EXPECT_TRUE(input.Read(0, 2, &temp));
+  // Cipher suite is after version(2) and random(32).
+  size_t pos = 34;
+  if (temp < SSL_LIBRARY_VERSION_TLS_1_3) {
+    // In old versions, we have to skip a session_id too.
+    EXPECT_TRUE(input.Read(pos, 1, &temp));
+    pos += 1 + temp;
+  }
+  output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2);
+  return CHANGE;
+}
+
 }  // namespace nss_test
--- a/gtests/ssl_gtest/tls_filter.h
+++ b/gtests/ssl_gtest/tls_filter.h
@@ -406,11 +406,24 @@ class TlsLastByteDamager : public TlsHan
     output->data()[output->len() - 1]++;
     return CHANGE;
   }
 
  private:
   uint8_t type_;
 };
 
+class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
+ public:
+  SelectedCipherSuiteReplacer(uint16_t suite) : cipher_suite_(suite) {}
+
+ protected:
+  PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+                                       const DataBuffer& input,
+                                       DataBuffer* output) override;
+
+ private:
+  uint16_t cipher_suite_;
+};
+
 }  // namespace nss_test
 
 #endif
--- a/lib/ssl/ssl3con.c
+++ b/lib/ssl/ssl3con.c
@@ -6568,46 +6568,96 @@ done:
     if (buf.data)
         PORT_Free(buf.data);
     return rv;
 }
 
 /* Once a cipher suite has been selected, make sure that the necessary secondary
  * information is properly set. */
 SECStatus
-ssl3_SetCipherSuite(sslSocket *ss, ssl3CipherSuite chosenSuite,
-                    PRBool initHashes)
-{
-    ss->ssl3.hs.cipher_suite = chosenSuite;
-    ss->ssl3.hs.suite_def = ssl_LookupCipherSuiteDef(chosenSuite);
+ssl3_SetupCipherSuite(sslSocket *ss, PRBool initHashes)
+{
+    ss->ssl3.hs.suite_def = ssl_LookupCipherSuiteDef(ss->ssl3.hs.cipher_suite);
     if (!ss->ssl3.hs.suite_def) {
         PORT_Assert(0);
         PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
         return SECFailure;
     }
 
     ss->ssl3.hs.kea_def = &kea_defs[ss->ssl3.hs.suite_def->key_exchange_alg];
     ss->ssl3.hs.preliminaryInfo |= ssl_preinfo_cipher_suite;
 
     if (!initHashes) {
         return SECSuccess;
     }
     /* Now we've have a cipher suite, initialize the handshake hashes. */
     return ssl3_InitHandshakeHashes(ss);
 }
 
+SECStatus
+ssl_ClientConsumeCipherSuite(sslSocket *ss, SSL3ProtocolVersion version,
+                             PRUint8 **b, unsigned int *length)
+{
+    PRUint32 temp;
+    int i;
+    SECStatus rv;
+
+    /* Find the selected cipher suite in our list. */
+    rv = ssl3_ConsumeHandshakeNumber(ss, &temp, 2, b, length);
+    if (rv != SECSuccess) {
+        return SECFailure; /* alert has been sent */
+    }
+
+    i = ssl3_config_match_init(ss);
+    PORT_Assert(i > 0);
+    if (i <= 0) {
+        return SECFailure;
+    }
+    for (i = 0; i < ssl_V3_SUITES_IMPLEMENTED; i++) {
+        ssl3CipherSuiteCfg *suite = &ss->cipherSuites[i];
+        if (temp == suite->cipher_suite) {
+            SSLVersionRange vrange = { version, version };
+            if (!config_match(suite, ss->ssl3.policy, &vrange, ss)) {
+                /* config_match already checks whether the cipher suite is
+                 * acceptable for the version, but the check is repeated here
+                 * in order to give a more precise error code. */
+                if (!ssl3_CipherSuiteAllowedForVersionRange(temp, &vrange)) {
+                    PORT_SetError(SSL_ERROR_CIPHER_DISALLOWED_FOR_VERSION);
+                } else {
+                    PORT_SetError(SSL_ERROR_NO_CYPHER_OVERLAP);
+                }
+                return SECFailure;
+            }
+            break;
+        }
+    }
+    if (i >= ssl_V3_SUITES_IMPLEMENTED) {
+        PORT_SetError(SSL_ERROR_NO_CYPHER_OVERLAP);
+        return SECFailure;
+    }
+
+    /* Don't let the server change its mind. */
+    if (ss->ssl3.hs.helloRetry && temp != ss->ssl3.hs.cipher_suite) {
+        (void)SSL3_SendAlert(ss, alert_fatal, illegal_parameter);
+        PORT_SetError(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+        return SECFailure;
+    }
+
+    ss->ssl3.hs.cipher_suite = (ssl3CipherSuite)temp;
+    return SECSuccess;
+}
+
 /* Called from ssl3_HandleHandshakeMessage() when it has deciphered a complete
  * ssl3 ServerHello message.
  * Caller must hold Handshake and RecvBuf locks.
  */
 static SECStatus
 ssl3_HandleServerHello(sslSocket *ss, PRUint8 *b, PRUint32 length)
 {
     PRUint32 temp;
-    PRBool suite_found = PR_FALSE;
     int i;
     int errCode = SSL_ERROR_RX_MALFORMED_SERVER_HELLO;
     SECStatus rv;
     SECItem sidBytes = { siBuffer, NULL, 0 };
     PRBool isTLS = PR_FALSE;
     SSL3AlertDescription desc = illegal_parameter;
 #ifndef TLS_1_3_DRAFT_VERSION
     SSL3ProtocolVersion downgradeCheckVersion;
@@ -6720,78 +6770,45 @@ ssl3_HandleServerHello(sslSocket *ss, PR
         }
         if (sidBytes.len > SSL3_SESSIONID_BYTES) {
             if (isTLS)
                 desc = decode_error;
             goto alert_loser; /* malformed. */
         }
     }
 
-    /* find selected cipher suite in our list. */
-    rv = ssl3_ConsumeHandshakeNumber(ss, &temp, 2, &b, &length);
-    if (rv != SECSuccess) {
-        goto loser; /* alert has been sent */
-    }
-    i = ssl3_config_match_init(ss);
-    PORT_Assert(i > 0);
-    if (i <= 0) {
+    rv = ssl_ClientConsumeCipherSuite(ss, ss->version, &b, &length);
+    if (rv != SECSuccess) {
         errCode = PORT_GetError();
-        goto loser;
-    }
-    for (i = 0; i < ssl_V3_SUITES_IMPLEMENTED; i++) {
-        ssl3CipherSuiteCfg *suite = &ss->cipherSuites[i];
-        if (temp == suite->cipher_suite) {
-            SSLVersionRange vrange = { ss->version, ss->version };
-            if (!config_match(suite, ss->ssl3.policy, &vrange, ss)) {
-                /* config_match already checks whether the cipher suite is
-                 * acceptable for the version, but the check is repeated here
-                 * in order to give a more precise error code. */
-                if (!ssl3_CipherSuiteAllowedForVersionRange(temp, &vrange)) {
-                    desc = handshake_failure;
-                    errCode = SSL_ERROR_CIPHER_DISALLOWED_FOR_VERSION;
-                    goto alert_loser;
-                }
-
-                break; /* failure */
-            }
-
-            suite_found = PR_TRUE;
-            break; /* success */
-        }
-    }
-    if (!suite_found) {
-        desc = handshake_failure;
-        errCode = SSL_ERROR_NO_CYPHER_OVERLAP;
         goto alert_loser;
     }
-
-    rv = ssl3_SetCipherSuite(ss, (ssl3CipherSuite)temp, PR_TRUE);
+    rv = ssl3_SetupCipherSuite(ss, PR_TRUE);
     if (rv != SECSuccess) {
         desc = internal_error;
         errCode = PORT_GetError();
-        goto alert_loser;
+        goto loser;
     }
 
     if (ss->version < SSL_LIBRARY_VERSION_TLS_1_3) {
+        PRBool found = PR_FALSE;
         /* find selected compression method in our list. */
         rv = ssl3_ConsumeHandshakeNumber(ss, &temp, 1, &b, &length);
         if (rv != SECSuccess) {
             goto loser; /* alert has been sent */
         }
-        suite_found = PR_FALSE;
         for (i = 0; i < ssl_compression_method_count; i++) {
             if (temp == ssl_compression_methods[i]) {
                 if (!ssl_CompressionEnabled(ss, ssl_compression_methods[i])) {
                     break; /* failure */
                 }
-                suite_found = PR_TRUE;
+                found = PR_TRUE;
                 break; /* success */
             }
         }
-        if (!suite_found) {
+        if (!found) {
             desc = handshake_failure;
             errCode = SSL_ERROR_NO_COMPRESSION_OVERLAP;
             goto alert_loser;
         }
         ss->ssl3.hs.compression = (SSLCompressionMethod)temp;
     } else {
         ss->ssl3.hs.compression = ssl_compression_null;
     }
@@ -8053,17 +8070,18 @@ ssl3_NegotiateCipherSuite(sslSocket *ss,
         ssl3CipherSuiteCfg *suite = &ss->cipherSuites[j];
         SSLVersionRange vrange = { ss->version, ss->version };
         if (!config_match(suite, ss->ssl3.policy, &vrange, ss)) {
             continue;
         }
         for (i = 0; i + 1 < suites->len; i += 2) {
             PRUint16 suite_i = (suites->data[i] << 8) | suites->data[i + 1];
             if (suite_i == suite->cipher_suite) {
-                return ssl3_SetCipherSuite(ss, suite_i, initHashes);
+                ss->ssl3.hs.cipher_suite = suite_i;
+                return ssl3_SetupCipherSuite(ss, initHashes);
             }
         }
     }
     return SECFailure;
 }
 
 /*
  * Call the SNI config hook.
@@ -8718,17 +8736,18 @@ ssl3_HandleClientHelloPart2(sslSocket *s
             if (!suite->enabled)
                 break;
 #endif
             /* Double check that the cached cipher suite is in the client's
              * list.  If it isn't, fall through and start a new session. */
             for (i = 0; i + 1 < suites->len; i += 2) {
                 PRUint16 suite_i = (suites->data[i] << 8) | suites->data[i + 1];
                 if (suite_i == suite->cipher_suite) {
-                    rv = ssl3_SetCipherSuite(ss, suite_i, PR_TRUE);
+                    ss->ssl3.hs.cipher_suite = suite_i;
+                    rv = ssl3_SetupCipherSuite(ss, PR_TRUE);
                     if (rv != SECSuccess) {
                         desc = internal_error;
                         errCode = PORT_GetError();
                         goto alert_loser;
                     }
 
                     /* Use the cached compression method. */
                     ss->ssl3.hs.compression =
@@ -9165,29 +9184,28 @@ ssl3_HandleV2ClientHello(sslSocket *ss, 
         errCode = PORT_GetError(); /* error code is already set. */
         goto alert_loser;
     }
 
     /* Select a cipher suite.
     **
     ** NOTE: This suite selection algorithm should be the same as the one in
     ** ssl3_HandleClientHello().
-    **
-    ** See the comments about export cipher suites in ssl3_HandleClientHello().
     */
     for (j = 0; j < ssl_V3_SUITES_IMPLEMENTED; j++) {
         ssl3CipherSuiteCfg *suite = &ss->cipherSuites[j];
         SSLVersionRange vrange = { ss->version, ss->version };
         if (!config_match(suite, ss->ssl3.policy, &vrange, ss)) {
             continue;
         }
         for (i = 0; i + 2 < suite_length; i += 3) {
             PRUint32 suite_i = (suites[i] << 16) | (suites[i + 1] << 8) | suites[i + 2];
             if (suite_i == suite->cipher_suite) {
-                rv = ssl3_SetCipherSuite(ss, suite_i, PR_TRUE);
+                ss->ssl3.hs.cipher_suite = suite_i;
+                rv = ssl3_SetupCipherSuite(ss, PR_TRUE);
                 if (rv != SECSuccess) {
                     desc = internal_error;
                     errCode = PORT_GetError();
                     goto alert_loser;
                 }
                 goto suite_found;
             }
         }
--- a/lib/ssl/sslimpl.h
+++ b/lib/ssl/sslimpl.h
@@ -1623,16 +1623,20 @@ extern SECStatus ssl3_HandleHandshakeMes
 extern void ssl3_DestroySSL3Info(sslSocket *ss);
 
 extern SECStatus ssl_ClientReadVersion(sslSocket *ss, PRUint8 **b,
                                        PRUint32 *length,
                                        SSL3ProtocolVersion *version);
 extern SECStatus ssl3_NegotiateVersion(sslSocket *ss,
                                        SSL3ProtocolVersion peerVersion,
                                        PRBool allowLargerPeerVersion);
+extern SECStatus ssl_ClientConsumeCipherSuite(sslSocket *ss,
+                                              SSL3ProtocolVersion version,
+                                              PRUint8 **b,
+                                              unsigned int *length);
 
 extern SECStatus ssl_GetPeerInfo(sslSocket *ss);
 
 /* ECDH functions */
 extern SECStatus ssl3_SendECDHClientKeyExchange(sslSocket *ss,
                                                 SECKEYPublicKey *svrPubKey);
 extern SECStatus ssl3_HandleECDHServerKeyExchange(sslSocket *ss,
                                                   PRUint8 *b, PRUint32 length);
@@ -1825,18 +1829,17 @@ SECStatus ssl_PickSignatureScheme(sslSoc
                                   SECKEYPrivateKey *privKey,
                                   const SSLSignatureScheme *peerSchemes,
                                   unsigned int peerSchemeCount,
                                   PRBool requireSha1);
 SECOidTag ssl3_HashTypeToOID(SSLHashType hashType);
 SSLHashType ssl_SignatureSchemeToHashType(SSLSignatureScheme scheme);
 KeyType ssl_SignatureSchemeToKeyType(SSLSignatureScheme scheme);
 
-SECStatus ssl3_SetCipherSuite(sslSocket *ss, ssl3CipherSuite chosenSuite,
-                              PRBool initHashes);
+SECStatus ssl3_SetupCipherSuite(sslSocket *ss, PRBool initHashes);
 
 /* Pull in TLS 1.3 functions */
 #include "tls13con.h"
 
 /********************** misc calls *********************/
 
 #ifdef DEBUG
 extern void ssl3_CheckCipherSuiteOrderConsistency();
--- a/lib/ssl/tls13con.c
+++ b/lib/ssl/tls13con.c
@@ -472,17 +472,18 @@ tls13_SetupClientHello(sslSocket *ss)
             FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
             SSL_AtomicIncrementLong(&ssl3stats->sch_sid_cache_not_ok);
             ss->sec.uncache(ss->sec.ci.sid);
             ssl_FreeSID(ss->sec.ci.sid);
             ss->sec.ci.sid = NULL;
             return SECFailure;
         }
 
-        rv = ssl3_SetCipherSuite(ss, ss->sec.ci.sid->u.ssl3.cipherSuite, PR_FALSE);
+        ss->ssl3.hs.cipher_suite = ss->sec.ci.sid->u.ssl3.cipherSuite;
+        rv = ssl3_SetupCipherSuite(ss, PR_FALSE);
         if (rv != SECSuccess) {
             FATAL_ERROR(ss, PORT_GetError(), internal_error);
             return SECFailure;
         }
 
         rv = tls13_ComputeEarlySecrets(ss);
         if (rv != SECSuccess) {
             FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
@@ -1513,32 +1514,39 @@ tls13_SendHelloRetryRequest(sslSocket *s
     if (ss->ssl3.hs.helloRetry) {
         FATAL_ERROR(ss, SSL_ERROR_BAD_2ND_CLIENT_HELLO, illegal_parameter);
         return SECFailure;
     }
 
     ssl_GetXmitBufLock(ss);
     rv = ssl3_AppendHandshakeHeader(ss, hello_retry_request,
                                     2 +     /* version */
+                                        2 + /* cipher suite */
                                         2 + /* extension length */
                                         2 + /* group extension id */
                                         2 + /* group extension length */
                                         2 /* group */);
     if (rv != SECSuccess) {
         FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
         goto loser;
     }
 
     rv = ssl3_AppendHandshakeNumber(
         ss, tls13_EncodeDraftVersion(ss->version), 2);
     if (rv != SECSuccess) {
         FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
         goto loser;
     }
 
+    rv = ssl3_AppendHandshakeNumber(ss, ss->ssl3.hs.cipher_suite, 2);
+    if (rv != SECSuccess) {
+        FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
+        goto loser;
+    }
+
     /* Length of extensions. */
     rv = ssl3_AppendHandshakeNumber(ss, 2 + 2 + 2, 2);
     if (rv != SECSuccess) {
         FATAL_ERROR(ss, SEC_ERROR_LIBRARY_FAILURE, internal_error);
         goto loser;
     }
 
     /* Key share extension - currently the only reason we send this. */
@@ -1739,16 +1747,21 @@ tls13_HandleHelloRetryRequest(sslSocket 
         return SECFailure; /* alert already sent */
     }
     if (version > ss->vrange.max || version < SSL_LIBRARY_VERSION_TLS_1_3) {
         FATAL_ERROR(ss, SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST,
                     protocol_version);
         return SECFailure;
     }
 
+    rv = ssl_ClientConsumeCipherSuite(ss, version, &b, &length);
+    if (rv != SECSuccess) {
+        return SECFailure; /* error code already set */
+    }
+
     /* Extensions. */
     rv = ssl3_ConsumeHandshakeNumber(ss, &tmp, 2, &b, &length);
     if (rv != SECSuccess) {
         return SECFailure; /* error code already set */
     }
     /* Extensions must be non-empty and use the remainder of the message.
      * This means that a HelloRetryRequest cannot be a no-op: we must have an
      * extension, it must be one that we understand and recognize as being valid