Bug 996237 - Unit testing cipher suite selection happy path. r=ekr
☠☠ backed out by abc2e9abe681 ☠ ☠
authorMartin Thomson <martin.thomson@gmail.com>
Thu, 31 Jul 2014 10:47:00 -0400
changeset 197507 9a13eb125a47339e3b07b88870cf246c597b0509
parent 197506 351434e9157b1ca48eef943c62d2c1a9fdd96a6a
child 197508 e5a36d268d11f629426824a7321e5702d8426f2e
push id1
push userroot
push dateMon, 20 Oct 2014 17:29:22 +0000
reviewersekr
bugs996237
milestone34.0a1
Bug 996237 - Unit testing cipher suite selection happy path. r=ekr
media/mtransport/test/transport_unittests.cpp
media/mtransport/transportlayer.h
media/mtransport/transportlayerdtls.cpp
media/mtransport/transportlayerdtls.h
--- a/media/mtransport/test/transport_unittests.cpp
+++ b/media/mtransport/test/transport_unittests.cpp
@@ -11,16 +11,17 @@
 #include <map>
 
 #include "sigslot.h"
 
 #include "logging.h"
 #include "nspr.h"
 #include "nss.h"
 #include "ssl.h"
+#include "sslproto.h"
 
 #include "nsThreadUtils.h"
 #include "nsXPCOM.h"
 
 #include "databuffer.h"
 #include "dtlsidentity.h"
 #include "nricectx.h"
 #include "nricemediastream.h"
@@ -354,16 +355,24 @@ class TransportTestPeer : public sigslot
           peer->fingerprint_len_);
 
       ASSERT_TRUE(NS_SUCCEEDED(res));
 
       mask <<= 1;
     }
   }
 
+  void SetupSrtp() {
+    // this mimics the setup we do elsewhere
+    std::vector<uint16_t> srtp_ciphers;
+    srtp_ciphers.push_back(SRTP_AES128_CM_HMAC_SHA1_80);
+    srtp_ciphers.push_back(SRTP_AES128_CM_HMAC_SHA1_32);
+
+    ASSERT_TRUE(NS_SUCCEEDED(dtls_->SetSrtpCiphers(srtp_ciphers)));
+  }
 
   void ConnectSocket_s(TransportTestPeer *peer) {
     nsresult res;
     res = loopback_->Init();
     ASSERT_EQ((nsresult)NS_OK, res);
 
     loopback_->Connect(peer->loopback_);
 
@@ -529,16 +538,41 @@ class TransportTestPeer : public sigslot
   }
 
   bool failed() {
     return state() == TransportLayer::TS_ERROR;
   }
 
   size_t received() { return received_; }
 
+  uint16_t cipherSuite() const {
+    nsresult rv;
+    uint16_t cipher;
+    RUN_ON_THREAD(test_utils->sts_target(),
+                  WrapRunnableRet(dtls_, &TransportLayerDtls::GetCipherSuite,
+                                  &cipher, &rv));
+
+    if (NS_FAILED(rv)) {
+      return TLS_NULL_WITH_NULL_NULL; // i.e., not good
+    }
+    return cipher;
+  }
+
+  uint16_t srtpCipher() const {
+    nsresult rv;
+    uint16_t cipher;
+    RUN_ON_THREAD(test_utils->sts_target(),
+                  WrapRunnableRet(dtls_, &TransportLayerDtls::GetSrtpCipher,
+                                  &cipher, &rv));
+    if (NS_FAILED(rv)) {
+      return 0; // the SRTP equivalent of TLS_NULL_WITH_NULL_NULL
+    }
+    return cipher;
+  }
+
  private:
   std::string name_;
   nsCOMPtr<nsIEventTarget> target_;
   size_t received_;
     mozilla::RefPtr<TransportFlow> flow_;
   TransportLayerLoopback *loopback_;
   TransportLayerLogging *logging_;
   TransportLayerLossy *lossy_;
@@ -580,16 +614,21 @@ class TransportTest : public ::testing::
     nsresult rv;
     target_ = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv);
     ASSERT_TRUE(NS_SUCCEEDED(rv));
 
     p1_ = new TransportTestPeer(target_, "P1");
     p2_ = new TransportTestPeer(target_, "P2");
   }
 
