hg commit -u "Endi S. Dewata <edewata@redhat.com>" -m "Bug 956866, SSL alert callback so client can log warnings such as no_renegotiation, r=ekr"
authorKai Engert <kaie@kuix.de>
Mon, 13 Mar 2017 07:49:50 +0100
changeset 13202 33be0a381a479ad1dd754ce370c8443e24235f1c
parent 13201 09c491ef3b4104d17c9fa68fb86bc4b1f19512d6
child 13204 97406323124d75e17bd4a88f98337580f9f1aa7f
push id2074
push userkaie@kuix.de
push dateMon, 13 Mar 2017 06:49:27 +0000
reviewersekr
bugs956866
hg commit -u "Endi S. Dewata <edewata@redhat.com>" -m "Bug 956866, SSL alert callback so client can log warnings such as no_renegotiation, r=ekr"
gtests/ssl_gtest/ssl_0rtt_unittest.cc
gtests/ssl_gtest/ssl_exporter_unittest.cc
gtests/ssl_gtest/ssl_extension_unittest.cc
gtests/ssl_gtest/ssl_version_unittest.cc
gtests/ssl_gtest/tls_agent.cc
gtests/ssl_gtest/tls_agent.h
gtests/ssl_gtest/tls_connect.cc
lib/ssl/ssl.def
lib/ssl/ssl.h
lib/ssl/ssl3con.c
lib/ssl/sslimpl.h
lib/ssl/sslsecur.c
lib/ssl/sslsock.c
--- a/gtests/ssl_gtest/ssl_0rtt_unittest.cc
+++ b/gtests/ssl_gtest/ssl_0rtt_unittest.cc
@@ -19,16 +19,18 @@ extern "C" {
 #include "tls_connect.h"
 #include "tls_filter.h"
 #include "tls_parser.h"
 
 namespace nss_test {
 
 TEST_P(TlsConnectTls13, ZeroRtt) {
   SetupForZeroRtt();
+  client_->SetExpectedAlertSentCount(1);
+  server_->SetExpectedAlertReceivedCount(1);
   client_->Set0RttEnabled(true);
   server_->Set0RttEnabled(true);
   ExpectResumption(RESUME_TICKET);
   ZeroRttSendReceive(true, true);
   Handshake();
   ExpectEarlyDataAccepted(true);
   CheckConnected();
   SendReceive();
@@ -98,16 +100,18 @@ TEST_P(TlsConnectTls13, ZeroRttServerOnl
   SendReceive();
   CheckKeys();
 }
 
 TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpn) {
   EnableAlpn();
   SetupForZeroRtt();
   EnableAlpn();
+  client_->SetExpectedAlertSentCount(1);
+  server_->SetExpectedAlertReceivedCount(1);
   client_->Set0RttEnabled(true);
   server_->Set0RttEnabled(true);
   ExpectResumption(RESUME_TICKET);
   ExpectEarlyDataAccepted(true);
   ZeroRttSendReceive(true, true, [this]() {
     client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a");
     return true;
   });
--- a/gtests/ssl_gtest/ssl_exporter_unittest.cc
+++ b/gtests/ssl_gtest/ssl_exporter_unittest.cc
@@ -86,16 +86,18 @@ int32_t RegularExporterShouldFail(TlsAge
                             strlen(kExporterLabel), PR_TRUE, kExporterContext,
                             sizeof(kExporterContext), val, sizeof(val)))
       << "regular exporter should fail";
   return 0;
 }
 
 TEST_P(TlsConnectTls13, EarlyExporter) {
   SetupForZeroRtt();
+  client_->SetExpectedAlertSentCount(1);
+  server_->SetExpectedAlertReceivedCount(1);
   client_->Set0RttEnabled(true);
   server_->Set0RttEnabled(true);
   ExpectResumption(RESUME_TICKET);
 
   client_->Handshake();  // Send ClientHello.
   uint8_t client_value[10] = {0};
   RegularExporterShouldFail(client_.get(), nullptr, 0);
   EXPECT_EQ(SECSuccess,
--- a/gtests/ssl_gtest/ssl_extension_unittest.cc
+++ b/gtests/ssl_gtest/ssl_extension_unittest.cc
@@ -162,33 +162,75 @@ class TlsExtensionAppender : public TlsH
 class TlsExtensionTestBase : public TlsConnectTestBase {
  protected:
   TlsExtensionTestBase(Mode mode, uint16_t version)
       : TlsConnectTestBase(mode, version) {}
   TlsExtensionTestBase(const std::string& mode, uint16_t version)
       : TlsConnectTestBase(mode, version) {}
 
   void ClientHelloErrorTest(std::shared_ptr<PacketFilter> filter,
-                            uint8_t alert = kTlsAlertDecodeError) {
+                            uint8_t desc = kTlsAlertDecodeError) {
+    SSLAlert alert;
+
     auto alert_recorder = std::make_shared<TlsAlertRecorder>();
     server_->SetPacketFilter(alert_recorder);
     client_->SetPacketFilter(filter);
     ConnectExpectFail();
+
     EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
-    EXPECT_EQ(alert, alert_recorder->description());
+    EXPECT_EQ(desc, alert_recorder->description());
+
+    // verify no alerts received by the server
+    EXPECT_EQ(0U, server_->alert_received_count());
+
+    // verify the alert sent by the server
+    EXPECT_EQ(1U, server_->alert_sent_count());
+    EXPECT_TRUE(server_->GetLastAlertSent(&alert));
+    EXPECT_EQ(kTlsAlertFatal, alert.level);
+    EXPECT_EQ(desc, alert.description);
+
+    // verify the alert received by the client
+    EXPECT_EQ(1U, client_->alert_received_count());
+    EXPECT_TRUE(client_->GetLastAlertReceived(&alert));
+    EXPECT_EQ(kTlsAlertFatal, alert.level);
+    EXPECT_EQ(desc, alert.description);
+
+    // verify no alerts sent by the client
+    EXPECT_EQ(0U, client_->alert_sent_count());
   }
 
   void ServerHelloErrorTest(std::shared_ptr<PacketFilter> filter,
-                            uint8_t alert = kTlsAlertDecodeError) {
+                            uint8_t desc = kTlsAlertDecodeError) {
+    SSLAlert alert;
+
     auto alert_recorder = std::make_shared<TlsAlertRecorder>();
     client_->SetPacketFilter(alert_recorder);
     server_->SetPacketFilter(filter);
     ConnectExpectFail();
+
     EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
-    EXPECT_EQ(alert, alert_recorder->description());
+    EXPECT_EQ(desc, alert_recorder->description());
+
+    // verify no alerts received by the client
+    EXPECT_EQ(0U, client_->alert_received_count());
+
+    // verify the alert sent by the client
+    EXPECT_EQ(1U, client_->alert_sent_count());
+    EXPECT_TRUE(client_->GetLastAlertSent(&alert));
+    EXPECT_EQ(kTlsAlertFatal, alert.level);
+    EXPECT_EQ(desc, alert.description);
+
+    // verify the alert received by the server
+    EXPECT_EQ(1U, server_->alert_received_count());
+    EXPECT_TRUE(server_->GetLastAlertReceived(&alert));
+    EXPECT_EQ(kTlsAlertFatal, alert.level);
+    EXPECT_EQ(desc, alert.description);
+
+    // verify no alerts sent by the server
+    EXPECT_EQ(0U, server_->alert_sent_count());
   }
 
   static void InitSimpleSni(DataBuffer* extension) {
     const char* name = "host.name";
     const size_t namelen = PL_strlen(name);
     extension->Allocate(namelen + 5);
     extension->Write(0, namelen + 3, 2);
     extension->Write(2, static_cast<uint32_t>(0), 1);  // 0 == hostname
--- a/gtests/ssl_gtest/ssl_version_unittest.cc
+++ b/gtests/ssl_gtest/ssl_version_unittest.cc
@@ -236,16 +236,17 @@ TEST_F(TlsConnectTest, Tls13RejectsRehan
   Connect();
   SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE);
   EXPECT_EQ(SECFailure, rv);
   EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError());
 }
 
 TEST_P(TlsConnectGeneric, AlertBeforeServerHello) {
   EnsureTlsSetup();
+  client_->SetExpectedAlertReceivedCount(1);
   client_->StartConnect();
   server_->StartConnect();
   client_->Handshake();  // Send ClientHello.
   static const uint8_t kWarningAlert[] = {kTlsAlertWarning,
                                           kTlsAlertUnrecognizedName};
   DataBuffer alert;
   TlsAgentTestBase::MakeRecord(mode_, kTlsAlertType,
                                SSL_LIBRARY_VERSION_TLS_1_0, kWarningAlert,
--- a/gtests/ssl_gtest/tls_agent.cc
+++ b/gtests/ssl_gtest/tls_agent.cc
@@ -55,16 +55,22 @@ TlsAgent::TlsAgent(const std::string& na
       falsestart_enabled_(false),
       expected_version_(0),
       expected_cipher_suite_(0),
       expect_resumption_(false),
       expect_client_auth_(false),
       can_falsestart_hook_called_(false),
       sni_hook_called_(false),
       auth_certificate_hook_called_(false),
+      alert_received_count_(0),
+      expected_alert_received_count_(0),
+      last_alert_received_({0, 0}),
+      alert_sent_count_(0),
+      expected_alert_sent_count_(0),
+      last_alert_sent_({0, 0}),
       handshake_callback_called_(false),
       error_code_(0),
       send_ctr_(0),
       recv_ctr_(0),
       expect_readwrite_error_(false),
       handshake_callback_(),
       auth_certificate_callback_(),
       sni_callback_(),
@@ -170,16 +176,24 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc
     EXPECT_EQ(SECSuccess, rv);
     if (rv != SECSuccess) return false;
   }
 
   rv = SSL_AuthCertificateHook(ssl_fd(), AuthCertificateHook, this);
   EXPECT_EQ(SECSuccess, rv);
   if (rv != SECSuccess) return false;
 
+  rv = SSL_AlertReceivedCallback(ssl_fd(), AlertReceivedCallback, this);
+  EXPECT_EQ(SECSuccess, rv);
+  if (rv != SECSuccess) return false;
+
+  rv = SSL_AlertSentCallback(ssl_fd(), AlertSentCallback, this);
+  EXPECT_EQ(SECSuccess, rv);
+  if (rv != SECSuccess) return false;
+
   rv = SSL_HandshakeCallback(ssl_fd(), HandshakeCallback, this);
   EXPECT_EQ(SECSuccess, rv);
   if (rv != SECSuccess) return false;
 
   return true;
 }
 
 void TlsAgent::SetupClientAuth() {
@@ -584,16 +598,21 @@ void TlsAgent::CheckSrtp() const {
 
 void TlsAgent::CheckErrorCode(int32_t expected) const {
   EXPECT_EQ(STATE_ERROR, state_);
   EXPECT_EQ(expected, error_code_)
       << "Got error code " << PORT_ErrorToName(error_code_) << " expecting "
       << PORT_ErrorToName(expected) << std::endl;
 }
 
+void TlsAgent::CheckAlerts() const {
+  EXPECT_EQ(expected_alert_received_count_, alert_received_count_);
+  EXPECT_EQ(expected_alert_sent_count_, alert_sent_count_);
+}
+
 void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const {
   ASSERT_EQ(0, error_code_);
   WAIT_(error_code_ != 0, delay);
   EXPECT_EQ(expected, error_code_)
       << "Got error code " << PORT_ErrorToName(error_code_) << " expecting "
       << PORT_ErrorToName(expected) << std::endl;
 }
 
--- a/gtests/ssl_gtest/tls_agent.h
+++ b/gtests/ssl_gtest/tls_agent.h
@@ -139,16 +139,17 @@ class TlsAgent : public PollTarget {
   void ExpectShortHeaders();
   void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count);
   void EnableAlpn(const uint8_t* val, size_t len);
   void CheckAlpn(SSLNextProtoState expected_state,
                  const std::string& expected = "") const;
   void EnableSrtp();
   void CheckSrtp() const;
   void CheckErrorCode(int32_t expected) const;
+  void CheckAlerts() const;
   void WaitForErrorCode(int32_t expected, uint32_t delay) const;
   // Send data on the socket, encrypting it.
   void SendData(size_t bytes, size_t blocksize = 1024);
   void SendBuffer(const DataBuffer& buf);
   // Send data directly to the underlying socket, skipping the TLS layer.
   void SendDirect(const DataBuffer& buf);
   void ReadBytes();
   void ResetSentBytes();  // Hack to test drops.
@@ -239,16 +240,44 @@ class TlsAgent : public PollTarget {
       AuthCertificateCallbackFunction auth_certificate_callback) {
     auth_certificate_callback_ = auth_certificate_callback;
   }
 
   void SetSniCallback(SniCallbackFunction sni_callback) {
     sni_callback_ = sni_callback;
   }
 
+  size_t alert_received_count() const { return alert_received_count_; }
+
+  void SetExpectedAlertReceivedCount(size_t count) {
+    expected_alert_received_count_ = count;
+  }
+
+  bool GetLastAlertReceived(SSLAlert* alert) const {
+    if (!alert_received_count_) {
+      return false;
+    }
+    *alert = last_alert_received_;
+    return true;
+  }
+
+  size_t alert_sent_count() const { return alert_sent_count_; }
+
+  void SetExpectedAlertSentCount(size_t count) {
+    expected_alert_sent_count_ = count;
+  }
+
+  bool GetLastAlertSent(SSLAlert* alert) const {
+    if (!alert_sent_count_) {
+      return false;
+    }
+    *alert = last_alert_sent_;
+    return true;
+  }
+
  private:
   const static char* states[];
 
   void SetState(State state);
 
   // Dummy auth certificate hook.
   static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
                                        PRBool checksig, PRBool isServer) {
@@ -320,16 +349,40 @@ class TlsAgent : public PollTarget {
     agent->CheckPreliminaryInfo();
     EXPECT_TRUE(agent->falsestart_enabled_);
     EXPECT_FALSE(agent->can_falsestart_hook_called_);
     agent->can_falsestart_hook_called_ = true;
     *canFalseStart = true;
     return SECSuccess;
   }
 
+  static void AlertReceivedCallback(const PRFileDesc* fd, void* arg,
+                                    const SSLAlert* alert) {
+    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+
+    std::cerr << agent->role_str()
+              << ": Alert received: level=" << static_cast<int>(alert->level)
+              << " desc=" << static_cast<int>(alert->description) << std::endl;
+
+    ++agent->alert_received_count_;
+    agent->last_alert_received_ = *alert;
+  }
+
+  static void AlertSentCallback(const PRFileDesc* fd, void* arg,
+                                const SSLAlert* alert) {
+    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+
+    std::cerr << agent->role_str()
+              << ": Alert sent: level=" << static_cast<int>(alert->level)
+              << " desc=" << static_cast<int>(alert->description) << std::endl;
+
+    ++agent->alert_sent_count_;
+    agent->last_alert_sent_ = *alert;
+  }
+
   static void HandshakeCallback(PRFileDesc* fd, void* arg) {
     TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
     agent->handshake_callback_called_ = true;
     agent->Connected();
     if (agent->handshake_callback_) {
       agent->handshake_callback_(agent);
     }
   }
@@ -351,16 +404,22 @@ class TlsAgent : public PollTarget {
   bool falsestart_enabled_;
   uint16_t expected_version_;
   uint16_t expected_cipher_suite_;
   bool expect_resumption_;
   bool expect_client_auth_;
   bool can_falsestart_hook_called_;
   bool sni_hook_called_;
   bool auth_certificate_hook_called_;
+  size_t alert_received_count_;
+  size_t expected_alert_received_count_;
+  SSLAlert last_alert_received_;
+  size_t alert_sent_count_;
+  size_t expected_alert_sent_count_;
+  SSLAlert last_alert_sent_;
   bool handshake_callback_called_;
   SSLChannelInfo info_;
   SSLCipherSuiteInfo csinfo_;
   SSLVersionRange vrange_;
   PRErrorCode error_code_;
   size_t send_ctr_;
   size_t recv_ctr_;
   bool expect_readwrite_error_;
--- a/gtests/ssl_gtest/tls_connect.cc
+++ b/gtests/ssl_gtest/tls_connect.cc
@@ -295,16 +295,19 @@ void TlsConnectTestBase::CheckConnected(
     session_ids_.push_back(sid_c1);
   }
 
   CheckExtendedMasterSecret();
   CheckEarlyDataAccepted();
   CheckResumption(expected_resumption_mode_);
   client_->CheckSecretsDestroyed();
   server_->CheckSecretsDestroyed();
+
+  client_->CheckAlerts();
+  server_->CheckAlerts();
 }
 
 void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
                                    SSLAuthType auth_type,
                                    SSLSignatureScheme sig_scheme) const {
   client_->CheckKEA(kea_type, kea_group);
   server_->CheckKEA(kea_type, kea_group);
   client_->CheckAuthType(auth_type, sig_scheme);
--- a/lib/ssl/ssl.def
+++ b/lib/ssl/ssl.def
@@ -222,8 +222,15 @@ SSL_SignatureSchemePrefGet;
 ;+*;
 ;+};
 ;+NSS_3.30 {    # NSS 3.30 release
 ;+    global:
 SSL_SetSessionTicketKeyPair;
 ;+    local:
 ;+*;
 ;+};
+;+NSS_3.30.0.1 { # Additional symbols for NSS 3.30 release
+;+    global:
+SSL_AlertReceivedCallback;
+SSL_AlertSentCallback;
+;+    local:
+;+*;
+;+};
--- a/lib/ssl/ssl.h
+++ b/lib/ssl/ssl.h
@@ -815,16 +815,35 @@ SSL_IMPORT PRFileDesc *SSL_ReconfigFD(PR
 /*
  * Set the client side argument for SSL to retrieve PKCS #11 pin.
  *  fd - the file descriptor for the connection in question
  *  a - pkcs11 application specific data
  */
 SSL_IMPORT SECStatus SSL_SetPKCS11PinArg(PRFileDesc *fd, void *a);
 
 /*
+** These are callbacks for dealing with SSL alerts.
+ */
+
+typedef PRUint8 SSLAlertLevel;
+typedef PRUint8 SSLAlertDescription;
+
+typedef struct {
+    SSLAlertLevel level;
+    SSLAlertDescription description;
+} SSLAlert;
+
+typedef void(PR_CALLBACK *SSLAlertCallback)(const PRFileDesc *fd, void *arg,
+                                            const SSLAlert *alert);
+
+SSL_IMPORT SECStatus SSL_AlertReceivedCallback(PRFileDesc *fd, SSLAlertCallback cb,
+                                               void *arg);
+SSL_IMPORT SECStatus SSL_AlertSentCallback(PRFileDesc *fd, SSLAlertCallback cb,
+                                           void *arg);
+/*
 ** This is a callback for dealing with server certs that are not authenticated
 ** by the client.  The client app can decide that it actually likes the
 ** cert by some external means and restart the connection.
 **
 ** The bad cert hook must return SECSuccess to override the result of the
 ** authenticate certificate hook, SECFailure if the certificate should still be
 ** considered invalid, or SECWouldBlock if the application will authenticate
 ** the certificate asynchronously. SECWouldBlock is only supported for
--- a/lib/ssl/ssl3con.c
+++ b/lib/ssl/ssl3con.c
@@ -3149,16 +3149,20 @@ SSL3_SendAlert(sslSocket *ss, SSL3AlertL
     }
     if (level == alert_fatal) {
         ss->ssl3.fatalAlertSent = PR_TRUE;
     }
     ssl_ReleaseXmitBufLock(ss);
     if (needHsLock) {
         ssl_ReleaseSSL3HandshakeLock(ss);
     }
+    if (rv == SECSuccess && ss->alertSentCallback) {
+        SSLAlert alert = { level, desc };
+        ss->alertSentCallback(ss->fd, ss->alertSentCallbackArg, &alert);
+    }
     return rv; /* error set by ssl3_FlushHandshake or ssl3_SendRecord */
 }
 
 /*
  * Send illegal_parameter alert.  Set generic error number.
  */
 static SECStatus
 ssl3_IllegalParameter(sslSocket *ss)
@@ -3261,16 +3265,21 @@ ssl3_HandleAlert(sslSocket *ss, sslBuffe
         return SECFailure;
     }
     level = (SSL3AlertLevel)buf->buf[0];
     desc = (SSL3AlertDescription)buf->buf[1];
     buf->len = 0;
     SSL_TRC(5, ("%d: SSL3[%d] received alert, level = %d, description = %d",
                 SSL_GETPID(), ss->fd, level, desc));
 
+    if (ss->alertReceivedCallback) {
+        SSLAlert alert = { level, desc };
+        ss->alertReceivedCallback(ss->fd, ss->alertReceivedCallbackArg, &alert);
+    }
+
     switch (desc) {
         case close_notify:
             ss->recvdCloseNotify = 1;
             error = SSL_ERROR_CLOSE_NOTIFY_ALERT;
             break;
         case unexpected_message:
             error = SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT;
             break;
--- a/lib/ssl/sslimpl.h
+++ b/lib/ssl/sslimpl.h
@@ -1129,16 +1129,20 @@ struct sslSocketStr {
 
     /* Callbacks */
     SSLAuthCertificate authCertificate;
     void *authCertificateArg;
     SSLGetClientAuthData getClientAuthData;
     void *getClientAuthDataArg;
     SSLSNISocketConfig sniSocketConfig;
     void *sniSocketConfigArg;
+    SSLAlertCallback alertReceivedCallback;
+    void *alertReceivedCallbackArg;
+    SSLAlertCallback alertSentCallback;
+    void *alertSentCallbackArg;
     SSLBadCertHandler handleBadCert;
     void *badCertArg;
     SSLHandshakeCallback handshakeCallback;
     void *handshakeCallbackData;
     SSLCanFalseStartCallback canFalseStartCallback;
     void *canFalseStartCallbackData;
     void *pkcs11PinArg;
     SSLNextProtoCallback nextProtoCallback;
--- a/lib/ssl/sslsecur.c
+++ b/lib/ssl/sslsecur.c
@@ -989,16 +989,52 @@ done:
 
 int
 ssl_SecureWrite(sslSocket *ss, const unsigned char *buf, int len)
 {
     return ssl_SecureSend(ss, buf, len, 0);
 }
 
 SECStatus
+SSL_AlertReceivedCallback(PRFileDesc *fd, SSLAlertCallback cb, void *arg)
+{
+    sslSocket *ss;
+
+    ss = ssl_FindSocket(fd);
+    if (!ss) {
+        SSL_DBG(("%d: SSL[%d]: unable to find socket in SSL_AlertReceivedCallback",
+                 SSL_GETPID(), fd));
+        return SECFailure;
+    }
+
+    ss->alertReceivedCallback = cb;
+    ss->alertReceivedCallbackArg = arg;
+
+    return SECSuccess;
+}
+
+SECStatus
+SSL_AlertSentCallback(PRFileDesc *fd, SSLAlertCallback cb, void *arg)
+{
+    sslSocket *ss;
+
+    ss = ssl_FindSocket(fd);
+    if (!ss) {
+        SSL_DBG(("%d: SSL[%d]: unable to find socket in SSL_AlertSentCallback",
+                 SSL_GETPID(), fd));
+        return SECFailure;
+    }
+
+    ss->alertSentCallback = cb;
+    ss->alertSentCallbackArg = arg;
+
+    return SECSuccess;
+}
+
+SECStatus
 SSL_BadCertHook(PRFileDesc *fd, SSLBadCertHandler f, void *arg)
 {
     sslSocket *ss;
 
     ss = ssl_FindSocket(fd);
     if (!ss) {
         SSL_DBG(("%d: SSL[%d]: bad socket in SSLBadCertHook",
                  SSL_GETPID(), fd));
--- a/lib/ssl/sslsock.c
+++ b/lib/ssl/sslsock.c
@@ -325,16 +325,20 @@ ssl_DupSocket(sslSocket *os)
          * XXX We should detect this, and not just march on with NULL pointers.
          */
         ss->authCertificate = os->authCertificate;
         ss->authCertificateArg = os->authCertificateArg;
         ss->getClientAuthData = os->getClientAuthData;
         ss->getClientAuthDataArg = os->getClientAuthDataArg;
         ss->sniSocketConfig = os->sniSocketConfig;
         ss->sniSocketConfigArg = os->sniSocketConfigArg;
+        ss->alertReceivedCallback = os->alertReceivedCallback;
+        ss->alertReceivedCallbackArg = os->alertReceivedCallbackArg;
+        ss->alertSentCallback = os->alertSentCallback;
+        ss->alertSentCallbackArg = os->alertSentCallbackArg;
         ss->handleBadCert = os->handleBadCert;
         ss->badCertArg = os->badCertArg;
         ss->handshakeCallback = os->handshakeCallback;
         ss->handshakeCallbackData = os->handshakeCallbackData;
         ss->canFalseStartCallback = os->canFalseStartCallback;
         ss->canFalseStartCallbackData = os->canFalseStartCallbackData;
         ss->pkcs11PinArg = os->pkcs11PinArg;
         ss->nextProtoCallback = os->nextProtoCallback;
@@ -2143,16 +2147,24 @@ SSL_ReconfigFD(PRFileDesc *model, PRFile
     if (sm->getClientAuthData)
         ss->getClientAuthData = sm->getClientAuthData;
     if (sm->getClientAuthDataArg)
         ss->getClientAuthDataArg = sm->getClientAuthDataArg;
     if (sm->sniSocketConfig)
         ss->sniSocketConfig = sm->sniSocketConfig;
     if (sm->sniSocketConfigArg)
         ss->sniSocketConfigArg = sm->sniSocketConfigArg;
+    if (ss->alertReceivedCallback) {
+        ss->alertReceivedCallback = sm->alertReceivedCallback;
+        ss->alertReceivedCallbackArg = sm->alertReceivedCallbackArg;
+    }
+    if (ss->alertSentCallback) {
+        ss->alertSentCallback = sm->alertSentCallback;
+        ss->alertSentCallbackArg = sm->alertSentCallbackArg;
+    }
     if (sm->handleBadCert)
         ss->handleBadCert = sm->handleBadCert;
     if (sm->badCertArg)
         ss->badCertArg = sm->badCertArg;
     if (sm->handshakeCallback)
         ss->handshakeCallback = sm->handshakeCallback;
     if (sm->handshakeCallbackData)
         ss->handshakeCallbackData = sm->handshakeCallbackData;
@@ -3685,16 +3697,20 @@ ssl_NewSocket(PRBool makeLocks, SSLProto
     ss->dbHandle = CERT_GetDefaultCertDB();
 
     /* Provide default implementation of hooks */
     ss->authCertificate = SSL_AuthCertificate;
     ss->authCertificateArg = (void *)ss->dbHandle;
     ss->sniSocketConfig = NULL;
     ss->sniSocketConfigArg = NULL;
     ss->getClientAuthData = NULL;
+    ss->alertReceivedCallback = NULL;
+    ss->alertReceivedCallbackArg = NULL;
+    ss->alertSentCallback = NULL;
+    ss->alertSentCallbackArg = NULL;
     ss->handleBadCert = NULL;
     ss->badCertArg = NULL;
     ss->pkcs11PinArg = NULL;
 
     ssl_ChooseOps(ss);
     ssl3_InitSocketPolicy(ss);
     for (i = 0; i < SSL_NAMED_GROUP_COUNT; ++i) {
         ss->namedGroupPreferences[i] = &ssl_named_groups[i];