Bug 1350561 - Test that adding an extension is properly handled, r=ekr
authorMartin Thomson <martin.thomson@gmail.com>
Sat, 25 Mar 2017 09:36:08 -0500
changeset 13250 0985b120ab0247bdcbf340566ccbde69745cc22f
parent 13249 3bcb0acab9260df0e20fcd6a8335cc19dae2b933
child 13251 229ce7626a6a83f7c131a8d0303bf13c75710998
push id2118
push usermartin.thomson@gmail.com
push dateSat, 25 Mar 2017 18:33:57 +0000
reviewersekr
bugs1350561
Bug 1350561 - Test that adding an extension is properly handled, r=ekr Differential Revision: https://nss-review.dev.mozaws.net/D272
gtests/ssl_gtest/ssl_extension_unittest.cc
gtests/ssl_gtest/tls_filter.cc
gtests/ssl_gtest/tls_filter.h
lib/ssl/ssl3ext.c
lib/ssl/tls13con.c
lib/ssl/tls13exthandle.c
--- a/gtests/ssl_gtest/ssl_extension_unittest.cc
+++ b/gtests/ssl_gtest/ssl_extension_unittest.cc
@@ -64,32 +64,21 @@ class TlsExtensionDamager : public TlsEx
 class TlsExtensionInjector : public TlsHandshakeFilter {
  public:
   TlsExtensionInjector(uint16_t ext, DataBuffer& data)
       : extension_(ext), data_(data) {}
 
   virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                                const DataBuffer& input,
                                                DataBuffer* output) {
-    size_t offset;
-    if (header.handshake_type() == kTlsHandshakeClientHello) {
-      TlsParser parser(input);
-      if (!TlsExtensionFilter::FindClientHelloExtensions(&parser, header)) {
-        return KEEP;
-      }
-      offset = parser.consumed();
-    } else if (header.handshake_type() == kTlsHandshakeServerHello) {
-      TlsParser parser(input);
-      if (!TlsExtensionFilter::FindServerHelloExtensions(&parser)) {
-        return KEEP;
-      }
-      offset = parser.consumed();
-    } else {
+    TlsParser parser(input);
+    if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
       return KEEP;
     }
+    size_t offset = parser.consumed();
 
     *output = input;
 
     // Increase the size of the extensions.
     uint16_t ext_len;
     memcpy(&ext_len, output->data() + offset, sizeof(ext_len));
     ext_len = htons(ntohs(ext_len) + data_.len() + 4);
     memcpy(output->data() + offset, &ext_len, sizeof(ext_len));
@@ -111,55 +100,71 @@ class TlsExtensionInjector : public TlsH
 
  private:
   const uint16_t extension_;
   const DataBuffer data_;
 };
 
 class TlsExtensionAppender : public TlsHandshakeFilter {
  public:
-  TlsExtensionAppender(uint16_t ext, DataBuffer& data)
-      : extension_(ext), data_(data) {}
+  TlsExtensionAppender(uint8_t handshake_type, uint16_t ext, DataBuffer& data)
+      : handshake_type_(handshake_type), extension_(ext), data_(data) {}
 
   virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                                const DataBuffer& input,
                                                DataBuffer* output) {
-    size_t offset;
-    TlsParser parser(input);
-    if (header.handshake_type() == kTlsHandshakeClientHello) {
-      if (!TlsExtensionFilter::FindClientHelloExtensions(&parser, header)) {
-        return KEEP;
-      }
-    } else if (header.handshake_type() == kTlsHandshakeServerHello) {
-      if (!TlsExtensionFilter::FindServerHelloExtensions(&parser)) {
-        return KEEP;
-      }
-    } else {
+    if (header.handshake_type() != handshake_type_) {
       return KEEP;
     }
-    offset = parser.consumed();
+
+    TlsParser parser(input);
+    if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
+      return KEEP;
+    }
     *output = input;
 
-    uint32_t ext_len;
-    if (!parser.Read(&ext_len, 2)) {
-      ADD_FAILURE();
+    // Increase the length of the extensions block.
+    if (!UpdateLength(output, parser.consumed(), 2)) {
       return KEEP;
     }
 
-    ext_len += 4 + data_.len();
-    output->Write(offset, ext_len, 2);
+    // Extensions in Certificate are nested twice.  Increase the size of the
+    // certificate list.
+    if (header.handshake_type() == kTlsHandshakeCertificate) {
+      TlsParser p2(input);
+      if (!p2.SkipVariable(1)) {
+        ADD_FAILURE();
+        return KEEP;
+      }
+      if (!UpdateLength(output, p2.consumed(), 3)) {
+        return KEEP;
+      }
+    }
 
-    offset = output->len();
+    size_t offset = output->len();
     offset = output->Write(offset, extension_, 2);
     WriteVariable(output, offset, data_, 2);
 
     return CHANGE;
   }
 
  private:
+  bool UpdateLength(DataBuffer* output, size_t offset, size_t size) {
+    uint32_t len;
+    if (!output->Read(offset, size, &len)) {
+      ADD_FAILURE();
+      return false;
+    }
+
+    len += 4 + data_.len();
+    output->Write(offset, len, size);
+    return true;
+  }
+
+  const uint8_t handshake_type_;
   const uint16_t extension_;
   const DataBuffer data_;
 };
 
 class TlsExtensionTestBase : public TlsConnectTestBase {
  protected:
   TlsExtensionTestBase(Mode mode, uint16_t version)
       : TlsConnectTestBase(mode, version) {}
@@ -859,19 +864,19 @@ TEST_F(TlsExtensionTest13Stream, ResumeO
   server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
 }
 
 TEST_F(TlsExtensionTest13Stream, ResumePskExtensionNotLast) {
   SetupForResume();
 
   const uint8_t empty_buf[] = {0};
   DataBuffer empty(empty_buf, 0);
-  client_->SetPacketFilter(
-      // Inject an unused extension.
-      std::make_shared<TlsExtensionAppender>(0xffff, empty));
+  // Inject an unused extension after the PSK extension.
+  client_->SetPacketFilter(std::make_shared<TlsExtensionAppender>(
+      kTlsHandshakeClientHello, 0xffff, empty));
   ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
   client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
   server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
 }
 
 TEST_F(TlsExtensionTest13Stream, ResumeNoKeModes) {
   SetupForResume();
 
@@ -980,22 +985,153 @@ TEST_P(TlsExtensionTest13, EmptyVersionL
   ConnectWithBogusVersionList(ext, sizeof(ext));
 }
 
 TEST_P(TlsExtensionTest13, OddVersionList) {
   static const uint8_t ext[] = {0x00, 0x01, 0x00};
   ConnectWithBogusVersionList(ext, sizeof(ext));
 }
 
+// TODO: this only tests extensions in server messages.  The client can extend
+// Certificate messages, which is not checked here.
+class TlsBogusExtensionTest
+    : public TlsConnectTestBase,
+      public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+ public:
+  TlsBogusExtensionTest()
+      : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
+ protected:
+  virtual void ConnectAndFail(uint8_t message) = 0;
+
+  void AddFilter(uint8_t message, uint16_t extension) {
+    static uint8_t empty_buf[1] = {0};
+    DataBuffer empty(empty_buf, 0);
+    auto filter = std::make_shared<TlsExtensionAppender>(message, extension, empty);
+    if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+      server_->SetTlsRecordFilter(filter);
+      filter->EnableDecryption();
+    } else {
+      server_->SetPacketFilter(filter);
+    }
+  }
+
+  void Run(uint8_t message, uint16_t extension = 0xff) {
+    EnsureTlsSetup();
+    AddFilter(message, extension);
+    ConnectAndFail(message);
+  }
+};
+
+class TlsBogusExtensionTestPre13 : public TlsBogusExtensionTest {
+ protected:
+  void ConnectAndFail(uint8_t) override {
+    ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension);
+  }
+};
+
+class TlsBogusExtensionTest13 : public TlsBogusExtensionTest {
+ protected:
+  void ConnectAndFail(uint8_t message) override {
+    if (message == kTlsHandshakeHelloRetryRequest) {
+      ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension);
+      return;
+    }
+
+    client_->StartConnect();
+    server_->StartConnect();
+    client_->Handshake();  // ClientHello
+    server_->Handshake();  // ServerHello
+
+    client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+    client_->Handshake();
+    if (mode_ == STREAM) {
+      server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+    }
+    server_->Handshake();
+  }
+};
+
+TEST_P(TlsBogusExtensionTestPre13, AddBogusExtensionServerHello) {
+  Run(kTlsHandshakeServerHello);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionServerHello) {
+  Run(kTlsHandshakeServerHello);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionEncryptedExtensions) {
+  Run(kTlsHandshakeEncryptedExtensions);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificate) {
+  Run(kTlsHandshakeCertificate);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificateRequest) {
+  server_->RequestClientAuth(false);
+  Run(kTlsHandshakeCertificateRequest);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) {
+  static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+  server_->ConfigNamedGroups(groups);
+
+  Run(kTlsHandshakeHelloRetryRequest);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddVersionExtensionServerHello) {
+  Run(kTlsHandshakeServerHello, ssl_tls13_supported_versions_xtn);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddVersionExtensionEncryptedExtensions) {
+  Run(kTlsHandshakeEncryptedExtensions, ssl_tls13_supported_versions_xtn);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddVersionExtensionCertificate) {
+  Run(kTlsHandshakeCertificate, ssl_tls13_supported_versions_xtn);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddVersionExtensionCertificateRequest) {
+  server_->RequestClientAuth(false);
+  Run(kTlsHandshakeCertificateRequest, ssl_tls13_supported_versions_xtn);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddVersionExtensionHelloRetryRequest) {
+  static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+  server_->ConfigNamedGroups(groups);
+
+  Run(kTlsHandshakeHelloRetryRequest, ssl_tls13_supported_versions_xtn);
+}
+
+
+// NewSessionTicket allows unknown extensions AND it isn't protected by the
+// Finished.  So adding an unknown extension doesn't cause an error.
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionNewSessionTicket) {
+  ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+
+  AddFilter(kTlsHandshakeNewSessionTicket, 0xff);
+  Connect();
+  SendReceive();
+  CheckKeys();
+
+  Reset();
+  ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+  ExpectResumption(RESUME_TICKET);
+  Connect();
+  SendReceive();
+}
+
 INSTANTIATE_TEST_CASE_P(ExtensionStream, TlsExtensionTestGeneric,
                         ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
                                            TlsConnectTestBase::kTlsVAll));
