Bug 1343036 - Early data size limit, r=franziskus,ekr
authorMartin Thomson <martin.thomson@gmail.com>
Mon, 06 Mar 2017 16:24:30 +1100
changeset 13227 d621b1e53054ea654f187990e959233c9f63b249
parent 13226 3dfc16480e837b69682f4d857c5e734f21eff3bb
child 13228 8e85df1972f9fa2c34bac40bd7d65ad2591c377f
push id2097
push usermartin.thomson@gmail.com
push dateSat, 18 Mar 2017 22:39:41 +0000
reviewersfranziskus, ekr
bugs1343036
Bug 1343036 - Early data size limit, r=franziskus,ekr
gtests/ssl_gtest/libssl_internals.c
gtests/ssl_gtest/libssl_internals.h
gtests/ssl_gtest/ssl_0rtt_unittest.cc
gtests/ssl_gtest/tls_connect.cc
lib/ssl/SSLerrs.h
lib/ssl/ssl3con.c
lib/ssl/ssl3exthandle.c
lib/ssl/sslerr.h
lib/ssl/sslimpl.h
lib/ssl/sslinfo.c
lib/ssl/sslsecur.c
lib/ssl/sslt.h
lib/ssl/tls13con.c
lib/ssl/tls13con.h
--- a/gtests/ssl_gtest/libssl_internals.c
+++ b/gtests/ssl_gtest/libssl_internals.c
@@ -28,25 +28,18 @@ SECStatus SSLInt_IncrementClientHandshak
 SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd,
                                          size_t rnd_len, uint8_t *msg,
                                          size_t msg_len) {
   sslSocket *ss = ssl_FindSocket(fd);
   if (!ss) {
     return SECFailure;
   }
 
-  SECStatus rv = ssl3_InitState(ss);
-  if (rv != SECSuccess) {
-    return rv;
-  }
-
-  rv = ssl3_RestartHandshakeHashes(ss);
-  if (rv != SECSuccess) {
-    return rv;
-  }
+  ssl3_InitState(ss);
+  ssl3_RestartHandshakeHashes(ss);
 
   // Ensure we don't overrun hs.client_random.
   rnd_len = PR_MIN(SSL3_RANDOM_LENGTH, rnd_len);
 
   // Zero the client_random struct.
   PORT_Memset(&ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH);
 
   // Copy over the challenge bytes.
@@ -61,21 +54,21 @@ PRBool SSLInt_ExtensionNegotiated(PRFile
   sslSocket *ss = ssl_FindSocket(fd);
   return (PRBool)(ss && ssl3_ExtensionNegotiated(ss, ext));
 }
 
 void SSLInt_ClearSessionTicketKey() { ssl_ResetSessionTicketKeys(); }
 
 SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu) {
   sslSocket *ss = ssl_FindSocket(fd);
-  if (ss) {
-    ss->ssl3.mtu = mtu;
-    return SECSuccess;
+  if (!ss) {
+    return SECFailure;
   }
-  return SECFailure;
+  ss->ssl3.mtu = mtu;
+  return SECSuccess;
 }
 
 PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) {
   PRCList *cur_p;
   PRInt32 ct = 0;
 
   sslSocket *ss = ssl_FindSocket(fd);
   if (!ss) {
@@ -189,17 +182,19 @@ SECStatus SSLInt_Set0RttAlpn(PRFileDesc 
   if (!ss) {
     return SECFailure;
   }
 
   ss->xtnData.nextProtoState = SSL_NEXT_PROTO_EARLY_VALUE;
   if (ss->xtnData.nextProto.data) {
     SECITEM_FreeItem(&ss->xtnData.nextProto, PR_FALSE);
   }
-  if (!SECITEM_AllocItem(NULL, &ss->xtnData.nextProto, len)) return SECFailure;
+  if (!SECITEM_AllocItem(NULL, &ss->xtnData.nextProto, len)) {
+    return SECFailure;
+  }
   PORT_Memcpy(ss->xtnData.nextProto.data, data, len);
 
   return SECSuccess;
 }
 
 PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType) {
   sslSocket *ss = ssl_FindSocket(fd);
   if (!ss) {
@@ -246,27 +241,29 @@ SECStatus SSLInt_AdvanceReadSeqNum(PRFil
   sslSocket *ss;
   ssl3CipherSpec *spec;
 
   ss = ssl_FindSocket(fd);
   if (!ss) {
     return SECFailure;
   }
   if (to >= (1ULL << 48)) {
+    PORT_SetError(SEC_ERROR_INVALID_ARGS);
     return SECFailure;
   }
   ssl_GetSpecWriteLock(ss);
   spec = ss->ssl3.crSpec;
   epoch = spec->read_seq_num >> 48;
   spec->read_seq_num = (epoch << 48) | to;
 
   /* For DTLS, we need to fix the record sequence number.  For this, we can just
    * scrub the entire structure on the assumption that the new sequence number
    * is far enough past the last received sequence number. */
   if (to <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) {
+    PORT_SetError(SEC_ERROR_INVALID_ARGS);
     return SECFailure;
   }
   dtls_RecordSetRecvd(&spec->recvdRecords, to);
 
   ssl_ReleaseSpecWriteLock(ss);
   return SECSuccess;
 }
 
@@ -274,16 +271,17 @@ SECStatus SSLInt_AdvanceWriteSeqNum(PRFi
   PRUint64 epoch;
   sslSocket *ss;
 
   ss = ssl_FindSocket(fd);
   if (!ss) {
     return SECFailure;
   }
   if (to >= (1ULL << 48)) {
+    PORT_SetError(SEC_ERROR_INVALID_ARGS);
     return SECFailure;
   }
   ssl_GetSpecWriteLock(ss);
   epoch = ss->ssl3.cwSpec->write_seq_num >> 48;
   ss->ssl3.cwSpec->write_seq_num = (epoch << 48) | to;
   ssl_ReleaseSpecWriteLock(ss);
   return SECSuccess;
 }
@@ -359,15 +357,41 @@ SECStatus SSLInt_UsingShortHeaders(PRFil
   sslSocket *ss;
 
   ss = ssl_FindSocket(fd);
   if (!ss) {
     return SECFailure;
   }
 
   *result = ss->ssl3.hs.shortHeaders;
-
   return SECSuccess;
 }
 
 void SSLInt_SetTicketLifetime(uint32_t lifetime) {
   ssl_ticket_lifetime = lifetime;
 }
+
+void SSLInt_SetMaxEarlyDataSize(uint32_t size) {
+  ssl_max_early_data_size = size;
+}
+
+SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) {
+  sslSocket *ss;
+
+  ss = ssl_FindSocket(fd);
+  if (!ss) {
+    return SECFailure;
+  }
+
+  /* This only works when resuming. */
+  if (!ss->statelessResume) {
+    PORT_SetError(SEC_INTERNAL_ONLY);
+    return SECFailure;
+  }
+
+  /* Modifying both specs allows this to be used on either peer. */
+  ssl_GetSpecWriteLock(ss);
+  ss->ssl3.crSpec->earlyDataRemaining = size;
+  ss->ssl3.cwSpec->earlyDataRemaining = size;
+  ssl_ReleaseSpecWriteLock(ss);
+
+  return SECSuccess;
+}
--- a/gtests/ssl_gtest/libssl_internals.h
+++ b/gtests/ssl_gtest/libssl_internals.h
@@ -45,10 +45,12 @@ SECStatus SSLInt_SetCipherSpecChangeFunc
                                          void *arg);
 PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec);
 SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer,
                                                 ssl3CipherSpec *spec);
 unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec);
 SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd);
 SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result);
 void SSLInt_SetTicketLifetime(uint32_t lifetime);
