Bug 1485883 - Handle use_srtp extension in gecko, r=drno
authorMartin Thomson <martin.thomson@gmail.com>
Thu, 23 Aug 2018 16:03:00 +1000
changeset 435178 2f345fa70dc0556eff561511bd61be23f366113b
parent 435177 709a197ccc2e5959d14df8e049ce60da3c727f5a
child 435179 9d2183e99373aeff0d5ff976d7d7e23505641bc9
push id34598
push userdluca@mozilla.com
push dateFri, 07 Sep 2018 16:36:02 +0000
treeherdermozilla-central@b8905df54d01 [default view] [failures only]
perfherder[talos] [build metrics] [platform microbench] (compared to previous push)
reviewersdrno
bugs1485883
milestone64.0a1
first release with
nightly linux32
nightly linux64
nightly mac
nightly win32
nightly win64
last release without
nightly linux32
nightly linux64
nightly mac
nightly win32
nightly win64
Bug 1485883 - Handle use_srtp extension in gecko, r=drno Summary: This implements the SRTP extension in TransportLayerDtls. My hope is that we can expunge the SRTP code from NSS in a few releases. Reviewers: drno Subscribers: ekr Tags: #secure-revision Bug #: 1485883 Differential Revision: https://phabricator.services.mozilla.com/D4188 MozReview-Commit-ID: Cwjrn9wsCQr
media/mtransport/test/transport_unittests.cpp
media/mtransport/transportlayerdtls.cpp
media/mtransport/transportlayerdtls.h
--- a/media/mtransport/test/transport_unittests.cpp
+++ b/media/mtransport/test/transport_unittests.cpp
@@ -5,25 +5,27 @@
  * You can obtain one at http://mozilla.org/MPL/2.0/. */
 
 // Original author: ekr@rtfm.com
 
 #include <iostream>
 #include <string>
 #include <map>
 #include <algorithm>
+#include <functional>
 
 #include "mozilla/UniquePtr.h"
 
 #include "sigslot.h"
 
 #include "logging.h"
 #include "nspr.h"
 #include "nss.h"
 #include "ssl.h"
+#include "sslexp.h"
 #include "sslproto.h"
 
 #include "nsThreadUtils.h"
 #include "nsXPCOM.h"
 
 #include "mediapacket.h"
 #include "dtlsidentity.h"
 #include "nricectxhandler.h"
@@ -444,17 +446,16 @@ class TransportTestPeer : public sigslot
         dtls_(new TransportLayerDtls()),
         identity_(DtlsIdentity::Generate()),
         ice_ctx_(NrIceCtxHandler::Create(name)),
         streams_(), candidates_(),
         peer_(nullptr),
         gathering_complete_(false),
         enabled_cipersuites_(),
         disabled_cipersuites_(),