-INSTANTIATE_TEST_CASE_P(ExtensionDatagram, TlsExtensionTestGeneric,
-                        ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
-                                           TlsConnectTestBase::kTlsV11Plus));
+INSTANTIATE_TEST_CASE_P(
+    ExtensionDatagram, TlsExtensionTestGeneric,
+    ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram,
+                       TlsConnectTestBase::kTlsV11Plus));
 INSTANTIATE_TEST_CASE_P(ExtensionDatagramOnly, TlsExtensionTestDtls,
                         TlsConnectTestBase::kTlsV11Plus);
 
 INSTANTIATE_TEST_CASE_P(ExtensionTls12Plus, TlsExtensionTest12Plus,
                         ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
                                            TlsConnectTestBase::kTlsV12Plus));
 
 INSTANTIATE_TEST_CASE_P(ExtensionPre13Stream, TlsExtensionTestPre13,
@@ -1003,9 +1139,21 @@ INSTANTIATE_TEST_CASE_P(ExtensionPre13St
                                            TlsConnectTestBase::kTlsV10ToV12));
 INSTANTIATE_TEST_CASE_P(ExtensionPre13Datagram, TlsExtensionTestPre13,
                         ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
                                            TlsConnectTestBase::kTlsV11V12));
 
 INSTANTIATE_TEST_CASE_P(ExtensionTls13, TlsExtensionTest13,
                         TlsConnectTestBase::kTlsModesAll);
 
