Bug 1235366 - Test case for writing between receipt of CCS and Finished, r=ekr
authorMartin Thomson <martin.thomson@gmail.com>
Thu, 04 Feb 2016 05:54:40 +1100
changeset 11847 88c7afc06b2e679de8b913ca94103cb99a4fcde4
parent 11846 3430648027ec608791e3100c43af89e6beee8d6c
child 11848 7d389601644f8dbcfea3212de3ab8ca9c9fb54b1
push id960
push usermartin.thomson@gmail.com
push dateWed, 03 Feb 2016 19:15:17 +0000
reviewersekr
bugs1235366
Bug 1235366 - Test case for writing between receipt of CCS and Finished, r=ekr
external_tests/ssl_gtest/databuffer.h
external_tests/ssl_gtest/ssl_extension_unittest.cc
external_tests/ssl_gtest/ssl_loopback_unittest.cc
external_tests/ssl_gtest/ssl_skip_unittest.cc
external_tests/ssl_gtest/test_io.cc
external_tests/ssl_gtest/test_io.h
external_tests/ssl_gtest/tls_agent.cc
external_tests/ssl_gtest/tls_agent.h
external_tests/ssl_gtest/tls_connect.cc
external_tests/ssl_gtest/tls_connect.h
external_tests/ssl_gtest/tls_filter.cc
external_tests/ssl_gtest/tls_filter.h
--- a/external_tests/ssl_gtest/databuffer.h
+++ b/external_tests/ssl_gtest/databuffer.h
@@ -59,43 +59,46 @@ class DataBuffer {
     } else {
       assert(len == 0);
       data_ = nullptr;
       len_ = 0;
     }
   }
 
   // Write will do a new allocation and expand the size of the buffer if needed.
