--- a/external_tests/ssl_gtest/databuffer.h
+++ b/external_tests/ssl_gtest/databuffer.h
@@ -59,43 +59,46 @@ class DataBuffer {
} else {
assert(len == 0);
data_ = nullptr;
len_ = 0;
}
}
// Write will do a new allocation and expand the size of the buffer if needed.
- void Write(size_t index, const uint8_t* val, size_t count) {
+ // Returns the offset of the end of the write.
+ size_t Write(size_t index, const uint8_t* val, size_t count) {
if (index + count > len_) {
size_t newlen = index + count;
uint8_t* tmp = new uint8_t[newlen]; // Always > 0.
memcpy(static_cast<void*>(tmp),
static_cast<const void*>(data_), len_);
if (index > len_) {
memset(static_cast<void*>(tmp + len_), 0, index - len_);
}
delete[] data_;
data_ = tmp;
len_ = newlen;
}
memcpy(static_cast<void*>(data_ + index),
static_cast<const void*>(val), count);
+ return index + count;
}
- void Write(size_t index, const DataBuffer& buf) {
- Write(index, buf.data(), buf.len());
+ size_t Write(size_t index, const DataBuffer& buf) {
+ return Write(index, buf.data(), buf.len());
}
// Write an integer, also performing host-to-network order conversion.
- void Write(size_t index, uint32_t val, size_t count) {
+ // Returns the offset of the end of the write.
+ size_t Write(size_t index, uint32_t val, size_t count) {
assert(count <= sizeof(uint32_t));
uint32_t nvalue = htonl(val);
auto* addr = reinterpret_cast<const uint8_t*>(&nvalue);
- Write(index, addr + sizeof(uint32_t) - count, count);
+ return Write(index, addr + sizeof(uint32_t) - count, count);
}
// This can't use the same trick as Write(), since we might be reading from a
// smaller data source.
bool Read(size_t index, size_t count, uint32_t* val) const {
assert(count < sizeof(uint32_t));
assert(val);
if ((index > len()) || (count > (len() - index))) {
--- a/external_tests/ssl_gtest/ssl_extension_unittest.cc
+++ b/external_tests/ssl_gtest/ssl_extension_unittest.cc
@@ -12,265 +12,264 @@
#include "tls_parser.h"
#include "tls_filter.h"
#include "tls_connect.h"
namespace nss_test {
class TlsExtensionFilter : public TlsHandshakeFilter {
protected:
- virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
- const DataBuffer& input, DataBuffer* output) {
- if (handshake_type == kTlsHandshakeClientHello) {
+ virtual PacketFilter::Action FilterHandshake(
+ const HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) {
+ if (header.handshake_type() == kTlsHandshakeClientHello) {
TlsParser parser(input);
- if (!FindClientHelloExtensions(parser, version)) {
- return false;
+ if (!FindClientHelloExtensions(&parser, header)) {
+ return KEEP;
}
- return FilterExtensions(parser, input, output);
+ return FilterExtensions(&parser, input, output);
}
- if (handshake_type == kTlsHandshakeServerHello) {
+ if (header.handshake_type() == kTlsHandshakeServerHello) {
TlsParser parser(input);
- if (!FindServerHelloExtensions(parser, version)) {
- return false;
+ if (!FindServerHelloExtensions(&parser, header.version())) {
+ return KEEP;
}
- return FilterExtensions(parser, input, output);
+ return FilterExtensions(&parser, input, output);
}
- return false;
+ return KEEP;
}
- virtual bool FilterExtension(uint16_t extension_type,
- const DataBuffer& input, DataBuffer* output) = 0;
+ virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) = 0;
public:
- static bool FindClientHelloExtensions(TlsParser& parser, uint16_t version) {
- if (!parser.Skip(2 + 32)) { // version + random
+ static bool FindClientHelloExtensions(TlsParser* parser, const Versioned& header) {
+ if (!parser->Skip(2 + 32)) { // version + random
return false;
}
- if (!parser.SkipVariable(1)) { // session ID
+ if (!parser->SkipVariable(1)) { // session ID
return false;
}
- if (IsDtls(version) && !parser.SkipVariable(1)) { // DTLS cookie
+ if (header.is_dtls() && !parser->SkipVariable(1)) { // DTLS cookie
return false;
}
- if (!parser.SkipVariable(2)) { // cipher suites
+ if (!parser->SkipVariable(2)) { // cipher suites
return false;
}
- if (!parser.SkipVariable(1)) { // compression methods
+ if (!parser->SkipVariable(1)) { // compression methods
return false;
}
return true;
}
- static bool FindServerHelloExtensions(TlsParser& parser, uint16_t version) {
- if (!parser.Skip(2 + 32)) { // version + random
+ static bool FindServerHelloExtensions(TlsParser* parser, uint16_t version) {
+ if (!parser->Skip(2 + 32)) { // version + random
return false;
}
- if (!parser.SkipVariable(1)) { // session ID
+ if (!parser->SkipVariable(1)) { // session ID
return false;
}
- if (!parser.Skip(2)) { // cipher suite
+ if (!parser->Skip(2)) { // cipher suite
return false;
}
if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) {
- if (!parser.Skip(1)) { // compression method
+ if (!parser->Skip(1)) { // compression method
return false;
}
}
return true;
}
private:
- bool FilterExtensions(TlsParser& parser,
- const DataBuffer& input, DataBuffer* output) {
- size_t length_offset = parser.consumed();
+ PacketFilter::Action FilterExtensions(TlsParser* parser,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ size_t length_offset = parser->consumed();
uint32_t all_extensions;
- if (!parser.Read(&all_extensions, 2)) {
- return false; // no extensions, odd but OK
+ if (!parser->Read(&all_extensions, 2)) {
+ return KEEP; // no extensions, odd but OK
}
- if (all_extensions != parser.remaining()) {
- return false; // malformed
+ if (all_extensions != parser->remaining()) {
+ return KEEP; // malformed
}
bool changed = false;
// Write out the start of the message.
output->Allocate(input.len());
- output->Write(0, input.data(), parser.consumed());
- size_t output_offset = parser.consumed();
+ size_t offset = output->Write(0, input.data(), parser->consumed());
- while (parser.remaining()) {
+ while (parser->remaining()) {
uint32_t extension_type;
- if (!parser.Read(&extension_type, 2)) {
- return false; // malformed
+ if (!parser->Read(&extension_type, 2)) {
+ return KEEP; // malformed
+ }
+
+ DataBuffer extension;
+ if (!parser->ReadVariable(&extension, 2)) {
+ return KEEP; // malformed
}
- // Copy extension type.
- output->Write(output_offset, extension_type, 2);
+ DataBuffer filtered;
+ PacketFilter::Action action = FilterExtension(extension_type, extension,
+ &filtered);
+ if (action == DROP) {
+ changed = true;
+ std::cerr << "extension drop: " << extension << std::endl;
+ continue;
+ }
- DataBuffer extension;
- if (!parser.ReadVariable(&extension, 2)) {
- return false; // malformed
+ const DataBuffer* source = &extension;
+ if (action == CHANGE) {
+ EXPECT_GT(0x10000, filtered.len());
+ changed = true;
+ std::cerr << "extension old: " << extension << std::endl;
+ std::cerr << "extension new: " << filtered << std::endl;
+ source = &filtered;
}
- output_offset = ApplyFilter(static_cast<uint16_t>(extension_type), extension,
- output, output_offset + 2, &changed);
+
+ // Write out extension.
+ offset = output->Write(offset, extension_type, 2);
+ offset = output->Write(offset, source->len(), 2);
+ offset = output->Write(offset, *source);
}
- output->Truncate(output_offset);
+ output->Truncate(offset);
if (changed) {
size_t newlen = output->len() - length_offset - 2;
+ EXPECT_GT(0x10000, newlen);
if (newlen >= 0x10000) {
- return false; // bad: size increased too much
+ return KEEP; // bad: size increased too much
}
output->Write(length_offset, newlen, 2);
+ return CHANGE;
}
- return changed;
- }
-
- size_t ApplyFilter(uint16_t extension_type, const DataBuffer& extension,
- DataBuffer* output, size_t offset, bool* changed) {
- const DataBuffer* source = &extension;
- DataBuffer filtered;
- if (FilterExtension(extension_type, extension, &filtered) &&
- filtered.len() < 0x10000) {
- *changed = true;
- std::cerr << "extension old: " << extension << std::endl;
- std::cerr << "extension new: " << filtered << std::endl;
- source = &filtered;
- }
-
- output->Write(offset, source->len(), 2);
- output->Write(offset + 2, *source);
- return offset + 2 + source->len();
+ return KEEP;
}
};
class TlsExtensionTruncator : public TlsExtensionFilter {
public:
TlsExtensionTruncator(uint16_t extension, size_t length)
: extension_(extension), length_(length) {}
- virtual bool FilterExtension(uint16_t extension_type,
- const DataBuffer& input, DataBuffer* output) {
+ virtual PacketFilter::Action FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
if (extension_type != extension_) {
- return false;
+ return KEEP;
}
if (input.len() <= length_) {
- return false;
+ return KEEP;
}
output->Assign(input.data(), length_);
- return true;
+ return CHANGE;
}
private:
uint16_t extension_;
size_t length_;
};
class TlsExtensionDamager : public TlsExtensionFilter {
public:
TlsExtensionDamager(uint16_t extension, size_t index)
: extension_(extension), index_(index) {}
- virtual bool FilterExtension(uint16_t extension_type,
- const DataBuffer& input, DataBuffer* output) {
+ virtual PacketFilter::Action FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
if (extension_type != extension_) {
- return false;
+ return KEEP;
}
*output = input;
output->data()[index_] += 73; // Increment selected for maximum damage
- return true;
+ return CHANGE;
}
private:
uint16_t extension_;
size_t index_;
};
class TlsExtensionReplacer : public TlsExtensionFilter {
public:
TlsExtensionReplacer(uint16_t extension, const DataBuffer& data)
: extension_(extension), data_(data) {}
- virtual bool FilterExtension(uint16_t extension_type,
- const DataBuffer& input, DataBuffer* output) {
+ virtual PacketFilter::Action FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
if (extension_type != extension_) {
- return false;
+ return KEEP;
}
*output = data_;
- return true;
+ return CHANGE;
}
private:
const uint16_t extension_;
const DataBuffer data_;
};
class TlsExtensionInjector : public TlsHandshakeFilter {
public:
TlsExtensionInjector(uint16_t ext, DataBuffer& data)
: extension_(ext), data_(data) {}
- virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
- const DataBuffer& input, DataBuffer* output) {
+ virtual PacketFilter::Action FilterHandshake(
+ const HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) {
size_t offset;
- if (handshake_type == kTlsHandshakeClientHello) {
+ if (header.handshake_type() == kTlsHandshakeClientHello) {
TlsParser parser(input);
- if (!TlsExtensionFilter::FindClientHelloExtensions(parser, version)) {
- return false;
+ if (!TlsExtensionFilter::FindClientHelloExtensions(&parser, header)) {
+ return KEEP;
}
offset = parser.consumed();
- } else if (handshake_type == kTlsHandshakeServerHello) {
+ } else if (header.handshake_type() == kTlsHandshakeServerHello) {
TlsParser parser(input);
- if (!TlsExtensionFilter::FindServerHelloExtensions(parser, version)) {
- return false;
+ if (!TlsExtensionFilter::FindServerHelloExtensions(&parser, header.version())) {
+ return KEEP;
}
offset = parser.consumed();
} else {
- return false;
+ return KEEP;
}
*output = input;
- std::cerr << "Pre:" << input << std::endl;
- std::cerr << "Lof:" << offset << std::endl;
-
// Increase the size of the extensions.
uint16_t* len_addr = reinterpret_cast<uint16_t*>(output->data() + offset);
- std::cerr << "L-p:" << ntohs(*len_addr) << std::endl;
*len_addr = htons(ntohs(*len_addr) + data_.len() + 4);
- std::cerr << "L-i:" << ntohs(*len_addr) << std::endl;
-
// Insert the extension type and length.
DataBuffer type_length;
type_length.Allocate(4);
type_length.Write(0, extension_, 2);
type_length.Write(2, data_.len(), 2);
output->Splice(type_length, offset + 2);
// Insert the payload.
output->Splice(data_, offset + 6);
- std::cerr << "Aft:" << *output << std::endl;
- return true;
+ return CHANGE;
}
private:
const uint16_t extension_;
const DataBuffer data_;
};
class TlsExtensionCapture : public TlsExtensionFilter {
public:
TlsExtensionCapture(uint16_t ext)
: extension_(ext), data_() {}
- virtual bool FilterExtension(uint16_t extension_type,
- const DataBuffer& input, DataBuffer* output) {
+ virtual PacketFilter::Action FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
if (extension_type == extension_) {
data_.Assign(input);
}
- return false;
+ return KEEP;
}
const DataBuffer& extension() const { return data_; }
private:
const uint16_t extension_;
DataBuffer data_;
};
@@ -623,20 +622,24 @@ TEST_P(TlsExtensionTest12Plus, Signature
*/
// Helper class - stores signed certificate timestamps as provided
// by the relevant callbacks on the client.
class SignedCertificateTimestampsExtractor {
public:
SignedCertificateTimestampsExtractor(TlsAgent& client) {
client.SetAuthCertificateCallback(
- [&](TlsAgent& agent, PRBool checksig, PRBool isServer) {
+ [&](TlsAgent& agent, PRBool checksig, PRBool isServer) -> SECStatus {
const SECItem *scts = SSL_PeerSignedCertTimestamps(agent.ssl_fd());
- ASSERT_TRUE(scts);
+ EXPECT_TRUE(scts);
+ if (!scts) {
+ return SECFailure;
+ }
auth_timestamps_.reset(new DataBuffer(scts->data, scts->len));
+ return SECSuccess;
}
);
client.SetHandshakeCallback(
[&](TlsAgent& agent) {
const SECItem *scts = SSL_PeerSignedCertTimestamps(agent.ssl_fd());
ASSERT_TRUE(scts);
handshake_timestamps_.reset(new DataBuffer(scts->data, scts->len));
}
--- a/external_tests/ssl_gtest/ssl_loopback_unittest.cc
+++ b/external_tests/ssl_gtest/ssl_loopback_unittest.cc
@@ -3,16 +3,17 @@
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
* You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "ssl.h"
#include "sslerr.h"
#include "sslproto.h"
#include <memory>
+#include <functional>
extern "C" {
// This is not something that should make you happy.
#include "libssl_internals.h"
}
#include "tls_parser.h"
#include "tls_filter.h"
@@ -42,43 +43,45 @@ uint8_t kBogusClientKeyExchange[] = {
};
// When we see the ClientKeyExchange from |client|, increment the
// ClientHelloVersion on |server|.
class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter {
public:
TlsInspectorClientHelloVersionChanger(TlsAgent* server) : server_(server) {}
- virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
- const DataBuffer& input, DataBuffer* output) {
- if (handshake_type == kTlsHandshakeClientKeyExchange) {
+ virtual PacketFilter::Action FilterHandshake(
+ const HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) {
+ if (header.handshake_type() == kTlsHandshakeClientKeyExchange) {
EXPECT_EQ(
SECSuccess,
SSLInt_IncrementClientHandshakeVersion(server_->ssl_fd()));
}
- return false;
+ return KEEP;
}
private:
TlsAgent* server_;
};
// Set the version number in the ClientHello.
class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter {
public:
TlsInspectorClientHelloVersionSetter(uint16_t version) : version_(version) {}
- virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
- const DataBuffer& input, DataBuffer* output) {
- if (handshake_type == kTlsHandshakeClientHello) {
+ virtual PacketFilter::Action FilterHandshake(
+ const HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) {
+ if (header.handshake_type() == kTlsHandshakeClientHello) {
*output = input;
output->Write(0, version_, 2);
- return true;
+ return CHANGE;
}
- return false;
+ return KEEP;
}
private:
uint16_t version_;
};
class TlsServerKeyExchangeEcdhe {
public:
@@ -118,16 +121,17 @@ TEST_P(TlsConnectGeneric, ConnectEcdsa)
ResetEcdsa();
Connect();
CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa);
}
TEST_P(TlsConnectGeneric, ConnectFalseStart) {
client_->EnableFalseStart();
Connect();
+ SendReceive();
}
TEST_P(TlsConnectGeneric, ConnectResumed) {
ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
Connect();
ResetRsa();
ExpectResumption(RESUME_SESSIONID);
@@ -805,16 +809,125 @@ TEST_F(TlsConnectTest, TestFallbackFromT
SSL_LIBRARY_VERSION_TLS_1_2);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
SSL_LIBRARY_VERSION_TLS_1_3);
ConnectExpectFail();
ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
}
#endif
+class BeforeFinished : public TlsRecordFilter {
+ private:
+ enum HandshakeState {
+ BEFORE_CCS,
+ AFTER_CCS,
+ DONE
+ };
+ typedef std::function<void(void)> VoidFunction;
+
+ public:
+ BeforeFinished(TlsAgent* client, TlsAgent* server,
+ VoidFunction before_ccs, VoidFunction before_finished)
+ : client_(client),
+ server_(server),
+ before_ccs_(before_ccs),
+ before_finished_(before_finished),
+ state_(BEFORE_CCS) {}
+
+ protected:
+ virtual PacketFilter::Action FilterRecord(
+ const RecordHeader& header, const DataBuffer& body, DataBuffer* out) {
+ switch (state_) {
+ case BEFORE_CCS:
+ // Awaken when we see the CCS.
+ if (header.content_type() == kTlsChangeCipherSpecType) {
+ before_ccs_();
+
+ // Write the CCS out as a separate write, so that we can make
+ // progress. Ordinarily, libssl sends the CCS and Finished together,
+ // but that means that they both get processed together.
+ DataBuffer ccs;
+ header.Write(&ccs, 0, body);
+ server_->SendDirect(ccs);
+ ForceRead();
+ state_ = AFTER_CCS;
+ // Request that the original record be dropped by the filter.
+ return DROP;
+ }
+ break;
+
+ case AFTER_CCS:
+ EXPECT_EQ(kTlsHandshakeType, header.content_type());
+ // This could check that data contains a Finished message, but it's
+ // encrypted, so that's too much extra work.
+
+ before_finished_();
+ state_ = DONE;
+ break;
+
+ case DONE:
+ break;
+ }
+ return KEEP;
+ }
+
+ private:
+ void ForceRead() {
+ // Read from the socket to get libssl to process the handshake messages that
+ // were sent from the server up until now.
+ uint8_t block[10];
+ int32_t rv = PR_Read(client_->ssl_fd(), block, sizeof(block));
+ // Expect a blocking error here, since the handshake shouldn't have completed.
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PR_GetError());
+ }
+
+ TlsAgent* client_;
+ TlsAgent* server_;
+ VoidFunction before_ccs_;
+ VoidFunction before_finished_;
+ HandshakeState state_;
+};
+
+// TODO Pre13
+TEST_P(TlsConnectGeneric, ClientWriteBetweenCCSAndFinishedWithFalseStart) {
+ client_->EnableFalseStart();
+ server_->SetPacketFilter(new BeforeFinished(client_, server_, [this]() {
+ EXPECT_TRUE(client_->can_falsestart_hook_called());
+ }, [this]() {
+ // Write something, which used to fail: bug 1235366.
+ client_->SendData(10);
+ }));
+
+ Connect();
+ server_->SendData(10);
+ Receive(10);
+}
+
+TEST_P(TlsConnectGeneric, AuthCompleteBeforeFinishedWithFalseStart) {
+ client_->EnableFalseStart();
+ client_->SetAuthCertificateCallback(
+ [](TlsAgent&, PRBool, PRBool) -> SECStatus {
+ return SECWouldBlock;
+ });
+ server_->SetPacketFilter(new BeforeFinished(client_, server_, []() {
+ // Do nothing before CCS
+ }, [this]() {
+ EXPECT_FALSE(client_->can_falsestart_hook_called());
+ // AuthComplete before Finished still enables false start.
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
+ EXPECT_TRUE(client_->can_falsestart_hook_called());
+ client_->SendData(10);
+ }));
+
+ Connect();
+ server_->SendData(10);
+ Receive(10);
+}
+
INSTANTIATE_TEST_CASE_P(VariantsStream10, TlsConnectGeneric,
::testing::Combine(
TlsConnectTestBase::kTlsModesStream,
TlsConnectTestBase::kTlsV10));
INSTANTIATE_TEST_CASE_P(VariantsAll, TlsConnectGeneric,
::testing::Combine(
TlsConnectTestBase::kTlsModesAll,
TlsConnectTestBase::kTlsV11V12));
--- a/external_tests/ssl_gtest/ssl_skip_unittest.cc
+++ b/external_tests/ssl_gtest/ssl_skip_unittest.cc
@@ -24,62 +24,57 @@ class TlsHandshakeSkipFilter : public Tl
// A TLS record filter that skips handshake messages of the identified type.
TlsHandshakeSkipFilter(uint8_t handshake_type)
: handshake_type_(handshake_type),
skipped_(false) {}
protected:
// Takes a record; if it is a handshake record, it removes the first handshake
// message that is of handshake_type_ type.
- virtual bool FilterRecord(uint8_t content_type, uint16_t version,
- const DataBuffer& input, DataBuffer* output) {
- if (content_type != kTlsHandshakeType) {
- return false;
+ virtual PacketFilter::Action FilterRecord(
+ const RecordHeader& record_header,
+ const DataBuffer& input, DataBuffer* output) {
+
+ if (record_header.content_type() != kTlsHandshakeType) {
+ return KEEP;
}
size_t output_offset = 0U;
output->Allocate(input.len());
TlsParser parser(input);
while (parser.remaining()) {
size_t start = parser.consumed();
- uint8_t handshake_type;
- if (!parser.Read(&handshake_type)) {
- return false;
- }
- uint32_t length;
- if (!TlsHandshakeFilter::ReadLength(&parser, version, &length)) {
- return false;
+ TlsHandshakeFilter::HandshakeHeader header;
+ DataBuffer ignored;
+ if (!header.Parse(&parser, record_header, &ignored)) {
+ return KEEP;
}
- if (!parser.Skip(length)) {
- return false;
- }
-
- if (skipped_ || handshake_type != handshake_type_) {
+ if (skipped_ || header.handshake_type() != handshake_type_) {
size_t entire_length = parser.consumed() - start;
output->Write(output_offset, input.data() + start,
entire_length);
// DTLS sequence numbers need to be rewritten
- if (skipped_ && IsDtls(version)) {
+ if (skipped_ && header.is_dtls()) {
output->data()[start + 5] -= 1;
}
output_offset += entire_length;
} else {
std::cerr << "Dropping handshake: "
<< static_cast<unsigned>(handshake_type_) << std::endl;
// We only need to report that the output contains changed data if we
// drop a handshake message. But once we've skipped one message, we
// have to modify all subsequent handshake messages so that they include
// the correct DTLS sequence numbers.
skipped_ = true;
}
}
output->Truncate(output_offset);
- return skipped_;
+ return skipped_ ? CHANGE : KEEP;
}
private:
// The type of handshake message to drop.
uint8_t handshake_type_;
// Whether this filter has ever skipped a handshake message. Track this so
// that sequence numbers on DTLS handshake messages can be rewritten in
// subsequent calls.
--- a/external_tests/ssl_gtest/test_io.cc
+++ b/external_tests/ssl_gtest/test_io.cc
@@ -353,21 +353,32 @@ int32_t DummyPrSocket::Write(const void
if (!peer_) {
PR_SetError(PR_IO_ERROR, 0);
return -1;
}
DataBuffer packet(static_cast<const uint8_t*>(buf),
static_cast<size_t>(length));
DataBuffer filtered;
- if (filter_ && filter_->Filter(packet, &filtered)) {
- LOG("Filtered packet: " << filtered);
- peer_->PacketReceived(filtered);
- } else {
- peer_->PacketReceived(packet);
+ PacketFilter::Action action = PacketFilter::KEEP;
+ if (filter_) {
+ action = filter_->Filter(packet, &filtered);
+ }
+ switch (action) {
+ case PacketFilter::CHANGE:
+ LOG("Original packet: " << packet);
+ LOG("Filtered packet: " << filtered);
+ peer_->PacketReceived(filtered);
+ break;
+ case PacketFilter::DROP:
+ LOG("Droppped packet: " << packet);
+ break;
+ case PacketFilter::KEEP:
+ peer_->PacketReceived(packet);
+ break;
}
// libssl can't handle it if this reports something other than the length
// of what was passed in (or less, but we're not doing partial writes).
return static_cast<int32_t>(packet.len());
}
Poller *Poller::instance;
--- a/external_tests/ssl_gtest/test_io.h
+++ b/external_tests/ssl_gtest/test_io.h
@@ -20,40 +20,48 @@ namespace nss_test {
class DataBuffer;
class Packet;
class DummyPrSocket; // Fwd decl.
// Allow us to inspect a packet before it is written.
class PacketFilter {
public:
+ enum Action {
+ KEEP, // keep the original packet unmodified
+ CHANGE, // change the packet to a different value
+ DROP // drop the packet
+ };
+
virtual ~PacketFilter() {}
// The packet filter takes input and has the option of mutating it.
//
// A filter that modifies the data places the modified data in *output and
- // returns true. A filter that does not modify data returns false, in which
- // case the value in *output is ignored.
- virtual bool Filter(const DataBuffer& input, DataBuffer* output) = 0;
+ // returns CHANGE. A filter that does not modify data returns LEAVE, in which
+ // case the value in *output is ignored. A Filter can return DROP, in which
+ // case the packet is dropped (and *output is ignored).
+ virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0;
};
enum Mode { STREAM, DGRAM };
inline std::ostream& operator<<(std::ostream& os, Mode m) {
return os << ((m == STREAM) ? "TLS" : "DTLS");
}
class DummyPrSocket {
public:
~DummyPrSocket();
static PRFileDesc* CreateFD(const std::string& name,
Mode mode); // Returns an FD.
static DummyPrSocket* GetAdapter(PRFileDesc* fd);
+ DummyPrSocket* peer() const { return peer_; }
void SetPeer(DummyPrSocket* peer) { peer_ = peer; }
void SetPacketFilter(PacketFilter* filter) { filter_ = filter; }
// Drops peer, packet filter and any outstanding packets.
void Reset();
void PacketReceived(const DataBuffer& data);
int32_t Read(void* data, int32_t len);
int32_t Recv(void* buf, int32_t buflen);
--- a/external_tests/ssl_gtest/tls_agent.cc
+++ b/external_tests/ssl_gtest/tls_agent.cc
@@ -6,16 +6,17 @@
#include "tls_agent.h"
#include "pk11func.h"
#include "ssl.h"
#include "sslerr.h"
#include "sslproto.h"
#include "keyhi.h"
+#include "databuffer.h"
#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
namespace nss_test {
const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"};
@@ -496,19 +497,21 @@ void TlsAgent::Handshake() {
&TlsAgent::ReadableCallback);
return;
break;
// TODO(ekr@rtfm.com): needs special case for DTLS
case SSL_ERROR_RX_MALFORMED_HANDSHAKE:
default:
if (IS_SSL_ERROR(err)) {
- LOG("Handshake failed with SSL error " << err - SSL_ERROR_BASE);
+ LOG("Handshake failed with SSL error " << (err - SSL_ERROR_BASE)
+ << ": " << PORT_ErrorToString(err));
} else {
- LOG("Handshake failed with error " << err);
+ LOG("Handshake failed with error " << err
+ << ": " << PORT_ErrorToString(err));
}
error_code_ = err;
SetState(STATE_ERROR);
return;
}
}
void TlsAgent::PrepareForRenegotiate() {
@@ -519,16 +522,21 @@ void TlsAgent::PrepareForRenegotiate() {
void TlsAgent::StartRenegotiate() {
PrepareForRenegotiate();
SECStatus rv = SSL_ReHandshake(ssl_fd_, PR_TRUE);
EXPECT_EQ(SECSuccess, rv);
}
+void TlsAgent::SendDirect(const DataBuffer& buf) {
+ LOG("Send Direct " << buf);
+ adapter_->peer()->PacketReceived(buf);
+}
+
void TlsAgent::SendData(size_t bytes, size_t blocksize) {
uint8_t block[4096];
ASSERT_LT(blocksize, sizeof(block));
while(bytes) {
size_t tosend = std::min(blocksize, bytes);
@@ -543,37 +551,38 @@ void TlsAgent::SendData(size_t bytes, si
bytes -= tosend;
}
}
void TlsAgent::ReadBytes() {
uint8_t block[1024];
- LOG("Reading application data from socket");
-
int32_t rv = PR_Read(ssl_fd_, block, sizeof(block));
+ LOG("ReadBytes " << rv);
- int32_t err = PR_GetError();
- if (err != PR_WOULD_BLOCK_ERROR) {
- if (expected_read_error_) {
+ if (rv >= 0) {
+ size_t count = static_cast<size_t>(rv);
+ for (size_t i = 0; i < count; ++i) {
+ ASSERT_EQ(recv_ctr_ & 0xff, block[i]);
+ recv_ctr_++;
+ }
+ } else {
+ int32_t err = PR_GetError();
+ LOG("Read error " << err << ": " << PORT_ErrorToString(err));
+ if (err != PR_WOULD_BLOCK_ERROR && expected_read_error_) {
error_code_ = err;
- } else {
- ASSERT_LE(0, rv);
- size_t count = static_cast<size_t>(rv);
- LOG("Read " << count << " bytes");
- for (size_t i = 0; i < count; ++i) {
- ASSERT_EQ(recv_ctr_ & 0xff, block[i]);
- recv_ctr_++;
- }
}
}
- Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
- &TlsAgent::ReadableCallback);
+ // If closed, then don't bother waiting around.
+ if (rv) {
+ Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
+ &TlsAgent::ReadableCallback);
+ }
}
void TlsAgent::ResetSentBytes() {
send_ctr_ = 0;
}
void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
EXPECT_TRUE(EnsureTlsSetup());
--- a/external_tests/ssl_gtest/tls_agent.h
+++ b/external_tests/ssl_gtest/tls_agent.h
@@ -27,17 +27,17 @@ enum SessionResumptionMode {
RESUME_SESSIONID = 1,
RESUME_TICKET = 2,
RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET
};
class TlsAgent;
typedef
- std::function<void(TlsAgent& agent, PRBool checksig, PRBool isServer)>
+ std::function<SECStatus(TlsAgent& agent, PRBool checksig, PRBool isServer)>
AuthCertificateCallbackFunction;
typedef
std::function<void(TlsAgent& agent)>
HandshakeCallbackFunction;
class TlsAgent : public PollTarget {
public:
@@ -95,17 +95,20 @@ class TlsAgent : public PollTarget {
void SetSignatureAlgorithms(const SSLSignatureAndHashAlg* algorithms,
size_t count);
void EnableAlpn(const uint8_t* val, size_t len);
void CheckAlpn(SSLNextProtoState expected_state,
const std::string& expected) const;
void EnableSrtp();
void CheckSrtp() const;
void CheckErrorCode(int32_t expected) const;
+ // Send data on the socket, encrypting it.
void SendData(size_t bytes, size_t blocksize = 1024);
+ // Send data directly to the underlying socket, skipping the TLS layer.
+ void SendDirect(const DataBuffer& buf);
void ReadBytes();
void ResetSentBytes(); // Hack to test drops.
void EnableExtendedMasterSecret();
void CheckExtendedMasterSecret(bool expected);
void DisableRollbackDetection();
void EnableCompression();
void SetDowngradeCheckVersion(uint16_t version);
@@ -149,16 +152,18 @@ class TlsAgent : public PollTarget {
std::vector<uint8_t> session_id() const {
return std::vector<uint8_t>(info_.sessionID,
info_.sessionID + info_.sessionIDLength);
}
size_t received_bytes() const { return recv_ctr_; }
int32_t error_code() const { return error_code_; }
+ bool can_falsestart_hook_called() const { return can_falsestart_hook_called_; }
+
void SetHandshakeCallback(HandshakeCallbackFunction handshake_callback) {
handshake_callback_ = handshake_callback;
}
void SetAuthCertificateCallback(
AuthCertificateCallbackFunction auth_certificate_callback) {
auth_certificate_callback_ = auth_certificate_callback;
}
@@ -176,17 +181,17 @@ class TlsAgent : public PollTarget {
// Dummy auth certificate hook.
static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
PRBool checksig, PRBool isServer) {
TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
agent->CheckPreliminaryInfo();
agent->auth_certificate_hook_called_ = true;
if (agent->auth_certificate_callback_) {
- agent->auth_certificate_callback_(*agent, checksig, isServer);
+ return agent->auth_certificate_callback_(*agent, checksig, isServer);
}
return SECSuccess;
}
// Client auth certificate hook.
static SECStatus ClientAuthenticated(void* arg, PRFileDesc* fd,
PRBool checksig, PRBool isServer) {
TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
@@ -232,16 +237,17 @@ class TlsAgent : public PollTarget {
return SSL_SNI_CURRENT_CONFIG_IS_USED;
}
static SECStatus CanFalseStartCallback(PRFileDesc *fd, void *arg,
PRBool *canFalseStart) {
TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
agent->CheckPreliminaryInfo();
EXPECT_TRUE(agent->falsestart_enabled_);
+ EXPECT_FALSE(agent->can_falsestart_hook_called_);
agent->can_falsestart_hook_called_ = true;
*canFalseStart = true;
return SECSuccess;
}
static void HandshakeCallback(PRFileDesc *fd, void *arg) {
TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
agent->CheckPreliminaryInfo();
--- a/external_tests/ssl_gtest/tls_connect.cc
+++ b/external_tests/ssl_gtest/tls_connect.cc
@@ -276,22 +276,27 @@ void TlsConnectTestBase::EnableSrtp() {
void TlsConnectTestBase::CheckSrtp() const {
client_->CheckSrtp();
server_->CheckSrtp();
}
void TlsConnectTestBase::SendReceive() {
client_->SendData(50);
server_->SendData(50);
- WAIT_(client_->received_bytes() == 50U &&
- server_->received_bytes() == 50U, 2000);
- ASSERT_EQ(50U, client_->received_bytes());
- ASSERT_EQ(50U, server_->received_bytes());
+ Receive(50);
}
+void TlsConnectTestBase::Receive(size_t amount) {
+ WAIT_(client_->received_bytes() == amount &&
+ server_->received_bytes() == amount, 2000);
+ ASSERT_EQ(amount, client_->received_bytes());
+ ASSERT_EQ(amount, server_->received_bytes());
+}
+
+
void TlsConnectTestBase::ExpectExtendedMasterSecret(bool expected) {
expect_extended_master_secret_ = expected;
}
void TlsConnectTestBase::CheckExtendedMasterSecret() {
client_->CheckExtendedMasterSecret(expect_extended_master_secret_);
server_->CheckExtendedMasterSecret(expect_extended_master_secret_);
}
--- a/external_tests/ssl_gtest/tls_connect.h
+++ b/external_tests/ssl_gtest/tls_connect.h
@@ -67,16 +67,17 @@ class TlsConnectTestBase : public ::test
void DisableEcdheCiphers();
void EnableExtendedMasterSecret();
void ConfigureSessionCache(SessionResumptionMode client,
SessionResumptionMode server);
void EnableAlpn();
void EnableSrtp();
void CheckSrtp() const;
void SendReceive();
+ void Receive(size_t amount);
void ExpectExtendedMasterSecret(bool expected);
protected:
Mode mode_;
TlsAgent* client_;
TlsAgent* server_;
uint16_t version_;
SessionResumptionMode expected_resumption_mode_;
--- a/external_tests/ssl_gtest/tls_filter.cc
+++ b/external_tests/ssl_gtest/tls_filter.cc
@@ -2,243 +2,273 @@
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this file,
* You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "tls_filter.h"
#include <iostream>
+#include "gtest_utils.h"
namespace nss_test {
-bool TlsRecordFilter::Filter(const DataBuffer& input, DataBuffer* output) {
+PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, DataBuffer* output) {
bool changed = false;
- size_t output_offset = 0U;
+ size_t offset = 0U;
output->Allocate(input.len());
TlsParser parser(input);
while (parser.remaining()) {
- size_t start = parser.consumed();
- uint8_t content_type;
- if (!parser.Read(&content_type)) {
- return false;
- }
- uint32_t version;
- if (!parser.Read(&version, 2)) {
- return false;
+ RecordHeader header;
+ DataBuffer record;
+ if (!header.Parse(&parser, &record)) {
+ return KEEP;
}
- if (IsDtls(version)) {
- if (!parser.Skip(8)) {
- return false;
- }
- }
- size_t header_len = parser.consumed() - start;
- output->Write(output_offset, input.data() + start, header_len);
-
- DataBuffer record;
- if (!parser.ReadVariable(&record, 2)) {
- return false;
+ DataBuffer filtered;
+ PacketFilter::Action action = FilterRecord(header, record, &filtered);
+ if (action == DROP) {
+ changed = true;
+ std::cerr << "record drop: " << record << std::endl;
+ continue; // don't copy this one
}
- // Move the offset in the output forward. ApplyFilter() returns the index
- // of the end of the record it wrote to the output, so we need to skip
- // over the content type and version for the value passed to it.
- output_offset = ApplyFilter(content_type, version, record, output,
- output_offset + header_len,
- &changed);
+ const DataBuffer* source = &record;
+ if (action == CHANGE) {
+ EXPECT_GT(0x10000, filtered.len());
+ changed = true;
+ std::cerr << "record old: " << record << std::endl;
+ std::cerr << "record new: " << filtered << std::endl;
+ source = &filtered;
+ }
+
+ offset = header.Write(output, offset, *source);
}
- output->Truncate(output_offset);
+ output->Truncate(offset);
// Record how many packets we actually touched.
if (changed) {
++count_;
+ return (offset == 0) ? DROP : CHANGE;
}
- return changed;
+ return KEEP;
}
-size_t TlsRecordFilter::ApplyFilter(uint8_t content_type, uint16_t version,
- const DataBuffer& record,
- DataBuffer* output,
- size_t offset, bool* changed) {
- const DataBuffer* source = &record;
- DataBuffer filtered;
- if (FilterRecord(content_type, version, record, &filtered) &&
- filtered.len() < 0x10000) {
- *changed = true;
- std::cerr << "record old: " << record << std::endl;
- std::cerr << "record new: " << filtered << std::endl;
- source = &filtered;
- }
-
- output->Write(offset, source->len(), 2);
- output->Write(offset + 2, *source);
- return offset + 2 + source->len();
-}
-
-bool TlsHandshakeFilter::FilterRecord(uint8_t content_type, uint16_t version,
- const DataBuffer& input,
- DataBuffer* output) {
- // Check that the first byte is as requested.
- if (content_type != kTlsHandshakeType) {
+bool TlsRecordFilter::RecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
+ if (!parser->Read(&content_type_)) {
return false;
}
+ uint32_t version;
+ if (!parser->Read(&version, 2)) {
+ return false;
+ }
+ version_ = version;
+
+ sequence_number_ = 0;
+ if (IsDtls(version)) {
+ uint32_t tmp;
+ if (!parser->Read(&tmp, 4)) {
+ return false;
+ }
+ sequence_number_ = static_cast<uint64_t>(tmp) << 32;
+ if (!parser->Read(&tmp, 4)) {
+ return false;
+ }
+ sequence_number_ |= static_cast<uint64_t>(tmp);
+ }
+ return parser->ReadVariable(body, 2);
+}
+
+size_t TlsRecordFilter::RecordHeader::Write(
+ DataBuffer* buffer, size_t offset, const DataBuffer& body) const {
+ offset = buffer->Write(offset, content_type_, 1);
+ offset = buffer->Write(offset, version_, 2);
+ if (is_dtls()) {
+ // write epoch (2 octet), and seqnum (6 octet)
+ offset = buffer->Write(offset, sequence_number_ >> 32, 4);
+ offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4);
+ }
+ offset = buffer->Write(offset, body.len(), 2);
+ offset = buffer->Write(offset, body);
+ return offset;
+}
+
+PacketFilter::Action TlsHandshakeFilter::FilterRecord(
+ const RecordHeader& record_header, const DataBuffer& input,
+ DataBuffer* output) {
+ // Check that the first byte is as requested.
+ if (record_header.content_type() != kTlsHandshakeType) {
+ return KEEP;
+ }
+
bool changed = false;
- size_t output_offset = 0U;
+ size_t offset = 0U;
output->Allocate(input.len()); // Preallocate a little.
TlsParser parser(input);
while (parser.remaining()) {
- size_t start = parser.consumed();
- uint8_t handshake_type;
- if (!parser.Read(&handshake_type)) {
- return false; // malformed
+ HandshakeHeader header;
+ DataBuffer handshake;
+ if (!header.Parse(&parser, record_header, &handshake)) {
+ return KEEP;
}
- uint32_t length;
- if (!ReadLength(&parser, version, &length)) {
- return false;
+
+ DataBuffer filtered;
+ PacketFilter::Action action = FilterHandshake(header, handshake, &filtered);
+ if (action == DROP) {
+ changed = true;
+ std::cerr << "handshake drop: " << handshake << std::endl;
+ continue;
}
- size_t header_len = parser.consumed() - start;
- output->Write(output_offset, input.data() + start, header_len);
-
- DataBuffer handshake;
- if (!parser.Read(&handshake, length)) {
- return false;
+ const DataBuffer* source = &handshake;
+ if (action == CHANGE) {
+ EXPECT_GT(0x1000000, filtered.len());
+ changed = true;
+ std::cerr << "handshake old: " << handshake << std::endl;
+ std::cerr << "handshake new: " << filtered << std::endl;
+ source = &filtered;
}
- // Move the offset in the output forward. ApplyFilter() returns the index
- // of the end of the message it wrote to the output, so we need to identify
- // offsets from the start of the message for length and the handshake
- // message.
- output_offset = ApplyFilter(version, handshake_type, handshake,
- output, output_offset + 1,
- output_offset + header_len,
- &changed);
+ offset = header.Write(output, offset, *source);
}
- output->Truncate(output_offset);
- return changed;
+ output->Truncate(offset);
+ return changed ? (offset ? CHANGE : DROP) : KEEP;
}
-bool TlsHandshakeFilter::ReadLength(TlsParser* parser, uint16_t version, uint32_t *length) {
+bool TlsHandshakeFilter::HandshakeHeader::ReadLength(TlsParser* parser,
+ const RecordHeader& header,
+ uint32_t *length) {
if (!parser->Read(length, 3)) {
return false; // malformed
}
- if (!IsDtls(version)) {
+ if (!header.is_dtls()) {
return true; // nothing left to do
}
// Read and check DTLS parameters
- if (!parser->Skip(2)) { // sequence number
+ uint32_t message_seq_tmp;
+ if (!parser->Read(&message_seq_tmp, 2)) { // sequence number
return false;
}
+ message_seq_ = message_seq_tmp;
uint32_t fragment_offset;
if (!parser->Read(&fragment_offset, 3)) {
return false;
}
uint32_t fragment_length;
if (!parser->Read(&fragment_length, 3)) {
return false;
}
// All current tests where we are using this code don't fragment.
return (fragment_offset == 0 && fragment_length == *length);
}
-size_t TlsHandshakeFilter::ApplyFilter(
- uint16_t version, uint8_t handshake_type, const DataBuffer& handshake,
- DataBuffer* output, size_t length_offset, size_t value_offset,
- bool* changed) {
- const DataBuffer* source = &handshake;
- DataBuffer filtered;
- if (FilterHandshake(version, handshake_type, handshake, &filtered) &&
- filtered.len() < 0x1000000) {
- *changed = true;
- std::cerr << "handshake old: " << handshake << std::endl;
- std::cerr << "handshake new: " << filtered << std::endl;
- source = &filtered;
+bool TlsHandshakeFilter::HandshakeHeader::Parse(
+ TlsParser* parser, const RecordHeader& record_header,
+ DataBuffer* body) {
+
+ version_ = record_header.version();
+ if (!parser->Read(&handshake_type_)) {
+ return false; // malformed
+ }
+ uint32_t length;
+ if (!ReadLength(parser, record_header, &length)) {
+ return false;
}
- // Back up and overwrite the (two) length field(s): the handshake message
- // length and the DTLS fragment length.
- output->Write(length_offset, source->len(), 3);
- if (IsDtls(version)) {
- output->Write(length_offset + 8, source->len(), 3);
- }
- output->Write(value_offset, *source);
- return value_offset + source->len();
+ return parser->Read(body, length);
}
-bool TlsInspectorRecordHandshakeMessage::FilterHandshake(
- uint16_t version, uint8_t handshake_type,
+size_t TlsHandshakeFilter::HandshakeHeader::Write(
+ DataBuffer* buffer, size_t offset, const DataBuffer& body) const {
+ offset = buffer->Write(offset, handshake_type(), 1);
+ offset = buffer->Write(offset, body.len(), 3);
+ if (is_dtls()) {
+ offset = buffer->Write(offset, message_seq_, 2);
+ offset = buffer->Write(offset, 0U, 3); // fragment_offset
+ offset = buffer->Write(offset, body.len(), 3);
+ }
+ offset = buffer->Write(offset, body);
+ return offset;
+}
+
+PacketFilter::Action TlsInspectorRecordHandshakeMessage::FilterHandshake(
+ const HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
// Only do this once.
if (buffer_.len()) {
- return false;
+ return KEEP;
}
- if (handshake_type == handshake_type_) {
+ if (header.handshake_type() == handshake_type_) {
buffer_ = input;
}
- return false;
+ return KEEP;
}
-bool TlsInspectorReplaceHandshakeMessage::FilterHandshake(
- uint16_t version, uint8_t handshake_type,
+PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake(
+ const HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
- if (handshake_type == handshake_type_) {
+ if (header.handshake_type() == handshake_type_) {
*output = buffer_;
- return true;
+ return CHANGE;
}
- return false;
+ return KEEP;
}
-bool TlsAlertRecorder::FilterRecord(uint8_t content_type, uint16_t version,
- const DataBuffer& input, DataBuffer* output) {
+PacketFilter::Action TlsAlertRecorder::FilterRecord(
+ const RecordHeader& header, const DataBuffer& input, DataBuffer* output) {
if (level_ == kTlsAlertFatal) { // already fatal
- return false;
+ return KEEP;
}
- if (content_type != kTlsAlertType) {
- return false;
+ if (header.content_type() != kTlsAlertType) {
+ return KEEP;
}
std::cerr << "Alert: " << input << std::endl;
TlsParser parser(input);
uint8_t lvl;
if (!parser.Read(&lvl)) {
- return false;
+ return KEEP;
}
if (lvl == kTlsAlertWarning) { // not strong enough
- return false;
+ return KEEP;
}
level_ = lvl;
(void)parser.Read(&description_);
- return false;
+ return KEEP;
}
ChainedPacketFilter::~ChainedPacketFilter() {
for (auto it = filters_.begin(); it != filters_.end(); ++it) {
delete *it;
}
}
-bool ChainedPacketFilter::Filter(const DataBuffer& input, DataBuffer* output) {
+PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input,
+ DataBuffer* output) {
DataBuffer in(input);
bool changed = false;
for (auto it = filters_.begin(); it != filters_.end(); ++it) {
- if ((*it)->Filter(in, output)) {
+ PacketFilter::Action action = (*it)->Filter(in, output);
+ if (action == DROP) {
+ return DROP;
+ }
+ if (action == CHANGE) {
in = *output;
changed = true;
}
}
- return changed;
+ return changed ? CHANGE : KEEP;
}
} // namespace nss_test
--- a/external_tests/ssl_gtest/tls_filter.h
+++ b/external_tests/ssl_gtest/tls_filter.h
@@ -15,94 +15,148 @@
namespace nss_test {
// Abstract filter that operates on entire (D)TLS records.
class TlsRecordFilter : public PacketFilter {
public:
TlsRecordFilter() : count_(0) {}
- virtual bool Filter(const DataBuffer& input, DataBuffer* output);
+ virtual PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output);
// Report how many packets were altered by the filter.
size_t filtered_packets() const { return count_; }
+ class Versioned {
+ public:
+ Versioned() : version_(0) {}
+ bool is_dtls() const { return IsDtls(version_); }
+ uint16_t version() const { return version_; }
+
+ protected:
+ uint16_t version_;
+ };
+
+ class RecordHeader : public Versioned {
+ public:
+ RecordHeader()
+ : Versioned(), content_type_(0), sequence_number_(0) {}
+
+ uint8_t content_type() const { return content_type_; }
+ uint64_t sequence_number() const { return sequence_number_; }
+ size_t header_length() const { return is_dtls() ? 11 : 3; }
+
+ // Parse the header; return true if successful; body in an outparam if OK.
+ bool Parse(TlsParser* parser, DataBuffer* body);
+ // Write the header and body to a buffer at the given offset.
+ // Return the offset of the end of the write.
+ size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const;
+
+ private:
+ uint8_t content_type_;
+ uint64_t sequence_number_;
+ };
+
protected:
- virtual bool FilterRecord(uint8_t content_type, uint16_t version,
- const DataBuffer& data, DataBuffer* changed) = 0;
+ // The record filter receives the record contentType, version and DTLS
+ // sequence number (which is zero for TLS), plus the existing record payload.
+ // It returns an action (KEEP, CHANGE, DROP). It writes to the `changed`
+ // outparam with the new record contents if it chooses to CHANGE the record.
+ virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) = 0;
+
private:
- size_t ApplyFilter(uint8_t content_type, uint16_t version,
- const DataBuffer& record, DataBuffer* output,
- size_t offset, bool* changed);
size_t count_;
};
// Abstract filter that operates on handshake messages rather than records.
// This assumes that the handshake messages are written in a block as entire
// records and that they don't span records or anything crazy like that.
class TlsHandshakeFilter : public TlsRecordFilter {
public:
TlsHandshakeFilter() {}
- // Reads the length from the record header.
- // This also reads the DTLS fragment information and checks it.
- static bool ReadLength(TlsParser* parser, uint16_t version, uint32_t *length);
+ class HandshakeHeader : public Versioned {
+ public:
+ HandshakeHeader()
+ : Versioned(), handshake_type_(0), message_seq_(0) {}
+
+ uint8_t handshake_type() const { return handshake_type_; }
+ bool Parse(TlsParser* parser, const RecordHeader& record_header,
+ DataBuffer* body);
+ size_t Write(DataBuffer* buffer, size_t offset,
+ const DataBuffer& body) const;
+
+ private:
+ // Reads the length from the record header.
+ // This also reads the DTLS fragment information and checks it.
+ bool ReadLength(TlsParser* parser, const RecordHeader& header,
+ uint32_t *length);
+
+ uint8_t handshake_type_;
+ uint16_t message_seq_;
+ // fragment_offset is always zero in these tests.
+ };
protected:
- virtual bool FilterRecord(uint8_t content_type, uint16_t version,
- const DataBuffer& input, DataBuffer* output);
- virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
- const DataBuffer& input, DataBuffer* output) = 0;
+ virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) = 0;
private:
- size_t ApplyFilter(uint16_t version, uint8_t handshake_type,
- const DataBuffer& record, DataBuffer* output,
- size_t length_offset, size_t value_offset, bool* changed);
};
// Make a copy of the first instance of a handshake message.
class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter {
public:
TlsInspectorRecordHandshakeMessage(uint8_t handshake_type)
: handshake_type_(handshake_type), buffer_() {}
- virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
- const DataBuffer& input, DataBuffer* output);
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
const DataBuffer& buffer() const { return buffer_; }
private:
uint8_t handshake_type_;
DataBuffer buffer_;
};
// Replace all instances of a handshake message.
class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
public:
TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type,
const DataBuffer& replacement)
: handshake_type_(handshake_type), buffer_(replacement) {}
- virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
- const DataBuffer& input, DataBuffer* output);
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
private:
uint8_t handshake_type_;
DataBuffer buffer_;
};
// Records an alert. If an alert has already been recorded, it won't save the
// new alert unless the old alert is a warning and the new one is fatal.
class TlsAlertRecorder : public TlsRecordFilter {
public:
TlsAlertRecorder() : level_(255), description_(255) {}
- virtual bool FilterRecord(uint8_t content_type, uint16_t version,
- const DataBuffer& input, DataBuffer* output);
+ virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
uint8_t level() const { return level_; }
uint8_t description() const { return description_; }
private:
uint8_t level_;
uint8_t description_;
};
@@ -110,17 +164,18 @@ class TlsAlertRecorder : public TlsRecor
// Runs multiple packet filters in series.
class ChainedPacketFilter : public PacketFilter {
public:
ChainedPacketFilter() {}
ChainedPacketFilter(const std::vector<PacketFilter*> filters)
: filters_(filters.begin(), filters.end()) {}
virtual ~ChainedPacketFilter();
- virtual bool Filter(const DataBuffer& input, DataBuffer* output);
+ virtual PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output);
// Takes ownership of the filter.
void Add(PacketFilter* filter) {
filters_.push_back(filter);
}
private:
std::vector<PacketFilter*> filters_;