gtests/ssl_gtest/tls_agent.h
author Martin Thomson <martin.thomson@gmail.com>
Mon, 09 Jul 2018 12:24:13 +1000
changeset 14455 ee357b00f2e6c44589dcd406097357888d59ef06
parent 14355 aa6678175aade961a3290e1edad69bf9b8548998
child 14907 403437c461fdd08f7a3a9dc7eba3c66e8c0c5ab9
permissions -rw-r--r--
Bug 1483129 - TLS 1.3 RFC version, r=ekr This retains the ability to negotiate draft versions of DTLS 1.3, but uses the final RFC version for TLS 1.3. This also refactors the handling of the downgrade sentinel. As we've discovered - to our dismay - some MitM boxes forward handshake messages when they shouldn't. This could result in triggering the downgrade sentinel. I've done two things here: - The server always sets the sentinel. It reduces the assumed version if it only supports a draft version though on the basis that the client might expect the full version. - The client has a new option SSL_ENABLE_HELLO_DOWNGRADE_CHECK which is disabled by default. The client will reject a handshake that appears to be a downgrade only when this is explicitly enabled. The client will allow an apparent downgrade to TLS 1.2 if it is running a draft version of TLS 1.3. The allowance for a draft version is now only effective for DTLS 1.3. Tests for version downgrade have been updated and enabled. These were rotten in a few ways, but nothing dramatic.

/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

#ifndef tls_agent_h_
#define tls_agent_h_

#include "prio.h"
#include "ssl.h"

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "test_io.h"

#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
#include "scoped_ptrs.h"

// Verbosity switch read by LOGV() below.  Defined outside this header
// (presumably set from the test runner's command line -- TODO confirm).
extern bool g_ssl_gtest_verbose;

namespace nss_test {

// LOG(msg) writes "<role>: <msg>" followed by std::endl to std::cerr.  |msg|
// is spliced into the stream expression unparenthesized on purpose, so it may
// be a chain such as `"x=" << x`.  It calls role_str(), so it is only usable
// where a role_str() is in scope (e.g. inside TlsAgent members).
#define LOG(msg) \
  std::cerr << role_str() << ": " << msg << std::endl
// LOGV(msg) is LOG(msg) gated on g_ssl_gtest_verbose.  The do/while(false)
// wrapper lets it be used as a single statement, even after an unbraced if.
#define LOGV(msg)              \
  do {                         \
    if (g_ssl_gtest_verbose) { \
      LOG(msg);                \
    }                          \
  } while (false)

// Flags taken by TlsAgent::ConfigureSessionCache().  Each mechanism occupies
// its own bit, and RESUME_BOTH is the union of the two mechanisms.
enum SessionResumptionMode {
  RESUME_NONE = 0,
  RESUME_SESSIONID = 1 << 0,
  RESUME_TICKET = 1 << 1,
  RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET
};

// Types referenced by pointer/reference below; they are defined elsewhere.
class TlsAgent;
class TlsCipherSpec;
class PacketFilter;
struct TlsRecord;

// Named-group lists declared here and defined in a separate source file.
extern const std::vector<SSLNamedGroup> kAllDHEGroups;
extern const std::vector<SSLNamedGroup> kECDHEGroups;
extern const std::vector<SSLNamedGroup> kFFDHEGroups;
extern const std::vector<SSLNamedGroup> kFasterDHEGroups;

// Signatures of the optional test callbacks that TlsAgent invokes from its
// NSS hooks.  They receive a bare TlsAgent pointer because TlsAgent installs
// the hooks itself and does not know who owns it.

// Invoked from the auth-certificate hooks; the result is returned to NSS.
using AuthCertificateCallbackFunction =
    std::function<SECStatus(TlsAgent* agent, bool checksig, bool isServer)>;

// Invoked once the handshake callback fires.
using HandshakeCallbackFunction = std::function<void(TlsAgent* agent)>;

// Invoked from the SNI hook; the result is returned to NSS.
using SniCallbackFunction =
    std::function<int32_t(TlsAgent* agent, const SECItem* srvNameArr,
                          PRUint32 srvNameArrSize)>;

class TlsAgent : public PollTarget {
 public:
  enum Role { CLIENT, SERVER };
  enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED, STATE_ERROR };

  static const std::string kClient;     // the client key is sign only
  static const std::string kRsa2048;    // bigger sign and encrypt for either
  static const std::string kRsa8192;    // biggest sign and encrypt for either
  static const std::string kServerRsa;  // both sign and encrypt
  static const std::string kServerRsaSign;
  static const std::string kServerRsaPss;
  static const std::string kServerRsaDecrypt;
  static const std::string kServerEcdsa256;
  static const std::string kServerEcdsa384;
  static const std::string kServerEcdsa521;
  static const std::string kServerEcdhEcdsa;
  static const std::string kServerEcdhRsa;
  static const std::string kServerDsa;

