Bug 1394956 - key_share after HelloRetryRequest can have multiple shares, r=ekr NSS_TLS13_DRAFT19_BRANCH
authorMartin Thomson <martin.thomson@gmail.com>
Wed, 30 Aug 2017 10:43:24 +1000
branchNSS_TLS13_DRAFT19_BRANCH
changeset 13608 3efb83875558adc1674dfa2ddba0a47f85979ed5
parent 13607 3ace64039e117f508f6f2951f269b5abc6d80509
child 13609 27bfdd0ee644c33d3445a844d17d9042e0210035
push id2389
push usermartin.thomson@gmail.com
push dateMon, 25 Sep 2017 04:22:23 +0000
reviewersekr
bugs1394956
Bug 1394956 - key_share after HelloRetryRequest can have multiple shares, r=ekr
gtests/ssl_gtest/ssl_hrr_unittest.cc
gtests/ssl_gtest/tls_connect.cc
gtests/ssl_gtest/tls_connect.h
lib/ssl/tls13con.c
lib/ssl/tls13exthandle.c
--- a/gtests/ssl_gtest/ssl_hrr_unittest.cc
+++ b/gtests/ssl_gtest/ssl_hrr_unittest.cc
@@ -182,16 +182,31 @@ TEST_P(TlsConnectTls13, RetryWithSameKey
   static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
                                                     ssl_grp_ec_secp521r1};
   server_->ConfigNamedGroups(groups);
   ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
   EXPECT_EQ(SSL_ERROR_BAD_2ND_CLIENT_HELLO, server_->error_code());
   EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code());
 }
 
+// Here we modify the second ClientHello so that the client retries with the
+// same shares, even though the server wanted something else.
+TEST_P(TlsConnectTls13, RetryWithTwoShares) {
+  EnsureTlsSetup();
+  EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
+  client_->SetPacketFilter(std::make_shared<KeyShareReplayer>());
+
+  static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
+                                                    ssl_grp_ec_secp521r1};
+  server_->ConfigNamedGroups(groups);
+  ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+  EXPECT_EQ(SSL_ERROR_BAD_2ND_CLIENT_HELLO, server_->error_code());
+  EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code());
+}
+
 TEST_P(TlsConnectTls13, RetryCallbackAccept) {
   EnsureTlsSetup();
 
   auto accept_hello = [](PRBool firstHello, const PRUint8* clientToken,
                          unsigned int clientTokenLen, PRUint8* appToken,
                          unsigned int* appTokenLen, unsigned int appTokenMax,
                          void* arg) {
     auto* called = reinterpret_cast<bool*>(arg);
@@ -366,40 +381,97 @@ TEST_P(TlsConnectTls13, RetryCallbackRet
   EXPECT_LT(0U, capture_hrr->buffer().len()) << "HelloRetryRequest expected";
   EXPECT_FALSE(capture_key_share->captured())
       << "no key_share extension expected";
 
   auto capture_cookie =
       std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn);
   client_->SetPacketFilter(capture_cookie);
 
-  Connect();
+  Handshake();
+  CheckConnected();
   EXPECT_EQ(2U, cb_called);
   EXPECT_TRUE(capture_cookie->captured()) << "should have a cookie";
 }
 
+static size_t CountShares(const DataBuffer& key_share) {
+  size_t count = 0;
+  uint32_t len = 0;
+  size_t offset = 2;
+
+  EXPECT_TRUE(key_share.Read(0, 2, &len));
+  EXPECT_EQ(key_share.len() - 2, len);
+  while (offset < key_share.len()) {
+    offset += 2;  // Skip KeyShareEntry.group
+    EXPECT_TRUE(key_share.Read(offset, 2, &len));
+    offset += 2 + len;  // Skip KeyShareEntry.key_exchange
+    ++count;
+  }
+  return count;
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackRetryWithAdditionalShares) {
+  EnsureTlsSetup();
+  EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
+
+  auto capture_server =
+      std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+  capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+  server_->SetPacketFilter(capture_server);
+
+  size_t cb_called = 0;
+  EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+                                                      RetryHello, &cb_called));
+
+  // Do the first message exchange.
+  StartConnect();
+  client_->Handshake();
+  server_->Handshake();
+
+  EXPECT_EQ(1U, cb_called) << "callback should be called once here";
+  EXPECT_FALSE(capture_server->captured())
+      << "no key_share extension expected from server";
+
+  auto capture_client_2nd =
+      std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+  client_->SetPacketFilter(capture_client_2nd);
+
+  Handshake();
+  CheckConnected();
+  EXPECT_EQ(2U, cb_called);
+  EXPECT_TRUE(capture_client_2nd->captured()) << "client should send key_share";
+  EXPECT_EQ(2U, CountShares(capture_client_2nd->extension()))
+      << "client should still send two shares";
+}
+
 // The callback should be run even if we have another reason to send
 // HelloRetryRequest.  In this case, the server sends HRR because the server
 // wants a P-384 key share and the client didn't offer one.
 TEST_P(TlsConnectTls13, RetryCallbackRetryWithGroupMismatch) {
   EnsureTlsSetup();
 
-  auto capture = std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn);
-  capture->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
-  server_->SetPacketFilter(capture);
+  auto capture_cookie =
+      std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn);
+  capture_cookie->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+  auto capture_key_share =
+      std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+  capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+  server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(
+      ChainedPacketFilterInit{capture_cookie, capture_key_share}));
 
   static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
   server_->ConfigNamedGroups(groups);
 
   size_t cb_called = 0;
   EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
                                                       RetryHello, &cb_called));
   Connect();
   EXPECT_EQ(2U, cb_called);
