Bug 1542077 - Added extra controls and tests to mp_set_int and mp_set_ulong. r=jcj,kjacobs
authorMarcus Burghardt <mburghardt@mozilla.com>
Tue, 13 Aug 2019 21:29:38 +0000
changeset 15251 9bc47e69613e9ee9c8aaaf555150f815fc8161d9
parent 15250 ec113de50cdd1d4ccb160870099165a0f4d2b99e
child 15252 dfd6996fe7425eb0437346d11a01082f16fcfe34
push id3462
push userjjones@mozilla.com
push dateTue, 13 Aug 2019 21:30:32 +0000
reviewersjcj, kjacobs
bugs1542077
Bug 1542077 - Added extra controls and tests to mp_set_int and mp_set_ulong. r=jcj,kjacobs Differential Revision: https://phabricator.services.mozilla.com/D40649
gtests/freebl_gtest/mpi_unittest.cc
lib/freebl/mpi/README
lib/freebl/mpi/mpi.c
lib/freebl/mpi/mpi.h
--- a/gtests/freebl_gtest/mpi_unittest.cc
+++ b/gtests/freebl_gtest/mpi_unittest.cc
@@ -285,9 +285,9 @@ TEST_F(DISABLED_MPITest, MpiCmpConstTest
   }
   printf("time c: %u\n", time_c / runs);
 
   mp_clear(&a);
   mp_clear(&b);
   mp_clear(&c);
 }
 
-}  // nss_test
+}  // namespace nss_test
--- a/lib/freebl/mpi/README
+++ b/lib/freebl/mpi/README
@@ -162,16 +162,17 @@ using the mp_clear() function.  Remember
 create as a local variable in a function must be mp_clear()'d before
 that function exits, or else the memory allocated to that mp_int will
 be orphaned and unrecoverable.
 
 To set an mp_int to a given value, the following functions are given:
 
         mp_set(mp_int *mp, mp_digit d);
         mp_set_int(mp_int *mp, long z);
+        mp_set_ulong(mp_int *mp, unsigned long z);
 
 The mp_set() function sets the mp_int to a single digit value, while
 mp_set_int() sets the mp_int to a signed long integer value.
 
 To set an mp_int to zero, use:
 
         mp_zero(mp_int *mp);
 