-}  // namespace nspr_test
+INSTANTIATE_TEST_CASE_P(BogusExtensionStream, TlsBogusExtensionTestPre13,
+                        ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
+                                           TlsConnectTestBase::kTlsV10ToV12));
+INSTANTIATE_TEST_CASE_P(
+    BogusExtensionDatagram, TlsBogusExtensionTestPre13,
+    ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram,
+                       TlsConnectTestBase::kTlsV11V12));
+
+INSTANTIATE_TEST_CASE_P(BogusExtension13, TlsBogusExtensionTest13,
+                        ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+                                           TlsConnectTestBase::kTlsV13));
+
+}  // namespace nss_test
--- a/gtests/ssl_gtest/tls_filter.cc
+++ b/gtests/ssl_gtest/tls_filter.cc
@@ -381,38 +381,17 @@ PacketFilter::Action ChainedPacketFilter
     if (action == CHANGE) {
       in = *output;
       changed = true;
     }
   }
   return changed ? CHANGE : KEEP;
 }
 
-PacketFilter::Action TlsExtensionFilter::FilterHandshake(
-    const HandshakeHeader& header, const DataBuffer& input,
-    DataBuffer* output) {
-  if (header.handshake_type() == kTlsHandshakeClientHello) {
-    TlsParser parser(input);
-    if (!FindClientHelloExtensions(&parser, header)) {
-      return KEEP;
-    }
-    return FilterExtensions(&parser, input, output);
-  }
-  if (header.handshake_type() == kTlsHandshakeServerHello) {
-    TlsParser parser(input);
-    if (!FindServerHelloExtensions(&parser)) {
-      return KEEP;
-    }
-    return FilterExtensions(&parser, input, output);
-  }
-  return KEEP;
-}
-
-bool TlsExtensionFilter::FindClientHelloExtensions(TlsParser* parser,
-                                                   const TlsVersioned& header) {
+bool FindClientHelloExtensions(TlsParser* parser, const TlsVersioned& header) {
   if (!parser->Skip(2 + 32)) {  // version + random
     return false;
   }
   if (!parser->SkipVariable(1)) {  // session ID
     return false;
   }
   if (header.is_dtls() && !parser->SkipVariable(1)) {  // DTLS cookie
     return false;
@@ -421,17 +400,17 @@ bool TlsExtensionFilter::FindClientHello
     return false;
   }
   if (!parser->SkipVariable(1)) {  // compression methods
     return false;
   }
   return true;
 }
 
-bool TlsExtensionFilter::FindServerHelloExtensions(TlsParser* parser) {
+bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) {
   uint32_t vtmp;
   if (!parser->Read(&vtmp, 2)) {
     return false;
   }
   uint16_t version = static_cast<uint16_t>(vtmp);
   if (!parser->Skip(32)) {  // random
     return false;
   }
@@ -446,16 +425,102 @@ bool TlsExtensionFilter::FindServerHello
   if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) {
     if (!parser->Skip(1)) {  // compression method
       return false;
     }
   }
   return true;
 }
 