-  EXPECT_TRUE(capture->captured()) << "cookie expected";
+  EXPECT_TRUE(capture_cookie->captured()) << "cookie expected";
+  EXPECT_TRUE(capture_key_share->captured()) << "key_share expected";
 }
 
 static const uint8_t kApplicationToken[] = {0x92, 0x44, 0x00};
 
 SSLHelloRetryRequestAction RetryHelloWithToken(
     PRBool firstHello, const PRUint8* clientToken, unsigned int clientTokenLen,
     PRUint8* appToken, unsigned int* appTokenLen, unsigned int appTokenMax,
     void* arg) {
@@ -731,16 +803,64 @@ TEST_P(TlsKeyExchange13, ConnectEcdhePre
   server_->ConfigNamedGroups(server_groups);
   EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
 
   Connect();
   CheckKeys();
   CheckKEXDetails(client_groups, client_groups);
 }
 
+// The callback should be run even if we have another reason to send
+// HelloRetryRequest.  In this case, the server sends HRR because the server
+// wants an X25519 key share and the client didn't offer one.
+TEST_P(TlsKeyExchange13,
+       RetryCallbackRetryWithGroupMismatchAndAdditionalShares) {
+  EnsureKeyShareSetup();
+
+  static const std::vector<SSLNamedGroup> client_groups = {
+      ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+  client_->ConfigNamedGroups(client_groups);
+  static const std::vector<SSLNamedGroup> server_groups = {
+      ssl_grp_ec_curve25519};
+  server_->ConfigNamedGroups(server_groups);
+  EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
+
+  auto capture_server =
+      std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+  capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+  server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(
+      ChainedPacketFilterInit{capture_hrr_, capture_server}));
+
+  size_t cb_called = 0;
+  EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+                                                      RetryHello, &cb_called));
+
+  // Do the first message exchange.
+  StartConnect();
+  client_->Handshake();
+  server_->Handshake();
+
+  EXPECT_EQ(1U, cb_called) << "callback should be called once here";
+  EXPECT_TRUE(capture_server->captured()) << "key_share extension expected";
+
+  uint32_t server_group = 0;
+  EXPECT_TRUE(capture_server->extension().Read(0, 2, &server_group));
+  EXPECT_EQ(ssl_grp_ec_curve25519, static_cast<SSLNamedGroup>(server_group));
+
+  Handshake();
+  CheckConnected();
+  EXPECT_EQ(2U, cb_called);
+  EXPECT_TRUE(shares_capture2_->captured()) << "client should send shares";
+
+  CheckKeys();
+  static const std::vector<SSLNamedGroup> client_shares(
+      client_groups.begin(), client_groups.begin() + 2);
+  CheckKEXDetails(client_groups, client_shares, server_groups[0]);
+}
+
 TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) {
   EnsureTlsSetup();
   client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
                            SSL_LIBRARY_VERSION_TLS_1_3);
   server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
                            SSL_LIBRARY_VERSION_TLS_1_3);
   static const std::vector<SSLNamedGroup> client_groups = {
       ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1};