  TlsAgent(const std::string& name, Role role, SSLProtocolVariant variant);
  virtual ~TlsAgent();

  void SetPeer(std::shared_ptr<TlsAgent>& peer) {
    adapter_->SetPeer(peer->adapter_);
  }

  void SetFilter(std::shared_ptr<PacketFilter> filter) {
    adapter_->SetPacketFilter(filter);
  }
  void ClearFilter() { adapter_->SetPacketFilter(nullptr); }

  void StartConnect(PRFileDesc* model = nullptr);
  void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group,
                size_t kea_size = 0) const;
  void CheckOriginalKEA(SSLNamedGroup kea_group) const;
  void CheckAuthType(SSLAuthType auth_type,
                     SSLSignatureScheme sig_scheme) const;

  void DisableAllCiphers();
  void EnableCiphersByAuthType(SSLAuthType authType);
  void EnableCiphersByKeyExchange(SSLKEAType kea);
  void EnableGroupsByKeyExchange(SSLKEAType kea);
  void EnableGroupsByAuthType(SSLAuthType authType);
  void EnableSingleCipher(uint16_t cipher);

  void Handshake();
  // Marks the internal state as CONNECTING in anticipation of renegotiation.
  void PrepareForRenegotiate();
  // Prepares for renegotiation, then actually triggers it.
  void StartRenegotiate();
  static bool LoadCertificate(const std::string& name,
                              ScopedCERTCertificate* cert,
                              ScopedSECKEYPrivateKey* priv);
  bool ConfigServerCert(const std::string& name, bool updateKeyBits = false,
                        const SSLExtraServerCertData* serverCertData = nullptr);
  bool ConfigServerCertWithChain(const std::string& name);
  bool EnsureTlsSetup(PRFileDesc* modelSocket = nullptr);

  void SetupClientAuth();
  void RequestClientAuth(bool requireAuth);

  void SetOption(int32_t option, int value);
  void ConfigureSessionCache(SessionResumptionMode mode);
  void Set0RttEnabled(bool en);
  void SetFallbackSCSVEnabled(bool en);
  void SetVersionRange(uint16_t minver, uint16_t maxver);
  void GetVersionRange(uint16_t* minver, uint16_t* maxver);
  void CheckPreliminaryInfo();
  void ResetPreliminaryInfo();
  void SetExpectedVersion(uint16_t version);
  void SetServerKeyBits(uint16_t bits);
  void ExpectReadWriteError();
  void EnableFalseStart();
  void ExpectResumption();
  void SkipVersionChecks();
  void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count);
  void EnableAlpn(const uint8_t* val, size_t len);
  void CheckAlpn(SSLNextProtoState expected_state,
                 const std::string& expected = "") const;
  void EnableSrtp();
  void CheckSrtp() const;
  void CheckErrorCode(int32_t expected) const;
  void WaitForErrorCode(int32_t expected, uint32_t delay) const;
  // Send data on the socket, encrypting it.
  void SendData(size_t bytes, size_t blocksize = 1024);
  void SendBuffer(const DataBuffer& buf);
  bool SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
                           uint64_t seq, uint8_t ct, const DataBuffer& buf);
  // Send data directly to the underlying socket, skipping the TLS layer.
  void SendDirect(const DataBuffer& buf);
  void SendRecordDirect(const TlsRecord& record);
  void ReadBytes(size_t max = 16384U);
  void ResetSentBytes();  // Hack to test drops.
  void EnableExtendedMasterSecret();
  void CheckExtendedMasterSecret(bool expected);
  void CheckEarlyDataAccepted(bool expected);
  void SetDowngradeCheckVersion(uint16_t version);
  void CheckSecretsDestroyed();
  void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
  void DisableECDHEServerKeyReuse();
  bool GetPeerChainLength(size_t* count);
  void CheckCipherSuite(uint16_t cipher_suite);
  void SetResumptionTokenCallback();
  bool MaybeSetResumptionToken();
  void SetResumptionToken(const std::vector<uint8_t>& resumption_token) {
    resumption_token_ = resumption_token;
  }
  const std::vector<uint8_t>& GetResumptionToken() const {
    return resumption_token_;
  }
  void GetTokenInfo(ScopedSSLResumptionTokenInfo& token) {
    SECStatus rv = SSL_GetResumptionTokenInfo(
        resumption_token_.data(), resumption_token_.size(), token.get(),
        sizeof(SSLResumptionTokenInfo));
    ASSERT_EQ(SECSuccess, rv);
  }
  void SetResumptionCallbackCalled() { resumption_callback_called_ = true; }
  bool resumption_callback_called() const {
    return resumption_callback_called_;
  }

  const std::string& name() const { return name_; }

  Role role() const { return role_; }
  std::string role_str() const { return role_ == SERVER ? "server" : "client"; }

  SSLProtocolVariant variant() const { return variant_; }

  State state() const { return state_; }

  const CERTCertificate* peer_cert() const {
    return SSL_PeerCertificate(ssl_fd_.get());
  }

  const char* state_str() const { return state_str(state()); }

  static const char* state_str(State state) { return states[state]; }

  PRFileDesc* ssl_fd() const { return ssl_fd_.get(); }
  std::shared_ptr<DummyPrSocket>& adapter() { return adapter_; }

  bool is_compressed() const {
    return info_.compressionMethod != ssl_compression_null;
  }
  uint16_t server_key_bits() const { return server_key_bits_; }
  uint16_t min_version() const { return vrange_.min; }
  uint16_t max_version() const { return vrange_.max; }
  uint16_t version() const {
    EXPECT_EQ(STATE_CONNECTED, state_);
    return info_.protocolVersion;
  }

  bool cipher_suite(uint16_t* suite) const {
    if (state_ != STATE_CONNECTED) return false;

    *suite = info_.cipherSuite;
    return true;
  }

  std::string cipher_suite_name() const {
    if (state_ != STATE_CONNECTED) return "UNKNOWN";

    return csinfo_.cipherSuiteName;
  }

  std::vector<uint8_t> session_id() const {
    return std::vector<uint8_t>(info_.sessionID,
                                info_.sessionID + info_.sessionIDLength);
  }

  bool auth_type(SSLAuthType* a) const {
    if (state_ != STATE_CONNECTED) return false;

    *a = info_.authType;
    return true;
  }

  bool kea_type(SSLKEAType* k) const {
    if (state_ != STATE_CONNECTED) return false;

    *k = info_.keaType;
    return true;
  }

  size_t received_bytes() const { return recv_ctr_; }
  PRErrorCode error_code() const { return error_code_; }

  bool can_falsestart_hook_called() const {
    return can_falsestart_hook_called_;
  }

  void SetHandshakeCallback(HandshakeCallbackFunction handshake_callback) {
    handshake_callback_ = handshake_callback;
  }

  void SetAuthCertificateCallback(
      AuthCertificateCallbackFunction auth_certificate_callback) {
    auth_certificate_callback_ = auth_certificate_callback;
  }

  void SetSniCallback(SniCallbackFunction sni_callback) {
    sni_callback_ = sni_callback;
  }

  void ExpectReceiveAlert(uint8_t alert, uint8_t level = 0);
  void ExpectSendAlert(uint8_t alert, uint8_t level = 0);

  std::string alpn_value_to_use_ = "";

 private:
  const static char* states[];

  void SetState(State state);
  void ValidateCipherSpecs();

  // Dummy auth certificate hook.
  static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
                                       PRBool checksig, PRBool isServer) {
    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
    agent->CheckPreliminaryInfo();
    agent->auth_certificate_hook_called_ = true;
    if (agent->auth_certificate_callback_) {
      return agent->auth_certificate_callback_(agent, checksig ? true : false,
                                               isServer ? true : false);
    }
    return SECSuccess;
  }

  // Client auth certificate hook.
  static SECStatus ClientAuthenticated(void* arg, PRFileDesc* fd,
                                       PRBool checksig, PRBool isServer) {
    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
    EXPECT_TRUE(agent->expect_client_auth_);
    EXPECT_EQ(PR_TRUE, isServer);
    if (agent->auth_certificate_callback_) {
      return agent->auth_certificate_callback_(agent, checksig ? true : false,
                                               isServer ? true : false);
    }
    return SECSuccess;
  }

  static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd,
                                         CERTDistNames* caNames,
                                         CERTCertificate** cert,
                                         SECKEYPrivateKey** privKey);

  static void ReadableCallback(PollTarget* self, Event event) {
    TlsAgent* agent = static_cast<TlsAgent*>(self);
    if (event == TIMER_EVENT) {
      agent->timer_handle_ = nullptr;
    }
    agent->ReadableCallback_int();
  }

  void ReadableCallback_int() {
    LOGV("Readable");
    switch (state_) {
      case STATE_CONNECTING:
        Handshake();
        break;
      case STATE_CONNECTED:
        ReadBytes();
        break;
      default:
        break;
    }
  }

  static PRInt32 SniHook(PRFileDesc* fd, const SECItem* srvNameArr,
                         PRUint32 srvNameArrSize, void* arg) {
    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
    agent->CheckPreliminaryInfo();
    agent->sni_hook_called_ = true;
    EXPECT_EQ(1UL, srvNameArrSize);
    if (agent->sni_callback_) {
      return agent->sni_callback_(agent, srvNameArr, srvNameArrSize);
    }
    return 0;  // First configuration.
  }

  static SECStatus CanFalseStartCallback(PRFileDesc* fd, void* arg,
                                         PRBool* canFalseStart) {
    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
    agent->CheckPreliminaryInfo();
    EXPECT_TRUE(agent->falsestart_enabled_);
    EXPECT_FALSE(agent->can_falsestart_hook_called_);
    agent->can_falsestart_hook_called_ = true;
    *canFalseStart = true;
    return SECSuccess;
  }

  void CheckAlert(bool sent, const SSLAlert* alert);

  static void AlertReceivedCallback(const PRFileDesc* fd, void* arg,
                                    const SSLAlert* alert) {
    reinterpret_cast<TlsAgent*>(arg)->CheckAlert(false, alert);
  }

  static void AlertSentCallback(const PRFileDesc* fd, void* arg,
                                const SSLAlert* alert) {
    reinterpret_cast<TlsAgent*>(arg)->CheckAlert(true, alert);
  }

  static void HandshakeCallback(PRFileDesc* fd, void* arg) {
    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
    agent->handshake_callback_called_ = true;
    agent->Connected();
    if (agent->handshake_callback_) {
      agent->handshake_callback_(agent);
    }
  }

  void DisableLameGroups();
  void ConfigStrongECGroups(bool en);
  void ConfigAllDHGroups(bool en);
  void CheckCallbacks() const;
  void Connected();

  const std::string name_;
  SSLProtocolVariant variant_;
  Role role_;
  uint16_t server_key_bits_;
  std::shared_ptr<DummyPrSocket> adapter_;
  ScopedPRFileDesc ssl_fd_;
  State state_;
  std::shared_ptr<Poller::Timer> timer_handle_;
  bool falsestart_enabled_;
  uint16_t expected_version_;
  uint16_t expected_cipher_suite_;
  bool expect_resumption_;
  bool expect_client_auth_;
  bool can_falsestart_hook_called_;
  bool sni_hook_called_;
  bool auth_certificate_hook_called_;
  uint8_t expected_received_alert_;
  uint8_t expected_received_alert_level_;
  uint8_t expected_sent_alert_;
  uint8_t expected_sent_alert_level_;
  bool handshake_callback_called_;
  bool resumption_callback_called_;
  SSLChannelInfo info_;
  SSLCipherSuiteInfo csinfo_;
  SSLVersionRange vrange_;
  PRErrorCode error_code_;
  size_t send_ctr_;
  size_t recv_ctr_;
  bool expect_readwrite_error_;
  HandshakeCallbackFunction handshake_callback_;
  AuthCertificateCallbackFunction auth_certificate_callback_;
  SniCallbackFunction sni_callback_;
  bool skip_version_checks_;
  std::vector<uint8_t> resumption_token_;
};