--- a/lib/freebl/mpi/mpi.c
+++ b/lib/freebl/mpi/mpi.c
@@ -339,16 +339,18 @@ mp_set(mp_int *mp, mp_digit d)
 /* {{{ mp_set_int(mp, z) */
 
 mp_err
 mp_set_int(mp_int *mp, long z)
 {
     unsigned long v = labs(z);
     mp_err res;
 
+    ARGCHK(mp != NULL, MP_BADARG);
+
     /* https://bugzilla.mozilla.org/show_bug.cgi?id=1509432 */
     if ((res = mp_set_ulong(mp, v)) != MP_OKAY) { /* avoids duplicated code */
         return res;
     }
 
     if (z < 0) {
         SIGN(mp) = NEG;
     }
@@ -1422,17 +1424,17 @@ mp_sqrmod(const mp_int *a, const mp_int 
 mp_err
 s_mp_exptmod(const mp_int *a, const mp_int *b, const mp_int *m, mp_int *c)
 {
     mp_int s, x, mu;
     mp_err res;
     mp_digit d;
     unsigned int dig, bit;
 
-    ARGCHK(a != NULL && b != NULL && c != NULL, MP_BADARG);
+    ARGCHK(a != NULL && b != NULL && c != NULL && m != NULL, MP_BADARG);
 
     if (mp_cmp_z(b) < 0 || mp_cmp_z(m) <= 0)
         return MP_RANGE;
 
     if ((res = mp_init(&s)) != MP_OKAY)
         return res;
     if ((res = mp_init_copy(&x, a)) != MP_OKAY ||
         (res = mp_mod(&x, m, &x)) != MP_OKAY)
@@ -1509,17 +1511,17 @@ X:
 /* {{{ mp_exptmod_d(a, d, m, c) */
 
 mp_err
 mp_exptmod_d(const mp_int *a, mp_digit d, const mp_int *m, mp_int *c)
 {
     mp_int s, x;
     mp_err res;
 
-    ARGCHK(a != NULL && c != NULL, MP_BADARG);
+    ARGCHK(a != NULL && c != NULL && m != NULL, MP_BADARG);
 
     if ((res = mp_init(&s)) != MP_OKAY)
         return res;
     if ((res = mp_init_copy(&x, a)) != MP_OKAY)
         goto X;
 
     mp_set(&s, 1);
 
@@ -1562,16 +1564,18 @@ X:
   mp_cmp_z(a)
 
   Compare a <=> 0.  Returns <0 if a<0, 0 if a=0, >0 if a>0.
  */
 
 int
 mp_cmp_z(const mp_int *a)
 {
+    ARGMPCHK(a != NULL);
+
     if (SIGN(a) == NEG)
         return MP_LT;
     else if (USED(a) == 1 && DIGIT(a, 0) == 0)
         return MP_EQ;
     else
         return MP_GT;
 
 } /* end mp_cmp_z() */
@@ -1652,17 +1656,17 @@ mp_cmp_mag(const mp_int *a, const mp_int
 /*
   mp_isodd(a)
 
   Returns a true (non-zero) value if a is odd, false (zero) otherwise.
  */
 int
 mp_isodd(const mp_int *a)
 {
-    ARGCHK(a != NULL, 0);
+    ARGMPCHK(a != NULL);
 
     return (int)(DIGIT(a, 0) & 1);
 
 } /* end mp_isodd() */
 
 /* }}} */
 
 /* {{{ mp_iseven(a) */
@@ -1996,17 +2000,17 @@ mp_trailing_zeros(const mp_int *mp)
 */
 mp_err
 s_mp_almost_inverse(const mp_int *a, const mp_int *p, mp_int *c)
 {
     mp_err res;
     mp_err k = 0;
     mp_int d, f, g;
 
-    ARGCHK(a && p && c, MP_BADARG);
+    ARGCHK(a != NULL && p != NULL && c != NULL, MP_BADARG);
 
     MP_DIGITS(&d) = 0;
     MP_DIGITS(&f) = 0;
     MP_DIGITS(&g) = 0;
     MP_CHECKOK(mp_init(&d));
     MP_CHECKOK(mp_init_copy(&f, a)); /* f = a */
     MP_CHECKOK(mp_init_copy(&g, p)); /* g = p */
 
@@ -2130,17 +2134,17 @@ CLEANUP:
 /* compute mod inverse using Schroeppel's method, only if m is odd */
 mp_err
 s_mp_invmod_odd_m(const mp_int *a, const mp_int *m, mp_int *c)
 {
     int k;
     mp_err res;
     mp_int x;
 
-    ARGCHK(a && m && c, MP_BADARG);
+    ARGCHK(a != NULL && m != NULL && c != NULL, MP_BADARG);
 
     if (mp_cmp_z(a) == 0 || mp_cmp_z(m) == 0)
         return MP_RANGE;
     if (mp_iseven(m))
         return MP_UNDEF;
 
     MP_DIGITS(&x) = 0;
 
@@ -2168,17 +2172,17 @@ CLEANUP:
 
 /* Known good algorithm for computing modular inverse.  But slow. */
 mp_err
 mp_invmod_xgcd(const mp_int *a, const mp_int *m, mp_int *c)
 {
     mp_int g, x;
     mp_err res;
 
-    ARGCHK(a && m && c, MP_BADARG);
+    ARGCHK(a != NULL && m != NULL && c != NULL, MP_BADARG);
 
     if (mp_cmp_z(a) == 0 || mp_cmp_z(m) == 0)
         return MP_RANGE;
 
     MP_DIGITS(&g) = 0;
     MP_DIGITS(&x) = 0;
     MP_CHECKOK(mp_init(&x));
     MP_CHECKOK(mp_init(&g));
@@ -2264,16 +2268,18 @@ mp_err
 s_mp_invmod_even_m(const mp_int *a, const mp_int *m, mp_int *c)
 {
     mp_err res;
     mp_size k;
     mp_int oddFactor, evenFactor; /* factors of the modulus */
     mp_int oddPart, evenPart;     /* parts to combine via CRT. */
     mp_int C2, tmp1, tmp2;
 
+    ARGCHK(a != NULL && m != NULL && c != NULL, MP_BADARG);
+
     /*static const mp_digit d1 = 1; */
     /*static const mp_int one = { MP_ZPOS, 1, 1, (mp_digit *)&d1 }; */
 
     if ((res = s_mp_ispow2(m)) >= 0) {
         k = res;
         return s_mp_invmod_2d(a, k, c);
     }
     MP_DIGITS(&oddFactor) = 0;
@@ -2342,18 +2348,17 @@ CLEANUP:
   Compute c = a^-1 (mod m), if there is an inverse for a (mod m).
   This is equivalent to the question of whether (a, m) = 1.  If not,
   MP_UNDEF is returned, and there is no inverse.
  */
 
 mp_err
 mp_invmod(const mp_int *a, const mp_int *m, mp_int *c)
 {
-
-    ARGCHK(a && m && c, MP_BADARG);
+    ARGCHK(a != NULL && m != NULL && c != NULL, MP_BADARG);
 
     if (mp_cmp_z(a) == 0 || mp_cmp_z(m) == 0)
         return MP_RANGE;
 
     if (mp_isodd(m)) {
         return s_mp_invmod_odd_m(a, m, c);
     }
     if (mp_iseven(a))
@@ -2710,16 +2715,18 @@ mp_strerror(mp_err ec)
 /* {{{ Memory management */
 
 /* {{{ s_mp_grow(mp, min) */
 
 /* Make sure there are at least 'min' digits allocated to mp              */
 mp_err
 s_mp_grow(mp_int *mp, mp_size min)
 {
+    ARGCHK(mp != NULL, MP_BADARG);
+
     if (min > ALLOC(mp)) {
         mp_digit *tmp;
 
         /* Set min to next nearest default precision block size */
         min = MP_ROUNDUP(min, s_mp_defprec);
 
         if ((tmp = s_mp_alloc(min, sizeof(mp_digit))) == NULL)
             return MP_MEM;
@@ -2739,16 +2746,18 @@ s_mp_grow(mp_int *mp, mp_size min)
 /* }}} */
 
 /* {{{ s_mp_pad(mp, min) */
 
 /* Make sure the used size of mp is at least 'min', growing if needed     */
 mp_err
 s_mp_pad(mp_int *mp, mp_size min)
 {
+    ARGCHK(mp != NULL, MP_BADARG);
+
     if (min > USED(mp)) {
         mp_err res;
 
         /* Make sure there is room to increase precision  */
         if (min > ALLOC(mp)) {
             if ((res = s_mp_grow(mp, min)) != MP_OKAY)
                 return res;
         } else {
@@ -2858,16 +2867,18 @@ s_mp_exch(mp_int *a, mp_int *b)
  */
 
 mp_err
 s_mp_lshd(mp_int *mp, mp_size p)
 {
     mp_err res;
     unsigned int ix;
 
+    ARGCHK(mp != NULL, MP_BADARG);
+
     if (p == 0)
         return MP_OKAY;
 
     if (MP_USED(mp) == 1 && MP_DIGIT(mp, 0) == 0)
         return MP_OKAY;
 
     if ((res = s_mp_pad(mp, USED(mp) + p)) != MP_OKAY)
         return res;
@@ -2990,16 +3001,18 @@ s_mp_div_2(mp_int *mp)
 
 mp_err
 s_mp_mul_2(mp_int *mp)
 {
     mp_digit *pd;
     unsigned int ix, used;
     mp_digit kin = 0;
 
+    ARGCHK(mp != NULL, MP_BADARG);
+
     /* Shift digits leftward by 1 bit */
     used = MP_USED(mp);
     pd = MP_DIGITS(mp);
     for (ix = 0; ix < used; ix++) {
         mp_digit d = *pd;
         *pd++ = (d << 1) | kin;
         kin = (d >> (DIGIT_BIT - 1));
     }
@@ -3099,16 +3112,18 @@ s_mp_div_2d(mp_int *mp, mp_digit d)
 mp_err
 s_mp_norm(mp_int *a, mp_int *b, mp_digit *pd)
 {
     mp_digit d;
     mp_digit mask;
     mp_digit b_msd;
     mp_err res = MP_OKAY;
 
+    ARGCHK(a != NULL && b != NULL && pd != NULL, MP_BADARG);
+
     d = 0;
     mask = DIGIT_MAX & ~(DIGIT_MAX >> 1); /* mask is msb of digit */
     b_msd = DIGIT(b, USED(b) - 1);
     while (!(b_msd & mask)) {
         b_msd <<= 1;
         ++d;
     }
 
@@ -4363,16 +4378,18 @@ CLEANUP:
 /* {{{ Primitive comparisons */
 
 /* {{{ s_mp_cmp(a, b) */
 
 /* Compare |a| <=> |b|, return 0 if equal, <0 if a<b, >0 if a>b           */
 int
 s_mp_cmp(const mp_int *a, const mp_int *b)
 {
+    ARGMPCHK(a != NULL && b != NULL);
+
     mp_size used_a = MP_USED(a);
     {
         mp_size used_b = MP_USED(b);
 
         if (used_a > used_b)
             goto IS_GT;
         if (used_a < used_b)
             goto IS_LT;
@@ -4414,16 +4431,18 @@ IS_GT:
 /* }}} */
 
 /* {{{ s_mp_cmp_d(a, d) */
 
 /* Compare |a| <=> d, return 0 if equal, <0 if a<d, >0 if a>d             */
 int
 s_mp_cmp_d(const mp_int *a, mp_digit d)
 {
+    ARGMPCHK(a != NULL);
+
     if (USED(a) > 1)
         return MP_GT;
 
     if (DIGIT(a, 0) < d)
         return MP_LT;
     else if (DIGIT(a, 0) > d)
         return MP_GT;
     else
@@ -4440,16 +4459,18 @@ s_mp_cmp_d(const mp_int *a, mp_digit d)
   k such that v = 2^k, i.e. lg(v).
  */
 int
 s_mp_ispow2(const mp_int *v)
 {
     mp_digit d;
     int extra = 0, ix;
 
+    ARGMPCHK(v != NULL);
+
     ix = MP_USED(v) - 1;
     d = MP_DIGIT(v, ix); /* most significant digit of v */
 
     extra = s_mp_ispow2d(d);
     if (extra < 0 || ix == 0)
         return extra;
 
     while (--ix >= 0) {
@@ -4767,20 +4788,17 @@ mp_to_signed_octets(const mp_int *mp, un
 /* output a buffer of big endian octets exactly as long as requested.
    constant time on the value of mp. */
 mp_err
 mp_to_fixlen_octets(const mp_int *mp, unsigned char *str, mp_size length)
 {
     int ix, jx;
     unsigned int bytes;
 
-    ARGCHK(mp != NULL, MP_BADARG);
-    ARGCHK(str != NULL, MP_BADARG);
-    ARGCHK(!SIGN(mp), MP_BADARG);
-    ARGCHK(length > 0, MP_BADARG);
+    ARGCHK(mp != NULL && str != NULL && !SIGN(mp) && length > 0, MP_BADARG);
 
     /* Constant time on the value of mp.  Don't use mp_unsigned_octet_size. */
     bytes = USED(mp) * MP_DIGIT_SIZE;
 
     /* If the output is shorter than the native size of mp, then check that any
      * bytes not written have zero values.  This check isn't constant time on
      * the assumption that timing-sensitive callers can guarantee that mp fits
      * in the allocated space. */
--- a/lib/freebl/mpi/mpi.h
+++ b/lib/freebl/mpi/mpi.h
@@ -283,28 +283,37 @@ void freebl_cpuid(unsigned long op, unsi
 #define RADIX MP_RADIX
 #define MAX_RADIX MP_MAX_RADIX
 #define SIGN(MP) MP_SIGN(MP)
 #define USED(MP) MP_USED(MP)
 #define ALLOC(MP) MP_ALLOC(MP)
 #define DIGITS(MP) MP_DIGITS(MP)
 #define DIGIT(MP, N) MP_DIGIT(MP, N)
 
+/* Functions which return an mp_err value will NULL-check their arguments via
+ * ARGCHK(condition, return), where the caller is responsible for checking the
+ * mp_err return code. For functions that return an integer type, the caller 
+ * has no way to tell if the value is an error code or a legitimate value. 
+ * Therefore, ARGMPCHK(condition) will trigger an assertion failure on debug
+ * builds, but no-op in optimized builds. */
 #if MP_ARGCHK == 1
+#define ARGMPCHK(X) /* */
 #define ARGCHK(X, Y)    \
     {                   \
         if (!(X)) {     \
             return (Y); \
         }               \
     }
 #elif MP_ARGCHK == 2
 #include <assert.h>
+#define ARGMPCHK(X) assert(X)
 #define ARGCHK(X, Y) assert(X)
 #else
-#define ARGCHK(X, Y) /*  */
+#define ARGMPCHK(X)  /* */
+#define ARGCHK(X, Y) /* */
 #endif
 
 #ifdef CT_VERIF
 void mp_taint(mp_int *mp);
 void mp_untaint(mp_int *mp);
 #endif
 
 SEC_END_PROTOS