--- a/gtests/ssl_gtest/tls_connect.cc
+++ b/gtests/ssl_gtest/tls_connect.cc
@@ -701,60 +701,66 @@ void TlsKeyExchangeTest::EnsureKeyShareS
 
 void TlsKeyExchangeTest::ConfigNamedGroups(
     const std::vector<SSLNamedGroup>& groups) {
   client_->ConfigNamedGroups(groups);
   server_->ConfigNamedGroups(groups);
 }
 
 std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails(
-    const DataBuffer& ext) {
+    const std::shared_ptr<TlsExtensionCapture>& capture) {
+  EXPECT_TRUE(capture->captured());
+  const DataBuffer& ext = capture->extension();
+
   uint32_t tmp = 0;
   EXPECT_TRUE(ext.Read(0, 2, &tmp));
   EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
   EXPECT_TRUE(ext.len() % 2 == 0);
+
   std::vector<SSLNamedGroup> groups;
   for (size_t i = 1; i < ext.len() / 2; i += 1) {
     EXPECT_TRUE(ext.Read(2 * i, 2, &tmp));
     groups.push_back(static_cast<SSLNamedGroup>(tmp));
   }
   return groups;
 }
 
 std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails(
-    const DataBuffer& ext) {
+    const std::shared_ptr<TlsExtensionCapture>& capture) {
+  EXPECT_TRUE(capture->captured());
+  const DataBuffer& ext = capture->extension();
+
   uint32_t tmp = 0;
   EXPECT_TRUE(ext.Read(0, 2, &tmp));
   EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
+
   std::vector<SSLNamedGroup> shares;
   size_t i = 2;
   while (i < ext.len()) {
     EXPECT_TRUE(ext.Read(i, 2, &tmp));
     shares.push_back(static_cast<SSLNamedGroup>(tmp));
     EXPECT_TRUE(ext.Read(i + 2, 2, &tmp));
     i += 4 + tmp;
   }
   EXPECT_EQ(ext.len(), i);
   return shares;
 }
 
 void TlsKeyExchangeTest::CheckKEXDetails(
     const std::vector<SSLNamedGroup>& expected_groups,
     const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) {
-  std::vector<SSLNamedGroup> groups =
-      GetGroupDetails(groups_capture_->extension());
+  std::vector<SSLNamedGroup> groups = GetGroupDetails(groups_capture_);
   EXPECT_EQ(expected_groups, groups);
 
   if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
     ASSERT_LT(0U, expected_shares.size());
-    std::vector<SSLNamedGroup> shares =
-        GetShareDetails(shares_capture_->extension());
+    std::vector<SSLNamedGroup> shares = GetShareDetails(shares_capture_);
     EXPECT_EQ(expected_shares, shares);
   } else {
-    EXPECT_EQ(0U, shares_capture_->extension().len());
+    EXPECT_FALSE(shares_capture_->captured());
   }
 
   EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0);
 }
 
 void TlsKeyExchangeTest::CheckKEXDetails(
     const std::vector<SSLNamedGroup>& expected_groups,
     const std::vector<SSLNamedGroup>& expected_shares) {
@@ -766,13 +772,11 @@ void TlsKeyExchangeTest::CheckKEXDetails
     const std::vector<SSLNamedGroup>& expected_shares,
     SSLNamedGroup expected_share2) {
   CheckKEXDetails(expected_groups, expected_shares, true);
 
   for (auto it : expected_shares) {
     EXPECT_NE(expected_share2, it);
   }
   std::vector<SSLNamedGroup> expected_shares2 = {expected_share2};
-  std::vector<SSLNamedGroup> shares =
-      GetShareDetails(shares_capture2_->extension());
-  EXPECT_EQ(expected_shares2, shares);
+  EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_));
 }
 }  // namespace nss_test
