Bug 1720235 SSL handling of signature algorithms ignores environmental invalid algorithms.
authorRobert Relyea <rrelyea@redhat.com>
Tue, 20 Jul 2021 13:02:54 -0700
changeset 15966 c71bb1bedf7d9d45b34cd483097d30c8a3f01602
parent 15965 8f41147c21926a70a1163f9b91b01b735ed621f1
child 15967 e9236397be133d5e3bcb4f162f0cdbc31f1e4153
push id4000
push userrrelyea@redhat.com
push dateFri, 23 Jul 2021 17:00:04 +0000
bugs1720235
Bug 1720235 SSL handling of signature algorithms ignores environmental invalid algorithms. Our QA is quite extensive on handling of alert corner cases. Our code that checks if a signature algorithm is supported ignores the role of policy. If SHA1 is turned off by policy, for instance, we only detect that late in the game. This shows up in our test cases as decrypt_alerts rather than illegal_parameter or handshake_error alerts. It also shows up in us apparently accepting a client auth request which only has invalid alerts. We also don't handle filtering out signature algorithms that are illegal in tls 13 mode. This patch not only fixes these issues, but also issues where we proposing signature algorithms in server mode that we don't support by policy. This patch includes: In gtests: 1) adding support for policy in ssl_gtests. Currently both the server an client will run with the same policy. The patch allows us to set policy on one and keeping the old policy on the other. 2) Update extension tests which failed in tls 1.3 because the patch now correctly rejects illegal tls 1.3 auth values. The test was updated to use a legal auth value in tls 1.3 (so we are correctly testing the format issue. 3) Update extension tests to handle the case where we try to use an illegal value for tls 1.3. 4) add tests to ssl_auth_unittests.cc to make sure we can properly connect even when several auth methods are turned off by policy (make sure we don't advertize them on the client side, and that the server doesn't select them when the client doesn't advertize them). 5) add tests to ssl_auth_unittests.cc to make sure we don't send empty client auth requests when the requester only sends invalid auth requests. patch itself: 1) The handling of policy checks for ssl schemes were scattered in various locations. I've consolidated them into a single function. That function now checks for NSS_ALG_USE_IN_ANY_SIGNATURE as if this is off by policy, we will fail if we try to use the algorithm in a signature in any case. NSS now supports policy on all signature algorithms, not just DSA, so we need to check the policy of all the algorithms. 2) to support the policy check on the signature algorithms, I added a new ssl_AuthTypeToOID, which also replaces our switch in checking if the SPKI matches our auth type. 3) ssl_SignatureSchemeValid now accepts an spkiOid of SEC_OID_UNKNOWN. To allow us to filter signature schemes based on version and policy restrictions before we try to select a certificate. This prevents us from sending empty client auth messages when we are presented with only invalid signature schemes. 4) We filter supported algorithms against policy early, preventing us from sending, or even setting invalid algorithms if they are turned off by policy. 5) ssl ConsumeSignatureScheme was handling alerts inconsistently. The Consume could send an allert in it's failure case, but the check of scheme validity wouldn't sent an alert. The collers were inconstent as well. Now ssl_ConsumeSignatureScheme always sends and alert on failure, and the callers do not. Differential Revision: https://phabricator.services.mozilla.com/D120392
gtests/ssl_gtest/nss_policy.h
gtests/ssl_gtest/ssl_auth_unittest.cc
gtests/ssl_gtest/ssl_extension_unittest.cc
gtests/ssl_gtest/tls_agent.cc
gtests/ssl_gtest/tls_agent.h
gtests/ssl_gtest/tls_connect.cc
lib/ssl/ssl3con.c
lib/ssl/sslimpl.h
new file mode 100644
--- /dev/null
+++ b/gtests/ssl_gtest/nss_policy.h
@@ -0,0 +1,78 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef nss_policy_h_
+#define nss_policy_h_
+
+#include "prtypes.h"
+#include "secoid.h"
+
+namespace nss_test {
+
+// container class to hold all a temp policy
+class NssPolicy {
+ public:
+  NssPolicy() : oid_(SEC_OID_UNKNOWN), set_(0), clear_(0) {}
+  NssPolicy(SECOidTag _oid, PRUint32 _set, PRUint32 _clear)
+      : oid_(_oid), set_(_set), clear_(_clear) {}
+  NssPolicy(const NssPolicy &p)
+      : oid_(p.oid_), set_(p.set_), clear_(p.clear_) {}
+  // clone the current policy for this oid
+  NssPolicy(SECOidTag _oid) : oid_(_oid), set_(0), clear_(0) {
+    NSS_GetAlgorithmPolicy(_oid, &set_);
+    clear_ = ~set_;
+  }
+  SECOidTag oid(void) const { return oid_; }
+  PRUint32 set(void) const { return set_; }
+  PRUint32 clear(void) const { return clear_; }
+  operator bool() const { return oid_ != SEC_OID_UNKNOWN; }
+
+ private:
+  SECOidTag oid_;
+  PRUint32 set_;
+  PRUint32 clear_;
+};
+
+// set the policy indicated in NssPolicy and restor the old policy
+// when we go out of scope
+class NssManagePolicy {
+ public:
+  NssManagePolicy(const NssPolicy &p) : policy_(p), current_(~(PRUint32)0) {
+    if (p) {
+      (void)NSS_GetAlgorithmPolicy(p.oid(), &current_);
+      (void)NSS_SetAlgorithmPolicy(p.oid(), p.set(), p.clear());
+    }
+  }
+  ~NssManagePolicy() {
+    if (policy_) {
+      (void)NSS_SetAlgorithmPolicy(policy_.oid(), current_, ~current_);
+    }
+  }
+
+ private:
+  NssPolicy policy_;
+  PRUint32 current_;
+};
+
+// wrapping PRFileDesc this way ensures that tests that attempt to access
+// PRFileDesc always correctly apply
+// the policy that was bound to that socket with TlsAgent::SetPolicy().
+class NssManagedFileDesc {
+ public:
+  NssManagedFileDesc(PRFileDesc *fd, const NssPolicy &policy)
+      : fd_(fd), managed_policy_(policy) {}
+  PRFileDesc *get(void) const { return fd_; }
+  operator PRFileDesc *() const { return fd_; }
+  bool operator==(PRFileDesc *fd) const { return fd_ == fd; }
+
+ private:
+  PRFileDesc *fd_;
+  NssManagePolicy managed_policy_;
+};
+
+}  // namespace nss_test
+
+#endif
--- a/gtests/ssl_gtest/ssl_auth_unittest.cc
+++ b/gtests/ssl_gtest/ssl_auth_unittest.cc
@@ -1793,16 +1793,175 @@ TEST_P(TlsSignatureSchemeConfiguration, 
   Reset(certificate_);
   EnsureTlsSetup();
   client_->SetSignatureSchemes(&signature_scheme_, 1);
   server_->SetSignatureSchemes(&signature_scheme_, 1);
   Connect();
   CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, auth_type_, signature_scheme_);
 }
 