+static bool FindHelloRetryExtensions(TlsParser* parser,
+                                     const TlsVersioned& header) {
+  // TODO for -19 add cipher suite
+  if (!parser->Skip(2)) {  // version
+    return false;
+  }
+  return true;
+}
+
+bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) {
+  return true;
+}
+
+static bool FindCertReqExtensions(TlsParser* parser,
+                                  const TlsVersioned& header) {
+  if (!parser->SkipVariable(1)) {  // request context
+    return false;
+  }
+  // TODO remove the next two for -19
+  if (!parser->SkipVariable(2)) {  // signature_algorithms
+    return false;
+  }
+  if (!parser->SkipVariable(2)) {  // certificate_authorities
+    return false;
+  }
+  return true;
+}
+
+// Only look at the EE cert for this one.
+static bool FindCertificateExtensions(TlsParser* parser,
+                                      const TlsVersioned& header) {
+  if (!parser->SkipVariable(1)) {  // request context
+    return false;
+  }
+  if (!parser->Skip(3)) {  // length of certificate list
+    return false;
+  }
+  if (!parser->SkipVariable(3)) {  // ASN1Cert
+    return false;
+  }
+  return true;
+}
+
+static bool FindNewSessionTicketExtensions(TlsParser* parser,
+                                           const TlsVersioned& header) {
+  if (!parser->Skip(8)) {  // lifetime, age add
+    return false;
+  }
+  if (!parser->SkipVariable(2)) {  // ticket
+    return false;
+  }
+  return true;
+}
+
+static const std::map<uint16_t, TlsExtensionFinder> kExtensionFinders = {
+    {kTlsHandshakeClientHello, FindClientHelloExtensions},
+    {kTlsHandshakeServerHello, FindServerHelloExtensions},
+    {kTlsHandshakeHelloRetryRequest, FindHelloRetryExtensions},
+    {kTlsHandshakeEncryptedExtensions, FindEncryptedExtensions},
+    {kTlsHandshakeCertificateRequest, FindCertReqExtensions},
+    {kTlsHandshakeCertificate, FindCertificateExtensions},
+    {kTlsHandshakeNewSessionTicket, FindNewSessionTicketExtensions}};
+
+bool TlsExtensionFilter::FindExtensions(TlsParser* parser,
+                                        const HandshakeHeader& header) {
+  auto it = kExtensionFinders.find(header.handshake_type());
+  if (it == kExtensionFinders.end()) {
+    return false;
+  }
+  return (it->second)(parser, header);
+}
+
+PacketFilter::Action TlsExtensionFilter::FilterHandshake(
+    const HandshakeHeader& header, const DataBuffer& input,
+    DataBuffer* output) {
+  if (handshake_types_.count(header.handshake_type()) == 0) {
+    return KEEP;
+  }
+
+  TlsParser parser(input);
+  if (!FindExtensions(&parser, header)) {
+    return KEEP;
+  }
+  return FilterExtensions(&parser, input, output);
+}
+
 PacketFilter::Action TlsExtensionFilter::FilterExtensions(
     TlsParser* parser, const DataBuffer& input, DataBuffer* output) {
   size_t length_offset = parser->consumed();
   uint32_t all_extensions;
   if (!parser->Read(&all_extensions, 2)) {
     return KEEP;  // no extensions, odd but OK
   }
   if (all_extensions != parser->remaining()) {
--- a/gtests/ssl_gtest/tls_filter.h
+++ b/gtests/ssl_gtest/tls_filter.h
@@ -4,16 +4,17 @@
  * License, v. 2.0. If a copy of the MPL was not distributed with this file,
  * You can obtain one at http://mozilla.org/MPL/2.0/. */
 
 #ifndef tls_filter_h_
 #define tls_filter_h_
 
 #include <functional>
 #include <memory>
+#include <set>
 #include <vector>
 
 #include "test_io.h"
 #include "tls_parser.h"
 #include "tls_protect.h"
 
 extern "C" {
 #include "libssl_internals.h"
@@ -245,35 +246,46 @@ class ChainedPacketFilter : public Packe
 
   // Takes ownership of the filter.
   void Add(std::shared_ptr<PacketFilter> filter) { filters_.push_back(filter); }
 
  private:
   std::vector<std::shared_ptr<PacketFilter>> filters_;
 };
 
+typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)>
+    TlsExtensionFinder;
+
 class TlsExtensionFilter : public TlsHandshakeFilter {
+ public:
+  TlsExtensionFilter() : handshake_types_() {
+    handshake_types_.insert(kTlsHandshakeClientHello);
+    handshake_types_.insert(kTlsHandshakeServerHello);
+  }
+
+  TlsExtensionFilter(const std::set<uint8_t>& types)
+      : handshake_types_(types) {}
+
+  static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header);
+
  protected:
   PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
                                        const DataBuffer& input,
                                        DataBuffer* output) override;
 
   virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
                                                const DataBuffer& input,
                                                DataBuffer* output) = 0;
 