-  void Write(size_t index, const uint8_t* val, size_t count) {
+  // Returns the offset of the end of the write.
+  size_t Write(size_t index, const uint8_t* val, size_t count) {
     if (index + count > len_) {
       size_t newlen = index + count;
       uint8_t* tmp = new uint8_t[newlen]; // Always > 0.
       memcpy(static_cast<void*>(tmp),
              static_cast<const void*>(data_), len_);
       if (index > len_) {
         memset(static_cast<void*>(tmp + len_), 0, index - len_);
       }
       delete[] data_;
       data_ = tmp;
       len_ = newlen;
     }
     memcpy(static_cast<void*>(data_ + index),
            static_cast<const void*>(val), count);
+    return index + count;
   }
 
-  void Write(size_t index, const DataBuffer& buf) {
-    Write(index, buf.data(), buf.len());
+  size_t Write(size_t index, const DataBuffer& buf) {
+    return Write(index, buf.data(), buf.len());
   }
 
   // Write an integer, also performing host-to-network order conversion.
-  void Write(size_t index, uint32_t val, size_t count) {
+  // Returns the offset of the end of the write.
+  size_t Write(size_t index, uint32_t val, size_t count) {
     assert(count <= sizeof(uint32_t));
     uint32_t nvalue = htonl(val);
     auto* addr = reinterpret_cast<const uint8_t*>(&nvalue);
-    Write(index, addr + sizeof(uint32_t) - count, count);
+    return Write(index, addr + sizeof(uint32_t) - count, count);
   }
 
   // This can't use the same trick as Write(), since we might be reading from a
   // smaller data source.
   bool Read(size_t index, size_t count, uint32_t* val) const {
     assert(count < sizeof(uint32_t));
     assert(val);
     if ((index > len()) || (count > (len() - index))) {
--- a/external_tests/ssl_gtest/ssl_extension_unittest.cc
+++ b/external_tests/ssl_gtest/ssl_extension_unittest.cc
@@ -12,265 +12,264 @@
 #include "tls_parser.h"
 #include "tls_filter.h"
 #include "tls_connect.h"
 
 namespace nss_test {
 
 class TlsExtensionFilter : public TlsHandshakeFilter {
  protected:
-  virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
-                               const DataBuffer& input, DataBuffer* output) {
-    if (handshake_type == kTlsHandshakeClientHello) {
+  virtual PacketFilter::Action FilterHandshake(
+      const HandshakeHeader& header,
+      const DataBuffer& input, DataBuffer* output) {
+    if (header.handshake_type() == kTlsHandshakeClientHello) {
       TlsParser parser(input);
-      if (!FindClientHelloExtensions(parser, version)) {
-        return false;
+      if (!FindClientHelloExtensions(&parser, header)) {
+        return KEEP;
       }
-      return FilterExtensions(parser, input, output);
+      return FilterExtensions(&parser, input, output);
     }
-    if (handshake_type == kTlsHandshakeServerHello) {
+    if (header.handshake_type() == kTlsHandshakeServerHello) {
       TlsParser parser(input);
-      if (!FindServerHelloExtensions(parser, version)) {
-        return false;
+      if (!FindServerHelloExtensions(&parser, header.version())) {
+        return KEEP;
       }
-      return FilterExtensions(parser, input, output);
+      return FilterExtensions(&parser, input, output);
     }
-    return false;
+    return KEEP;
   }
 
-  virtual bool FilterExtension(uint16_t extension_type,
-                               const DataBuffer& input, DataBuffer* output) = 0;
+  virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+                                               const DataBuffer& input,
+                                               DataBuffer* output) = 0;
 
  public:
-  static bool FindClientHelloExtensions(TlsParser& parser, uint16_t version) {
-    if (!parser.Skip(2 + 32)) { // version + random
+  static bool FindClientHelloExtensions(TlsParser* parser, const Versioned& header) {
+    if (!parser->Skip(2 + 32)) { // version + random
       return false;
     }
-    if (!parser.SkipVariable(1)) { // session ID
+    if (!parser->SkipVariable(1)) { // session ID
       return false;
     }
-    if (IsDtls(version) && !parser.SkipVariable(1)) { // DTLS cookie
+    if (header.is_dtls() && !parser->SkipVariable(1)) { // DTLS cookie
       return false;
     }
-    if (!parser.SkipVariable(2)) { // cipher suites
+    if (!parser->SkipVariable(2)) { // cipher suites
       return false;
     }
-    if (!parser.SkipVariable(1)) { // compression methods
+    if (!parser->SkipVariable(1)) { // compression methods
       return false;
     }
     return true;
   }
 
-  static bool FindServerHelloExtensions(TlsParser& parser, uint16_t version) {
-    if (!parser.Skip(2 + 32)) { // version + random
+  static bool FindServerHelloExtensions(TlsParser* parser, uint16_t version) {
+    if (!parser->Skip(2 + 32)) { // version + random
       return false;
     }
-    if (!parser.SkipVariable(1)) { // session ID
+    if (!parser->SkipVariable(1)) { // session ID
       return false;
     }
-    if (!parser.Skip(2)) { // cipher suite
+    if (!parser->Skip(2)) { // cipher suite
       return false;
     }
     if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) {
-      if (!parser.Skip(1)) { // compression method
+      if (!parser->Skip(1)) { // compression method
         return false;
       }
     }
     return true;
   }
 
  private:
-  bool FilterExtensions(TlsParser& parser,
-                        const DataBuffer& input, DataBuffer* output) {
-    size_t length_offset = parser.consumed();
+  PacketFilter::Action FilterExtensions(TlsParser* parser,
+                                        const DataBuffer& input,
+                                        DataBuffer* output) {
+    size_t length_offset = parser->consumed();
     uint32_t all_extensions;
-    if (!parser.Read(&all_extensions, 2)) {
-      return false; // no extensions, odd but OK
+    if (!parser->Read(&all_extensions, 2)) {
+      return KEEP; // no extensions, odd but OK
     }
-    if (all_extensions != parser.remaining()) {
-      return false; // malformed
+    if (all_extensions != parser->remaining()) {
+      return KEEP; // malformed
     }
 
     bool changed = false;
 
     // Write out the start of the message.
     output->Allocate(input.len());
-    output->Write(0, input.data(), parser.consumed());
-    size_t output_offset = parser.consumed();
+    size_t offset = output->Write(0, input.data(), parser->consumed());
 
-    while (parser.remaining()) {
+    while (parser->remaining()) {
       uint32_t extension_type;
-      if (!parser.Read(&extension_type, 2)) {
-        return false; // malformed
+      if (!parser->Read(&extension_type, 2)) {
+        return KEEP; // malformed
+      }
+
+      DataBuffer extension;
+      if (!parser->ReadVariable(&extension, 2)) {
+        return KEEP; // malformed
       }
 
-      // Copy extension type.
-      output->Write(output_offset, extension_type, 2);
+      DataBuffer filtered;
+      PacketFilter::Action action = FilterExtension(extension_type, extension,
+                                                    &filtered);
+      if (action == DROP) {
+        changed = true;
+        std::cerr << "extension drop: " << extension << std::endl;
+        continue;
+      }
 
-      DataBuffer extension;
-      if (!parser.ReadVariable(&extension, 2)) {
-        return false; // malformed
+      const DataBuffer* source = &extension;
+      if (action == CHANGE) {
+        EXPECT_GT(0x10000, filtered.len());
+        changed = true;
+        std::cerr << "extension old: " << extension << std::endl;
+        std::cerr << "extension new: " << filtered << std::endl;
+        source = &filtered;
       }
-      output_offset = ApplyFilter(static_cast<uint16_t>(extension_type), extension,
-                                  output, output_offset + 2, &changed);
+
+      // Write out extension.
+      offset = output->Write(offset, extension_type, 2);
+      offset = output->Write(offset, source->len(), 2);
+      offset = output->Write(offset, *source);
     }
-    output->Truncate(output_offset);
+    output->Truncate(offset);
 
     if (changed) {
       size_t newlen = output->len() - length_offset - 2;
+      EXPECT_GT(0x10000, newlen);
       if (newlen >= 0x10000) {
-        return false; // bad: size increased too much
+        return KEEP; // bad: size increased too much
       }
       output->Write(length_offset, newlen, 2);
+      return CHANGE;
     }
-    return changed;
-  }
-
-  size_t ApplyFilter(uint16_t extension_type, const DataBuffer& extension,
-                     DataBuffer* output, size_t offset, bool* changed) {
-    const DataBuffer* source = &extension;
-    DataBuffer filtered;
-    if (FilterExtension(extension_type, extension, &filtered) &&
-        filtered.len() < 0x10000) {
-      *changed = true;
-      std::cerr << "extension old: " << extension << std::endl;
-      std::cerr << "extension new: " << filtered << std::endl;
-      source = &filtered;
-    }
-
-    output->Write(offset, source->len(), 2);
-    output->Write(offset + 2, *source);
-    return offset + 2 + source->len();
+    return KEEP;
   }
 };
 
 class TlsExtensionTruncator : public TlsExtensionFilter {
  public:
   TlsExtensionTruncator(uint16_t extension, size_t length)
       : extension_(extension), length_(length) {}
-  virtual bool FilterExtension(uint16_t extension_type,
-                               const DataBuffer& input, DataBuffer* output) {
+  virtual PacketFilter::Action FilterExtension(
+      uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
     if (extension_type != extension_) {
-      return false;
+      return KEEP;
     }
     if (input.len() <= length_) {
-      return false;
+      return KEEP;
     }
 
     output->Assign(input.data(), length_);
-    return true;
+    return CHANGE;
   }
  private:
     uint16_t extension_;
     size_t length_;
 };
 
 class TlsExtensionDamager : public TlsExtensionFilter {
  public:
   TlsExtensionDamager(uint16_t extension, size_t index)
       : extension_(extension), index_(index) {}
-  virtual bool FilterExtension(uint16_t extension_type,
-                               const DataBuffer& input, DataBuffer* output) {
+  virtual PacketFilter::Action FilterExtension(
+      uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
     if (extension_type != extension_) {
-      return false;
+      return KEEP;
     }
 
     *output = input;
     output->data()[index_] += 73; // Increment selected for maximum damage
-    return true;
+    return CHANGE;
   }
  private:
   uint16_t extension_;
   size_t index_;
 };
 
 class TlsExtensionReplacer : public TlsExtensionFilter {
  public:
   TlsExtensionReplacer(uint16_t extension, const DataBuffer& data)
       : extension_(extension), data_(data) {}
-  virtual bool FilterExtension(uint16_t extension_type,
-                               const DataBuffer& input, DataBuffer* output) {
+  virtual PacketFilter::Action FilterExtension(
+      uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
     if (extension_type != extension_) {
-      return false;
+      return KEEP;
     }
 
     *output = data_;
-    return true;
+    return CHANGE;
   }
  private:
   const uint16_t extension_;
   const DataBuffer data_;
 };
 
 class TlsExtensionInjector : public TlsHandshakeFilter {
  public:
   TlsExtensionInjector(uint16_t ext, DataBuffer& data)
       : extension_(ext), data_(data) {}
 
-  virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
-                               const DataBuffer& input, DataBuffer* output) {
+  virtual PacketFilter::Action FilterHandshake(
+      const HandshakeHeader& header,
+      const DataBuffer& input, DataBuffer* output) {
     size_t offset;
-    if (handshake_type == kTlsHandshakeClientHello) {
+    if (header.handshake_type() == kTlsHandshakeClientHello) {
       TlsParser parser(input);
-      if (!TlsExtensionFilter::FindClientHelloExtensions(parser, version)) {
-        return false;
+      if (!TlsExtensionFilter::FindClientHelloExtensions(&parser, header)) {
+        return KEEP;
       }
       offset = parser.consumed();
-    } else if (handshake_type == kTlsHandshakeServerHello) {
+    } else if (header.handshake_type() == kTlsHandshakeServerHello) {
       TlsParser parser(input);
-      if (!TlsExtensionFilter::FindServerHelloExtensions(parser, version)) {
-        return false;
+      if (!TlsExtensionFilter::FindServerHelloExtensions(&parser, header.version())) {
+        return KEEP;
       }
       offset = parser.consumed();
     } else {
-      return false;
+      return KEEP;
     }
 
     *output = input;
 
-    std::cerr << "Pre:" << input << std::endl;
-    std::cerr << "Lof:" << offset << std::endl;
-
     // Increase the size of the extensions.
     uint16_t* len_addr = reinterpret_cast<uint16_t*>(output->data() + offset);
-    std::cerr << "L-p:" << ntohs(*len_addr) << std::endl;
     *len_addr = htons(ntohs(*len_addr) + data_.len() + 4);
-    std::cerr << "L-i:" << ntohs(*len_addr) << std::endl;
-
 
     // Insert the extension type and length.
     DataBuffer type_length;
     type_length.Allocate(4);
     type_length.Write(0, extension_, 2);
     type_length.Write(2, data_.len(), 2);
     output->Splice(type_length, offset + 2);
 
     // Insert the payload.
     output->Splice(data_, offset + 6);
 
-    std::cerr << "Aft:" << *output << std::endl;
-    return true;
+    return CHANGE;
   }
 
  private:
   const uint16_t extension_;
   const DataBuffer data_;
 };
 
 class TlsExtensionCapture : public TlsExtensionFilter {
  public:
   TlsExtensionCapture(uint16_t ext)
       : extension_(ext), data_() {}
 
-  virtual bool FilterExtension(uint16_t extension_type,
-                               const DataBuffer& input, DataBuffer* output) {
+  virtual PacketFilter::Action FilterExtension(
+      uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
     if (extension_type == extension_) {
       data_.Assign(input);
     }
-    return false;
+    return KEEP;
   }
 
   const DataBuffer& extension() const { return data_; }
 
  private:
   const uint16_t extension_;
   DataBuffer data_;
 };
@@ -623,20 +622,24 @@ TEST_P(TlsExtensionTest12Plus, Signature
  */
 
 // Helper class - stores signed certificate timestamps as provided
 // by the relevant callbacks on the client.
 class SignedCertificateTimestampsExtractor {
  public:
   SignedCertificateTimestampsExtractor(TlsAgent& client) {
     client.SetAuthCertificateCallback(
-      [&](TlsAgent& agent, PRBool checksig, PRBool isServer) {
+      [&](TlsAgent& agent, PRBool checksig, PRBool isServer) -> SECStatus {
         const SECItem *scts = SSL_PeerSignedCertTimestamps(agent.ssl_fd());
-        ASSERT_TRUE(scts);
+        EXPECT_TRUE(scts);
+        if (!scts) {
+          return SECFailure;
+        }
         auth_timestamps_.reset(new DataBuffer(scts->data, scts->len));
+        return SECSuccess;
       }
     );
     client.SetHandshakeCallback(
       [&](TlsAgent& agent) {
         const SECItem *scts = SSL_PeerSignedCertTimestamps(agent.ssl_fd());
         ASSERT_TRUE(scts);
         handshake_timestamps_.reset(new DataBuffer(scts->data, scts->len));
       }
--- a/external_tests/ssl_gtest/ssl_loopback_unittest.cc
+++ b/external_tests/ssl_gtest/ssl_loopback_unittest.cc
@@ -3,16 +3,17 @@
 /* This Source Code Form is subject to the terms of the Mozilla Public
  * License, v. 2.0. If a copy of the MPL was not distributed with this file,
  * You can obtain one at http://mozilla.org/MPL/2.0/. */
 
 #include "ssl.h"
 #include "sslerr.h"
 #include "sslproto.h"
 #include <memory>
+#include <functional>
 
 extern "C" {
 // This is not something that should make you happy.
 #include "libssl_internals.h"
 }
 
 #include "tls_parser.h"
 #include "tls_filter.h"
@@ -42,43 +43,45 @@ uint8_t kBogusClientKeyExchange[] = {
 };
 
 // When we see the ClientKeyExchange from |client|, increment the
 // ClientHelloVersion on |server|.
 class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter {
  public:
   TlsInspectorClientHelloVersionChanger(TlsAgent* server) : server_(server) {}
 
-  virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
-                               const DataBuffer& input, DataBuffer* output) {
-    if (handshake_type == kTlsHandshakeClientKeyExchange) {
+  virtual PacketFilter::Action FilterHandshake(
+      const HandshakeHeader& header,
+      const DataBuffer& input, DataBuffer* output) {
+    if (header.handshake_type() == kTlsHandshakeClientKeyExchange) {
       EXPECT_EQ(
           SECSuccess,
           SSLInt_IncrementClientHandshakeVersion(server_->ssl_fd()));
     }
-    return false;
+    return KEEP;
   }
 
  private:
   TlsAgent* server_;
 };
 
 // Set the version number in the ClientHello.
 class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter {
  public:
   TlsInspectorClientHelloVersionSetter(uint16_t version) : version_(version) {}
 
-  virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
-                               const DataBuffer& input, DataBuffer* output) {
-    if (handshake_type == kTlsHandshakeClientHello) {
+  virtual PacketFilter::Action FilterHandshake(
+      const HandshakeHeader& header,
+      const DataBuffer& input, DataBuffer* output) {
+    if (header.handshake_type() == kTlsHandshakeClientHello) {
       *output = input;
       output->Write(0, version_, 2);
-      return true;
+      return CHANGE;
     }
-    return false;
+    return KEEP;
   }
 
  private:
   uint16_t version_;
 };
 
 class TlsServerKeyExchangeEcdhe {
  public:
@@ -118,16 +121,17 @@ TEST_P(TlsConnectGeneric, ConnectEcdsa) 
   ResetEcdsa();
   Connect();
   CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa);
 }
 
 TEST_P(TlsConnectGeneric, ConnectFalseStart) {
   client_->EnableFalseStart();
   Connect();
+  SendReceive();
 }
 
 TEST_P(TlsConnectGeneric, ConnectResumed) {
   ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
   Connect();
 
   ResetRsa();
   ExpectResumption(RESUME_SESSIONID);
@@ -805,16 +809,125 @@ TEST_F(TlsConnectTest, TestFallbackFromT
                            SSL_LIBRARY_VERSION_TLS_1_2);
   server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
                            SSL_LIBRARY_VERSION_TLS_1_3);
   ConnectExpectFail();
   ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
 }
 #endif
 
+class BeforeFinished : public TlsRecordFilter {
+ private:
+  enum HandshakeState {
+    BEFORE_CCS,
+    AFTER_CCS,
+    DONE
+  };
+  typedef std::function<void(void)> VoidFunction;
+
+ public:
+  BeforeFinished(TlsAgent* client, TlsAgent* server,
+                 VoidFunction before_ccs, VoidFunction before_finished)
+      : client_(client),
+        server_(server),
+        before_ccs_(before_ccs),
+        before_finished_(before_finished),
+        state_(BEFORE_CCS) {}
+
+ protected:
+  virtual PacketFilter::Action FilterRecord(
+      const RecordHeader& header, const DataBuffer& body, DataBuffer* out) {
+    switch (state_) {
+      case BEFORE_CCS:
+        // Awaken when we see the CCS.
+        if (header.content_type() == kTlsChangeCipherSpecType) {
+          before_ccs_();
+
+          // Write the CCS out as a separate write, so that we can make
+          // progress. Ordinarily, libssl sends the CCS and Finished together,
+          // but that means that they both get processed together.
+          DataBuffer ccs;
+          header.Write(&ccs, 0, body);
+          server_->SendDirect(ccs);
+          ForceRead();
+          state_ = AFTER_CCS;
+          // Request that the original record be dropped by the filter.
+          return DROP;
+        }
+        break;
+
+      case AFTER_CCS:
+        EXPECT_EQ(kTlsHandshakeType, header.content_type());
+        // This could check that data contains a Finished message, but it's
+        // encrypted, so that's too much extra work.
+
+        before_finished_();
+        state_ = DONE;
+        break;
+
+      case DONE:
+        break;
+    }
+    return KEEP;
+  }
+
+ private:
+  void ForceRead() {
+    // Read from the socket to get libssl to process the handshake messages that
+    // were sent from the server up until now.
+    uint8_t block[10];
+    int32_t rv = PR_Read(client_->ssl_fd(), block, sizeof(block));
+    // Expect a blocking error here, since the handshake shouldn't have completed.
+    EXPECT_GT(0, rv);
+    EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PR_GetError());
+  }
+
+  TlsAgent* client_;
+  TlsAgent* server_;
+  VoidFunction before_ccs_;
+  VoidFunction before_finished_;
+  HandshakeState state_;
+};
+
+// TODO Pre13
+TEST_P(TlsConnectGeneric, ClientWriteBetweenCCSAndFinishedWithFalseStart) {
+  client_->EnableFalseStart();
+  server_->SetPacketFilter(new BeforeFinished(client_, server_, [this]() {
+        EXPECT_TRUE(client_->can_falsestart_hook_called());
+      }, [this]() {
+        // Write something, which used to fail: bug 1235366.
+        client_->SendData(10);
+      }));
+
+  Connect();
+  server_->SendData(10);
+  Receive(10);
+}
+
+TEST_P(TlsConnectGeneric, AuthCompleteBeforeFinishedWithFalseStart) {
+  client_->EnableFalseStart();
+  client_->SetAuthCertificateCallback(
+      [](TlsAgent&, PRBool, PRBool) -> SECStatus {
+        return SECWouldBlock;
+      });
+  server_->SetPacketFilter(new BeforeFinished(client_, server_, []() {
+        // Do nothing before CCS
+      }, [this]() {
+        EXPECT_FALSE(client_->can_falsestart_hook_called());
+        // AuthComplete before Finished still enables false start.
+        EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
+        EXPECT_TRUE(client_->can_falsestart_hook_called());
+        client_->SendData(10);
+      }));
+
+  Connect();
+  server_->SendData(10);
+  Receive(10);
+}
+
 INSTANTIATE_TEST_CASE_P(VariantsStream10, TlsConnectGeneric,
                         ::testing::Combine(
                           TlsConnectTestBase::kTlsModesStream,
                           TlsConnectTestBase::kTlsV10));
 INSTANTIATE_TEST_CASE_P(VariantsAll, TlsConnectGeneric,
                         ::testing::Combine(
                           TlsConnectTestBase::kTlsModesAll,
                           TlsConnectTestBase::kTlsV11V12));
--- a/external_tests/ssl_gtest/ssl_skip_unittest.cc
+++ b/external_tests/ssl_gtest/ssl_skip_unittest.cc
@@ -24,62 +24,57 @@ class TlsHandshakeSkipFilter : public Tl
   // A TLS record filter that skips handshake messages of the identified type.
   TlsHandshakeSkipFilter(uint8_t handshake_type)
       : handshake_type_(handshake_type),
         skipped_(false) {}
 
  protected:
   // Takes a record; if it is a handshake record, it removes the first handshake
   // message that is of handshake_type_ type.
-  virtual bool FilterRecord(uint8_t content_type, uint16_t version,
-                            const DataBuffer& input, DataBuffer* output) {
-    if (content_type != kTlsHandshakeType) {
-      return false;
+  virtual PacketFilter::Action FilterRecord(
+      const RecordHeader& record_header,
+      const DataBuffer& input, DataBuffer* output) {
+
+    if (record_header.content_type() != kTlsHandshakeType) {
+      return KEEP;
     }
 
     size_t output_offset = 0U;
     output->Allocate(input.len());
 
     TlsParser parser(input);
     while (parser.remaining()) {
       size_t start = parser.consumed();
-      uint8_t handshake_type;
-      if (!parser.Read(&handshake_type)) {
-        return false;
-      }
-      uint32_t length;
-      if (!TlsHandshakeFilter::ReadLength(&parser, version, &length)) {
-        return false;
+      TlsHandshakeFilter::HandshakeHeader header;
+      DataBuffer ignored;
+      if (!header.Parse(&parser, record_header, &ignored)) {
+        return KEEP;
       }
 
-      if (!parser.Skip(length)) {
-        return false;
-      }
-
-      if (skipped_ || handshake_type != handshake_type_) {
+      if (skipped_ || header.handshake_type() != handshake_type_) {
         size_t entire_length = parser.consumed() - start;
         output->Write(output_offset, input.data() + start,
                       entire_length);
         // DTLS sequence numbers need to be rewritten
-        if (skipped_ && IsDtls(version)) {
+        if (skipped_ && header.is_dtls()) {
           output->data()[start + 5] -= 1;
         }
         output_offset += entire_length;
       } else {
         std::cerr << "Dropping handshake: "
                   << static_cast<unsigned>(handshake_type_) << std::endl;
         // We only need to report that the output contains changed data if we
         // drop a handshake message.  But once we've skipped one message, we
         // have to modify all subsequent handshake messages so that they include
         // the correct DTLS sequence numbers.
         skipped_ = true;
       }
     }
     output->Truncate(output_offset);
-    return skipped_;
+    return skipped_ ? CHANGE : KEEP;
   }
 
  private:
   // The type of handshake message to drop.
   uint8_t handshake_type_;
   // Whether this filter has ever skipped a handshake message.  Track this so
   // that sequence numbers on DTLS handshake messages can be rewritten in
   // subsequent calls.
--- a/external_tests/ssl_gtest/test_io.cc
+++ b/external_tests/ssl_gtest/test_io.cc
@@ -353,21 +353,32 @@ int32_t DummyPrSocket::Write(const void 
   if (!peer_) {
     PR_SetError(PR_IO_ERROR, 0);
     return -1;
   }
 
   DataBuffer packet(static_cast<const uint8_t*>(buf),
                     static_cast<size_t>(length));
   DataBuffer filtered;
-  if (filter_ && filter_->Filter(packet, &filtered)) {
-    LOG("Filtered packet: " << filtered);
-    peer_->PacketReceived(filtered);
-  } else {
-    peer_->PacketReceived(packet);
+  PacketFilter::Action action = PacketFilter::KEEP;
+  if (filter_) {
+    action = filter_->Filter(packet, &filtered);
+  }
+  switch (action) {
+    case PacketFilter::CHANGE:
+      LOG("Original packet: " << packet);
+      LOG("Filtered packet: " << filtered);
+      peer_->PacketReceived(filtered);
+      break;
+    case PacketFilter::DROP:
+      LOG("Droppped packet: " << packet);
+      break;
+    case PacketFilter::KEEP:
+      peer_->PacketReceived(packet);
+      break;
   }
   // libssl can't handle it if this reports something other than the length
   // of what was passed in (or less, but we're not doing partial writes).
   return static_cast<int32_t>(packet.len());
 }
 
 Poller *Poller::instance;
 
--- a/external_tests/ssl_gtest/test_io.h
+++ b/external_tests/ssl_gtest/test_io.h
@@ -20,40 +20,48 @@ namespace nss_test {
 
 class DataBuffer;
 class Packet;
 class DummyPrSocket;  // Fwd decl.
 
 // Allow us to inspect a packet before it is written.
 class PacketFilter {
  public:
+  enum Action {
+    KEEP,   // keep the original packet unmodified
+    CHANGE, // change the packet to a different value
+    DROP    // drop the packet
+  };
+
   virtual ~PacketFilter() {}
 
   // The packet filter takes input and has the option of mutating it.
   //
   // A filter that modifies the data places the modified data in *output and
-  // returns true.  A filter that does not modify data returns false, in which
-  // case the value in *output is ignored.
-  virtual bool Filter(const DataBuffer& input, DataBuffer* output) = 0;
+  // returns CHANGE.  A filter that does not modify data returns LEAVE, in which
+  // case the value in *output is ignored.  A Filter can return DROP, in which
+  // case the packet is dropped (and *output is ignored).
+  virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0;
 };
 
 enum Mode { STREAM, DGRAM };
 
 inline std::ostream& operator<<(std::ostream& os, Mode m) {
   return os << ((m == STREAM) ? "TLS" : "DTLS");
 }
 
 class DummyPrSocket {
  public:
   ~DummyPrSocket();
 
   static PRFileDesc* CreateFD(const std::string& name,
                               Mode mode);  // Returns an FD.
   static DummyPrSocket* GetAdapter(PRFileDesc* fd);
 
+  DummyPrSocket* peer() const { return peer_; }
   void SetPeer(DummyPrSocket* peer) { peer_ = peer; }
   void SetPacketFilter(PacketFilter* filter) { filter_ = filter; }
   // Drops peer, packet filter and any outstanding packets.
   void Reset();
 
   void PacketReceived(const DataBuffer& data);
   int32_t Read(void* data, int32_t len);
   int32_t Recv(void* buf, int32_t buflen);
--- a/external_tests/ssl_gtest/tls_agent.cc
+++ b/external_tests/ssl_gtest/tls_agent.cc
@@ -6,16 +6,17 @@
 
 #include "tls_agent.h"
 
 #include "pk11func.h"
 #include "ssl.h"
 #include "sslerr.h"
 #include "sslproto.h"
 #include "keyhi.h"
+#include "databuffer.h"
 
 #define GTEST_HAS_RTTI 0
 #include "gtest/gtest.h"
 
 namespace nss_test {
 
 
 const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"};
@@ -496,19 +497,21 @@ void TlsAgent::Handshake() {
                                &TlsAgent::ReadableCallback);
       return;
       break;
 
       // TODO(ekr@rtfm.com): needs special case for DTLS
     case SSL_ERROR_RX_MALFORMED_HANDSHAKE:
     default:
       if (IS_SSL_ERROR(err)) {
-        LOG("Handshake failed with SSL error " << err - SSL_ERROR_BASE);
+        LOG("Handshake failed with SSL error " << (err - SSL_ERROR_BASE)
+            << ": " << PORT_ErrorToString(err));
       } else {
-        LOG("Handshake failed with error " << err);
+        LOG("Handshake failed with error " << err
+            << ": " << PORT_ErrorToString(err));
       }
       error_code_ = err;
       SetState(STATE_ERROR);
       return;
   }
 }
 
 void TlsAgent::PrepareForRenegotiate() {
@@ -519,16 +522,21 @@ void TlsAgent::PrepareForRenegotiate() {
 
 void TlsAgent::StartRenegotiate() {
   PrepareForRenegotiate();
 
   SECStatus rv = SSL_ReHandshake(ssl_fd_, PR_TRUE);
   EXPECT_EQ(SECSuccess, rv);
 }
 
+void TlsAgent::SendDirect(const DataBuffer& buf) {
+  LOG("Send Direct " << buf);
+  adapter_->peer()->PacketReceived(buf);
+}
+
 void TlsAgent::SendData(size_t bytes, size_t blocksize) {
   uint8_t block[4096];
 
   ASSERT_LT(blocksize, sizeof(block));
 
   while(bytes) {
     size_t tosend = std::min(blocksize, bytes);
 
@@ -543,37 +551,38 @@ void TlsAgent::SendData(size_t bytes, si
 
     bytes -= tosend;
   }
 }
 
 void TlsAgent::ReadBytes() {
   uint8_t block[1024];
 
-  LOG("Reading application data from socket");
-
   int32_t rv = PR_Read(ssl_fd_, block, sizeof(block));
+  LOG("ReadBytes " << rv);
 
-  int32_t err = PR_GetError();
-  if (err != PR_WOULD_BLOCK_ERROR) {
-    if (expected_read_error_) {
+  if (rv >= 0) {
+    size_t count = static_cast<size_t>(rv);
+    for (size_t i = 0; i < count; ++i) {
+      ASSERT_EQ(recv_ctr_ & 0xff, block[i]);
+      recv_ctr_++;
+    }
+  } else {
+    int32_t err = PR_GetError();
+    LOG("Read error " << err << ": " << PORT_ErrorToString(err));
+    if (err != PR_WOULD_BLOCK_ERROR && expected_read_error_) {
       error_code_ = err;
-    } else {
-      ASSERT_LE(0, rv);
-      size_t count = static_cast<size_t>(rv);
-      LOG("Read " << count << " bytes");
-      for (size_t i = 0; i < count; ++i) {
-        ASSERT_EQ(recv_ctr_ & 0xff, block[i]);
-        recv_ctr_++;
-      }
     }
   }
 
-  Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
-                           &TlsAgent::ReadableCallback);
+  // If closed, then don't bother waiting around.
+  if (rv) {
+    Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
+                             &TlsAgent::ReadableCallback);
+  }
 }
 
 void TlsAgent::ResetSentBytes() {
   send_ctr_ = 0;
 }
 
 void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
   EXPECT_TRUE(EnsureTlsSetup());
--- a/external_tests/ssl_gtest/tls_agent.h
+++ b/external_tests/ssl_gtest/tls_agent.h
@@ -27,17 +27,17 @@ enum SessionResumptionMode {
   RESUME_SESSIONID = 1,
   RESUME_TICKET = 2,
   RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET
 };
 
 class TlsAgent;
 
 typedef
-  std::function<void(TlsAgent& agent, PRBool checksig, PRBool isServer)>
+  std::function<SECStatus(TlsAgent& agent, PRBool checksig, PRBool isServer)>
   AuthCertificateCallbackFunction;
 
 typedef
   std::function<void(TlsAgent& agent)>
   HandshakeCallbackFunction;
 
 class TlsAgent : public PollTarget {
  public:
@@ -95,17 +95,20 @@ class TlsAgent : public PollTarget {
   void SetSignatureAlgorithms(const SSLSignatureAndHashAlg* algorithms,
                               size_t count);
   void EnableAlpn(const uint8_t* val, size_t len);
   void CheckAlpn(SSLNextProtoState expected_state,
                  const std::string& expected) const;
   void EnableSrtp();
   void CheckSrtp() const;
   void CheckErrorCode(int32_t expected) const;
+  // Send data on the socket, encrypting it.
   void SendData(size_t bytes, size_t blocksize = 1024);
+  // Send data directly to the underlying socket, skipping the TLS layer.
+  void SendDirect(const DataBuffer& buf);
   void ReadBytes();
   void ResetSentBytes(); // Hack to test drops.
   void EnableExtendedMasterSecret();
   void CheckExtendedMasterSecret(bool expected);
   void DisableRollbackDetection();
   void EnableCompression();
   void SetDowngradeCheckVersion(uint16_t version);
 
@@ -149,16 +152,18 @@ class TlsAgent : public PollTarget {
   std::vector<uint8_t> session_id() const {
     return std::vector<uint8_t>(info_.sessionID,
                                 info_.sessionID + info_.sessionIDLength);
   }
 
   size_t received_bytes() const { return recv_ctr_; }
   int32_t error_code() const { return error_code_; }
 
+  bool can_falsestart_hook_called() const { return can_falsestart_hook_called_; }
+
   void SetHandshakeCallback(HandshakeCallbackFunction handshake_callback) {
     handshake_callback_ = handshake_callback;
   }
 
   void SetAuthCertificateCallback(
       AuthCertificateCallbackFunction auth_certificate_callback) {
     auth_certificate_callback_ = auth_certificate_callback;
   }
@@ -176,17 +181,17 @@ class TlsAgent : public PollTarget {
 
   // Dummy auth certificate hook.
   static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
                                        PRBool checksig, PRBool isServer) {
     TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
     agent->CheckPreliminaryInfo();
     agent->auth_certificate_hook_called_ = true;
     if (agent->auth_certificate_callback_) {
-      agent->auth_certificate_callback_(*agent, checksig, isServer);
+      return agent->auth_certificate_callback_(*agent, checksig, isServer);
     }
     return SECSuccess;
   }
 
   // Client auth certificate hook.
   static SECStatus ClientAuthenticated(void* arg, PRFileDesc* fd,
                                        PRBool checksig, PRBool isServer) {
     TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
@@ -232,16 +237,17 @@ class TlsAgent : public PollTarget {
     return SSL_SNI_CURRENT_CONFIG_IS_USED;
   }
 
   static SECStatus CanFalseStartCallback(PRFileDesc *fd, void *arg,
                                          PRBool *canFalseStart) {
     TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
     agent->CheckPreliminaryInfo();
     EXPECT_TRUE(agent->falsestart_enabled_);
+    EXPECT_FALSE(agent->can_falsestart_hook_called_);
     agent->can_falsestart_hook_called_ = true;
     *canFalseStart = true;
     return SECSuccess;
   }
 
   static void HandshakeCallback(PRFileDesc *fd, void *arg) {
     TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
     agent->CheckPreliminaryInfo();
--- a/external_tests/ssl_gtest/tls_connect.cc
+++ b/external_tests/ssl_gtest/tls_connect.cc
@@ -276,22 +276,27 @@ void TlsConnectTestBase::EnableSrtp() {
 void TlsConnectTestBase::CheckSrtp() const {
   client_->CheckSrtp();
   server_->CheckSrtp();
 }
 
 void TlsConnectTestBase::SendReceive() {
   client_->SendData(50);
   server_->SendData(50);
-  WAIT_(client_->received_bytes() == 50U &&
-        server_->received_bytes() == 50U, 2000);
-  ASSERT_EQ(50U, client_->received_bytes());
-  ASSERT_EQ(50U, server_->received_bytes());
+  Receive(50);
 }
 
+void TlsConnectTestBase::Receive(size_t amount) {
+  WAIT_(client_->received_bytes() == amount &&
+        server_->received_bytes() == amount, 2000);
+  ASSERT_EQ(amount, client_->received_bytes());
+  ASSERT_EQ(amount, server_->received_bytes());
+}
+
+
 void TlsConnectTestBase::ExpectExtendedMasterSecret(bool expected) {
   expect_extended_master_secret_ = expected;
 }
 
 void TlsConnectTestBase::CheckExtendedMasterSecret() {
   client_->CheckExtendedMasterSecret(expect_extended_master_secret_);
   server_->CheckExtendedMasterSecret(expect_extended_master_secret_);
 }
--- a/external_tests/ssl_gtest/tls_connect.h
+++ b/external_tests/ssl_gtest/tls_connect.h
@@ -67,16 +67,17 @@ class TlsConnectTestBase : public ::test
   void DisableEcdheCiphers();
   void EnableExtendedMasterSecret();
   void ConfigureSessionCache(SessionResumptionMode client,
                              SessionResumptionMode server);
   void EnableAlpn();
   void EnableSrtp();
   void CheckSrtp() const;
   void SendReceive();
+  void Receive(size_t amount);
   void ExpectExtendedMasterSecret(bool expected);
 
  protected:
   Mode mode_;
   TlsAgent* client_;
   TlsAgent* server_;
   uint16_t version_;
   SessionResumptionMode expected_resumption_mode_;
--- a/external_tests/ssl_gtest/tls_filter.cc
+++ b/external_tests/ssl_gtest/tls_filter.cc
@@ -2,243 +2,273 @@
 /* vim: set ts=2 et sw=2 tw=80: */
 /* This Source Code Form is subject to the terms of the Mozilla Public
  * License, v. 2.0. If a copy of the MPL was not distributed with this file,
  * You can obtain one at http://mozilla.org/MPL/2.0/. */
 
 #include "tls_filter.h"
 
 #include <iostream>
+#include "gtest_utils.h"
 
 namespace nss_test {
 
-bool TlsRecordFilter::Filter(const DataBuffer& input, DataBuffer* output) {
+PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, DataBuffer* output) {
   bool changed = false;
-  size_t output_offset = 0U;
+  size_t offset = 0U;
   output->Allocate(input.len());
 
   TlsParser parser(input);
   while (parser.remaining()) {
-    size_t start = parser.consumed();
-    uint8_t content_type;
-    if (!parser.Read(&content_type)) {
-      return false;
-    }
-    uint32_t version;
-    if (!parser.Read(&version, 2)) {
-      return false;
+    RecordHeader header;
+    DataBuffer record;
+    if (!header.Parse(&parser, &record)) {
+      return KEEP;
     }
 
-    if (IsDtls(version)) {
-      if (!parser.Skip(8)) {
-        return false;
-      }
-    }
-    size_t header_len = parser.consumed() - start;
-    output->Write(output_offset, input.data() + start, header_len);
-
-    DataBuffer record;
-    if (!parser.ReadVariable(&record, 2)) {
-      return false;
+    DataBuffer filtered;
+    PacketFilter::Action action = FilterRecord(header, record, &filtered);
+    if (action == DROP) {
+      changed = true;
+      std::cerr << "record drop: " << record << std::endl;
+      continue; // don't copy this one
     }
 
-    // Move the offset in the output forward.  ApplyFilter() returns the index
-    // of the end of the record it wrote to the output, so we need to skip
-    // over the content type and version for the value passed to it.
-    output_offset = ApplyFilter(content_type, version, record, output,
-                                output_offset + header_len,
-                                &changed);
+    const DataBuffer* source = &record;
+    if (action == CHANGE) {
+      EXPECT_GT(0x10000, filtered.len());
+      changed = true;
+      std::cerr << "record old: " << record << std::endl;
+      std::cerr << "record new: " << filtered << std::endl;
+      source = &filtered;
+    }
+
+    offset = header.Write(output, offset, *source);
   }
-  output->Truncate(output_offset);
+  output->Truncate(offset);
 
   // Record how many packets we actually touched.
   if (changed) {
     ++count_;
+    return (offset == 0) ? DROP : CHANGE;
   }
 
-  return changed;
+  return KEEP;
 }
 
-size_t TlsRecordFilter::ApplyFilter(uint8_t content_type, uint16_t version,
-                                    const DataBuffer& record,
-                                    DataBuffer* output,
-                                    size_t offset, bool* changed) {
-  const DataBuffer* source = &record;
-  DataBuffer filtered;
-  if (FilterRecord(content_type, version, record, &filtered) &&
-      filtered.len() < 0x10000) {
-    *changed = true;
-    std::cerr << "record old: " << record << std::endl;
-    std::cerr << "record new: " << filtered << std::endl;
-    source = &filtered;
-  }
-
-  output->Write(offset, source->len(), 2);
-  output->Write(offset + 2, *source);
-  return offset + 2 + source->len();
-}
-
-bool TlsHandshakeFilter::FilterRecord(uint8_t content_type, uint16_t version,
-                                      const DataBuffer& input,
-                                      DataBuffer* output) {
-  // Check that the first byte is as requested.
-  if (content_type != kTlsHandshakeType) {
+bool TlsRecordFilter::RecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
+  if (!parser->Read(&content_type_)) {
     return false;
   }
 
+  uint32_t version;
+  if (!parser->Read(&version, 2)) {
+    return false;
+  }
+  version_ = version;
+
+  sequence_number_ = 0;
+  if (IsDtls(version)) {
+    uint32_t tmp;
+    if (!parser->Read(&tmp, 4)) {
+      return false;
+    }
+    sequence_number_ = static_cast<uint64_t>(tmp) << 32;
+    if (!parser->Read(&tmp, 4)) {
+      return false;
+    }
+    sequence_number_ |= static_cast<uint64_t>(tmp);
+  }
+  return parser->ReadVariable(body, 2);
+}
+
+size_t TlsRecordFilter::RecordHeader::Write(
+    DataBuffer* buffer, size_t offset, const DataBuffer& body) const {
+  offset = buffer->Write(offset, content_type_, 1);
+  offset = buffer->Write(offset, version_, 2);
+  if (is_dtls()) {
+    // write epoch (2 octet), and seqnum (6 octet)
+    offset = buffer->Write(offset, sequence_number_ >> 32, 4);
+    offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4);
+  }
+  offset = buffer->Write(offset, body.len(), 2);
+  offset = buffer->Write(offset, body);
+  return offset;
+}
+
+PacketFilter::Action TlsHandshakeFilter::FilterRecord(
+    const RecordHeader& record_header, const DataBuffer& input,
+    DataBuffer* output) {
+  // Check that the first byte is as requested.
+  if (record_header.content_type() != kTlsHandshakeType) {
+    return KEEP;
+  }
+
   bool changed = false;
-  size_t output_offset = 0U;
+  size_t offset = 0U;
   output->Allocate(input.len()); // Preallocate a little.
 
   TlsParser parser(input);
   while (parser.remaining()) {
-    size_t start = parser.consumed();
-    uint8_t handshake_type;
-    if (!parser.Read(&handshake_type)) {
-      return false; // malformed
+    HandshakeHeader header;
+    DataBuffer handshake;
+    if (!header.Parse(&parser, record_header, &handshake)) {
+      return KEEP;
     }
-    uint32_t length;
-    if (!ReadLength(&parser, version, &length)) {
-      return false;
+
+    DataBuffer filtered;
+    PacketFilter::Action action = FilterHandshake(header, handshake, &filtered);
+    if (action == DROP) {
+      changed = true;
+      std::cerr << "handshake drop: " << handshake << std::endl;
+      continue;
     }
 
-    size_t header_len = parser.consumed() - start;
-    output->Write(output_offset, input.data() + start, header_len);
-
-    DataBuffer handshake;
-    if (!parser.Read(&handshake, length)) {
-      return false;
+    const DataBuffer* source = &handshake;
+    if (action == CHANGE) {
+      EXPECT_GT(0x1000000, filtered.len());
+      changed = true;
+      std::cerr << "handshake old: " << handshake << std::endl;
+      std::cerr << "handshake new: " << filtered << std::endl;
+      source = &filtered;
     }
 
-    // Move the offset in the output forward.  ApplyFilter() returns the index
-    // of the end of the message it wrote to the output, so we need to identify
-    // offsets from the start of the message for length and the handshake
-    // message.
-    output_offset = ApplyFilter(version, handshake_type, handshake,
-                                output, output_offset + 1,
-                                output_offset + header_len,
-                                &changed);
+    offset = header.Write(output, offset, *source);
   }
-  output->Truncate(output_offset);
-  return changed;
+  output->Truncate(offset);
+  return changed ? (offset ? CHANGE : DROP) : KEEP;
 }
 
-bool TlsHandshakeFilter::ReadLength(TlsParser* parser, uint16_t version, uint32_t *length) {
+bool TlsHandshakeFilter::HandshakeHeader::ReadLength(TlsParser* parser,
+                                                     const RecordHeader& header,
+                                                     uint32_t *length) {
   if (!parser->Read(length, 3)) {
     return false; // malformed
   }
 
-  if (!IsDtls(version)) {
+  if (!header.is_dtls()) {
     return true; // nothing left to do
   }
 
   // Read and check DTLS parameters
-  if (!parser->Skip(2)) { // sequence number
+  uint32_t message_seq_tmp;
+  if (!parser->Read(&message_seq_tmp, 2)) { // sequence number
     return false;
   }
+  message_seq_ = message_seq_tmp;
 
   uint32_t fragment_offset;
   if (!parser->Read(&fragment_offset, 3)) {
     return false;
   }
 
   uint32_t fragment_length;
   if (!parser->Read(&fragment_length, 3)) {
     return false;
   }
 
   // All current tests where we are using this code don't fragment.
   return (fragment_offset == 0 && fragment_length == *length);
 }
 
-size_t TlsHandshakeFilter::ApplyFilter(
-    uint16_t version, uint8_t handshake_type, const DataBuffer& handshake,
-    DataBuffer* output, size_t length_offset, size_t value_offset,
-    bool* changed) {
-  const DataBuffer* source = &handshake;
-  DataBuffer filtered;
-  if (FilterHandshake(version, handshake_type, handshake, &filtered) &&
-      filtered.len() < 0x1000000) {
-    *changed = true;
-    std::cerr << "handshake old: " << handshake << std::endl;
-    std::cerr << "handshake new: " << filtered << std::endl;
-    source = &filtered;
+bool TlsHandshakeFilter::HandshakeHeader::Parse(
+    TlsParser* parser, const RecordHeader& record_header,
+    DataBuffer* body) {
+
+  version_ = record_header.version();
+  if (!parser->Read(&handshake_type_)) {
+    return false; // malformed
+  }
+  uint32_t length;
+  if (!ReadLength(parser, record_header, &length)) {
+    return false;
   }
 
-  // Back up and overwrite the (two) length field(s): the handshake message
-  // length and the DTLS fragment length.
-  output->Write(length_offset, source->len(), 3);
-  if (IsDtls(version)) {
-    output->Write(length_offset + 8, source->len(), 3);
-  }
-  output->Write(value_offset, *source);
-  return value_offset + source->len();
+  return parser->Read(body, length);
 }
 
-bool TlsInspectorRecordHandshakeMessage::FilterHandshake(
-    uint16_t version, uint8_t handshake_type,
+size_t TlsHandshakeFilter::HandshakeHeader::Write(
+    DataBuffer* buffer, size_t offset, const DataBuffer& body) const {
+    offset = buffer->Write(offset, handshake_type(), 1);
+    offset = buffer->Write(offset, body.len(), 3);
+    if (is_dtls()) {
+      offset = buffer->Write(offset, message_seq_, 2);
+      offset = buffer->Write(offset, 0U, 3); // fragment_offset
+      offset = buffer->Write(offset, body.len(), 3);
+    }
+    offset = buffer->Write(offset, body);
+    return offset;
+}
+
+PacketFilter::Action TlsInspectorRecordHandshakeMessage::FilterHandshake(
+    const HandshakeHeader& header,
     const DataBuffer& input, DataBuffer* output) {
   // Only do this once.
   if (buffer_.len()) {
-    return false;
+    return KEEP;
   }
 
-  if (handshake_type == handshake_type_) {
+  if (header.handshake_type() == handshake_type_) {
     buffer_ = input;
   }
-  return false;
+  return KEEP;
 }
 
 
-bool TlsInspectorReplaceHandshakeMessage::FilterHandshake(
-    uint16_t version, uint8_t handshake_type,
+PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake(
+    const HandshakeHeader& header,
     const DataBuffer& input, DataBuffer* output) {
-  if (handshake_type == handshake_type_) {
+  if (header.handshake_type() == handshake_type_) {
     *output = buffer_;
-    return true;
+    return CHANGE;
   }
 
-  return false;
+  return KEEP;
 }
 
-bool TlsAlertRecorder::FilterRecord(uint8_t content_type, uint16_t version,
-                                    const DataBuffer& input, DataBuffer* output) {
+PacketFilter::Action TlsAlertRecorder::FilterRecord(
+    const RecordHeader& header, const DataBuffer& input, DataBuffer* output) {
   if (level_ == kTlsAlertFatal) { // already fatal
-    return false;
+    return KEEP;
   }
-  if (content_type != kTlsAlertType) {
-    return false;
+  if (header.content_type() != kTlsAlertType) {
+    return KEEP;
   }
 
   std::cerr << "Alert: " << input << std::endl;
 
   TlsParser parser(input);
   uint8_t lvl;
   if (!parser.Read(&lvl)) {
-    return false;
+    return KEEP;
   }
   if (lvl == kTlsAlertWarning) { // not strong enough
-    return false;
+    return KEEP;
   }
   level_ = lvl;
   (void)parser.Read(&description_);
-  return false;
+  return KEEP;
 }
 
 ChainedPacketFilter::~ChainedPacketFilter() {
   for (auto it = filters_.begin(); it != filters_.end(); ++it) {
     delete *it;
   }
 }
 
-bool ChainedPacketFilter::Filter(const DataBuffer& input, DataBuffer* output) {
+PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input,
+                                                 DataBuffer* output) {
   DataBuffer in(input);
   bool changed = false;
   for (auto it = filters_.begin(); it != filters_.end(); ++it) {
-    if ((*it)->Filter(in, output)) {
+    PacketFilter::Action action = (*it)->Filter(in, output);
+    if (action == DROP) {
+      return DROP;
+    }
+    if (action == CHANGE) {
       in = *output;
       changed = true;
     }
   }
-  return changed;
+  return changed ? CHANGE : KEEP;
 }
 
 }  // namespace nss_test
--- a/external_tests/ssl_gtest/tls_filter.h
+++ b/external_tests/ssl_gtest/tls_filter.h
@@ -15,94 +15,148 @@
 
 namespace nss_test {
 
 // Abstract filter that operates on entire (D)TLS records.
 class TlsRecordFilter : public PacketFilter {
  public:
   TlsRecordFilter() : count_(0) {}
 
-  virtual bool Filter(const DataBuffer& input, DataBuffer* output);
+  virtual PacketFilter::Action Filter(const DataBuffer& input,
+                                      DataBuffer* output);
 
   // Report how many packets were altered by the filter.
   size_t filtered_packets() const { return count_; }
 
+  class Versioned {
+   public:
+    Versioned() : version_(0) {}
+    bool is_dtls() const { return IsDtls(version_); }
+    uint16_t version() const { return version_; }
+
+   protected:
+    uint16_t version_;
+  };
+
+  class RecordHeader : public Versioned {
+   public:
+    RecordHeader()
+        : Versioned(), content_type_(0), sequence_number_(0) {}
+
+    uint8_t content_type() const { return content_type_; }
+    uint64_t sequence_number() const { return sequence_number_; }
+    size_t header_length() const { return is_dtls() ? 11 : 3; }
+
+    // Parse the header; return true if successful; body in an outparam if OK.
+    bool Parse(TlsParser* parser, DataBuffer* body);
+    // Write the header and body to a buffer at the given offset.
+    // Return the offset of the end of the write.
+    size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const;
+
+   private:
+    uint8_t content_type_;
+    uint64_t sequence_number_;
+  };
+
  protected:
-  virtual bool FilterRecord(uint8_t content_type, uint16_t version,
-                            const DataBuffer& data, DataBuffer* changed) = 0;
+  // The record filter receives the record contentType, version and DTLS
+  // sequence number (which is zero for TLS), plus the existing record payload.
+  // It returns an action (KEEP, CHANGE, DROP).  It writes to the `changed`
+  // outparam with the new record contents if it chooses to CHANGE the record.
+  virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+                                            const DataBuffer& data,
+                                            DataBuffer* changed) = 0;
+
  private:
-  size_t ApplyFilter(uint8_t content_type, uint16_t version,
-                     const DataBuffer& record, DataBuffer* output,
-                     size_t offset, bool* changed);
 
   size_t count_;
 };
 
 // Abstract filter that operates on handshake messages rather than records.
 // This assumes that the handshake messages are written in a block as entire
 // records and that they don't span records or anything crazy like that.
 class TlsHandshakeFilter : public TlsRecordFilter {
  public:
   TlsHandshakeFilter() {}
 
-  // Reads the length from the record header.
-  // This also reads the DTLS fragment information and checks it.
-  static bool ReadLength(TlsParser* parser, uint16_t version, uint32_t *length);
+  class HandshakeHeader : public Versioned {
+   public:
+    HandshakeHeader()
+        : Versioned(), handshake_type_(0), message_seq_(0) {}
+
+    uint8_t handshake_type() const { return handshake_type_; }
+    bool Parse(TlsParser* parser, const RecordHeader& record_header,
+               DataBuffer* body);
+    size_t Write(DataBuffer* buffer, size_t offset,
+                 const DataBuffer& body) const;
+
+   private:
+    // Reads the length from the record header.
+    // This also reads the DTLS fragment information and checks it.
+    bool ReadLength(TlsParser* parser, const RecordHeader& header,
+                    uint32_t *length);
+
+    uint8_t handshake_type_;
+    uint16_t message_seq_;
+    // fragment_offset is always zero in these tests.
+  };
 
  protected:
-  virtual bool FilterRecord(uint8_t content_type, uint16_t version,
-                            const DataBuffer& input, DataBuffer* output);
-  virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
-                               const DataBuffer& input, DataBuffer* output) = 0;
+  virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+                                            const DataBuffer& input,
+                                            DataBuffer* output);
+  virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+                                               const DataBuffer& input,
+                                               DataBuffer* output) = 0;
 
  private:
-  size_t ApplyFilter(uint16_t version, uint8_t handshake_type,
-                     const DataBuffer& record, DataBuffer* output,
-                     size_t length_offset, size_t value_offset, bool* changed);
 };
 
 // Make a copy of the first instance of a handshake message.
 class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter {
  public:
   TlsInspectorRecordHandshakeMessage(uint8_t handshake_type)
       : handshake_type_(handshake_type), buffer_() {}
 
-  virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
-                               const DataBuffer& input, DataBuffer* output);
+  virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+                                               const DataBuffer& input,
+                                               DataBuffer* output);
 
   const DataBuffer& buffer() const { return buffer_; }
 
  private:
   uint8_t handshake_type_;
   DataBuffer buffer_;
 };
 
 // Replace all instances of a handshake message.
 class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
  public:
   TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type,
                                       const DataBuffer& replacement)
       : handshake_type_(handshake_type), buffer_(replacement) {}
 
-  virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type,
-                               const DataBuffer& input, DataBuffer* output);
+  virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+                                               const DataBuffer& input,
+                                               DataBuffer* output);
 
  private:
   uint8_t handshake_type_;
   DataBuffer buffer_;
 };
 
 // Records an alert.  If an alert has already been recorded, it won't save the
 // new alert unless the old alert is a warning and the new one is fatal.
 class TlsAlertRecorder : public TlsRecordFilter {
  public:
   TlsAlertRecorder() : level_(255), description_(255) {}
 
-  virtual bool FilterRecord(uint8_t content_type, uint16_t version,
-                            const DataBuffer& input, DataBuffer* output);
+  virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+                                            const DataBuffer& input,
+                                            DataBuffer* output);
 
   uint8_t level() const { return level_; }
   uint8_t description() const { return description_; }
 
  private:
   uint8_t level_;
   uint8_t description_;
 };
@@ -110,17 +164,18 @@ class TlsAlertRecorder : public TlsRecor
 // Runs multiple packet filters in series.
 class ChainedPacketFilter : public PacketFilter {
  public:
   ChainedPacketFilter() {}
   ChainedPacketFilter(const std::vector<PacketFilter*> filters)
       : filters_(filters.begin(), filters.end()) {}
   virtual ~ChainedPacketFilter();
 
-  virtual bool Filter(const DataBuffer& input, DataBuffer* output);
+  virtual PacketFilter::Action Filter(const DataBuffer& input,
+                                      DataBuffer* output);
 
   // Takes ownership of the filter.
   void Add(PacketFilter* filter) {
     filters_.push_back(filter);
   }
 
  private:
   std::vector<PacketFilter*> filters_;