+class Tls12CertificateRequestReplacer : public TlsHandshakeFilter {
+ public:
+  Tls12CertificateRequestReplacer(const std::shared_ptr<TlsAgent>& a,
+                                  SSLSignatureScheme scheme)
+      : TlsHandshakeFilter(a, {kTlsHandshakeCertificateRequest}),
+        scheme_(scheme) {}
+
+  virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+                                               const DataBuffer& input,
+                                               DataBuffer* output) {
+    uint32_t offset = 0;
+
+    if (header.handshake_type() != ssl_hs_certificate_request) {
+      return KEEP;
+    }
+
+    *output = input;
+
+    uint32_t types_len = 0;
+    if (!output->Read(offset, 1, &types_len)) {
+      ADD_FAILURE();
+      return KEEP;
+    }
+    offset += 1 + types_len;
+    uint32_t scheme_len = 0;
+    if (!output->Read(offset, 2, &scheme_len)) {
+      ADD_FAILURE();
+      return KEEP;
+    }
+    DataBuffer schemes;
+    schemes.Write(0, 2, 2);
+    schemes.Write(2, scheme_, 2);
+    output->Write(offset, 2, schemes.len());
+    output->Splice(schemes, offset + 2, scheme_len);
+
+    return CHANGE;
+  }
+
+ private:
+  SSLSignatureScheme scheme_;
+};
+
+//
+// Test how policy interacts with client auth connections
+//
+
+// TLS/DTLS version algorithm policy
+typedef std::tuple<SSLProtocolVariant, uint16_t, SECOidTag, PRUint32>
+    PolicySignatureSchemeProfile;
+
+// Only TLS 1.2 handles client auth schemes inside
+// the certificate request packet, so our failure tests for
+// those kinds of connections only occur here.
+class TlsConnectAuthWithPolicyTls12
+    : public TlsConnectTestBase,
+      public ::testing::WithParamInterface<PolicySignatureSchemeProfile> {
+ public:
+  TlsConnectAuthWithPolicyTls12()
+      : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
+    alg_ = std::get<2>(GetParam());
+    policy_ = std::get<3>(GetParam());
+    // use the algorithm to select which single scheme to deploy
+    // We use these schemes to force servers sending schemes the client
+    // didn't advertise to make sure the client will still filter these
+    // by policy and detect that no valid schemes were presented, rather
+    // than sending an empty client auth message.
+    switch (alg_) {
+      case SEC_OID_SHA256:
+      case SEC_OID_PKCS1_RSA_PSS_SIGNATURE:
+        scheme_ = ssl_sig_rsa_pss_pss_sha256;
+        break;
+      case SEC_OID_PKCS1_RSA_ENCRYPTION:
+        scheme_ = ssl_sig_rsa_pkcs1_sha256;
+        break;
+      case SEC_OID_ANSIX962_EC_PUBLIC_KEY:
+        scheme_ = ssl_sig_ecdsa_secp256r1_sha256;
+        break;
+      default:
+        ADD_FAILURE() << "need to update algorithm table in "
+                         "TlsConnectAuthWithPolicyTls12";
+        scheme_ = ssl_sig_none;
+        break;
+    }
+  }
+
+ protected:
+  SECOidTag alg_;
+  PRUint32 policy_;
+  SSLSignatureScheme scheme_;
+};
+
+// Only TLS 1.2 and greater looks at schemes extensions on client auth
+class TlsConnectAuthWithPolicyTls12Plus
+    : public TlsConnectTestBase,
+      public ::testing::WithParamInterface<PolicySignatureSchemeProfile> {
+ public:
+  TlsConnectAuthWithPolicyTls12Plus()
+      : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
+    alg_ = std::get<2>(GetParam());
+    policy_ = std::get<3>(GetParam());
+  }
+
+ protected:
+  SECOidTag alg_;
+  PRUint32 policy_;
+};
+
+// make sure we can turn single algorithms off by policy an still connect
+// this is basically testing that we are properly filtering our schemes
+// by policy before communicating them to the server, and that the
+// server is respecting our choices
+TEST_P(TlsConnectAuthWithPolicyTls12Plus, PolicySuccessTest) {
+  // in TLS 1.3, RSA PKCS1 is restricted. If we are also
+  // restricting RSA PSS by policy, we can't use the default
+  // RSA certificate as the server cert, switch to ECDSA
+  if ((version_ >= SSL_LIBRARY_VERSION_TLS_1_3) &&
+      (alg_ == SEC_OID_PKCS1_RSA_PSS_SIGNATURE)) {
+    Reset(TlsAgent::kServerEcdsa256);
+  }
+  client_->SetPolicy(alg_, 0, policy_);  // Disable policy for client
+  client_->SetupClientAuth();
+  server_->RequestClientAuth(false);
+  Connect();
+}
+
+// make sure we fail if the server ignores our policy preference and
+// requests client auth with a scheme we don't support
+TEST_P(TlsConnectAuthWithPolicyTls12, PolicyFailureTest) {
+  client_->SetPolicy(alg_, 0, policy_);
+  client_->SetupClientAuth();
+  server_->RequestClientAuth(false);
+  MakeTlsFilter<Tls12CertificateRequestReplacer>(server_, scheme_);
+  ConnectExpectAlert(client_, kTlsAlertHandshakeFailure);
+  client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+  server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    SignaturesWithPolicyFail, TlsConnectAuthWithPolicyTls12,
+    ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+                       TlsConnectTestBase::kTlsV12,
+                       ::testing::Values(SEC_OID_SHA256,
+                                         SEC_OID_PKCS1_RSA_PSS_SIGNATURE,
+                                         SEC_OID_PKCS1_RSA_ENCRYPTION,
+                                         SEC_OID_ANSIX962_EC_PUBLIC_KEY),
+                       ::testing::Values(NSS_USE_ALG_IN_SSL_KX,
+                                         NSS_USE_ALG_IN_ANY_SIGNATURE)));
+
+INSTANTIATE_TEST_SUITE_P(
+    SignaturesWithPolicySuccess, TlsConnectAuthWithPolicyTls12Plus,
+    ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+                       TlsConnectTestBase::kTlsV12Plus,
+                       ::testing::Values(SEC_OID_SHA256,
+                                         SEC_OID_PKCS1_RSA_PSS_SIGNATURE,
+                                         SEC_OID_PKCS1_RSA_ENCRYPTION,
+                                         SEC_OID_ANSIX962_EC_PUBLIC_KEY),
+                       ::testing::Values(NSS_USE_ALG_IN_SSL_KX,
+                                         NSS_USE_ALG_IN_ANY_SIGNATURE)));
+
 INSTANTIATE_TEST_SUITE_P(
     SignatureSchemeRsa, TlsSignatureSchemeConfiguration,
     ::testing::Combine(
         TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12,
         ::testing::Values(TlsAgent::kServerRsaSign),
         ::testing::Values(ssl_auth_rsa_sign),
         ::testing::Values(ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384,
                           ssl_sig_rsa_pkcs1_sha512, ssl_sig_rsa_pss_rsae_sha256,
--- a/gtests/ssl_gtest/ssl_extension_unittest.cc
+++ b/gtests/ssl_gtest/ssl_extension_unittest.cc
@@ -425,17 +425,20 @@ TEST_P(TlsExtensionTestDtls, SrtpOdd) {
 TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) {
   const uint8_t val[] = {0x00};
   DataBuffer extension(val, sizeof(val));
   ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
       client_, ssl_signature_algorithms_xtn, extension));
 }
 
 TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) {
-  const uint8_t val[] = {0x00, 0x02, 0x04, 0x01, 0x00};  // sha-256, rsa
+  // make sure the test uses an algorithm that is legal for
+  // tls 1.3 (or tls 1.3 will throw a handshake failure alert
+  // instead of a decode error alert)
+  const uint8_t val[] = {0x00, 0x02, 0x08, 0x09, 0x00};  // sha-256, rsa-pss-pss
   DataBuffer extension(val, sizeof(val));
   ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
       client_, ssl_signature_algorithms_xtn, extension));
 }
 
 TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) {
   const uint8_t val[] = {0x00, 0x00};
   DataBuffer extension(val, sizeof(val));
@@ -1126,16 +1129,25 @@ TEST_P(TlsExtensionTest13, EmptyVersionL
   ConnectWithBogusVersionList(kExt, sizeof(kExt));
 }
 
 TEST_P(TlsExtensionTest13, OddVersionList) {
   static const uint8_t kExt[] = {0x00, 0x01, 0x00};
   ConnectWithBogusVersionList(kExt, sizeof(kExt));
 }
 
+TEST_P(TlsExtensionTest13, SignatureAlgorithmsInvalidTls13) {
+  // testing the case where we ask for a invalid parameter for tls13
+  const uint8_t val[] = {0x00, 0x02, 0x04, 0x01};  // sha-256, rsa-pkcs1
+  DataBuffer extension(val, sizeof(val));
+  ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+                           client_, ssl_signature_algorithms_xtn, extension),
+                       kTlsAlertHandshakeFailure);
+}
+
 // Use the stream version number for TLS 1.3 (0x0304) in DTLS.
 TEST_F(TlsConnectDatagram13, TlsVersionInDtls) {
   static const uint8_t kExt[] = {0x02, 0x03, 0x04};
 
   DataBuffer versions_buf(kExt, sizeof(kExt));
   MakeTlsFilter<TlsExtensionReplacer>(client_, ssl_tls13_supported_versions_xtn,
                                       versions_buf);
   ConnectExpectAlert(server_, kTlsAlertProtocolVersion);
--- a/gtests/ssl_gtest/tls_agent.cc
+++ b/gtests/ssl_gtest/tls_agent.cc
@@ -88,17 +88,18 @@ TlsAgent::TlsAgent(const std::string& nm
       error_code_(0),
       send_ctr_(0),
       recv_ctr_(0),
       expect_readwrite_error_(false),
       handshake_callback_(),
       auth_certificate_callback_(),
       sni_callback_(),
       skip_version_checks_(false),
-      resumption_token_() {
+      resumption_token_(),
+      policy_() {
   memset(&info_, 0, sizeof(info_));
   memset(&csinfo_, 0, sizeof(csinfo_));
   SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_);
   EXPECT_EQ(SECSuccess, rv);
 }
 
 TlsAgent::~TlsAgent() {
   if (timer_handle_) {
@@ -222,16 +223,17 @@ bool TlsAgent::ConfigServerCert(const st
   rv = SSL_ConfigServerCert(ssl_fd(), cert.get(), priv.get(), serverCertData,
                             serverCertData ? sizeof(*serverCertData) : 0);
   return rv == SECSuccess;
 }
 
 bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
   // Don't set up twice
   if (ssl_fd_) return true;
+  NssManagePolicy policyManage(policy_);
 
   ScopedPRFileDesc dummy_fd(adapter_->CreateFD());
   EXPECT_NE(nullptr, dummy_fd);
   if (!dummy_fd) {
     return false;
   }
   if (adapter_->variant() == ssl_variant_stream) {
     ssl_fd_.reset(SSL_ImportFD(modelSocket, dummy_fd.get()));
@@ -314,17 +316,17 @@ bool TlsAgent::MaybeSetResumptionToken()
       if (expect_psk_ == ssl_psk_resume) return false;
     }
   }
 
   return true;
 }
 
 void TlsAgent::SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx) {
-  EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(ssl_fd_.get(), ctx.get()));
+  EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(ssl_fd(), ctx.get()));
 }
 
 void TlsAgent::SetupClientAuth() {
   EXPECT_TRUE(EnsureTlsSetup());
   ASSERT_EQ(CLIENT, role_);
 
   EXPECT_EQ(SECSuccess,
             SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook,
@@ -864,18 +866,18 @@ void TlsAgent::CheckCallbacks() const {
 }
 
 void TlsAgent::ResetPreliminaryInfo() {
   expected_version_ = 0;
   expected_cipher_suite_ = 0;
 }
 
 void TlsAgent::UpdatePreliminaryChannelInfo() {
-  SECStatus rv = SSL_GetPreliminaryChannelInfo(ssl_fd_.get(), &pre_info_,
-                                               sizeof(pre_info_));
+  SECStatus rv =
+      SSL_GetPreliminaryChannelInfo(ssl_fd(), &pre_info_, sizeof(pre_info_));
   EXPECT_EQ(SECSuccess, rv);
   EXPECT_EQ(sizeof(pre_info_), pre_info_.length);
 }
 
 void TlsAgent::ValidateCipherSpecs() {
   PRInt32 cipherSpecs = SSLInt_CountCipherSpecs(ssl_fd());
   // We use one ciphersuite in each direction.
   PRInt32 expected = 2;
--- a/gtests/ssl_gtest/tls_agent.h
+++ b/gtests/ssl_gtest/tls_agent.h
@@ -9,16 +9,17 @@
 
 #include "prio.h"
 #include "ssl.h"
 #include "sslproto.h"
 
 #include <functional>
 #include <iostream>
 
+#include "nss_policy.h"
 #include "test_io.h"
 
 #define GTEST_HAS_RTTI 0
 #include "gtest/gtest.h"
 #include "nss_scoped_ptrs.h"
 #include "scoped_ptrs_ssl.h"
 
 extern bool g_ssl_gtest_verbose;
@@ -226,17 +227,19 @@ class TlsAgent : public PollTarget {
   const CERTCertificate* peer_cert() const {
     return SSL_PeerCertificate(ssl_fd_.get());
   }
 
   const char* state_str() const { return state_str(state()); }
 
   static const char* state_str(State state) { return states[state]; }
 
-  PRFileDesc* ssl_fd() const { return ssl_fd_.get(); }
+  NssManagedFileDesc ssl_fd() const {
+    return NssManagedFileDesc(ssl_fd_.get(), policy_);
+  }
   std::shared_ptr<DummyPrSocket>& adapter() { return adapter_; }
 
   const SSLChannelInfo& info() const {
     EXPECT_EQ(STATE_CONNECTED, state_);
     return info_;
   }
 
   const SSLPreliminaryChannelInfo& pre_info() const { return pre_info_; }
@@ -302,16 +305,20 @@ class TlsAgent : public PollTarget {
   void SetSniCallback(SniCallbackFunction sni_callback) {
     sni_callback_ = sni_callback;
   }
 
   void ExpectReceiveAlert(uint8_t alert, uint8_t level = 0);
   void ExpectSendAlert(uint8_t alert, uint8_t level = 0);
 
   std::string alpn_value_to_use_ = "";
+  // set the given policy before this agent runs
+  void SetPolicy(SECOidTag oid, PRUint32 set, PRUint32 clear) {
+    policy_ = NssPolicy(oid, set, clear);
+  }
 
  private:
   const static char* states[];
 
   void SetState(State state);
   void ValidateCipherSpecs();
 
   // Dummy auth certificate hook.
@@ -448,16 +455,17 @@ class TlsAgent : public PollTarget {
   size_t send_ctr_;
   size_t recv_ctr_;
   bool expect_readwrite_error_;
   HandshakeCallbackFunction handshake_callback_;
   AuthCertificateCallbackFunction auth_certificate_callback_;
   SniCallbackFunction sni_callback_;
   bool skip_version_checks_;
   std::vector<uint8_t> resumption_token_;
+  NssPolicy policy_;
 };
 
 inline std::ostream& operator<<(std::ostream& stream,
                                 const TlsAgent::State& state) {
   return stream << TlsAgent::state_str(state);
 }
 
 class TlsAgentTestBase : public ::testing::Test {
--- a/gtests/ssl_gtest/tls_connect.cc
+++ b/gtests/ssl_gtest/tls_connect.cc
@@ -374,20 +374,20 @@ void TlsConnectTestBase::ExpectResumptio
     client_->ExpectResumption();
     server_->ExpectResumption();
     expected_resumptions_ = num_resumptions;
   }
   EXPECT_EQ(expected_resumptions_ == 0, expected == RESUME_NONE);
 }
 
 void TlsConnectTestBase::EnsureTlsSetup() {
-  EXPECT_TRUE(server_->EnsureTlsSetup(server_model_ ? server_model_->ssl_fd()
-                                                    : nullptr));
-  EXPECT_TRUE(client_->EnsureTlsSetup(client_model_ ? client_model_->ssl_fd()
-                                                    : nullptr));
+  EXPECT_TRUE(server_->EnsureTlsSetup(
+      server_model_ ? server_model_->ssl_fd().get() : nullptr));
+  EXPECT_TRUE(client_->EnsureTlsSetup(
+      client_model_ ? client_model_->ssl_fd().get() : nullptr));
   server_->SetAntiReplayContext(anti_replay_);
   EXPECT_EQ(SECSuccess, SSL_SetTimeFunc(client_->ssl_fd(),
                                         TlsConnectTestBase::TimeFunc, &now_));
   EXPECT_EQ(SECSuccess, SSL_SetTimeFunc(server_->ssl_fd(),
                                         TlsConnectTestBase::TimeFunc, &now_));
 }
 
 void TlsConnectTestBase::Handshake() {
--- a/lib/ssl/ssl3con.c
+++ b/lib/ssl/ssl3con.c
@@ -68,16 +68,18 @@ static CK_MECHANISM_TYPE ssl3_GetHashMec
 static CK_MECHANISM_TYPE ssl3_GetMgfMechanismByHashType(SSLHashType hash);
 PRBool ssl_IsRsaPssSignatureScheme(SSLSignatureScheme scheme);
 PRBool ssl_IsRsaeSignatureScheme(SSLSignatureScheme scheme);
 PRBool ssl_IsRsaPkcs1SignatureScheme(SSLSignatureScheme scheme);
 PRBool ssl_IsDsaSignatureScheme(SSLSignatureScheme scheme);
 static SECStatus ssl3_UpdateDefaultHandshakeHashes(sslSocket *ss,
                                                    const unsigned char *b,
                                                    unsigned int l);
+const PRUint32 kSSLSigSchemePolicy =
+    NSS_USE_ALG_IN_SSL_KX | NSS_USE_ALG_IN_ANY_SIGNATURE;
 
 const PRUint8 ssl_hello_retry_random[] = {
     0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11,
     0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
     0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E,
     0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C
 };
 PR_STATIC_ASSERT(PR_ARRAY_SIZE(ssl_hello_retry_random) == SSL3_RANDOM_LENGTH);
@@ -779,16 +781,51 @@ ssl_HasCert(const sslSocket *ss, PRUint1
         return PR_TRUE;
     }
     if (authType == ssl_auth_rsa_sign) {
         return ssl_HasCert(ss, maxVersion, ssl_auth_rsa_pss);
     }
     return PR_FALSE;
 }
 
+/* return true if the scheme is allowed by policy, This prevents
+ * failures later when our actual signatures are rejected by
+ * policy by either ssl code, or lower level NSS code */
+static PRBool
+ssl_SchemePolicyOK(SSLSignatureScheme scheme, PRUint32 require)
+{
+    /* Hash policy. */
+    PRUint32 policy;
+    SECOidTag hashOID = ssl3_HashTypeToOID(ssl_SignatureSchemeToHashType(scheme));
+    SECOidTag sigOID;
+
+    /* policy bits needed to enable a SignatureScheme */
+    SECStatus rv = NSS_GetAlgorithmPolicy(hashOID, &policy);
+    if (rv == SECSuccess &&
+        (policy & require) != require) {
+        return PR_FALSE;
+    }
+
+    /* ssl_SignatureSchemeToAuthType reports rsa for rsa_pss_rsae, but we
+     * actually implement pss signatures when we sign, so just use RSA_PSS
+     * for all RSA PSS Siganture schemes */
+    if (ssl_IsRsaPssSignatureScheme(scheme)) {
+        sigOID = SEC_OID_PKCS1_RSA_PSS_SIGNATURE;
+    } else {
+        sigOID = ssl3_AuthTypeToOID(ssl_SignatureSchemeToAuthType(scheme));
+    }
+    /* Signature Policy. */
+    rv = NSS_GetAlgorithmPolicy(sigOID, &policy);
+    if (rv == SECSuccess &&
+        (policy & require) != require) {
+        return PR_FALSE;
+    }
+    return PR_TRUE;
+}
+
 /* Check that a signature scheme is accepted.
  * Both by policy and by having a token that supports it. */
 static PRBool
 ssl_SignatureSchemeAccepted(PRUint16 minVersion,
                             SSLSignatureScheme scheme,
                             PRBool forCert)
 {
     /* Disable RSA-PSS schemes if there are no tokens to verify them. */
@@ -804,33 +841,19 @@ ssl_SignatureSchemeAccepted(PRUint16 min
         if (minVersion >= SSL_LIBRARY_VERSION_TLS_1_3) {
             return PR_FALSE;
         }
     } else if (ssl_IsDsaSignatureScheme(scheme)) {
         /* DSA: not in TLS 1.3, and check policy. */
         if (minVersion >= SSL_LIBRARY_VERSION_TLS_1_3) {
             return PR_FALSE;
         }
-        PRUint32 dsaPolicy;
-        SECStatus rv = NSS_GetAlgorithmPolicy(SEC_OID_ANSIX9_DSA_SIGNATURE,
-                                              &dsaPolicy);
-        if (rv == SECSuccess && (dsaPolicy & NSS_USE_ALG_IN_SSL_KX) == 0) {
-            return PR_FALSE;
-        }
-    }
-
-    /* Hash policy. */
-    PRUint32 hashPolicy;
-    SSLHashType hashType = ssl_SignatureSchemeToHashType(scheme);
-    SECOidTag hashOID = ssl3_HashTypeToOID(hashType);
-    SECStatus rv = NSS_GetAlgorithmPolicy(hashOID, &hashPolicy);
-    if (rv == SECSuccess && (hashPolicy & NSS_USE_ALG_IN_SSL_KX) == 0) {
-        return PR_FALSE;
-    }
-    return PR_TRUE;
+    }
+
+    return ssl_SchemePolicyOK(scheme, kSSLSigSchemePolicy);
 }
 
 static SECStatus
 ssl_CheckSignatureSchemes(sslSocket *ss)
 {
     if (ss->vrange.max < SSL_LIBRARY_VERSION_TLS_1_2) {
         return SECSuccess;
     }
@@ -4232,16 +4255,36 @@ ssl3_HashTypeToOID(SSLHashType hashType)
         case ssl_hash_sha512:
             return SEC_OID_SHA512;
         default:
             break;
     }
     return SEC_OID_UNKNOWN;
 }
 
+SECOidTag
+ssl3_AuthTypeToOID(SSLAuthType authType)
+{
+    switch (authType) {
+        case ssl_auth_rsa_sign:
+            return SEC_OID_PKCS1_RSA_ENCRYPTION;
+        case ssl_auth_rsa_pss:
+            return SEC_OID_PKCS1_RSA_PSS_SIGNATURE;
+        case ssl_auth_ecdsa:
+            return SEC_OID_ANSIX962_EC_PUBLIC_KEY;
+        case ssl_auth_dsa:
+            return SEC_OID_ANSIX9_DSA_SIGNATURE;
+        default:
+            break;
+    }
+    /* shouldn't ever get there */
+    PORT_Assert(0);
+    return SEC_OID_UNKNOWN;
+}
+
 SSLHashType
 ssl_SignatureSchemeToHashType(SSLSignatureScheme scheme)
 {
     switch (scheme) {
         case ssl_sig_rsa_pkcs1_sha1:
         case ssl_sig_dsa_sha1:
         case ssl_sig_ecdsa_sha1:
             return ssl_hash_sha1;
@@ -4272,59 +4315,41 @@ ssl_SignatureSchemeToHashType(SSLSignatu
     }
     PORT_Assert(0);
     return ssl_hash_none;
 }
 
 static PRBool
 ssl_SignatureSchemeMatchesSpkiOid(SSLSignatureScheme scheme, SECOidTag spkiOid)
 {
-    switch (scheme) {
-        case ssl_sig_rsa_pkcs1_sha256:
-        case ssl_sig_rsa_pkcs1_sha384:
-        case ssl_sig_rsa_pkcs1_sha512:
-        case ssl_sig_rsa_pkcs1_sha1:
-        case ssl_sig_rsa_pss_rsae_sha256:
-        case ssl_sig_rsa_pss_rsae_sha384:
-        case ssl_sig_rsa_pss_rsae_sha512:
-        case ssl_sig_rsa_pkcs1_sha1md5:
-            return (spkiOid == SEC_OID_X500_RSA_ENCRYPTION) ||
-                   (spkiOid == SEC_OID_PKCS1_RSA_ENCRYPTION);
-        case ssl_sig_rsa_pss_pss_sha256:
-        case ssl_sig_rsa_pss_pss_sha384:
-        case ssl_sig_rsa_pss_pss_sha512:
-            return spkiOid == SEC_OID_PKCS1_RSA_PSS_SIGNATURE;
-        case ssl_sig_ecdsa_secp256r1_sha256:
-        case ssl_sig_ecdsa_secp384r1_sha384:
-        case ssl_sig_ecdsa_secp521r1_sha512:
-        case ssl_sig_ecdsa_sha1:
-            return spkiOid == SEC_OID_ANSIX962_EC_PUBLIC_KEY;
-        case ssl_sig_dsa_sha256:
-        case ssl_sig_dsa_sha384:
-        case ssl_sig_dsa_sha512:
-        case ssl_sig_dsa_sha1:
-            return spkiOid == SEC_OID_ANSIX9_DSA_SIGNATURE;
-        case ssl_sig_none:
-        case ssl_sig_ed25519:
-        case ssl_sig_ed448:
-            break;
-    }
-    PORT_Assert(0);
+    SECOidTag authOid = ssl3_AuthTypeToOID(ssl_SignatureSchemeToAuthType(scheme));
+
+    if (spkiOid == authOid) {
+        return PR_TRUE;
+    }
+    if ((authOid == SEC_OID_PKCS1_RSA_ENCRYPTION) &&
+        (spkiOid == SEC_OID_X500_RSA_ENCRYPTION)) {
+        return PR_TRUE;
+    }
     return PR_FALSE;
 }
 
 /* Validate that the signature scheme works for the given key type. */
 PRBool
 ssl_SignatureSchemeValid(SSLSignatureScheme scheme, SECOidTag spkiOid,
                          PRBool isTls13)
 {
     if (!ssl_IsSupportedSignatureScheme(scheme)) {
         return PR_FALSE;
     }
-    if (!ssl_SignatureSchemeMatchesSpkiOid(scheme, spkiOid)) {
+    /* if we are purposefully passed SEC_OID_UNKNOWN, it means
+     * we not checking the scheme against a potential key, so skip
+     * the call */
+    if ((spkiOid != SEC_OID_UNKNOWN) &&
+        !ssl_SignatureSchemeMatchesSpkiOid(scheme, spkiOid)) {
         return PR_FALSE;
     }
     if (isTls13) {
         if (ssl_SignatureSchemeToHashType(scheme) == ssl_hash_sha1) {
             return PR_FALSE;
         }
         if (ssl_IsRsaPkcs1SignatureScheme(scheme)) {
             return PR_FALSE;
@@ -4512,17 +4537,18 @@ ssl_CheckSignatureSchemeConsistency(sslS
     if (!isTLS13 && !ss->sec.isServer) {
         if (!ssl_SignatureKeyMatchesSpkiOid(ss->ssl3.hs.kea_def, spkiOid)) {
             PORT_SetError(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM);
             return SECFailure;
         }
     }
 
     /* Verify that the signature scheme matches the signing key. */
-    if (!ssl_SignatureSchemeValid(scheme, spkiOid, isTLS13)) {
+    if ((spkiOid == SEC_OID_UNKNOWN) ||
+        !ssl_SignatureSchemeValid(scheme, spkiOid, isTLS13)) {
         PORT_SetError(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM);
         return SECFailure;
     }
 
     if (!ssl_SignatureSchemeEnabled(ss, scheme)) {
         PORT_SetError(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
         return SECFailure;
     }
@@ -4547,17 +4573,18 @@ ssl_IsSupportedSignatureScheme(SSLSignat
         case ssl_sig_ecdsa_secp256r1_sha256:
         case ssl_sig_ecdsa_secp384r1_sha384:
         case ssl_sig_ecdsa_secp521r1_sha512:
         case ssl_sig_dsa_sha1:
         case ssl_sig_dsa_sha256:
         case ssl_sig_dsa_sha384:
         case ssl_sig_dsa_sha512:
         case ssl_sig_ecdsa_sha1:
-            return PR_TRUE;
+            return ssl_SchemePolicyOK(scheme, kSSLSigSchemePolicy);
+            break;
 
         case ssl_sig_rsa_pkcs1_sha1md5:
         case ssl_sig_none:
         case ssl_sig_ed25519:
         case ssl_sig_ed448:
             return PR_FALSE;
     }
     return PR_FALSE;
@@ -4672,19 +4699,20 @@ SECStatus
 ssl_ConsumeSignatureScheme(sslSocket *ss, PRUint8 **b,
                            PRUint32 *length, SSLSignatureScheme *out)
 {
     PRUint32 tmp;
     SECStatus rv;
 
     rv = ssl3_ConsumeHandshakeNumber(ss, &tmp, 2, b, length);
     if (rv != SECSuccess) {
-        return SECFailure; /* Error code set already. */
+        return SECFailure; /* Alert sent, Error code set already. */
     }
     if (!ssl_IsSupportedSignatureScheme((SSLSignatureScheme)tmp)) {
+        SSL3_SendAlert(ss, alert_fatal, illegal_parameter);
         PORT_SetError(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
         return SECFailure;
     }
     *out = (SSLSignatureScheme)tmp;
     return SECSuccess;
 }
 
 /**************************************************************************
@@ -6431,40 +6459,30 @@ ssl3_SendClientKeyExchange(sslSocket *ss
 PRBool
 ssl_CanUseSignatureScheme(SSLSignatureScheme scheme,
                           const SSLSignatureScheme *peerSchemes,
                           unsigned int peerSchemeCount,
                           PRBool requireSha1,
                           PRBool slotDoesPss)
 {
     SSLHashType hashType;
-    SECOidTag hashOID;
-    PRUint32 policy;
     unsigned int i;
 
     /* Skip RSA-PSS schemes when the certificate's private key slot does
      * not support this signature mechanism. */
     if (ssl_IsRsaPssSignatureScheme(scheme) && !slotDoesPss) {
         return PR_FALSE;
     }
 
-    if (ssl_IsDsaSignatureScheme(scheme) &&
-        (NSS_GetAlgorithmPolicy(SEC_OID_ANSIX9_DSA_SIGNATURE, &policy) ==
-         SECSuccess) &&
-        !(policy & NSS_USE_ALG_IN_SSL_KX)) {
-        return PR_FALSE;
-    }
-
     hashType = ssl_SignatureSchemeToHashType(scheme);
     if (requireSha1 && (hashType != ssl_hash_sha1)) {
         return PR_FALSE;
     }
-    hashOID = ssl3_HashTypeToOID(hashType);
-    if ((NSS_GetAlgorithmPolicy(hashOID, &policy) == SECSuccess) &&
-        !(policy & NSS_USE_ALG_IN_SSL_KX)) {
+
+    if (!ssl_SchemePolicyOK(scheme, kSSLSigSchemePolicy)) {
         return PR_FALSE;
     }
 
     for (i = 0; i < peerSchemeCount; i++) {
         if (peerSchemes[i] == scheme) {
             return PR_TRUE;
         }
     }
@@ -6528,16 +6546,19 @@ ssl_PickSignatureScheme(sslSocket *ss,
             PORT_SetError(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
             return SECFailure;
         }
         ss->ssl3.hs.signatureScheme = scheme;
         return SECSuccess;
     }
 
     spkiOid = SECOID_GetAlgorithmTag(&cert->subjectPublicKeyInfo.algorithm);
+    if (spkiOid == SEC_OID_UNKNOWN) {
+        return SECFailure;
+    }
 
     /* Now we have to search based on the key type. Go through our preferred
      * schemes in order and find the first that can be used. */
     for (i = 0; i < ss->ssl3.signatureSchemeCount; ++i) {
         scheme = ss->ssl3.signatureSchemes[i];
 
         if (ssl_SignatureSchemeValid(scheme, spkiOid, isTLS13) &&
             ssl_CanUseSignatureScheme(scheme, peerSchemes, peerSchemeCount,
@@ -7425,17 +7446,17 @@ ssl_HandleDHServerKeyExchange(sslSocket 
     if (!ssl_IsValidDHEShare(&dh_p, &dh_Ys)) {
         errCode = SSL_ERROR_RX_MALFORMED_DHE_KEY_SHARE;
         goto alert_loser;
     }
 
     if (ss->version >= SSL_LIBRARY_VERSION_TLS_1_2) {
         rv = ssl_ConsumeSignatureScheme(ss, &b, &length, &sigScheme);
         if (rv != SECSuccess) {
-            goto alert_loser; /* malformed or unsupported. */
+            goto loser; /* alert already sent */
         }
         rv = ssl_CheckSignatureSchemeConsistency(
             ss, sigScheme, &ss->sec.peerCert->subjectPublicKeyInfo);
         if (rv != SECSuccess) {
             goto alert_loser;
         }
         hashAlg = ssl_SignatureSchemeToHashType(sigScheme);
     } else {
@@ -7695,17 +7716,19 @@ ssl_ParseSignatureSchemes(const sslSocke
     for (; numRemaining && numSupported < MAX_SIGNATURE_SCHEMES; --numRemaining) {
         PRUint32 tmp;
         rv = ssl3_ExtConsumeHandshakeNumber(ss, &tmp, 2, &buf.data, &buf.len);
         if (rv != SECSuccess) {
             PORT_Assert(0);
             PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
             return SECFailure;
         }
-        if (ssl_IsSupportedSignatureScheme((SSLSignatureScheme)tmp)) {
+        if (ssl_SignatureSchemeValid((SSLSignatureScheme)tmp, SEC_OID_UNKNOWN,
+                                     (PRBool)ss->version >= SSL_LIBRARY_VERSION_TLS_1_3)) {
+            ;
             schemes[numSupported++] = (SSLSignatureScheme)tmp;
         }
     }
 
     if (!numSupported) {
         if (!arena) {
             PORT_Free(schemes);
         }
@@ -10281,17 +10304,20 @@ ssl3_HandleCertificateVerify(sslSocket *
 
     /* TLS 1.3 is handled by tls13_HandleCertificateVerify */
     PORT_Assert(ss->ssl3.prSpec->version <= SSL_LIBRARY_VERSION_TLS_1_2);
 
     if (ss->ssl3.prSpec->version == SSL_LIBRARY_VERSION_TLS_1_2) {
         PORT_Assert(ss->ssl3.hs.hashType == handshake_hash_record);
         rv = ssl_ConsumeSignatureScheme(ss, &b, &length, &sigScheme);
         if (rv != SECSuccess) {
-            goto loser; /* malformed or unsupported. */
+            if (PORT_GetError() == SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM) {
+                errCode = SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM;
+            }
+            goto loser; /* alert already sent */
         }
         rv = ssl_CheckSignatureSchemeConsistency(
             ss, sigScheme, &ss->sec.peerCert->subjectPublicKeyInfo);
         if (rv != SECSuccess) {
             errCode = PORT_GetError();
             desc = illegal_parameter;
             goto alert_loser;
         }
--- a/lib/ssl/sslimpl.h
+++ b/lib/ssl/sslimpl.h
@@ -1778,16 +1778,17 @@ SECStatus ssl_PrivateKeySupportsRsaPss(S
 SECStatus ssl_PickSignatureScheme(sslSocket *ss,
                                   CERTCertificate *cert,
                                   SECKEYPublicKey *pubKey,
                                   SECKEYPrivateKey *privKey,
                                   const SSLSignatureScheme *peerSchemes,
                                   unsigned int peerSchemeCount,
                                   PRBool requireSha1);
 SECOidTag ssl3_HashTypeToOID(SSLHashType hashType);
+SECOidTag ssl3_AuthTypeToOID(SSLAuthType hashType);
 SSLHashType ssl_SignatureSchemeToHashType(SSLSignatureScheme scheme);
 SSLAuthType ssl_SignatureSchemeToAuthType(SSLSignatureScheme scheme);
 
 SECStatus ssl3_SetupCipherSuite(sslSocket *ss, PRBool initHashes);
 SECStatus ssl_InsertRecordHeader(const sslSocket *ss, ssl3CipherSpec *cwSpec,
                                  SSLContentType contentType, sslBuffer *wrBuf,
                                  PRBool *needsLength);
 PRBool ssl_SignatureSchemeValid(SSLSignatureScheme scheme, SECOidTag spkiOid,