Bug 1396487 - Refactor 1/n-1 record splitting code, r=ekr NSS_TLS13_DRAFT19_BRANCH
authorMartin Thomson <martin.thomson@gmail.com>
Tue, 15 Aug 2017 22:37:23 +1000
branchNSS_TLS13_DRAFT19_BRANCH
changeset 13609 27bfdd0ee644c33d3445a844d17d9042e0210035
parent 13608 3efb83875558adc1674dfa2ddba0a47f85979ed5
child 13610 7039fffea93782512241df32268edf6aa20438ba
push id2390
push usermartin.thomson@gmail.com
push dateTue, 26 Sep 2017 06:20:53 +0000
reviewersekr
bugs1396487, 97269, 97268
Bug 1396487 - Refactor 1/n-1 record splitting code, r=ekr It turns out that something changed a while back and we started splitting far more than is needed. The original design split into 1/n-1/n/n/n, but now we split 1/n-1/1/n-1/1/n-1 for large writes. That's inefficient and the code is unnecessarily complex in order to support it. This splits just once for each write, but it splits 1/n/n/n/n/remainder, unlike the original design, which you can see here: https://src.chromium.org/viewvc/chrome/trunk/src/net/third_party/nss/ssl/ssl3con.c?r1=97269&r2=97268&pathrev=97269 Also, because ssl3_SendApplicationData is the only place that needs to care about this, and it was preventing tests from actually testing this, I moved the splitting there instead.
gtests/ssl_gtest/ssl_loopback_unittest.cc
lib/ssl/ssl3con.c
--- a/gtests/ssl_gtest/ssl_loopback_unittest.cc
+++ b/gtests/ssl_gtest/ssl_loopback_unittest.cc
@@ -332,17 +332,17 @@ TEST_P(TlsConnectGeneric, ConnectWithCom
   EXPECT_EQ(client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 &&
                 variant_ != ssl_variant_datagram,
             client_->is_compressed());
   SendReceive();
 }
 
 TEST_P(TlsConnectDatagram, TestDtlsHolddownExpiry) {
   Connect();
-  std::cerr << "Expiring holddown timer\n";
+  std::cerr << "Expiring holddown timer" << std::endl;
   SSLInt_ForceRtTimerExpiry(client_->ssl_fd());
   SSLInt_ForceRtTimerExpiry(server_->ssl_fd());
   SendReceive();
   if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
     // One for send, one for receive.
     EXPECT_EQ(2, SSLInt_CountTls13CipherSpecs(client_->ssl_fd()));
   }
 }
@@ -350,17 +350,17 @@ TEST_P(TlsConnectDatagram, TestDtlsHoldd
 class TlsPreCCSHeaderInjector : public TlsRecordFilter {
  public:
   TlsPreCCSHeaderInjector() {}
   virtual PacketFilter::Action FilterRecord(
       const TlsRecordHeader& record_header, const DataBuffer& input,
       size_t* offset, DataBuffer* output) override {
     if (record_header.content_type() != kTlsChangeCipherSpecType) return KEEP;
 
-    std::cerr << "Injecting Finished header before CCS\n";
+    std::cerr << "Injecting Finished header before CCS" << std::endl;
     const uint8_t hhdr[] = {kTlsHandshakeFinished, 0x00, 0x00, 0x0c};
     DataBuffer hhdr_buf(hhdr, sizeof(hhdr));
     TlsRecordHeader nhdr(record_header.version(), kTlsHandshakeType, 0);
     *offset = nhdr.Write(output, *offset, hhdr_buf);
     *offset = record_header.Write(output, *offset, input);
     return CHANGE;
   }
 };
@@ -460,16 +460,37 @@ TEST_F(TlsConnectStreamTls13, BothAltHan
   Connect();
   ASSERT_EQ(kTlsAltHandshakeType, header_filter->header(0)->content_type());
   ASSERT_EQ(kTlsHandshakeType, header_filter->header(1)->content_type());
   uint32_t ver;
   ASSERT_TRUE(sh_filter->buffer().Read(0, 2, &ver));
   ASSERT_EQ((uint32_t)(0x7a00 | TLS_1_3_DRAFT_VERSION), ver);
 }
 
+static size_t ExpectedCbcLen(size_t in, size_t hmac = 20, size_t block = 16) {
+  // MAC-then-Encrypt expansion formula:
+  return ((in + hmac + (block - 1)) / 16) * 16;
+}
+
+TEST_F(TlsConnectTest, OneNRecordSplitting) {
+  ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_0);
+  EnsureTlsSetup();
+  ConnectWithCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA);
+  auto records = std::make_shared<TlsRecordRecorder>();
+  server_->SetPacketFilter(records);
+  // This should be split into 1, 16384 and 20.
+  DataBuffer big_buffer;
+  big_buffer.Allocate(1 + 16384 + 20);
+  server_->SendBuffer(big_buffer);
+  ASSERT_EQ(3U, records->count());
+  EXPECT_EQ(ExpectedCbcLen(1), records->record(0).buffer.len());
+  EXPECT_EQ(ExpectedCbcLen(16384), records->record(1).buffer.len());
+  EXPECT_EQ(ExpectedCbcLen(20), records->record(2).buffer.len());
+}
+
 INSTANTIATE_TEST_CASE_P(
     GenericStream, TlsConnectGeneric,
     ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
                        TlsConnectTestBase::kTlsVAll));
 INSTANTIATE_TEST_CASE_P(
     GenericDatagram, TlsConnectGeneric,
     ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
                        TlsConnectTestBase::kTlsV11Plus));