+  void SetupSrtp() {
+    p1_->SetupSrtp();
+    p2_->SetupSrtp();
+  }
+
   void SetDtlsPeer(int digests = 1, unsigned int damage = 0) {
     p1_->SetDtlsPeer(p2_, digests, damage);
     p2_->SetDtlsPeer(p1_, digests, damage);
   }
 
   void SetDtlsAllowAll() {
     p1_->SetDtlsAllowAll();
     p2_->SetDtlsAllowAll();
@@ -655,18 +694,40 @@ class TransportTest : public ::testing::
 
 TEST_F(TransportTest, TestNoDtlsVerificationSettings) {
   ConnectSocketExpectFail();
 }
 
 TEST_F(TransportTest, TestConnect) {
   SetDtlsPeer();
   ConnectSocket();
+
+  // check that everything was negotiated properly
+  ASSERT_EQ(TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, p1_->cipherSuite());
+  ASSERT_EQ(TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, p2_->cipherSuite());
+
+  // no SRTP on this one
+  ASSERT_EQ(0, p1_->srtpCipher());
+  ASSERT_EQ(0, p2_->srtpCipher());
 }
 
+TEST_F(TransportTest, TestConnectSrtp) {
+  SetupSrtp();
+  SetDtlsPeer();
+  ConnectSocket();
+
+  ASSERT_EQ(TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, p1_->cipherSuite());
+  ASSERT_EQ(TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, p2_->cipherSuite());
+
+  // SRTP is on
+  ASSERT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, p1_->srtpCipher());
+  ASSERT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, p2_->srtpCipher());
+}
+
+
 TEST_F(TransportTest, TestConnectDestroyFlowsMainThread) {
   SetDtlsPeer();
   ConnectSocket();
   DestroyPeerFlows();
 }
 
 TEST_F(TransportTest, TestConnectAllowAll) {
   SetDtlsAllowAll();
--- a/media/mtransport/transportlayer.h
+++ b/media/mtransport/transportlayer.h
@@ -87,30 +87,30 @@ class TransportLayer : public sigslot::h
   const std::string& flow_id() const {
     return flow_id_;
   }
 
  protected:
   virtual void WasInserted() {}
   virtual void SetState(State state, const char *file, unsigned line);
   // Check if we are on the right thread
-  void CheckThread() {
+  void CheckThread() const {
     NS_ABORT_IF_FALSE(CheckThreadInt(), "Wrong thread");
   }
 
   Mode mode_;
   State state_;
   std::string flow_id_;
   TransportLayer *downward_; // The next layer in the stack
   nsCOMPtr<nsIEventTarget> target_;
 
  private:
   DISALLOW_COPY_ASSIGN(TransportLayer);
 
-  bool CheckThreadInt() {
+  bool CheckThreadInt() const {
     bool on;
 
     if (!target_)  // OK if no thread set.
       return true;
 
     NS_ENSURE_SUCCESS(target_->IsOnCurrentThread(&on), false);
     NS_ENSURE_TRUE(on, false);
 
--- a/media/mtransport/transportlayerdtls.cpp
+++ b/media/mtransport/transportlayerdtls.cpp
@@ -684,16 +684,35 @@ bool TransportLayerDtls::SetupCipherSuit
                   "Unable to disable suite: " << DisabledCiphers[i]);
         return false;
       }
     }
   }
   return true;
 }
 
+nsresult TransportLayerDtls::GetCipherSuite(uint16_t* cipherSuite) const {
+  CheckThread();
+  if (!cipherSuite) {
+    MOZ_MTLOG(ML_ERROR, LAYER_INFO << "GetCipherSuite passed a nullptr");
+    return NS_ERROR_NULL_POINTER;
+  }
+  if (state_ != TS_OPEN) {
+    return NS_ERROR_NOT_AVAILABLE;
+  }
+  SSLChannelInfo info;
+  SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info, sizeof(info));
+  if (rv != SECSuccess) {
+    MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "GetCipherSuite can't get channel info");
+    return NS_ERROR_FAILURE;
+  }
+  *cipherSuite = info.cipherSuite;
+  return NS_OK;
+}
+
 void TransportLayerDtls::StateChange(TransportLayer *layer, State state) {
   if (state <= state_) {
     MOZ_MTLOG(ML_ERROR, "Lower layer state is going backwards from ours");
     TL_SET_STATE(TS_ERROR);
     return;
   }
 
   switch (state) {
@@ -895,17 +914,17 @@ SECStatus TransportLayerDtls::GetClientA
 
 nsresult TransportLayerDtls::SetSrtpCiphers(std::vector<uint16_t> ciphers) {
   // TODO: We should check these
   srtp_ciphers_ = ciphers;
 
   return NS_OK;
 }
 
-nsresult TransportLayerDtls::GetSrtpCipher(uint16_t *cipher) {
+nsresult TransportLayerDtls::GetSrtpCipher(uint16_t *cipher) const {
   CheckThread();
   SECStatus rv = SSL_GetSRTPCipher(ssl_fd_, cipher);
   if (rv != SECSuccess) {
     MOZ_MTLOG(ML_DEBUG, "No SRTP cipher negotiated");
     return NS_ERROR_FAILURE;
   }
 
   return NS_OK;
--- a/media/mtransport/transportlayerdtls.h
+++ b/media/mtransport/transportlayerdtls.h
@@ -68,18 +68,20 @@ class TransportLayerDtls : public Transp
   void SetIdentity(const RefPtr<DtlsIdentity>& identity) {
     identity_ = identity;
   }
   nsresult SetVerificationAllowAll();
   nsresult SetVerificationDigest(const std::string digest_algorithm,
                                  const unsigned char *digest_value,
                                  size_t digest_len);
 
+  nsresult GetCipherSuite(uint16_t* cipherSuite) const;
+
   nsresult SetSrtpCiphers(std::vector<uint16_t> ciphers);
-  nsresult GetSrtpCipher(uint16_t *cipher);
+  nsresult GetSrtpCipher(uint16_t *cipher) const;
 
   nsresult ExportKeyingMaterial(const std::string& label,
                                 bool use_context,
                                 const std::string& context,
                                 unsigned char *out,
                                 unsigned int outlen);
 
   const CERTCertificate *GetPeerCert() const {