Bug 996237 - Unit testing cipher suite selection happy path. r=ekr
--- a/media/mtransport/test/transport_unittests.cpp
+++ b/media/mtransport/test/transport_unittests.cpp
@@ -11,16 +11,17 @@
#include <map>
#include "sigslot.h"
#include "logging.h"
#include "nspr.h"
#include "nss.h"
#include "ssl.h"
+#include "sslproto.h"
#include "nsThreadUtils.h"
#include "nsXPCOM.h"
#include "databuffer.h"
#include "dtlsidentity.h"
#include "nricectx.h"
#include "nricemediastream.h"
@@ -354,16 +355,24 @@ class TransportTestPeer : public sigslot
peer->fingerprint_len_);
ASSERT_TRUE(NS_SUCCEEDED(res));
mask <<= 1;
}
}
+ void SetupSrtp() {
+ // this mimics the setup we do elsewhere
+ std::vector<uint16_t> srtp_ciphers;
+ srtp_ciphers.push_back(SRTP_AES128_CM_HMAC_SHA1_80);
+ srtp_ciphers.push_back(SRTP_AES128_CM_HMAC_SHA1_32);
+
+ ASSERT_TRUE(NS_SUCCEEDED(dtls_->SetSrtpCiphers(srtp_ciphers)));
+ }
void ConnectSocket_s(TransportTestPeer *peer) {
nsresult res;
res = loopback_->Init();
ASSERT_EQ((nsresult)NS_OK, res);
loopback_->Connect(peer->loopback_);
@@ -529,16 +538,41 @@ class TransportTestPeer : public sigslot
}
bool failed() {
return state() == TransportLayer::TS_ERROR;
}
size_t received() { return received_; }
+ uint16_t cipherSuite() const {
+ nsresult rv;
+ uint16_t cipher;
+ RUN_ON_THREAD(test_utils->sts_target(),
+ WrapRunnableRet(dtls_, &TransportLayerDtls::GetCipherSuite,
+ &cipher, &rv));
+
+ if (NS_FAILED(rv)) {
+ return TLS_NULL_WITH_NULL_NULL; // i.e., not good
+ }
+ return cipher;
+ }
+
+ uint16_t srtpCipher() const {
+ nsresult rv;
+ uint16_t cipher;
+ RUN_ON_THREAD(test_utils->sts_target(),
+ WrapRunnableRet(dtls_, &TransportLayerDtls::GetSrtpCipher,
+ &cipher, &rv));
+ if (NS_FAILED(rv)) {
+ return 0; // the SRTP equivalent of TLS_NULL_WITH_NULL_NULL
+ }
+ return cipher;
+ }
+
private:
std::string name_;
nsCOMPtr<nsIEventTarget> target_;
size_t received_;
mozilla::RefPtr<TransportFlow> flow_;
TransportLayerLoopback *loopback_;
TransportLayerLogging *logging_;
TransportLayerLossy *lossy_;
@@ -580,16 +614,21 @@ class TransportTest : public ::testing::
nsresult rv;
target_ = do_GetService(NS_SOCKETTRANSPORTSERVICE_CONTRACTID, &rv);
ASSERT_TRUE(NS_SUCCEEDED(rv));
p1_ = new TransportTestPeer(target_, "P1");
p2_ = new TransportTestPeer(target_, "P2");
}
+ void SetupSrtp() {
+ p1_->SetupSrtp();
+ p2_->SetupSrtp();
+ }
+
void SetDtlsPeer(int digests = 1, unsigned int damage = 0) {
p1_->SetDtlsPeer(p2_, digests, damage);
p2_->SetDtlsPeer(p1_, digests, damage);
}
void SetDtlsAllowAll() {
p1_->SetDtlsAllowAll();
p2_->SetDtlsAllowAll();
@@ -655,18 +694,40 @@ class TransportTest : public ::testing::
TEST_F(TransportTest, TestNoDtlsVerificationSettings) {
ConnectSocketExpectFail();
}
TEST_F(TransportTest, TestConnect) {
SetDtlsPeer();
ConnectSocket();
+
+ // check that everything was negotiated properly
+ ASSERT_EQ(TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, p1_->cipherSuite());
+ ASSERT_EQ(TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, p2_->cipherSuite());
+
+ // no SRTP on this one
+ ASSERT_EQ(0, p1_->srtpCipher());
+ ASSERT_EQ(0, p2_->srtpCipher());
}
+TEST_F(TransportTest, TestConnectSrtp) {
+ SetupSrtp();
+ SetDtlsPeer();
+ ConnectSocket();
+
+ ASSERT_EQ(TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, p1_->cipherSuite());
+ ASSERT_EQ(TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, p2_->cipherSuite());
+
+ // SRTP is on
+ ASSERT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, p1_->srtpCipher());
+ ASSERT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, p2_->srtpCipher());
+}
+
+
TEST_F(TransportTest, TestConnectDestroyFlowsMainThread) {
SetDtlsPeer();
ConnectSocket();
DestroyPeerFlows();
}
TEST_F(TransportTest, TestConnectAllowAll) {
SetDtlsAllowAll();
--- a/media/mtransport/transportlayer.h
+++ b/media/mtransport/transportlayer.h
@@ -87,30 +87,30 @@ class TransportLayer : public sigslot::h
const std::string& flow_id() const {
return flow_id_;
}
protected:
virtual void WasInserted() {}
virtual void SetState(State state, const char *file, unsigned line);
// Check if we are on the right thread
- void CheckThread() {
+ void CheckThread() const {
NS_ABORT_IF_FALSE(CheckThreadInt(), "Wrong thread");
}
Mode mode_;
State state_;
std::string flow_id_;
TransportLayer *downward_; // The next layer in the stack
nsCOMPtr<nsIEventTarget> target_;
private:
DISALLOW_COPY_ASSIGN(TransportLayer);
- bool CheckThreadInt() {
+ bool CheckThreadInt() const {
bool on;
if (!target_) // OK if no thread set.
return true;
NS_ENSURE_SUCCESS(target_->IsOnCurrentThread(&on), false);
NS_ENSURE_TRUE(on, false);
--- a/media/mtransport/transportlayerdtls.cpp
+++ b/media/mtransport/transportlayerdtls.cpp
@@ -684,16 +684,35 @@ bool TransportLayerDtls::SetupCipherSuit
"Unable to disable suite: " << DisabledCiphers[i]);
return false;
}
}
}
return true;
}
+nsresult TransportLayerDtls::GetCipherSuite(uint16_t* cipherSuite) const {
+ CheckThread();
+ if (!cipherSuite) {
+ MOZ_MTLOG(ML_ERROR, LAYER_INFO << "GetCipherSuite passed a nullptr");
+ return NS_ERROR_NULL_POINTER;
+ }
+ if (state_ != TS_OPEN) {
+ return NS_ERROR_NOT_AVAILABLE;
+ }
+ SSLChannelInfo info;
+ SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info, sizeof(info));
+ if (rv != SECSuccess) {
+ MOZ_MTLOG(ML_NOTICE, LAYER_INFO << "GetCipherSuite can't get channel info");
+ return NS_ERROR_FAILURE;
+ }
+ *cipherSuite = info.cipherSuite;
+ return NS_OK;
+}
+
void TransportLayerDtls::StateChange(TransportLayer *layer, State state) {
if (state <= state_) {
MOZ_MTLOG(ML_ERROR, "Lower layer state is going backwards from ours");
TL_SET_STATE(TS_ERROR);
return;
}
switch (state) {
@@ -895,17 +914,17 @@ SECStatus TransportLayerDtls::GetClientA
nsresult TransportLayerDtls::SetSrtpCiphers(std::vector<uint16_t> ciphers) {
// TODO: We should check these
srtp_ciphers_ = ciphers;
return NS_OK;
}
-nsresult TransportLayerDtls::GetSrtpCipher(uint16_t *cipher) {
+nsresult TransportLayerDtls::GetSrtpCipher(uint16_t *cipher) const {
CheckThread();
SECStatus rv = SSL_GetSRTPCipher(ssl_fd_, cipher);
if (rv != SECSuccess) {
MOZ_MTLOG(ML_DEBUG, "No SRTP cipher negotiated");
return NS_ERROR_FAILURE;
}
return NS_OK;
--- a/media/mtransport/transportlayerdtls.h
+++ b/media/mtransport/transportlayerdtls.h
@@ -68,18 +68,20 @@ class TransportLayerDtls : public Transp
void SetIdentity(const RefPtr<DtlsIdentity>& identity) {
identity_ = identity;
}
nsresult SetVerificationAllowAll();
nsresult SetVerificationDigest(const std::string digest_algorithm,
const unsigned char *digest_value,
size_t digest_len);
+ nsresult GetCipherSuite(uint16_t* cipherSuite) const;
+
nsresult SetSrtpCiphers(std::vector<uint16_t> ciphers);
- nsresult GetSrtpCipher(uint16_t *cipher);
+ nsresult GetSrtpCipher(uint16_t *cipher) const;
nsresult ExportKeyingMaterial(const std::string& label,
bool use_context,
const std::string& context,
unsigned char *out,
unsigned int outlen);
const CERTCertificate *GetPeerCert() const {