security/patches/bug-542832-ssl-restart-tstclnt-4.patch
author Brian Smith <bsmith@mozilla.com>
Thu, 01 Dec 2011 14:33:37 -0800
changeset 81102 0ef53633ccc715a5705788ae22c8a844ba0bd88d
permissions -rw-r--r--
Bug 698552: Add SSL_RestartAfterAuthCertificate to mozilla-central's copy of NSS_3_13_2_BETA1, r=kaie, r=honzab

Index: mozilla/security/nss/cmd/tstclnt/tstclnt.c
===================================================================
RCS file: /cvsroot/mozilla/security/nss/cmd/tstclnt/tstclnt.c,v
retrieving revision 1.64
diff -u -8 -p -r1.64 tstclnt.c
--- mozilla/security/nss/cmd/tstclnt/tstclnt.c	6 Oct 2011 22:42:33 -0000	1.64
+++ mozilla/security/nss/cmd/tstclnt/tstclnt.c	16 Nov 2011 08:24:12 -0000
@@ -212,16 +212,18 @@ static void Usage(const char *progName)
                     "-n nickname");
     fprintf(stderr, 
             "%-20s Bypass PKCS11 layer for SSL encryption and MACing.\n", "-B");
     fprintf(stderr, "%-20s Disable SSL v2.\n", "-2");
     fprintf(stderr, "%-20s Disable SSL v3.\n", "-3");
     fprintf(stderr, "%-20s Disable TLS (SSL v3.1).\n", "-T");
     fprintf(stderr, "%-20s Prints only payload data. Skips HTTP header.\n", "-S");
     fprintf(stderr, "%-20s Client speaks first. \n", "-f");
+    fprintf(stderr, "%-20s Use synchronous certificate validation "
+                    "(required for SSL2)\n", "-O");
     fprintf(stderr, "%-20s Override bad server cert. Make it OK.\n", "-o");
     fprintf(stderr, "%-20s Disable SSL socket locking.\n", "-s");
     fprintf(stderr, "%-20s Verbose progress reporting.\n", "-v");
     fprintf(stderr, "%-20s Use export policy.\n", "-x");
     fprintf(stderr, "%-20s Ping the server and then exit.\n", "-q");
     fprintf(stderr, "%-20s Renegotiate N times (resuming session if N>1).\n", "-r N");
     fprintf(stderr, "%-20s Enable the session ticket extension.\n", "-u");
     fprintf(stderr, "%-20s Enable compression.\n", "-z");
@@ -288,30 +290,54 @@ disableAllSSLCiphers(void)
 	    fprintf(stderr,
 	            "SSL_CipherPrefSet didn't like value 0x%04x (i = %d): %s\n",
 	    	   suite, i, SECU_Strerror(err));
 	    exit(2);
 	}
     }
 }
 
+typedef struct
+{
+   PRBool shouldPause; /* PR_TRUE if we should use asynchronous peer cert 
+                        * authentication */
+   PRBool isPaused;    /* PR_TRUE if libssl is waiting for us to validate the
+                        * peer's certificate and restart the handshake. */
+   void * dbHandle;    /* Certificate database handle to use while
+                        * authenticating the peer's certificate. */
+} ServerCertAuth;
+
 /*
  * Callback is called when incoming certificate is not valid.
  * Returns SECSuccess to accept the cert anyway, SECFailure to reject.
  */
 static SECStatus 
 ownBadCertHandler(void * arg, PRFileDesc * socket)
 {
     PRErrorCode err = PR_GetError();
     /* can log invalid cert here */
     fprintf(stderr, "Bad server certificate: %d, %s\n", err, 
             SECU_Strerror(err));
     return SECSuccess;	/* override, say it's OK. */
 }
 