-        reuse_dhe_key_(false),
         test_utils_(utils) {
     std::vector<NrIceStunServer> stun_servers;
     UniquePtr<NrIceStunServer> server(NrIceStunServer::Create(
         std::string((char *)"stun.services.mozilla.com"), 3478));
     stun_servers.push_back(*server);
     EXPECT_TRUE(NS_SUCCEEDED(ice_ctx_->ctx()->SetStunServers(stun_servers)));
 
     dtls_->SetIdentity(identity_);
@@ -565,25 +566,18 @@ class TransportTestPeer : public sigslot
     flow_->PushLayer(loopback_);
     flow_->PushLayer(logging_);
     flow_->PushLayer(lossy_);
     flow_->PushLayer(dtls_);
 
     if (dtls_->state() != TransportLayer::TS_ERROR) {
       // Don't execute these blocks if DTLS didn't initialize.
       TweakCiphers(dtls_->internal_fd());
-      if (reuse_dhe_key_) {
-        // TransportLayerDtls automatically sets this pref to false
-        // so set it back for test.
-        // This is pretty gross. Dig directly into the NSS FD. The problem
-        // is that we are testing a feature which TransaportLayerDtls doesn't
-        // expose.
-        SECStatus rv = SSL_OptionSet(dtls_->internal_fd(),
-                                     SSL_REUSE_SERVER_ECDHE_KEY, PR_TRUE);
-        ASSERT_EQ(SECSuccess, rv);
+      if (post_setup_) {
+        post_setup_(dtls_->internal_fd());
       }
     }
 
     dtls_->SignalPacketReceived.connect(this, &TransportTestPeer::PacketReceived);
   }
 
   void TweakCiphers(PRFileDesc* fd) {
     for (unsigned short& enabled_cipersuite : enabled_cipersuites_) {
@@ -769,18 +763,18 @@ class TransportTestPeer : public sigslot
   }
 
   void SetCipherSuiteChanges(const std::vector<uint16_t>& enableThese,
                              const std::vector<uint16_t>& disableThese) {
     disabled_cipersuites_ = disableThese;
     enabled_cipersuites_ = enableThese;
   }
 
-  void SetReuseECDHEKey() {
-    reuse_dhe_key_ = true;
+  void SetPostSetup(const std::function<void(PRFileDesc*)>& setup) {
+    post_setup_ = std::move(setup);
   }
 
   TransportLayer::State state() {
     TransportLayer::State tstate;
 
     RUN_ON_THREAD(test_utils_->sts_target(),
                   WrapRunnableRet(&tstate, dtls_, &TransportLayer::state));
 
@@ -841,18 +835,18 @@ class TransportTestPeer : public sigslot
   std::vector<RefPtr<NrIceMediaStream> > streams_;
   std::map<std::string, std::vector<std::string> > candidates_;
   TransportTestPeer *peer_;
   bool gathering_complete_;
   unsigned char fingerprint_[TransportLayerDtls::kMaxDigestLength];
   size_t fingerprint_len_;
   std::vector<uint16_t> enabled_cipersuites_;
   std::vector<uint16_t> disabled_cipersuites_;
-  bool reuse_dhe_key_;
   MtransportTestUtils* test_utils_;
+  std::function<void(PRFileDesc* fd)> post_setup_ = nullptr;
 };
 
 
 class TransportTest : public MtransportTest {
  public:
   TransportTest() {
     fds_[0] = nullptr;
     fds_[1] = nullptr;
@@ -1158,32 +1152,43 @@ TEST_F(TransportTest, TestConnectVerifyN
 
   // Now compare these two to see if they are the same.
   ASSERT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) &&
                (!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(),
                         dhe1.public_key_.len())));
 }
 
 TEST_F(TransportTest, TestConnectVerifyReusedECDHE) {
+
+  auto set_reuse_ecdhe_key = [](PRFileDesc* fd) {
+    // TransportLayerDtls automatically sets this pref to false
+    // so set it back for test.
+    // This is pretty gross. Dig directly into the NSS FD. The problem
+    // is that we are testing a feature which TransaportLayerDtls doesn't
+    // expose.
+    SECStatus rv = SSL_OptionSet(fd, SSL_REUSE_SERVER_ECDHE_KEY, PR_TRUE);
+    ASSERT_EQ(SECSuccess, rv);
+  };
+
   SetDtlsPeer();
   DtlsInspectorRecordHandshakeMessage *i1 = new
     DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
   p1_->SetInspector(i1);
-  p1_->SetReuseECDHEKey();
+  p1_->SetPostSetup(set_reuse_ecdhe_key);
   ConnectSocket();
   TlsServerKeyExchangeECDHE dhe1;
   ASSERT_TRUE(dhe1.Parse(i1->buffer().data(), i1->buffer().len()));
 
   Reset();
   SetDtlsPeer();
   DtlsInspectorRecordHandshakeMessage *i2 = new
     DtlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
 
   p1_->SetInspector(i2);
-  p1_->SetReuseECDHEKey();
+  p1_->SetPostSetup(set_reuse_ecdhe_key);
 
   ConnectSocket();
   TlsServerKeyExchangeECDHE dhe2;
   ASSERT_TRUE(dhe2.Parse(i2->buffer().data(), i2->buffer().len()));
 
   // Now compare these two to see if they are the same.
   ASSERT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len());
   ASSERT_TRUE(!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(),
@@ -1295,22 +1300,122 @@ TEST_F(TransportTest, TestSrtpMismatch) 
   std::vector<uint16_t> setA;
   setA.push_back(SRTP_AES128_CM_HMAC_SHA1_80);
   std::vector<uint16_t> setB;
   setB.push_back(SRTP_AES128_CM_HMAC_SHA1_32);
 
   p1_->SetSrtpCiphers(setA);
   p2_->SetSrtpCiphers(setB);
   SetDtlsPeer();
-  ConnectSocket();
+  ConnectSocketExpectFail();
 
   ASSERT_EQ(0, p1_->srtpCipher());
   ASSERT_EQ(0, p2_->srtpCipher());
 }
 
+static SECStatus NoopXtnHandler(PRFileDesc* fd, SSLHandshakeType message,
+                                const uint8_t* data, unsigned int len,
+                                SSLAlertDescription* alert, void* arg) {
+  return SECSuccess;
+}
+
+static PRBool WriteFixedXtn(PRFileDesc* fd, SSLHandshakeType message,
+                            uint8_t* data, unsigned int* len,
+                            unsigned int max_len, void* arg) {
+  // When we enable TLS 1.3, change ssl_hs_server_hello here to
+  // ssl_hs_encrypted_extensions.  At the same time, add a test that writes to
+  // ssl_hs_server_hello, which should fail.
+  if (message != ssl_hs_client_hello && message != ssl_hs_server_hello) {
+    return false;
+  }
+
+  auto v = reinterpret_cast<std::vector<uint8_t>*>(arg);
+  memcpy(data, &((*v)[0]), v->size());
+  *len = v->size();
+  return true;
+}
+
+// Note that |value| needs to be readable after this function returns.
+static void InstallBadSrtpExtensionWriter(TransportTestPeer* peer,
+                                          std::vector<uint8_t>* value) {
+  peer->SetPostSetup([value](PRFileDesc* fd) {
+      // Override the handler that is installed by the DTLS setup.
+      SECStatus rv = SSL_InstallExtensionHooks(
+          fd, ssl_use_srtp_xtn, WriteFixedXtn, value, NoopXtnHandler, nullptr);
+      ASSERT_EQ(SECSuccess, rv);
+    });
+}
+
+TEST_F(TransportTest, TestSrtpErrorServerSendsTwoSrtpCiphers) {
+  // Server (p1_) sends an extension with two values, and empty MKI.
+  std::vector<uint8_t> xtn = { 0x04, 0x00, 0x01, 0x00, 0x02, 0x00 };
+  InstallBadSrtpExtensionWriter(p1_, &xtn);
+  SetupSrtp();
+  SetDtlsPeer();
+  ConnectSocketExpectFail();
+}
+
+TEST_F(TransportTest, TestSrtpErrorServerSendsTwoMki) {
+  // Server (p1_) sends an MKI.
+  std::vector<uint8_t> xtn = { 0x02, 0x00, 0x01, 0x01, 0x00 };
+  InstallBadSrtpExtensionWriter(p1_, &xtn);
+  SetupSrtp();
+  SetDtlsPeer();
+  ConnectSocketExpectFail();
+}
+
+TEST_F(TransportTest, TestSrtpErrorServerSendsUnknownValue) {
+  std::vector<uint8_t> xtn = { 0x02, 0x9a, 0xf1, 0x00 };
+  InstallBadSrtpExtensionWriter(p1_, &xtn);
+  SetupSrtp();
+  SetDtlsPeer();
+  ConnectSocketExpectFail();
+}
+
+TEST_F(TransportTest, TestSrtpErrorServerSendsOverflow) {
+  std::vector<uint8_t> xtn = { 0x32, 0x00, 0x01, 0x00 };
+  InstallBadSrtpExtensionWriter(p1_, &xtn);
+  SetupSrtp();
+  SetDtlsPeer();
+  ConnectSocketExpectFail();
+}
+
+TEST_F(TransportTest, TestSrtpErrorServerSendsUnevenList) {
+  std::vector<uint8_t> xtn = { 0x01, 0x00, 0x00 };
+  InstallBadSrtpExtensionWriter(p1_, &xtn);
+  SetupSrtp();
+  SetDtlsPeer();
+  ConnectSocketExpectFail();
+}
+
+TEST_F(TransportTest, TestSrtpErrorClientSendsUnevenList) {
+  std::vector<uint8_t> xtn = { 0x01, 0x00, 0x00 };
+  InstallBadSrtpExtensionWriter(p2_, &xtn);
+  SetupSrtp();
+  SetDtlsPeer();
+  ConnectSocketExpectFail();
+}
+
+TEST_F(TransportTest, OnlyServerSendsSrtpXtn) {
+  p1_->SetupSrtp();
+  SetDtlsPeer();
+  ConnectSocketExpectState(TransportLayer::TS_ERROR,
+                           TransportLayer::TS_CLOSED);
+}
+
+TEST_F(TransportTest, OnlyClientSendsSrtpXtn) {
+  p2_->SetupSrtp();
+  SetDtlsPeer();
+  // This means that the server won't semd the extension as well.  The server
+  // (p1) thinks that everything is OK.  The client (p2) notices the problem
+  // after connecting and aborts.
+  ConnectSocketExpectState(TransportLayer::TS_CLOSED,
+                           TransportLayer::TS_ERROR);
+}
+
 // NSS doesn't support DHE suites on the server end.
 // This checks to see if we barf when that's the only option available.
 TEST_F(TransportTest, TestDheOnlyFails) {
   SetDtlsPeer();
 
   // p2_ is the client
   // setting this on p1_ (the server) causes NSS to assert
   ConfigureOneCipher(p2_, TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
--- a/media/mtransport/transportlayerdtls.cpp
+++ b/media/mtransport/transportlayerdtls.cpp
@@ -19,18 +19,18 @@
 #include "mozilla/UniquePtr.h"
 #include "mozilla/Unused.h"
 #include "nsCOMPtr.h"
 #include "nsComponentManagerUtils.h"
 #include "nsComponentManagerUtils.h"
 #include "nsIEventTarget.h"
 #include "nsNetCID.h"
 #include "nsServiceManagerUtils.h"
-#include "ssl.h"
 #include "sslerr.h"
+#include "sslexp.h"
 #include "sslproto.h"
 #include "transportflow.h"
 
 
 namespace mozilla {
 
 MOZ_MTLOG_MODULE("mtransport")
 
@@ -739,26 +739,26 @@ static const uint32_t DisabledCiphers[] 
   TLS_ECDHE_RSA_WITH_NULL_SHA,
   TLS_ECDH_ECDSA_WITH_NULL_SHA,
   TLS_ECDH_RSA_WITH_NULL_SHA,
   TLS_RSA_WITH_NULL_SHA,
   TLS_RSA_WITH_NULL_SHA256,
   TLS_RSA_WITH_NULL_MD5,
 };
 
-bool TransportLayerDtls::SetupCipherSuites(UniquePRFileDesc& ssl_fd) const {
+bool TransportLayerDtls::SetupCipherSuites(UniquePRFileDesc& ssl_fd) {
   SECStatus rv;
 
   // Set the SRTP ciphers
-  if (!srtp_ciphers_.empty()) {
-    // Note: std::vector is guaranteed to contiguous
-    rv = SSL_SetSRTPCiphers(ssl_fd.get(), &srtp_ciphers_[0],
-                            srtp_ciphers_.size());
+  if (!enabled_srtp_ciphers_.empty()) {
+    rv = SSL_InstallExtensionHooks(ssl_fd.get(), ssl_use_srtp_xtn,
+                                   TransportLayerDtls::WriteSrtpXtn, this,
+                                   TransportLayerDtls::HandleSrtpXtn, this);
     if (rv != SECSuccess) {
-      MOZ_MTLOG(ML_ERROR, "Couldn't set SRTP cipher suite");
+      MOZ_MTLOG(ML_ERROR, LAYER_INFO << "unable to set SRTP extension handler");
       return false;
     }
   }
 
   for (const auto& cipher : EnabledCiphers) {
     MOZ_MTLOG(ML_DEBUG, LAYER_INFO << "Enabling: " << cipher);
     rv = SSL_CipherPrefSet(ssl_fd.get(), cipher, PR_TRUE);
     if (rv != SECSuccess) {
@@ -878,16 +878,22 @@ void TransportLayerDtls::Handshake() {
     if (!CheckAlpn()) {
       // Despite connecting, the connection doesn't have a valid ALPN label.
       // Forcibly close the connection so that the peer isn't left hanging
       // (assuming the close_notify isn't dropped).
       ssl_fd_ = nullptr;
       TL_SET_STATE(TS_ERROR);
       return;
     }
+    if (!enabled_srtp_ciphers_.empty() && srtp_cipher_ == 0) {
+      // We enabled SRTP, but got no cipher, this should have failed.
+      ssl_fd_ = nullptr;
+      TL_SET_STATE(TS_ERROR);
+      return;
+    }
 
     TL_SET_STATE(TS_OPEN);
 
     RecordCipherTelemetry();
   } else {
     int32_t err = PR_GetError();
     switch(err) {
       case SSL_ERROR_RX_MALFORMED_HANDSHAKE:
@@ -1136,35 +1142,217 @@ SECStatus TransportLayerDtls::GetClientA
     *pRetCert = nullptr;
     PR_SetError(PR_OUT_OF_MEMORY_ERROR, 0);
     return SECFailure;
   }
 
   return SECSuccess;
 }
 
-nsresult TransportLayerDtls::SetSrtpCiphers(std::vector<uint16_t> ciphers) {
-  // TODO: We should check these
-  srtp_ciphers_ = ciphers;
-
+nsresult TransportLayerDtls::SetSrtpCiphers(const std::vector<uint16_t>& ciphers) {
+  enabled_srtp_ciphers_ = std::move(ciphers);
   return NS_OK;
 }
 
 nsresult TransportLayerDtls::GetSrtpCipher(uint16_t *cipher) const {
   CheckThread();
-  if (state_ != TS_OPEN) {
+  if (srtp_cipher_ == 0) {
     return NS_ERROR_NOT_AVAILABLE;
   }
-  SECStatus rv = SSL_GetSRTPCipher(ssl_fd_.get(), cipher);
+  *cipher = srtp_cipher_;
+  return NS_OK;
+}
+
+static uint8_t* WriteUint16(uint8_t* cursor, uint16_t v) {
+  *cursor++ = v >> 8;
+  *cursor++ = v & 0xff;
+  return cursor;
+}
+
+static SSLHandshakeType SrtpXtnServerMessage(PRFileDesc* fd) {
+  SSLPreliminaryChannelInfo preinfo;
+  SECStatus rv = SSL_GetPreliminaryChannelInfo(fd, &preinfo, sizeof(preinfo));
   if (rv != SECSuccess) {
-    MOZ_MTLOG(ML_DEBUG, "No SRTP cipher negotiated");
-    return NS_ERROR_FAILURE;
+    MOZ_ASSERT(false, "Can't get version info");
+    return ssl_hs_client_hello;
+  }
+  return (preinfo.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3)
+      ? ssl_hs_encrypted_extensions
+      : ssl_hs_server_hello;
+}
+
+/* static */ PRBool TransportLayerDtls::WriteSrtpXtn(
+    PRFileDesc* fd, SSLHandshakeType message, uint8_t* data,
+    unsigned int* len, unsigned int max_len, void* arg) {
+  auto self = reinterpret_cast<TransportLayerDtls*>(arg);
+
+  // ClientHello: send all supported versions.
+  if (message == ssl_hs_client_hello) {
+    MOZ_ASSERT(self->role_ == CLIENT);
+    MOZ_ASSERT(self->enabled_srtp_ciphers_.size(), "Haven't enabled SRTP");
+    // We will take 2 octets for each cipher, plus a 2 octet length and 1 octet
+    // for the length of the empty MKI.
+    if (max_len < self->enabled_srtp_ciphers_.size() * 2 + 3) {
+      MOZ_ASSERT(false, "Not enough space to send SRTP extension");
+      return false;
+    }
+    uint8_t* cursor = WriteUint16(data, self->enabled_srtp_ciphers_.size() * 2);
+    for (auto cs : self->enabled_srtp_ciphers_) {
+      cursor = WriteUint16(cursor, cs);
+    }
+    *cursor++ = 0; // MKI is empty
+    *len = cursor - data;
+    return true;
+  }
+
+  if (message == SrtpXtnServerMessage(fd)) {
+    MOZ_ASSERT(self->role_ == SERVER);
+    if (!self->srtp_cipher_) {
+      // Not negotiated. Definitely bad, but the connection can fail later.
+      return false;
+    }
+    if (max_len < 5) {
+      MOZ_ASSERT(false, "Not enough space to send SRTP extension");
+      return false;
+    }
+
+    uint8_t* cursor = WriteUint16(data, 2); // Length = 2.
+    cursor = WriteUint16(cursor, self->srtp_cipher_);
+    *cursor++ = 0; // No MKI
+    *len = cursor - data;
+    return true;
+  }
+
+  return false;
+}
+
+class TlsParser {
+ public:
+  TlsParser(const uint8_t* data, size_t len)
+      : cursor_(data), remaining_(len) {}
+
+  bool error() const { return error_; }
+  size_t remaining() const { return remaining_; }
+
+  template<typename T,
+           class = typename std::enable_if<std::is_unsigned<T>::value>::type>
+  void Read(T* v, size_t sz = sizeof(T)) {
+    MOZ_ASSERT(sz <= sizeof(T), "Type is too small to hold the value requested");
+    if (remaining_ < sz) {
+      error_ = true;
+      return;
+    }
+
+    T result = 0;
+    for (size_t i = 0; i < sz; ++i) {
+      result = (result << 8) | *cursor_++;
+      remaining_--;
+    }
+    *v = result;
   }
 
-  return NS_OK;
+  template<typename T,
+           class = typename std::enable_if<std::is_unsigned<T>::value>::type>
+  void ReadVector(std::vector<T>* v, size_t w) {
+    MOZ_ASSERT(v->empty(), "vector needs to be empty");
+
+    uint32_t len;
+    Read(&len, w);
+    if (error_ || len % sizeof(T) != 0 || len > remaining_) {
+      error_ = true;
+      return;
+    }
+
+    size_t count = len / sizeof(T);
+    v->reserve(count);
+    for (T i = 0; !error_ && i < count; ++i) {
+      T item;
+      Read(&item);
+      if (!error_) {
+        v->push_back(item);
+      }
+    }
+  }
+
+  void Skip(size_t n) {
+    if (remaining_ < n) {
+      error_ = true;
+    } else {
+      cursor_ += n;
+      remaining_ -= n;
+    }
+  }
+
+  size_t SkipVector(size_t w) {
+    uint32_t len = 0;
+    Read(&len, w);
+    Skip(len);
+    return len;
+  }
+
+ private:
+  const uint8_t* cursor_;
+  size_t remaining_;
+  bool error_ = false;
+};
+
+/* static */ SECStatus TransportLayerDtls::HandleSrtpXtn(
+    PRFileDesc* fd, SSLHandshakeType message, const uint8_t* data,
+    unsigned int len, SSLAlertDescription* alert, void* arg) {
+  static const uint8_t kTlsAlertHandshakeFailure = 40;
+  static const uint8_t kTlsAlertIllegalParameter = 47;
+  static const uint8_t kTlsAlertDecodeError = 50;
+  static const uint8_t kTlsAlertUnsupportedExtension = 110;
+
+  auto self = reinterpret_cast<TransportLayerDtls*>(arg);
+
+  // Parse the extension.
+  TlsParser parser(data, len);
+  std::vector<uint16_t> advertised;
+  parser.ReadVector(&advertised, 2);
+  size_t mki_len = parser.SkipVector(1);
+  if (parser.error() || parser.remaining() > 0) {
+    *alert = kTlsAlertDecodeError;
+    return SECFailure;
+  }
+
+  if (message == ssl_hs_client_hello) {
+    MOZ_ASSERT(self->role_ == SERVER);
+    if (self->enabled_srtp_ciphers_.empty()) {
+      // We don't have SRTP enabled, which is probably bad, but no sense in
+      // having the handshake fail at this point, let the client decide if this
+      // is a problem.
+      return SECSuccess;
+    }
+
+    for (auto supported : self->enabled_srtp_ciphers_) {
+      auto it = std::find(advertised.begin(), advertised.end(), supported);
+      if (it != advertised.end()) {
+        self->srtp_cipher_ = supported;
+        return SECSuccess;
+      }
+    }
+
+    // No common cipher.
+    *alert = kTlsAlertHandshakeFailure;
+    return SECFailure;
+  }
+
+  if (message == SrtpXtnServerMessage(fd)) {
+    MOZ_ASSERT(self->role_ == CLIENT);
+    if (advertised.size() != 1 || mki_len > 0) {
+      *alert = kTlsAlertIllegalParameter;
+      return SECFailure;
+    }
+    self->srtp_cipher_ = advertised[0];
+    return SECSuccess;
+  }
+
+  *alert = kTlsAlertUnsupportedExtension;
+  return SECFailure;
 }
 
 nsresult TransportLayerDtls::ExportKeyingMaterial(const std::string& label,
                                                   bool use_context,
                                                   const std::string& context,
                                                   unsigned char *out,
                                                   unsigned int outlen) {
   CheckThread();
--- a/media/mtransport/transportlayerdtls.h
+++ b/media/mtransport/transportlayerdtls.h
@@ -19,16 +19,17 @@
 #include "mozilla/TimeStamp.h"
 #include "nsCOMPtr.h"
 #include "nsIEventTarget.h"
 #include "nsITimer.h"
 #include "ScopedNSSTypes.h"
 #include "m_cpp_utils.h"
 #include "dtlsidentity.h"
 #include "transportlayer.h"
+#include "ssl.h"
 
 namespace mozilla {
 
 struct Packet;
 
 class TransportLayerNSPRAdapter {
  public:
   explicit TransportLayerNSPRAdapter(TransportLayer *output) :
@@ -46,22 +47,17 @@ class TransportLayerNSPRAdapter {
 
   TransportLayer *output_;
   std::queue<MediaPacket *> input_;
   bool enabled_;
 };
 
 class TransportLayerDtls final : public TransportLayer {
  public:
-  TransportLayerDtls() :
-      role_(CLIENT),
-      verification_mode_(VERIFY_UNSET),
-      ssl_fd_(nullptr),
-      auth_hook_called_(false),
-      cert_ok_(false) {}
+  TransportLayerDtls() = default;
 
   virtual ~TransportLayerDtls();
 
   enum Role { CLIENT, SERVER};
   enum Verification { VERIFY_UNSET, VERIFY_ALLOW_ALL, VERIFY_DIGEST};
   const static size_t kMaxDigestLength = HASH_LENGTH_MAX;
 
   // DTLS-specific operations
@@ -77,17 +73,17 @@ class TransportLayerDtls final : public 
 
   nsresult SetVerificationAllowAll();
   nsresult SetVerificationDigest(const std::string digest_algorithm,
                                  const unsigned char *digest_value,
                                  size_t digest_len);
 
   nsresult GetCipherSuite(uint16_t* cipherSuite) const;
 
-  nsresult SetSrtpCiphers(std::vector<uint16_t> ciphers);
+  nsresult SetSrtpCiphers(const std::vector<uint16_t>& ciphers);
   nsresult GetSrtpCipher(uint16_t *cipher) const;
 
   nsresult ExportKeyingMaterial(const std::string& label,
                                 bool use_context,
                                 const std::string& context,
                                 unsigned char *out,
                                 unsigned int outlen);
 
@@ -131,17 +127,17 @@ class TransportLayerDtls final : public 
 
    private:
     ~VerificationDigest() {}
     DISALLOW_COPY_ASSIGN(VerificationDigest);
   };
 
 
   bool Setup();
-  bool SetupCipherSuites(UniquePRFileDesc& ssl_fd) const;
+  bool SetupCipherSuites(UniquePRFileDesc& ssl_fd);
   bool SetupAlpn(UniquePRFileDesc& ssl_fd) const;
   void GetDecryptedPackets();
   void Handshake();
 
   bool CheckAlpn();
 
   static SECStatus GetClientAuthDataHook(void *arg, PRFileDesc *fd,
                                          CERTDistNames *caNames,
@@ -158,36 +154,45 @@ class TransportLayerDtls final : public 
   static void TimerCallback(nsITimer *timer, void *arg);
 
   SECStatus CheckDigest(const RefPtr<VerificationDigest>& digest,
                         UniqueCERTCertificate& cert) const;
 
   void RecordHandshakeCompletionTelemetry(TransportLayer::State endState);
   void RecordCipherTelemetry();
 
+  static PRBool WriteSrtpXtn(PRFileDesc* fd, SSLHandshakeType message,
+                             uint8_t* data, unsigned int* len,
+                             unsigned int max_len, void* arg);
+
+  static SECStatus HandleSrtpXtn(PRFileDesc* fd, SSLHandshakeType message,
+                                 const uint8_t* data, unsigned int len,
+                                 SSLAlertDescription* alert, void* arg);
+
   RefPtr<DtlsIdentity> identity_;
   // What ALPN identifiers are permitted.
   std::set<std::string> alpn_allowed_;
   // What ALPN identifier is used if ALPN is not supported.
   // The empty string indicates that ALPN is required.
   std::string alpn_default_;
   // What ALPN string was negotiated.
   std::string alpn_;
-  std::vector<uint16_t> srtp_ciphers_;
+  std::vector<uint16_t> enabled_srtp_ciphers_;
+  uint16_t srtp_cipher_ = 0;
 
-  Role role_;
-  Verification verification_mode_;
+  Role role_ = CLIENT;
+  Verification verification_mode_ = VERIFY_UNSET;
   std::vector<RefPtr<VerificationDigest> > digests_;
 
   // Must delete nspr_io_adapter after ssl_fd_ b/c ssl_fd_ causes an alert
   // (ssl_fd_ contains an un-owning pointer to nspr_io_adapter_)
-  UniquePtr<TransportLayerNSPRAdapter> nspr_io_adapter_;
-  UniquePRFileDesc ssl_fd_;
+  UniquePtr<TransportLayerNSPRAdapter> nspr_io_adapter_ = nullptr;
+  UniquePRFileDesc ssl_fd_ = nullptr;
 
-  nsCOMPtr<nsITimer> timer_;
-  bool auth_hook_called_;
-  bool cert_ok_;
+  nsCOMPtr<nsITimer> timer_ = nullptr;
+  bool auth_hook_called_ = false;
+  bool cert_ok_ = false;
   TimeStamp handshake_started_;
 };
 
 
 }  // close namespace
 #endif