--- a/lib/ssl/ssl3con.c
+++ b/lib/ssl/ssl3con.c
@@ -2652,103 +2652,48 @@ ssl_ProtectRecord(sslSocket *ss, ssl3Cip
         return SECFailure;
     }
     ++cwSpec->write_seq_num;
 
     return SECSuccess;
 }
 
 SECStatus
-ssl_ProtectRecordMaybeSplit(sslSocket *ss, ssl3CipherSpec *cwSpec,
-                            SSL3ContentType type, PRBool capRecordVersion,
-                            const PRUint8 *pIn, unsigned int nIn,
-                            unsigned int *written)
+ssl_ProtectNextRecord(sslSocket *ss, ssl3CipherSpec *spec,
+                      SSL3ContentType type, PRBool capRecordVersion,
+                      const PRUint8 *pIn, unsigned int nIn,
+                      unsigned int *written)
 {
     sslBuffer *wrBuf = &ss->sec.writeBuf;
-    unsigned int contentLen = PR_MIN(nIn, MAX_FRAGMENT_LENGTH);
+    unsigned int contentLen;
     unsigned int spaceNeeded;
-    unsigned int numRecords;
-    SECStatus rv;
-
-    if (nIn > 1 && ss->opt.cbcRandomIV &&
-        ss->ssl3.cwSpec->version < SSL_LIBRARY_VERSION_TLS_1_1 &&
-        type == content_application_data &&
-        ss->ssl3.cwSpec->cipher_def->type == type_block /* CBC mode */) {
-        /* We will split the first byte of the record into its own record,
-         * as explained in the documentation for SSL_CBC_RANDOM_IV in ssl.h
-         */
-        numRecords = 2;
-    } else {
-        numRecords = 1;
-    }
-
-    spaceNeeded = contentLen + (numRecords * SSL3_BUFFER_FUDGE);
-    if (ss->ssl3.cwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_1 &&
-        ss->ssl3.cwSpec->cipher_def->type == type_block) {
-        spaceNeeded += ss->ssl3.cwSpec->cipher_def->iv_size;
+    SECStatus rv;
+
+    contentLen = PR_MIN(nIn, MAX_FRAGMENT_LENGTH);
+    spaceNeeded = contentLen + SSL3_BUFFER_FUDGE;
+    if (spec->version >= SSL_LIBRARY_VERSION_TLS_1_1 &&
+        spec->cipher_def->type == type_block) {
+        spaceNeeded += spec->cipher_def->iv_size;
     }
     if (spaceNeeded > SSL_BUFFER_SPACE(wrBuf)) {
         rv = sslBuffer_Grow(wrBuf, spaceNeeded);
         if (rv != SECSuccess) {
-            SSL_DBG(("%d: SSL3[%d]: expand write buffer to %d bytes",
+            SSL_DBG(("%d: SSL3[%d]: failed to expand write buffer to %d",
                      SSL_GETPID(), ss->fd, spaceNeeded));
             return SECFailure;
         }
     }
 
-    if (numRecords == 2) {
-        rv = ssl_ProtectRecord(ss, ss->ssl3.cwSpec, capRecordVersion, type,
-                               pIn, 1, wrBuf);
-        if (rv != SECSuccess) {
-            return SECFailure;
-        }
-
-        PRINT_BUF(50, (ss, "send (encrypted) record data [1/2]:",
-                       SSL_BUFFER_BASE(wrBuf), SSL_BUFFER_LEN(wrBuf)));
-
-        {
-            sslBuffer secondRecord = SSL_BUFFER_FIXED(SSL_BUFFER_NEXT(wrBuf),
-                                                      SSL_BUFFER_SPACE(wrBuf));
-
-            rv = ssl_ProtectRecord(ss, ss->ssl3.cwSpec, capRecordVersion, type,
-                                   pIn + 1, contentLen - 1, &secondRecord);
-            if (rv != SECSuccess) {
-                return SECFailure;
-            }
-            PRINT_BUF(50, (ss, "send (encrypted) record data [2/2]:",
-                           SSL_BUFFER_BASE(&secondRecord),
-                           SSL_BUFFER_LEN(&secondRecord)));
-            rv = sslBuffer_Skip(wrBuf, SSL_BUFFER_LEN(&secondRecord), NULL);
-            if (rv != SECSuccess) {
-                return SECFailure;
-            }
-        }
-    } else {
-        ssl3CipherSpec *spec;
-
-        if (cwSpec) {
-            /* cwSpec can only be set for retransmissions of DTLS handshake
-             * messages. */
-            PORT_Assert(IS_DTLS(ss) &&
-                        (type == content_handshake ||
-                         type == content_change_cipher_spec));
-            spec = cwSpec;
-        } else {
-            spec = ss->ssl3.cwSpec;
-        }
-
-        rv = ssl_ProtectRecord(ss, spec, !IS_DTLS(ss) && capRecordVersion,
-                               type, pIn, contentLen, wrBuf);
-        if (rv != SECSuccess) {
-            return SECFailure;
-        }
-        PRINT_BUF(50, (ss, "send (encrypted) record data:",
-                       SSL_BUFFER_BASE(wrBuf), SSL_BUFFER_LEN(wrBuf)));
-    }
-
+    rv = ssl_ProtectRecord(ss, spec, capRecordVersion,
+                           type, pIn, contentLen, wrBuf);
+    if (rv != SECSuccess) {
+        return SECFailure;
+    }
+    PRINT_BUF(50, (ss, "send (encrypted) record data:",
+                   SSL_BUFFER_BASE(wrBuf), SSL_BUFFER_LEN(wrBuf)));
     *written = contentLen;
     return SECSuccess;
 }
 
 /* Process the plain text before sending it.
  * Returns the number of bytes of plaintext that were successfully sent
  *  plus the number of bytes of plaintext that were copied into the
  *  output (write) buffer.
@@ -2783,16 +2728,17 @@ PRInt32
 ssl3_SendRecord(sslSocket *ss,
                 ssl3CipherSpec *cwSpec, /* non-NULL for DTLS retransmits */
                 SSL3ContentType type,
                 const PRUint8 *pIn, /* input buffer */
                 PRInt32 nIn,        /* bytes of input */
                 PRInt32 flags)
 {
     sslBuffer *wrBuf = &ss->sec.writeBuf;
+    ssl3CipherSpec *spec;
     SECStatus rv;
     PRInt32 totalSent = 0;
     PRBool capRecordVersion;
 
     SSL_TRC(3, ("%d: SSL3[%d] SendRecord type: %s nIn=%d",
                 SSL_GETPID(), ss->fd, ssl3_DecodeContentType(type),
                 nIn));
     PRINT_BUF(50, (ss, "Send record (plain text)", pIn, nIn));
@@ -2826,29 +2772,45 @@ ssl3_SendRecord(sslSocket *ss,
     }
 
     /* check for Token Presence */
     if (!ssl3_ClientAuthTokenPresent(ss->sec.ci.sid)) {
         PORT_SetError(SSL_ERROR_TOKEN_INSERTION_REMOVAL);
         return SECFailure;
     }
 
+    if (cwSpec) {
+        /* cwSpec can only be set for retransmissions of the DTLS handshake. */
+        PORT_Assert(IS_DTLS(ss) &&
+                    (type == content_handshake ||
+                     type == content_change_cipher_spec));
+        spec = cwSpec;
+    } else {
+        spec = ss->ssl3.cwSpec;
+    }
+
     while (nIn > 0) {
-        unsigned int contentLen = 0;
-
-        ssl_GetSpecReadLock(ss); /********************************/
-        rv = ssl_ProtectRecordMaybeSplit(ss, cwSpec, type, capRecordVersion,
-                                         pIn, nIn, &contentLen);
-        ssl_ReleaseSpecReadLock(ss); /************************************/
+        unsigned int written = 0;
+
+        ssl_GetSpecReadLock(ss);
+        rv = ssl_ProtectNextRecord(ss, spec, type, capRecordVersion,
+                                   pIn, nIn, &written);
+        ssl_ReleaseSpecReadLock(ss);
         if (rv != SECSuccess) {
             return SECFailure;
         }
 
-        pIn += contentLen;
-        nIn -= contentLen;
+        PORT_Assert(written > 0);
+        /* DTLS should not fragment non-application data here. */
+        if (IS_DTLS(ss) && type != content_application_data) {
+            PORT_Assert(written == nIn);
+        }
+
+        pIn += written;
+        nIn -= written;
         PORT_Assert(nIn >= 0);
 
         /* If there's still some previously saved ciphertext,
          * or the caller doesn't want us to send the data yet,
          * then add all our new ciphertext to the amount previously saved.
          */
         if ((ss->pendingBuf.len > 0) ||
             (flags & ssl_SEND_FLAG_FORCE_INTO_BUFFER)) {
@@ -2901,32 +2863,33 @@ ssl3_SendRecord(sslSocket *ss,
                                        SSL_BUFFER_LEN(wrBuf) - sent);
                 if (rv != SECSuccess) {
                     /* presumably a memory error, SEC_ERROR_NO_MEMORY */
                     return SECFailure;
                 }
             }
         }
         wrBuf->len = 0;
-        totalSent += contentLen;
+        totalSent += written;
     }
     return totalSent;
 }
 
 #define SSL3_PENDING_HIGH_WATER 1024
 
 /* Attempt to send the content of "in" in an SSL application_data record.
  * Returns "len" or SECFailure,   never SECWouldBlock, nor SECSuccess.
  */
 int
 ssl3_SendApplicationData(sslSocket *ss, const unsigned char *in,
                          PRInt32 len, PRInt32 flags)
 {
     PRInt32 totalSent = 0;
     PRInt32 discarded = 0;
+    PRBool splitNeeded = PR_FALSE;
 
     PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss));
     /* These flags for internal use only */
     PORT_Assert(!(flags & ssl_SEND_FLAG_NO_RETRANSMIT));
     if (len < 0 || !in) {
         PORT_SetError(PR_INVALID_ARGUMENT_ERROR);
         return SECFailure;
     }