+void SSLInt_SetMaxEarlyDataSize(uint32_t size);
+SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size);
 
 #endif  // ndef libssl_internals_h_
--- a/gtests/ssl_gtest/ssl_0rtt_unittest.cc
+++ b/gtests/ssl_gtest/ssl_0rtt_unittest.cc
@@ -277,9 +277,115 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRtt
   if (mode_ == STREAM) {
     // The server sends an alert when receiving the early app data record.
     ASSERT_TRUE_WAIT(
         (server_->error_code() == SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA),
         2000);
   }
 }
 
+static void CheckEarlyDataLimit(const std::shared_ptr<TlsAgent>& agent,
+                                size_t expected_size) {
+  SSLPreliminaryChannelInfo preinfo;
+  SECStatus rv =
+      SSL_GetPreliminaryChannelInfo(agent->ssl_fd(), &preinfo, sizeof(preinfo));
+  EXPECT_EQ(SECSuccess, rv);
+  EXPECT_EQ(expected_size, static_cast<size_t>(preinfo.maxEarlyDataSize));
+}
+
+TEST_P(TlsConnectTls13, SendTooMuchEarlyData) {
+  const char* big_message = "0123456789abcdef";
+  const size_t short_size = strlen(big_message) - 1;
+  const PRInt32 short_length = static_cast<PRInt32>(short_size);
+  SSLInt_SetMaxEarlyDataSize(static_cast<PRUint32>(short_size));
+  SetupForZeroRtt();
+
+  client_->Set0RttEnabled(true);
+  server_->Set0RttEnabled(true);
+  ExpectResumption(RESUME_TICKET);
+  client_->SetExpectedAlertSentCount(1);
+  server_->SetExpectedAlertReceivedCount(1);
+
+  client_->Handshake();
+  CheckEarlyDataLimit(client_, short_size);
+
+  PRInt32 sent;
+  // Writing more than the limit will succeed in TLS, but fail in DTLS.
+  if (mode_ == STREAM) {
+    sent = PR_Write(client_->ssl_fd(), big_message,
+                    static_cast<PRInt32>(strlen(big_message)));
+  } else {
+    sent = PR_Write(client_->ssl_fd(), big_message,
+                    static_cast<PRInt32>(strlen(big_message)));
+    EXPECT_GE(0, sent);
+    EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+    // Try an exact-sized write now.
+    sent = PR_Write(client_->ssl_fd(), big_message, short_length);
+  }
+  EXPECT_EQ(short_length, sent);
+
+  // Even a single octet write should now fail.
+  sent = PR_Write(client_->ssl_fd(), big_message, 1);
+  EXPECT_GE(0, sent);
+  EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+  // Process the ClientHello and read 0-RTT.
+  server_->Handshake();
+  CheckEarlyDataLimit(server_, short_size);
+
+  std::vector<uint8_t> buf(short_size + 1);
+  PRInt32 read = PR_Read(server_->ssl_fd(), buf.data(), buf.capacity());
+  EXPECT_EQ(short_length, read);
+  EXPECT_EQ(0, memcmp(big_message, buf.data(), short_size));
+
+  // Second read fails.
+  read = PR_Read(server_->ssl_fd(), buf.data(), buf.capacity());
+  EXPECT_EQ(SECFailure, read);
+  EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+  Handshake();
+  ExpectEarlyDataAccepted(true);
+  CheckConnected();
+  SendReceive();
+}
+
+TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) {
+  const size_t limit = 5;
+  SSLInt_SetMaxEarlyDataSize(limit);
+  SetupForZeroRtt();
+
+  client_->Set0RttEnabled(true);
+  server_->Set0RttEnabled(true);
+  ExpectResumption(RESUME_TICKET);
+
+  client_->Handshake();  // Send ClientHello
+  CheckEarlyDataLimit(client_, limit);
+
+  // Lift the limit on the client.
+  EXPECT_EQ(SECSuccess,
+            SSLInt_SetSocketMaxEarlyDataSize(client_->ssl_fd(), 1000));
+
+  // Send message
+  const char* message = "0123456789abcdef";
+  const PRInt32 message_len = static_cast<PRInt32>(strlen(message));
+  EXPECT_EQ(message_len, PR_Write(client_->ssl_fd(), message, message_len));
+
+  server_->Handshake();  // Process ClientHello, send server flight.
+  server_->Handshake();  // Just to make sure that we don't read ahead.
+  CheckEarlyDataLimit(server_, limit);
+
+  // Attempt to read early data.
+  std::vector<uint8_t> buf(strlen(message) + 1);
+  EXPECT_GT(0, PR_Read(server_->ssl_fd(), buf.data(), buf.capacity()));
+  if (mode_ == STREAM) {
+    // This error isn't fatal for DTLS.
+    server_->CheckErrorCode(SSL_ERROR_TOO_MUCH_EARLY_DATA);
+  }
+
+  client_->Handshake();  // Process the handshake.
+  client_->Handshake();  // Process the alert.
+  if (mode_ == STREAM) {
+    client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+  }
+}
+
 }  // namespace nss_test