--- a/gtests/ssl_gtest/tls_connect.h
+++ b/gtests/ssl_gtest/tls_connect.h
@@ -259,18 +259,20 @@ class TlsKeyExchangeTest : public TlsCon
  protected:
   std::shared_ptr<TlsExtensionCapture> groups_capture_;
   std::shared_ptr<TlsExtensionCapture> shares_capture_;
   std::shared_ptr<TlsExtensionCapture> shares_capture2_;
   std::shared_ptr<TlsInspectorRecordHandshakeMessage> capture_hrr_;
 
   void EnsureKeyShareSetup();
   void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
-  std::vector<SSLNamedGroup> GetGroupDetails(const DataBuffer& ext);
-  std::vector<SSLNamedGroup> GetShareDetails(const DataBuffer& ext);
+  std::vector<SSLNamedGroup> GetGroupDetails(
+      const std::shared_ptr<TlsExtensionCapture>& capture);
+  std::vector<SSLNamedGroup> GetShareDetails(
+      const std::shared_ptr<TlsExtensionCapture>& capture);
   void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
                        const std::vector<SSLNamedGroup>& expectedShares);
   void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
                        const std::vector<SSLNamedGroup>& expectedShares,
                        SSLNamedGroup expectedShare2);
 
  private:
   void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
--- a/lib/ssl/tls13con.c
+++ b/lib/ssl/tls13con.c
@@ -1428,20 +1428,31 @@ tls13_HandleClientHelloPart2(sslSocket *
                         illegal_parameter);
             goto loser;
         }
         if (!clientShare) {
             FATAL_ERROR(ss, SSL_ERROR_BAD_2ND_CLIENT_HELLO,
                         illegal_parameter);
             goto loser;
         }
-        if (previousGroup && clientShare->group != previousGroup) {
-            FATAL_ERROR(ss, SSL_ERROR_BAD_2ND_CLIENT_HELLO,
-                        illegal_parameter);
-            goto loser;
+
+        /* If we requested a new key share, check that the client provided just
+         * one of the right type. */
+        if (previousGroup) {
+            if (PR_PREV_LINK(&ss->xtnData.remoteKeyShares) !=
+                PR_NEXT_LINK(&ss->xtnData.remoteKeyShares)) {
+                FATAL_ERROR(ss, SSL_ERROR_BAD_2ND_CLIENT_HELLO,
+                            illegal_parameter);
+                goto loser;
+            }
+            if (clientShare->group != previousGroup) {
+                FATAL_ERROR(ss, SSL_ERROR_BAD_2ND_CLIENT_HELLO,
+                            illegal_parameter);
+                goto loser;
+            }
         }
     }
 
     rv = tls13_MaybeSendHelloRetry(ss, requestedGroup, &hrr);
     if (rv != SECSuccess) {
         goto loser;
     }
     if (hrr) {
--- a/lib/ssl/tls13exthandle.c
+++ b/lib/ssl/tls13exthandle.c
@@ -316,26 +316,16 @@ tls13_ServerHandleKeyShareXtn(const sslS
     }
 
     while (data->len) {
         rv = tls13_HandleKeyShareEntry(ss, xtnData, data);
         if (rv != SECSuccess)
             goto loser;
     }
 
-    /* Check that the client only offered one share if this is
-     * after HRR. */
-    if (ss->ssl3.hs.helloRetry) {
-        if (PR_PREV_LINK(&xtnData->remoteKeyShares) !=
-            PR_NEXT_LINK(&xtnData->remoteKeyShares)) {
-            PORT_SetError(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
-            goto loser;
-        }
-    }
-
     return SECSuccess;
 
 loser:
     tls13_DestroyKeyShares(&xtnData->remoteKeyShares);
     return SECFailure;
 }
 
 SECStatus