+static SECStatus 
+ownAuthCertificate(void *arg, PRFileDesc *fd, PRBool checkSig,
+                       PRBool isServer)
+{
+    ServerCertAuth * serverCertAuth = (ServerCertAuth *) arg;
+
+    FPRINTF(stderr, "using asynchronous certificate validation\n", progName);
+
+    PORT_Assert(serverCertAuth->shouldPause);
+    PORT_Assert(!serverCertAuth->isPaused);
+    serverCertAuth->isPaused = PR_TRUE;
+    return SECWouldBlock;
+}
+
 SECStatus
 own_GetClientAuthData(void *                       arg,
                       PRFileDesc *                 socket,
                       struct CERTDistNamesStr *    caNames,
                       struct CERTCertificateStr ** pRetCert,
                       struct SECKEYPrivateKeyStr **pRetKey)
 {
     if (verbose > 1) {
@@ -493,21 +519,47 @@ separateReqHeader(const PRFileDesc* outF
     } else if (((c) >= 'a') && ((c) <= 'f')) { \
 	i = (c) - 'a' + 10; \
     } else if (((c) >= 'A') && ((c) <= 'F')) { \
 	i = (c) - 'A' + 10; \
     } else { \
 	Usage(progName); \
     }
 
+static SECStatus
+restartHandshakeAfterServerCertIfNeeded(PRFileDesc * fd,
+                                        ServerCertAuth * serverCertAuth,
+                                        PRBool override)
+{
+    SECStatus rv;
+    
+    if (!serverCertAuth->isPaused)
+	return SECSuccess;
+    
+    FPRINTF(stderr, "%s: handshake was paused by auth certificate hook\n",
+            progName);
+
+    serverCertAuth->isPaused = PR_FALSE;
+    rv = SSL_AuthCertificate(serverCertAuth->dbHandle, fd, PR_TRUE, PR_FALSE);
+    if (rv != SECSuccess && override) {
+        rv = ownBadCertHandler(NULL, fd);
+    }
+    if (rv != SECSuccess) {
+	return rv;
+    }
+    
+    rv = SSL_RestartHandshakeAfterAuthCertificate(fd);
+
+    return rv;
+}
+    
 int main(int argc, char **argv)
 {
     PRFileDesc *       s;
     PRFileDesc *       std_out;
-    CERTCertDBHandle * handle;
     char *             host	=  NULL;
     char *             certDir  =  NULL;
     char *             nickname =  NULL;
     char *             cipherString = NULL;
     char *             tmp;
     int                multiplier = 0;
     SECStatus          rv;
     PRStatus           status;
@@ -525,51 +577,58 @@ int main(int argc, char **argv)
     int                enableFalseStart = 0;
     PRSocketOptionData opt;
     PRNetAddr          addr;
     PRPollDesc         pollset[2];
     PRBool             pingServerFirst = PR_FALSE;
     PRBool             clientSpeaksFirst = PR_FALSE;
     PRBool             wrStarted = PR_FALSE;
     PRBool             skipProtoHeader = PR_FALSE;
+    ServerCertAuth     serverCertAuth;
     int                headerSeparatorPtrnId = 0;
     int                error = 0;
     PRUint16           portno = 443;
     char *             hs1SniHostName = NULL;
     char *             hs2SniHostName = NULL;
     PLOptState *optstate;
     PLOptStatus optstatus;
     PRStatus prStatus;
 
+    serverCertAuth.shouldPause = PR_TRUE;
+    serverCertAuth.isPaused = PR_FALSE;
+    serverCertAuth.dbHandle = NULL;
+
     progName = strrchr(argv[0], '/');
     if (!progName)
 	progName = strrchr(argv[0], '\\');
     progName = progName ? progName+1 : argv[0];
 
     tmp = PR_GetEnv("NSS_DEBUG_TIMEOUT");
     if (tmp && tmp[0]) {
        int sec = PORT_Atoi(tmp);
        if (sec > 0) {
            maxInterval = PR_SecondsToInterval(sec);
        }
     }
 
     optstate = PL_CreateOptState(argc, argv,
-                                 "23BSTW:a:c:d:fgh:m:n:op:qr:suvw:xz");
+                                 "23BOSTW:a:c:d:fgh:m:n:op:qr:suvw:xz");
     while ((optstatus = PL_GetNextOpt(optstate)) == PL_OPT_OK) {
 	switch (optstate->option) {
 	  case '?':
 	  default : Usage(progName); 			break;
 
           case '2': disableSSL2 = 1; 			break;
 
           case '3': disableSSL3 = 1; 			break;
 
           case 'B': bypassPKCS11 = 1; 			break;
 
+          case 'O': serverCertAuth.shouldPause = PR_FALSE; break;
+
           case 'S': skipProtoHeader = PR_TRUE;                 break;
 
           case 'T': disableTLS  = 1; 			break;
 
           case 'a': if (!hs1SniHostName) {
                         hs1SniHostName = PORT_Strdup(optstate->value);
                     } else if (!hs2SniHostName) {
                         hs2SniHostName =  PORT_Strdup(optstate->value);
@@ -645,24 +704,18 @@ int main(int argc, char **argv)
     } else {
 	char *certDirTmp = certDir;
 	certDir = SECU_ConfigDirectory(certDirTmp);
 	PORT_Free(certDirTmp);
     }
     rv = NSS_Init(certDir);
     if (rv != SECSuccess) {
 	SECU_PrintError(progName, "unable to open cert database");
-#if 0
-    rv = CERT_OpenVolatileCertDB(handle);
-	CERT_SetDefaultCertDB(handle);
-#else
 	return 1;
-#endif
     }
-    handle = CERT_GetDefaultCertDB();
 
     /* set the policy bits true for all the cipher suites. */
     if (useExportPolicy)
 	NSS_SetExportPolicy();
     else
 	NSS_SetDomesticPolicy();
 
     /* all the SSL2 and SSL3 cipher suites are enabled by default. */
@@ -871,17 +924,23 @@ int main(int argc, char **argv)
     rv = SSL_OptionSet(s, SSL_ENABLE_FALSE_START, enableFalseStart);
     if (rv != SECSuccess) {
 	SECU_PrintError(progName, "error enabling false start");
 	return 1;
     }
 
     SSL_SetPKCS11PinArg(s, &pwdata);
 
-    SSL_AuthCertificateHook(s, SSL_AuthCertificate, (void *)handle);
+    serverCertAuth.dbHandle = CERT_GetDefaultCertDB();
+
+    if (serverCertAuth.shouldPause) {
+	SSL_AuthCertificateHook(s, ownAuthCertificate, &serverCertAuth);
+    } else {
+	SSL_AuthCertificateHook(s, SSL_AuthCertificate, serverCertAuth.dbHandle);
+    }
     if (override) {
 	SSL_BadCertHook(s, ownBadCertHandler, NULL);
     }
     SSL_GetClientAuthDataHook(s, own_GetClientAuthData, (void *)nickname);
     SSL_HandshakeCallback(s, handshakeCallback, hs2SniHostName);
     if (hs1SniHostName) {
         SSL_SetURL(s, hs1SniHostName);
     } else {
@@ -979,16 +1038,24 @@ int main(int argc, char **argv)
     ** socket, read data from socket and write to stdout.
     */
     FPRINTF(stderr, "%s: ready...\n", progName);
 
     while (pollset[SSOCK_FD].in_flags | pollset[STDIN_FD].in_flags) {
 	char buf[4000];	/* buffer for stdin */
 	int nb;		/* num bytes read from stdin. */
 
+	rv = restartHandshakeAfterServerCertIfNeeded(s, &serverCertAuth,
+						     override);
+	if (rv != SECSuccess) {
+	    error = 254; /* 254 (usually) means "handshake failed" */
+	    SECU_PrintError(progName, "authentication of server cert failed");
+	    goto done;
+	}
+	        
 	pollset[SSOCK_FD].out_flags = 0;
 	pollset[STDIN_FD].out_flags = 0;
 
 	FPRINTF(stderr, "%s: about to call PR_Poll !\n", progName);
 	filesReady = PR_Poll(pollset, npds, PR_INTERVAL_NO_TIMEOUT);
 	if (filesReady < 0) {
 	    SECU_PrintError(progName, "select failed");
 	    error = 1;
@@ -1037,16 +1104,25 @@ int main(int argc, char **argv)
 			    goto done;
 			}
 			cc = 0;
 		    }
 		    bufp += cc;
 		    nb   -= cc;
 		    if (nb <= 0) 
 		    	break;
+
+		    rv = restartHandshakeAfterServerCertIfNeeded(s,
+				&serverCertAuth, override);
+		    if (rv != SECSuccess) {
+			error = 254; /* 254 (usually) means "handshake failed" */
+			SECU_PrintError(progName, "authentication of server cert failed");
+			goto done;
+		    }
+
 		    pollset[SSOCK_FD].in_flags = PR_POLL_WRITE | PR_POLL_EXCEPT;
 		    pollset[SSOCK_FD].out_flags = 0;
 		    FPRINTF(stderr,
 		            "%s: about to call PR_Poll on writable socket !\n", 
 			    progName);
 		    cc = PR_Poll(pollset, 1, PR_INTERVAL_NO_TIMEOUT);
 		    FPRINTF(stderr,
 		            "%s: PR_Poll returned with writable socket !\n", 
Index: mozilla/security/nss/tests/ssl/ssl.sh
===================================================================
RCS file: /cvsroot/mozilla/security/nss/tests/ssl/ssl.sh,v
retrieving revision 1.106
diff -u -8 -p -r1.106 ssl.sh
--- mozilla/security/nss/tests/ssl/ssl.sh	29 Jan 2010 22:36:25 -0000	1.106
+++ mozilla/security/nss/tests/ssl/ssl.sh	16 Nov 2011 08:24:14 -0000
@@ -303,16 +303,26 @@ ssl_cov()
                
   exec < ${SSLCOV}
   while read ectype tls param testname
   do
       echo "${testname}" | grep "EXPORT" > /dev/null 
       EXP=$?
       echo "${testname}" | grep "SSL2" > /dev/null
       SSL2=$?
+
+      if [ "${SSL2}" -eq 0 ] ; then
+          # We cannot use asynchronous cert verification with SSL2
+          SSL2_FLAGS=-O
+      else
+          # Do not enable SSL2 for non-SSL2-specific tests. SSL2 is disabled by
+          # default in libssl but it is enabled by default in tstclnt; we want
+          # to test the libssl default whenever possible.
+          SSL2_FLAGS=-2
+      fi
       
       if [ "$NORM_EXT" = "Extended Test" -a "${SSL2}" -eq 0 ] ; then
           echo "$SCRIPTNAME: skipping  $testname for $NORM_EXT"
       elif [ "$ectype" = "ECC" -a -z "$NSS_ENABLE_ECC" ] ; then
           echo "$SCRIPTNAME: skipping  $testname (ECC only)"
       elif [ "$SERVER_MODE" = "fips" -o "$CLIENT_MODE" = "fips" ] && [ "$SSL2" -eq 0 -o "$EXP" -eq 0 ] ; then
           echo "$SCRIPTNAME: skipping  $testname (non-FIPS only)"
       elif [ "`echo $ectype | cut -b 1`" != "#" ] ; then
@@ -345,21 +355,21 @@ ssl_cov()
               is_selfserv_alive
             else
               kill_selfserv
               start_selfserv
               mixed=0
             fi
           fi
 
-          echo "tstclnt -p ${PORT} -h ${HOSTADDR} -c ${param} ${TLS_FLAG} ${CLIENT_OPTIONS} \\"
+          echo "tstclnt -p ${PORT} -h ${HOSTADDR} -c ${param} ${TLS_FLAG} ${SSL2_FLAGS} ${CLIENT_OPTIONS} \\"
           echo "        -f -d ${P_R_CLIENTDIR} -v -w nss < ${REQUEST_FILE}"
 
           rm ${TMP}/$HOST.tmp.$$ 2>/dev/null
-          ${PROFTOOL} ${BINDIR}/tstclnt -p ${PORT} -h ${HOSTADDR} -c ${param} ${TLS_FLAG} ${CLIENT_OPTIONS} -f \
+          ${PROFTOOL} ${BINDIR}/tstclnt -p ${PORT} -h ${HOSTADDR} -c ${param} ${TLS_FLAG} ${SSL2_FLAGS} ${CLIENT_OPTIONS} -f \
                   -d ${P_R_CLIENTDIR} -v -w nss < ${REQUEST_FILE} \
                   >${TMP}/$HOST.tmp.$$  2>&1
           ret=$?
           cat ${TMP}/$HOST.tmp.$$ 
           rm ${TMP}/$HOST.tmp.$$ 2>/dev/null
           html_msg $ret 0 "${testname}" \
                    "produced a returncode of $ret, expected is 0"
       fi