- public:
-  static bool FindClientHelloExtensions(TlsParser* parser,
-                                        const TlsVersioned& header);
-  static bool FindServerHelloExtensions(TlsParser* parser);
-
  private:
   PacketFilter::Action FilterExtensions(TlsParser* parser,
                                         const DataBuffer& input,
                                         DataBuffer* output);
+
+  std::set<uint8_t> handshake_types_;
 };
 
 class TlsExtensionCapture : public TlsExtensionFilter {
  public:
   TlsExtensionCapture(uint16_t ext, bool last = false)
       : extension_(ext), captured_(false), last_(last), data_() {}
 
   const DataBuffer& extension() const { return data_; }
--- a/lib/ssl/ssl3ext.c
+++ b/lib/ssl/ssl3ext.c
@@ -82,16 +82,20 @@ static const ssl3ExtensionHandler newSes
 
 /* This table is used by the client to handle server certificates in TLS 1.3 */
 static const ssl3ExtensionHandler serverCertificateHandlers[] = {
     { ssl_signed_cert_timestamp_xtn, &ssl3_ClientHandleSignedCertTimestampXtn },
     { ssl_cert_status_xtn, &ssl3_ClientHandleStatusRequestXtn },
     { -1, NULL }
 };
 
+static const ssl3ExtensionHandler certificateRequestHandlers[] = {
+    { -1, NULL }
+};
+
 /* Tables of functions to format TLS hello extensions, one function per
  * extension.
  * These static tables are for the formatting of client hello extensions.
  * The server's table of hello senders is dynamic, in the socket struct,
  * and sender functions are registered there.
  * NB: the order of these extensions can have an impact on compatibility. Some
  * servers (e.g. Tomcat) will terminate the connection if the last extension in
  * the client hello is empty (for example, the extended master secret
@@ -244,17 +248,20 @@ ssl3_FindExtension(sslSocket *ss, SSLExt
  * In TLS >= 1.3, the client checks that extensions appear in the
  * right phase.
  */
 SECStatus
 ssl3_HandleParsedExtensions(sslSocket *ss,
                             SSL3HandshakeType handshakeMessage)
 {
     const ssl3ExtensionHandler *handlers;
-    PRBool isTLS13 = ss->version >= SSL_LIBRARY_VERSION_TLS_1_3;
+    /* HelloRetryRequest doesn't set ss->version. It might be safe to
+     * do so, but we weren't entirely sure. TODO(ekr@rtfm.com). */
+    PRBool isTLS13 = (ss->version >= SSL_LIBRARY_VERSION_TLS_1_3) ||
+            (handshakeMessage == hello_retry_request);
     PRCList *cursor;
 
     switch (handshakeMessage) {
         case client_hello:
             handlers = clientHelloHandlers;
             break;
         case new_session_ticket:
             PORT_Assert(ss->version >= SSL_LIBRARY_VERSION_TLS_1_3);
@@ -272,16 +279,20 @@ ssl3_HandleParsedExtensions(sslSocket *s
             } else {
                 handlers = serverHelloHandlersSSL3;
             }
             break;
         case certificate:
             PORT_Assert(!ss->sec.isServer);
             handlers = serverCertificateHandlers;
             break;
+        case certificate_request:
+            PORT_Assert(!ss->sec.isServer);
+            handlers = certificateRequestHandlers;
+            break;
         default:
             PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
             PORT_Assert(0);
             return SECFailure;
     }
 
     for (cursor = PR_NEXT_LINK(&ss->ssl3.hs.remoteExtensions);
          cursor != &ss->ssl3.hs.remoteExtensions;
--- a/lib/ssl/tls13con.c
+++ b/lib/ssl/tls13con.c
@@ -1752,17 +1752,17 @@ tls13_HandleHelloRetryRequest(sslSocket 
 
 static SECStatus
 tls13_HandleCertificateRequest(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
 {
     SECStatus rv;
     TLS13CertificateRequest *certRequest = NULL;
     SECItem context = { siBuffer, NULL, 0 };
     PLArenaPool *arena;
-    PRUint32 extensionsLength;
+    SECItem extensionsData = { siBuffer, NULL, 0 };
 
     SSL_TRC(3, ("%d: TLS13[%d]: handle certificate_request sequence",
                 SSL_GETPID(), ss->fd));
 
     PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
     PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
 
     /* Client */
@@ -1810,24 +1810,26 @@ tls13_HandleCertificateRequest(sslSocket
         goto loser;
     }
 
     rv = ssl3_ParseCertificateRequestCAs(ss, &b, &length, arena,
                                          &certRequest->ca_list);
     if (rv != SECSuccess)
         goto loser; /* alert already sent */
 
-    /* Verify that the extensions length is correct. */
-    rv = ssl3_ConsumeHandshakeNumber(ss, &extensionsLength, 2, &b, &length);
+    /* Verify that the extensions are sane. */
+    rv = ssl3_ConsumeHandshakeVariable(ss, &extensionsData, 2, &b, &length);
     if (rv != SECSuccess) {
-        goto loser; /* alert already sent */
-    }
-    if (extensionsLength != length) {
-        FATAL_ERROR(ss, SSL_ERROR_RX_MALFORMED_CERT_REQUEST,
-                    illegal_parameter);
+        goto loser;
+    }
+
+    /* Process all the extensions (note: currently a no-op). */
+    rv = ssl3_HandleExtensions(ss, &extensionsData.data, &extensionsData.len,
+                               certificate_request);
+    if (rv != SECSuccess) {
         goto loser;
     }
 
     rv = SECITEM_CopyItem(arena, &certRequest->context, &context);
     if (rv != SECSuccess)
         goto loser;
 
     TLS13_SET_HS_STATE(ss, wait_server_cert);
@@ -4029,17 +4031,18 @@ tls13_ExtensionAllowed(PRUint16 extensio
 {
     unsigned int i;
 
     PORT_Assert((message == client_hello) ||
                 (message == server_hello) ||
                 (message == hello_retry_request) ||
                 (message == encrypted_extensions) ||
                 (message == new_session_ticket) ||
-                (message == certificate));
+                (message == certificate) ||
+                (message == certificate_request));
 
     for (i = 0; i < PR_ARRAY_SIZE(KnownExtensions); i++) {
         if (KnownExtensions[i].ex_value == extension)
             break;
     }
     if (i == PR_ARRAY_SIZE(KnownExtensions)) {
         /* We have never heard of this extension which is OK
          * in client_hello and new_session_ticket. */
--- a/lib/ssl/tls13exthandle.c
+++ b/lib/ssl/tls13exthandle.c
@@ -915,16 +915,19 @@ tls13_ClientSendSupportedVersionsXtn(con
             return -1;
 
         for (version = ss->vrange.max; version >= ss->vrange.min; --version) {
             rv = ssl3_ExtAppendHandshakeNumber(
                 ss, tls13_EncodeDraftVersion(version), 2);
             if (rv != SECSuccess)
                 return -1;
         }
+
+        xtnData->advertised[xtnData->numAdvertised++] =
+            ssl_tls13_supported_versions_xtn;
     }
 
     return extensions_len;
 }
 
 /*
  *    struct {
  *        opaque cookie<1..2^16-1>;