// Streams the printable name of |state| (TlsAgent::state_str()), which lets
// gtest print State values, e.g. when an EXPECT_EQ on a State fails.
inline std::ostream& operator<<(std::ostream& stream,
                                const TlsAgent::State& state) {
  const char* const label = TlsAgent::state_str(state);
  stream << label;
  return stream;
}

// Base fixture for tests that drive a single TlsAgent directly.  Anything the
// agent writes goes to |sink_adapter_|.  The Make*() helpers build records and
// handshake messages into a caller-supplied DataBuffer; ProcessMessage()
// presumably feeds such a buffer to the agent and checks the resulting state
// (TODO confirm against the implementation file).
class TlsAgentTestBase : public ::testing::Test {
 public:
  static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll;

  // |version| is stored in version_; 0 is the default when a fixture does
  // not pin one.
  TlsAgentTestBase(TlsAgent::Role role, SSLProtocolVariant variant,
                   uint16_t version = 0)
      : agent_(nullptr),
        role_(role),
        variant_(variant),
        version_(version),
        // make_shared: one allocation for object and control block, and no
        // naked new.
        sink_adapter_(std::make_shared<DummyPrSocket>("sink", variant)) {}
  virtual ~TlsAgentTestBase() = default;

  void SetUp();
  void TearDown();