--- a/gtests/ssl_gtest/tls_connect.cc
+++ b/gtests/ssl_gtest/tls_connect.cc
@@ -170,16 +170,17 @@ void TlsConnectTestBase::ClearServerCach
   SSLInt_ClearSessionTicketKey();
   SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
 }
 
 void TlsConnectTestBase::SetUp() {
   SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
   SSLInt_ClearSessionTicketKey();
   SSLInt_SetTicketLifetime(30);
+  SSLInt_SetMaxEarlyDataSize(1024);
   ClearStats();
   Init();
 }
 
 void TlsConnectTestBase::TearDown() {
   client_ = nullptr;
   server_ = nullptr;
 
--- a/lib/ssl/SSLerrs.h
+++ b/lib/ssl/SSLerrs.h
@@ -503,8 +503,11 @@ ER3(SSL_ERROR_MISSING_SIGNATURE_ALGORITH
 ER3(SSL_ERROR_MALFORMED_PSK_KEY_EXCHANGE_MODES, (SSL_ERROR_BASE + 158),
     "SSL received a malformed PSK key exchange modes extension.")
 
 ER3(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES, (SSL_ERROR_BASE + 159),
     "SSL expected a PSK key exchange modes extension.")
 
 ER3(SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA, (SSL_ERROR_BASE + 160),
     "SSL got a pre-TLS 1.3 version even though we sent early data.")
+
+ER3(SSL_ERROR_TOO_MUCH_EARLY_DATA, (SSL_ERROR_BASE + 161),
+    "SSL received more early data than permitted.")
--- a/lib/ssl/ssl3con.c
+++ b/lib/ssl/ssl3con.c
@@ -2721,20 +2721,17 @@ ssl3_SendRecord(sslSocket *ss,
     }
 
     if (ss->ssl3.initialized == PR_FALSE) {
         /* This can happen on a server if the very first incoming record
         ** looks like a defective ssl3 record (e.g. too long), and we're
         ** trying to send an alert.
         */
         PR_ASSERT(type == content_alert);
-        rv = ssl3_InitState(ss);
-        if (rv != SECSuccess) {
-            return SECFailure; /* ssl3_InitState has set the error code. */
-        }
+        ssl3_InitState(ss);
     }
 
     /* check for Token Presence */
     if (!ssl3_ClientAuthTokenPresent(ss->sec.ci.sid)) {
         PORT_SetError(SSL_ERROR_TOKEN_INSERTION_REMOVAL);
         return SECFailure;
     }
 
@@ -2930,16 +2927,17 @@ ssl3_SendApplicationData(sslSocket *ss, 
              * the middle of a large application data write.  (See
              * Bugzilla bug 127740, comment #1.)
              */
             ssl_ReleaseXmitBufLock(ss);
             PR_Sleep(PR_INTERVAL_NO_WAIT); /* PR_Yield(); */
             ssl_GetXmitBufLock(ss);
         }
         toSend = PR_MIN(len - totalSent, MAX_FRAGMENT_LENGTH);
+
         /*
          * Note that the 0 epoch is OK because flags will never require
          * its use, as guaranteed by the PORT_Assert above.
          */
         sent = ssl3_SendRecord(ss, NULL, content_application_data,
                                in + totalSent, toSend, flags);
         if (sent < 0) {
             if (totalSent > 0 && PR_GetError() == PR_WOULD_BLOCK_ERROR) {
@@ -4096,34 +4094,31 @@ ssl3_InitHandshakeHashes(sslSocket *ss)
             return SECFailure;
         }
         sslBuffer_Clear(&ss->ssl3.hs.messages);
     }
 
     return SECSuccess;
 }
 
-SECStatus
+void
 ssl3_RestartHandshakeHashes(sslSocket *ss)
 {
-    SECStatus rv = SECSuccess;
-
     SSL_TRC(30, ("%d: SSL3[%d]: reset handshake hashes",
                  SSL_GETPID(), ss->fd));
     ss->ssl3.hs.hashType = handshake_hash_unknown;
     ss->ssl3.hs.messages.len = 0;
     if (ss->ssl3.hs.md5) {
         PK11_DestroyContext(ss->ssl3.hs.md5, PR_TRUE);
         ss->ssl3.hs.md5 = NULL;
     }
     if (ss->ssl3.hs.sha) {
         PK11_DestroyContext(ss->ssl3.hs.sha, PR_TRUE);
         ss->ssl3.hs.sha = NULL;
     }
-    return rv;
 }
 
 /*
  * Handshake messages
  */
 /* Called from  ssl3_InitHandshakeHashes()
 **      ssl3_AppendHandshake()
 **      ssl3_HandleV2ClientHello()
@@ -5019,25 +5014,18 @@ ssl3_SendClientHello(sslSocket *ss, sslC
         return SECFailure;
     }
 
     /* If we are responding to a HelloRetryRequest, don't reinitialize. We need
      * to maintain the handshake hashes. */
     if (ss->ssl3.hs.helloRetry) {
         PORT_Assert(type == client_hello_retry);
     } else {
-        rv = ssl3_InitState(ss);
-        if (rv != SECSuccess) {
-            return rv; /* ssl3_InitState has set the error code. */
-        }
-
-        rv = ssl3_RestartHandshakeHashes(ss);
-        if (rv != SECSuccess) {
-            return rv;
-        }
+        ssl3_InitState(ss);
+        ssl3_RestartHandshakeHashes(ss);
     }
 
     /* These must be reset every handshake. */
     ss->ssl3.hs.sendingSCSV = PR_FALSE;
     ss->ssl3.hs.preliminaryInfo = 0;
     PORT_Assert(IS_DTLS(ss) || type != client_hello_retransmit);
     SECITEM_FreeItem(&ss->ssl3.hs.newSessionTicket.ticket, PR_FALSE);
     ss->ssl3.hs.receivedNewSessionTicket = PR_FALSE;
@@ -9159,26 +9147,18 @@ ssl3_HandleV2ClientHello(sslSocket *ss, 
 
     ssl3_ResetExtensionData(&ss->xtnData);
 
     version = (buffer[1] << 8) | buffer[2];
     if (version < SSL_LIBRARY_VERSION_3_0) {
         goto loser;
     }
 
-    rv = ssl3_InitState(ss);
-    if (rv != SECSuccess) {
-        ssl_ReleaseSSL3HandshakeLock(ss);
-        return rv; /* ssl3_InitState has set the error code. */
-    }
-    rv = ssl3_RestartHandshakeHashes(ss);
-    if (rv != SECSuccess) {
-        ssl_ReleaseSSL3HandshakeLock(ss);
-        return rv;
-    }
+    ssl3_InitState(ss);
+    ssl3_RestartHandshakeHashes(ss);
 
     if (ss->ssl3.hs.ws != wait_client_hello) {
         desc = unexpected_message;
         errCode = SSL_ERROR_RX_UNEXPECTED_CLIENT_HELLO;
         goto alert_loser;
     }
 
     total += suite_length = (buffer[3] << 8) | buffer[4];
@@ -11790,20 +11770,17 @@ ssl3_HandleHandshakeMessage(sslSocket *s
     hdr[0] = (PRUint8)ss->ssl3.hs.msg_type;
     hdr[1] = (PRUint8)(length >> 16);
     hdr[2] = (PRUint8)(length >> 8);
     hdr[3] = (PRUint8)(length);
 
     /* Start new handshake hashes when we start a new handshake.  Unless this is
      * TLS 1.3 and we sent a HelloRetryRequest. */
     if (ss->ssl3.hs.msg_type == client_hello && !ss->ssl3.hs.helloRetry) {
-        rv = ssl3_RestartHandshakeHashes(ss);
-        if (rv != SECSuccess) {
-            return rv;
-        }
+        ssl3_RestartHandshakeHashes(ss);
     }
     /* We should not include hello_request and hello_verify_request messages
      * in the handshake hashes */
     if ((ss->ssl3.hs.msg_type != hello_request) &&
         (ss->ssl3.hs.msg_type != hello_verify_request)) {
         rv = ssl3_UpdateHandshakeHashes(ss, (unsigned char *)hdr, 4);
         if (rv != SECSuccess)
             return rv; /* err code already set. */
@@ -12585,21 +12562,18 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Cip
     SSL3ContentType rType;
     sslBuffer *plaintext;
     sslBuffer temp_buf = { NULL, 0, 0 };
     SSL3AlertDescription alert = internal_error;
     PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
 
     if (!ss->ssl3.initialized) {
         ssl_GetSSL3HandshakeLock(ss);
-        rv = ssl3_InitState(ss);
+        ssl3_InitState(ss);
         ssl_ReleaseSSL3HandshakeLock(ss);
-        if (rv != SECSuccess) {
-            return rv; /* ssl3_InitState has set the error code. */
-        }
     }
 
     /* check for Token Presence */
     if (!ssl3_ClientAuthTokenPresent(ss->sec.ci.sid)) {
         PORT_SetError(SSL_ERROR_TOKEN_INSERTION_REMOVAL);
         return SECFailure;
     }
 
@@ -12908,26 +12882,24 @@ ssl3_InitCipherSpec(ssl3CipherSpec *spec
 }
 
 /* Called from: ssl3_SendRecord
 **      ssl3_SendClientHello()
 **      ssl3_HandleV2ClientHello()
 **      ssl3_HandleRecord()
 **
 ** This function should perhaps acquire and release the SpecWriteLock.
-**
-**
 */
-SECStatus
+void
 ssl3_InitState(sslSocket *ss)
 {
     PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
 
     if (ss->ssl3.initialized)
-        return SECSuccess; /* Function should be idempotent */
+        return; /* Function should be idempotent */
 
     ss->ssl3.policy = SSL_ALLOWED;
 
     ssl_InitSecState(&ss->sec);
 
     ssl_GetSpecWriteLock(ss);
     ss->ssl3.crSpec = ss->ssl3.cwSpec = &ss->ssl3.specs[0];
     ss->ssl3.prSpec = ss->ssl3.pwSpec = &ss->ssl3.specs[1];
@@ -12972,17 +12944,16 @@ ssl3_InitState(sslSocket *ss)
     PORT_Memset(&ss->ssl3.hs.newSessionTicket, 0,
                 sizeof(ss->ssl3.hs.newSessionTicket));
 
     ss->ssl3.hs.zeroRttState = ssl_0rtt_none;
 
     ssl_FilterSupportedGroups(ss);
 
     ss->ssl3.initialized = PR_TRUE;
-    return SECSuccess;
 }
 
 /* record the export policy for this cipher suite */
 SECStatus
 ssl3_SetPolicy(ssl3CipherSuite which, int policy)
 {
     ssl3CipherSuiteCfg *suite;
 
--- a/lib/ssl/ssl3exthandle.c
+++ b/lib/ssl/ssl3exthandle.c
@@ -873,17 +873,17 @@ ssl3_ClientHandleStatusRequestXtn(const 
     }
 
     /* Keep track of negotiated extensions. */
     xtnData->negotiated[xtnData->numNegotiated++] = ex_type;
     return SECSuccess;
 }
 
 PRUint32 ssl_ticket_lifetime = 2 * 24 * 60 * 60; /* 2 days in seconds */
-#define TLS_EX_SESS_TICKET_VERSION (0x0103)
+#define TLS_EX_SESS_TICKET_VERSION (0x0104)
 
 /*
  * Called from ssl3_SendNewSessionTicket, tls13_SendNewSessionTicket
  */
 SECStatus
 ssl3_EncodeSessionTicket(sslSocket *ss,
                          const NewSessionTicket *ticket,
                          SECItem *ticket_data)
@@ -998,17 +998,18 @@ ssl3_EncodeSessionTicket(sslSocket *ss,
         + ms_item.len                          /* master_secret */
         + 1                                    /* client_auth_type */
         + cert_length                          /* cert */
         + 1                                    /* server name type */
         + srvNameLen                           /* name len + length field */
         + 1                                    /* extendedMasterSecretUsed */
         + sizeof(ticket->ticket_lifetime_hint) /* ticket lifetime hint */
         + sizeof(ticket->flags)                /* ticket flags */
-        + 1 + alpnSelection.len;               /* npn value + length field. */
+        + 1 + alpnSelection.len                /* npn value + length field. */
+        + 4;                                   /* maxEarlyData */
 #ifdef UNSAFE_FUZZER_MODE
     padding_length = 0;
 #else
     padding_length = AES_BLOCK_SIZE -
                      (ciphertext_length %
                       AES_BLOCK_SIZE);
 #endif
     ciphertext_length += padding_length;
@@ -1147,16 +1148,20 @@ ssl3_EncodeSessionTicket(sslSocket *ss,
     if (rv != SECSuccess)
         goto loser;
     if (alpnSelection.len) {
         rv = ssl3_AppendToItem(&plaintext, alpnSelection.data, alpnSelection.len);
         if (rv != SECSuccess)
             goto loser;
     }
 
+    rv = ssl3_AppendNumberToItem(&plaintext, ssl_max_early_data_size, 4);
+    if (rv != SECSuccess)
+        goto loser;
+
     PORT_Assert(plaintext.len == padding_length);
     for (i = 0; i < padding_length; i++)
         plaintext.data[i] = (unsigned char)padding_length;
 
     if (SECITEM_AllocItem(NULL, &ciphertext, ciphertext_length) == NULL) {
         rv = SECFailure;
         goto loser;
     }
@@ -1603,16 +1608,22 @@ ssl3_ProcessSessionTicketCommon(sslSocke
         rv = SECITEM_CopyItem(NULL, &parsed_session_ticket->alpnSelection,
                               &alpn_item);
         if (rv != SECSuccess)
             goto no_ticket;
         if (alpn_item.len >= 256)
             goto no_ticket;
     }
 
+    rv = ssl3_ExtConsumeHandshakeNumber(ss, &temp, 4, &buffer, &buffer_len);
+    if (rv != SECSuccess) {
+        goto no_ticket;
+    }
+    parsed_session_ticket->maxEarlyData = temp;
+
 #ifndef UNSAFE_FUZZER_MODE
     /* Done parsing.  Check that all bytes have been consumed. */
     if (buffer_len != padding_length) {
         goto no_ticket;
     }
 #endif
 
     /* Use the ticket if it has not expired, otherwise free the allocated
@@ -1637,16 +1648,18 @@ ssl3_ProcessSessionTicketCommon(sslSocke
         sid->keaType = parsed_session_ticket->keaType;
         sid->keaKeyBits = parsed_session_ticket->keaKeyBits;
         sid->namedCurve = parsed_session_ticket->namedCurve;
 
         if (SECITEM_CopyItem(NULL, &sid->u.ssl3.locked.sessionTicket.ticket,
                              &extension_data) != SECSuccess)
             goto no_ticket;
         sid->u.ssl3.locked.sessionTicket.flags = parsed_session_ticket->flags;
+        sid->u.ssl3.locked.sessionTicket.max_early_data_size =
+            parsed_session_ticket->maxEarlyData;
 
         if (parsed_session_ticket->ms_length >
             sizeof(sid->u.ssl3.keys.wrapped_master_secret))
             goto no_ticket;
         PORT_Memcpy(sid->u.ssl3.keys.wrapped_master_secret,
                     parsed_session_ticket->master_secret,
                     parsed_session_ticket->ms_length);
         sid->u.ssl3.keys.wrapped_master_secret_len =
--- a/lib/ssl/sslerr.h
+++ b/lib/ssl/sslerr.h
@@ -240,15 +240,16 @@ typedef enum {
     SSL_ERROR_TOO_MANY_RECORDS = (SSL_ERROR_BASE + 153),
     SSL_ERROR_RX_UNEXPECTED_HELLO_RETRY_REQUEST = (SSL_ERROR_BASE + 154),
     SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST = (SSL_ERROR_BASE + 155),
     SSL_ERROR_BAD_2ND_CLIENT_HELLO = (SSL_ERROR_BASE + 156),
     SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION = (SSL_ERROR_BASE + 157),
     SSL_ERROR_MALFORMED_PSK_KEY_EXCHANGE_MODES = (SSL_ERROR_BASE + 158),
     SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES = (SSL_ERROR_BASE + 159),
     SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA = (SSL_ERROR_BASE + 160),
+    SSL_ERROR_TOO_MUCH_EARLY_DATA = (SSL_ERROR_BASE + 161),
     SSL_ERROR_END_OF_LIST   /* let the c compiler determine the value of this. */
 } SSLErrorCodes;
 #endif /* NO_SECURITY_ERROR_ENUM */
 
 /* clang-format on */
 
 #endif /* __SSL_ERR_H_ */
--- a/lib/ssl/sslimpl.h
+++ b/lib/ssl/sslimpl.h
@@ -499,16 +499,19 @@ struct ssl3CipherSpecStr {
     sslSequenceNumber write_seq_num;
     sslSequenceNumber read_seq_num;
     SSL3ProtocolVersion version;
     ssl3KeyMaterial client;
     ssl3KeyMaterial server;
     SECItem msItem;
     DTLSEpoch epoch;
     DTLSRecvdRecords recvdRecords;
+    /* The number of 0-RTT bytes that can be sent or received in TLS 1.3. This
+     * will be zero for everything but 0-RTT. */
+    PRUint32 earlyDataRemaining;
 
     PRUint8 refCt;
     const char *phase;
 };
 
 typedef enum { never_cached,
                in_client_cache,
                in_server_cache,
@@ -1008,16 +1011,17 @@ typedef struct SessionTicketStr {
     SSL3Opaque master_secret[48];
     PRBool extendedMasterSecretUsed;
     ClientIdentity client_identity;
     SECItem peer_cert;
     PRUint32 timestamp;
     PRUint32 flags;
     SECItem srvName; /* negotiated server name */
     SECItem alpnSelection;
+    PRUint32 maxEarlyData;
 } SessionTicket;
 
 /*
  * SSL2 buffers used in SSL3.
  *     writeBuf in the SecurityInfo maintained by sslsecur.c is used
  *              to hold the data just about to be passed to the kernel
  *     sendBuf in the ConnectInfo maintained by sslcon.c is used
  *              to hold handshake messages as they are accumulated
@@ -1226,16 +1230,17 @@ struct sslSocketStr {
 };
 
 extern char ssl_debug;
 extern char ssl_trace;
 extern FILE *ssl_trace_iob;
 extern FILE *ssl_keylog_iob;
 extern PRUint32 ssl3_sid_timeout;
 extern PRUint32 ssl_ticket_lifetime;
+extern PRUint32 ssl_max_early_data_size;
 
 extern const char *const ssl3_cipherName[];
 
 extern sslSessionIDLookupFunc ssl_sid_lookup;
 extern sslSessionIDCacheFunc ssl_sid_cache;
 extern sslSessionIDUncacheFunc ssl_sid_uncache;
 
 extern const sslNamedGroupDef ssl_named_groups[];
@@ -1345,18 +1350,18 @@ extern SECStatus ssl_EnableNagleDelay(ss
 extern void ssl_FinishHandshake(sslSocket *ss);
 
 extern SECStatus ssl_CipherPolicySet(PRInt32 which, PRInt32 policy);
 
 extern SECStatus ssl_CipherPrefSetDefault(PRInt32 which, PRBool enabled);
 
 extern SECStatus ssl3_ConstrainRangeByPolicy(void);
 
-extern SECStatus ssl3_InitState(sslSocket *ss);
-extern SECStatus ssl3_RestartHandshakeHashes(sslSocket *ss);
+extern void ssl3_InitState(sslSocket *ss);
+extern void ssl3_RestartHandshakeHashes(sslSocket *ss);
 extern SECStatus ssl3_UpdateHandshakeHashes(sslSocket *ss,
                                             const unsigned char *b,
                                             unsigned int l);
 
 /* Returns PR_TRUE if we are still waiting for the server to complete its
  * response to our client second round. Once we've received the Finished from
  * the server then there is no need to check false start.
  */
--- a/lib/ssl/sslinfo.c
+++ b/lib/ssl/sslinfo.c
@@ -136,18 +136,29 @@ SSL_GetPreliminaryChannelInfo(PRFileDesc
 
     memset(&inf, 0, sizeof(inf));
     inf.length = PR_MIN(sizeof(inf), len);
 
     inf.valuesSet = ss->ssl3.hs.preliminaryInfo;
     inf.protocolVersion = ss->version;
     inf.cipherSuite = ss->ssl3.hs.cipher_suite;
     inf.canSendEarlyData = !ss->sec.isServer &&
-                           (ss->ssl3.hs.zeroRttState == ssl_0rtt_sent) &&
-                           !ss->firstHsDone;
+                           (ss->ssl3.hs.zeroRttState == ssl_0rtt_sent ||
+                            ss->ssl3.hs.zeroRttState == ssl_0rtt_accepted);
+    /* We shouldn't be able to send early data if the handshake is done. */
+    PORT_Assert(!ss->firstHsDone || !inf.canSendEarlyData);
+
+    if (ss->sec.ci.sid &&
+        (ss->ssl3.hs.zeroRttState == ssl_0rtt_sent ||
+         ss->ssl3.hs.zeroRttState == ssl_0rtt_accepted)) {
+        inf.maxEarlyDataSize =
+            ss->sec.ci.sid->u.ssl3.locked.sessionTicket.max_early_data_size;
+    } else {
+        inf.maxEarlyDataSize = 0;
+    }
 
     memcpy(info, &inf, inf.length);
     return SECSuccess;
 }
 
 /* name */
 #define CS_(x) x, #x
 #define CS(x) CS_(TLS_##x)
--- a/lib/ssl/sslsecur.c
+++ b/lib/ssl/sslsecur.c
@@ -879,16 +879,17 @@ ssl_SecureRead(sslSocket *ss, unsigned c
     return ssl_SecureRecv(ss, buf, len, 0);
 }
 
 /* Caller holds the SSL Socket's write lock. SSL_LOCK_WRITER(ss) */
 int
 ssl_SecureSend(sslSocket *ss, const unsigned char *buf, int len, int flags)
 {
     int rv = 0;
+    PRBool zeroRtt = PR_FALSE;
 
     SSL_TRC(2, ("%d: SSL[%d]: SecureSend: sending %d bytes",
                 SSL_GETPID(), ss->fd, len));
 
     if (ss->shutdownHow & ssl_SHUTDOWN_SEND) {
         PORT_SetError(PR_SOCKET_SHUTDOWN_ERROR);
         rv = PR_FAILURE;
         goto done;
@@ -918,65 +919,67 @@ ssl_SecureSend(sslSocket *ss, const unsi
         ss->writerThread = PR_GetCurrentThread();
 
     /* Check to see if we can write even though we're not finished.
      *
      * Case 1: False start
      * Case 2: TLS 1.3 0-RTT
      */
     if (!ss->firstHsDone) {
-        PRBool falseStart = PR_FALSE;
+        PRBool allowEarlySend = PR_FALSE;
+
         ssl_Get1stHandshakeLock(ss);
         if (ss->opt.enableFalseStart ||
             (ss->opt.enable0RttData && !ss->sec.isServer)) {
             ssl_GetSSL3HandshakeLock(ss);
             /* The client can sometimes send before the handshake is fully
              * complete. In TLS 1.2: false start; in TLS 1.3: 0-RTT. */
-            falseStart = ss->ssl3.hs.canFalseStart ||
-                         ss->ssl3.hs.zeroRttState == ssl_0rtt_sent ||
-                         ss->ssl3.hs.zeroRttState == ssl_0rtt_accepted;
+            zeroRtt = ss->ssl3.hs.zeroRttState == ssl_0rtt_sent ||
+                      ss->ssl3.hs.zeroRttState == ssl_0rtt_accepted;
+            allowEarlySend = ss->ssl3.hs.canFalseStart || zeroRtt;
             ssl_ReleaseSSL3HandshakeLock(ss);
         }
-        if (!falseStart && ss->handshake) {
+        if (!allowEarlySend && ss->handshake) {
             rv = ssl_Do1stHandshake(ss);
         }
         ssl_Release1stHandshakeLock(ss);
     }
     if (rv < 0) {
         ss->writerThread = NULL;
         goto done;
     }
 
+    if (zeroRtt) {
+        /* There's a limit to the number of early data octets we can send.
+         *
+         * Note that taking this lock doesn't prevent the cipher specs from
+         * being changed out between here and when records are ultimately
+         * encrypted.  The only effect of that is to occasionally do an
+         * unnecessary short write when data is identified as 0-RTT here but
+         * 1-RTT later.
+         */
+        ssl_GetSpecReadLock(ss);
+        len = tls13_LimitEarlyData(ss, content_application_data, len);
+        ssl_ReleaseSpecReadLock(ss);
+    }
+
     /* Check for zero length writes after we do housekeeping so we make forward
      * progress.
      */
     if (len == 0) {
         rv = 0;
         goto done;
     }
     PORT_Assert(buf != NULL);
     if (!buf) {
         PORT_SetError(PR_INVALID_ARGUMENT_ERROR);
         rv = PR_FAILURE;
         goto done;
     }
 
-    if (!ss->firstHsDone) {
-#ifdef DEBUG
-        ssl_GetSSL3HandshakeLock(ss);
-        PORT_Assert(!ss->sec.isServer &&
-                    (ss->ssl3.hs.canFalseStart ||
-                     ss->ssl3.hs.zeroRttState == ssl_0rtt_sent ||
-                     ss->ssl3.hs.zeroRttState == ssl_0rtt_accepted));
-        ssl_ReleaseSSL3HandshakeLock(ss);
-#endif
-        SSL_TRC(3, ("%d: SSL[%d]: SecureSend: sending data due to false start",
-                    SSL_GETPID(), ss->fd));
-    }
-
     ssl_GetXmitBufLock(ss);
     rv = ssl3_SendApplicationData(ss, buf, len, flags);
     ssl_ReleaseXmitBufLock(ss);
     ss->writerThread = NULL;
 done:
     if (rv < 0) {
         SSL_TRC(2, ("%d: SSL[%d]: SecureSend: returning %d count, error %d",
                     SSL_GETPID(), ss->fd, rv, PORT_GetError()));
--- a/lib/ssl/sslt.h
+++ b/lib/ssl/sslt.h
@@ -299,16 +299,25 @@ typedef struct SSLPreliminaryChannelInfo
     PRUint16 cipherSuite;
 
     /* The following fields were added in NSS 3.29. */
     /* |canSendEarlyData| is true when a 0-RTT is enabled. This can only be
      * true after sending the ClientHello and before the handshake completes.
      */
     PRBool canSendEarlyData;
 
+    /* The following fields were added in NSS 3.31. */
+    /* The number of early data octets that a client is permitted to send on
+     * this connection.  The value will be zero if the connection was not
+     * resumed or early data is not permitted.  For a client, this value only
+     * has meaning if |canSendEarlyData| is true.  For a server, this indicates
+     * the value that was advertised in the session ticket that was used to
+     * resume this session. */
+    PRUint32 maxEarlyDataSize;
+
     /* When adding new fields to this structure, please document the
      * NSS version in which they were added. */
 } SSLPreliminaryChannelInfo;
 
 typedef struct SSLCipherSuiteInfoStr {
     /* On return, SSL_GetCipherSuitelInfo sets |length| to the smaller of
      * the |len| argument and the length of the struct used by NSS.
      * Callers must ensure the application uses a version of NSS that
--- a/lib/ssl/tls13con.c
+++ b/lib/ssl/tls13con.c
@@ -2764,16 +2764,21 @@ tls13_SetCipherSpec(sslSocket *ss, Traff
     } else {
         /* The sequence number has the high 16 bits as the epoch. */
         spec->read_seq_num = spec->write_seq_num =
             (sslSequenceNumber)spec->epoch << 48;
 
         dtls_InitRecvdRecords(&spec->recvdRecords);
     }
 
+    if (type == TrafficKeyEarlyApplicationData) {
+        spec->earlyDataRemaining =
+            ss->sec.ci.sid->u.ssl3.locked.sessionTicket.max_early_data_size;
+    }
+
     /* Now that we've set almost everything up, finally cut over. */
     ssl_GetSpecWriteLock(ss);
     tls13_CipherSpecRelease(*specp); /* May delete old cipher. */
     *specp = spec;                   /* Overwrite. */
     ssl_ReleaseSpecWriteLock(ss);
 
     SSL_TRC(3, ("%d: TLS13[%d]: %s installed key for phase='%s'.%d dir=%s",
                 SSL_GETPID(), ss->fd, SSL_ROLE(ss),
@@ -3770,17 +3775,17 @@ tls13_SendClientSecondRound(sslSocket *s
  *   struct {
  *       uint32 ticket_lifetime;
  *       uint32 ticket_age_add;
  *       opaque ticket<1..2^16-1>;
  *       TicketExtension extensions<0..2^16-2>;
  *   } NewSessionTicket;
  */
 
-#define MAX_EARLY_DATA_SIZE (2 << 16) /* Arbitrary limit. */
+PRUint32 ssl_max_early_data_size = (2 << 16); /* Arbitrary limit. */
 
 SECStatus
 tls13_SendNewSessionTicket(sslSocket *ss)
 {
     PRUint16 message_length;
     SECItem ticket_data = { 0, NULL, 0 };
     SECStatus rv;
     NewSessionTicket ticket = { 0 };
@@ -3840,17 +3845,17 @@ tls13_SendNewSessionTicket(sslSocket *ss
         if (rv != SECSuccess)
             goto loser;
 
         /* Length */
         rv = ssl3_AppendHandshakeNumber(ss, 4, 2);
         if (rv != SECSuccess)
             goto loser;
 
-        rv = ssl3_AppendHandshakeNumber(ss, MAX_EARLY_DATA_SIZE, 4);
+        rv = ssl3_AppendHandshakeNumber(ss, ssl_max_early_data_size, 4);
         if (rv != SECSuccess)
             goto loser;
     }
 
     SECITEM_FreeItem(&ticket_data, PR_FALSE);
     return SECSuccess;
 
 loser:
@@ -4084,16 +4089,38 @@ tls13_FormatAdditionalData(PRUint8 *aad,
 {
     PRUint8 *ptr = aad;
 
     PORT_Assert(length == 8);
     ptr = ssl_EncodeUintX(seqNum, 8, ptr);
     PORT_Assert((ptr - aad) == length);
 }
 
+PRInt32
+tls13_LimitEarlyData(sslSocket *ss, SSL3ContentType type, PRInt32 toSend)
+{
+    PRInt32 reduced;
+
+    PORT_Assert(type == content_application_data);
+    PORT_Assert(ss->vrange.max >= SSL_LIBRARY_VERSION_TLS_1_3);
+    PORT_Assert(!ss->firstHsDone);
+    if (ss->ssl3.cwSpec->epoch != TrafficKeyEarlyApplicationData) {
+        return toSend;
+    }
+
+    if (IS_DTLS(ss) && toSend > ss->ssl3.cwSpec->earlyDataRemaining) {
+        /* Don't split application data records in DTLS. */
+        return 0;
+    }
+
+    reduced = PR_MIN(toSend, ss->ssl3.cwSpec->earlyDataRemaining);
+    ss->ssl3.cwSpec->earlyDataRemaining -= reduced;
+    return reduced;
+}
+
 SECStatus
 tls13_ProtectRecord(sslSocket *ss,
                     ssl3CipherSpec *cwSpec,
                     SSL3ContentType type,
                     const SSL3Opaque *pIn,
                     PRUint32 contentLen,
                     sslBuffer *wrBuf)
 {
@@ -4235,16 +4262,27 @@ tls13_UnprotectRecord(sslSocket *ss, SSL
         PORT_SetError(SSL_ERROR_BAD_BLOCK_PADDING);
         return SECFailure;
     }
 
     /* Record the type. */
     cText->type = plaintext->buf[plaintext->len - 1];
     --plaintext->len;
 
+    /* Check that we haven't received too much 0-RTT data. */
+    if (crSpec->epoch == TrafficKeyEarlyApplicationData &&
+        cText->type == content_application_data) {
+        if (plaintext->len > crSpec->earlyDataRemaining) {
+            *alert = unexpected_message;
+            PORT_SetError(SSL_ERROR_TOO_MUCH_EARLY_DATA);
+            return SECFailure;
+        }
+        crSpec->earlyDataRemaining -= plaintext->len;
+    }
+
     SSL_TRC(10,
             ("%d: TLS13[%d]: %s received record of length=%d type=%d",
              SSL_GETPID(), ss->fd, SSL_ROLE(ss),
              plaintext->len, cText->type));
 
     return SECSuccess;
 }
 
--- a/lib/ssl/tls13con.h
+++ b/lib/ssl/tls13con.h
@@ -40,16 +40,17 @@ SSLHashType tls13_GetHashForCipherSuite(
 SSLHashType tls13_GetHash(const sslSocket *ss);
 unsigned int tls13_GetHashSizeForHash(SSLHashType hash);
 unsigned int tls13_GetHashSize(const sslSocket *ss);
 CK_MECHANISM_TYPE tls13_GetHkdfMechanism(sslSocket *ss);
 void tls13_FatalError(sslSocket *ss, PRErrorCode prError,
                       SSL3AlertDescription desc);
 SECStatus tls13_SetupClientHello(sslSocket *ss);
 SECStatus tls13_MaybeDo0RTTHandshake(sslSocket *ss);
+PRInt32 tls13_LimitEarlyData(sslSocket *ss, SSL3ContentType type, PRInt32 toSend);
 PRBool tls13_AllowPskCipher(const sslSocket *ss,
                             const ssl3CipherSuiteDef *cipher_def);
 PRBool tls13_PskSuiteEnabled(sslSocket *ss);
 SECStatus tls13_ComputePskBinder(sslSocket *ss, PRBool sending,
                                  unsigned int prefixLength,
                                  PRUint8 *output, unsigned int *outputLen,
                                  unsigned int maxOutputLen);
 SECStatus tls13_HandleClientHelloPart2(sslSocket *ss,