@@ -2943,31 +2906,47 @@ ssl3_SendApplicationData(sslSocket *ss, 
         if (in[0] != (unsigned char)(ss->appDataBuffered)) {
             PORT_SetError(PR_INVALID_ARGUMENT_ERROR);
             return SECFailure;
         }
         in++;
         len--;
         discarded = 1;
     }
+
+    /* We will split the first byte of the record into its own record, as
+     * explained in the documentation for SSL_CBC_RANDOM_IV in ssl.h.
+     */
+    if (len > 1 && ss->opt.cbcRandomIV &&
+        ss->version < SSL_LIBRARY_VERSION_TLS_1_1 &&
+        ss->ssl3.cwSpec->cipher_def->type == type_block /* CBC */) {
+        splitNeeded = PR_TRUE;
+    }
+
     while (len > totalSent) {
         PRInt32 sent, toSend;
 
         if (totalSent > 0) {
             /*
              * The thread yield is intended to give the reader thread a
              * chance to get some cycles while the writer thread is in
              * the middle of a large application data write.  (See
              * Bugzilla bug 127740, comment #1.)
              */
             ssl_ReleaseXmitBufLock(ss);
             PR_Sleep(PR_INTERVAL_NO_WAIT); /* PR_Yield(); */
             ssl_GetXmitBufLock(ss);
         }
-        toSend = PR_MIN(len - totalSent, MAX_FRAGMENT_LENGTH);
+
+        if (splitNeeded) {
+            toSend = 1;
+            splitNeeded = PR_FALSE;
+        } else {
+            toSend = PR_MIN(len - totalSent, MAX_FRAGMENT_LENGTH);
+        }
 
         /*
          * Note that the 0 epoch is OK because flags will never require
          * its use, as guaranteed by the PORT_Assert above.
          */
         sent = ssl3_SendRecord(ss, NULL, content_application_data,
                                in + totalSent, toSend, flags);
         if (sent < 0) {