  void ExpectAlert(uint8_t alert);

  static void MakeRecord(SSLProtocolVariant variant, uint8_t type,
                         uint16_t version, const uint8_t* buf, size_t len,
                         DataBuffer* out, uint64_t seq_num = 0);
  // Same as the static overload, using this fixture's variant_.
  void MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf,
                  size_t len, DataBuffer* out, uint64_t seq_num = 0) const;
  void MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len,
                            DataBuffer* out, uint64_t seq_num = 0) const;
  void MakeHandshakeMessageFragment(uint8_t hs_type, const uint8_t* data,
                                    size_t hs_len, DataBuffer* out,
                                    uint64_t seq_num, uint32_t fragment_offset,
                                    uint32_t fragment_length) const;
  DataBuffer MakeCannedTls13ServerHello();
  static void MakeTrivialHandshakeRecord(uint8_t hs_type, size_t hs_len,
                                         DataBuffer* out);
  // Maps a role parameter string to a Role: "CLIENT" yields CLIENT, and any
  // other string yields SERVER.
  static inline TlsAgent::Role ToRole(const std::string& str) {
    return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER;
  }

  void Init(const std::string& server_name = TlsAgent::kServerRsa);
  void Reset(const std::string& server_name = TlsAgent::kServerRsa);

 protected:
  void EnsureInit();
  void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state,
                      int32_t error_code = 0);

  std::shared_ptr<TlsAgent> agent_;
  TlsAgent::Role role_;
  SSLProtocolVariant variant_;
  uint16_t version_;
  // This adapter is here just to accept packets from this agent.
  std::shared_ptr<DummyPrSocket> sink_adapter_;
};

// Fixture parameterized over (role string, protocol variant, version).  The
// role string is translated with ToRole().
class TlsAgentTest
    : public TlsAgentTestBase,
      public ::testing::WithParamInterface<
          std::tuple<std::string, SSLProtocolVariant, uint16_t>> {
 public:
  TlsAgentTest() : TlsAgentTest(GetParam()) {}

 private:
  // Unpacks the parameter tuple into the base-fixture arguments.
  explicit TlsAgentTest(const ParamType& param)
      : TlsAgentTestBase(ToRole(std::get<0>(param)), std::get<1>(param),
                         std::get<2>(param)) {}
};

// Client-only fixture parameterized over (protocol variant, version).
class TlsAgentTestClient : public TlsAgentTestBase,
                           public ::testing::WithParamInterface<
                               std::tuple<SSLProtocolVariant, uint16_t>> {
 public:
  TlsAgentTestClient() : TlsAgentTestClient(GetParam()) {}

 private:
  // Forwards the unpacked (variant, version) parameter to the base fixture.
  explicit TlsAgentTestClient(const ParamType& param)
      : TlsAgentTestBase(TlsAgent::CLIENT, std::get<0>(param),
                         std::get<1>(param)) {}
};

// Identical to TlsAgentTestClient but registered under its own name
// (presumably so TLS 1.3-specific parameter sets can be instantiated for it).
class TlsAgentTestClient13 : public TlsAgentTestClient {
 public:
  TlsAgentTestClient13() = default;
};

// Unparameterized fixture: a client agent using ssl_variant_stream.
class TlsAgentStreamTestClient : public TlsAgentTestBase {
 public:
  TlsAgentStreamTestClient()
      : TlsAgentTestBase(/* role = */ TlsAgent::CLIENT,
                         /* variant = */ ssl_variant_stream) {}
};

// Unparameterized fixture: a server agent using ssl_variant_stream.
class TlsAgentStreamTestServer : public TlsAgentTestBase {
 public:
  TlsAgentStreamTestServer()
      : TlsAgentTestBase(/* role = */ TlsAgent::SERVER,
                         /* variant = */ ssl_variant_stream) {}
};

// Unparameterized fixture: a client agent using ssl_variant_datagram.
class TlsAgentDgramTestClient : public TlsAgentTestBase {
 public:
  TlsAgentDgramTestClient()
      : TlsAgentTestBase(/* role = */ TlsAgent::CLIENT,
                         /* variant = */ ssl_variant_datagram) {}
};

// Two version ranges are equal when both their min and max bounds match.
inline bool operator==(const SSLVersionRange& vr1, const SSLVersionRange& vr2) {
  if (vr1.min != vr2.min) {
    return false;
  }
  return vr1.max == vr2.max;
}

}  // namespace nss_test

#endif