Bug 1631597 - Constant-time GCD and modular inversion r=rrelyea,kjacobs NSS_3_53_BRANCH
author: Sohaib ul Hassan <sohaibulhassan@tuni.fi>
Tue, 16 Jun 2020 15:40:57 -0700
branch: NSS_3_53_BRANCH
changeset: 15674 c5c89b18053aad6147f82abecc568653b78095b4
parent: 15639 5c1dff547a19533fdb4b0a98453cd2c784c0ece6
child: 15675 fca7a9ba4da2735a3d844aac4411cd5074d456f7
push id: 3775
push user: jjones@mozilla.com
push date: Tue, 16 Jun 2020 23:52:22 +0000
reviewers: rrelyea, kjacobs
bugs: 1631597
Bug 1631597 - Constant-time GCD and modular inversion r=rrelyea,kjacobs The implementation is based on the work by Bernstein and Yang (https://eprint.iacr.org/2019/266), "Fast constant-time gcd computation and modular inversion". It fixes the old mp_gcd and s_mp_invmod_odd_m functions. The patch also fixes mpl_significant_bits, s_mp_div_2d and s_mp_mul_2d by using less control flow to reduce side-channel leaks. Co-author: Billy Bob Brumley. Differential Revision: https://phabricator.services.mozilla.com/D78668
lib/freebl/mpi/mpi.c
lib/freebl/mpi/mpi.h
lib/freebl/mpi/mplogic.c
--- a/lib/freebl/mpi/mpi.c
+++ b/lib/freebl/mpi/mpi.c
@@ -3,16 +3,17 @@
  *
  *  Arbitrary precision integer arithmetic library
  *
  * This Source Code Form is subject to the terms of the Mozilla Public
  * License, v. 2.0. If a copy of the MPL was not distributed with this
  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
 
 #include "mpi-priv.h"
+#include "mplogic.h"
 #if defined(OSF1)
 #include <c_asm.h>
 #endif
 
 #if defined(__arm__) && \
     ((defined(__thumb__) && !defined(__thumb2__)) || defined(__ARM_ARCH_3__))
 /* 16-bit thumb or ARM v3 doesn't work inlined assember version */
 #undef MP_ASSEMBLY_MULTIPLY
@@ -1683,108 +1684,122 @@ mp_iseven(const mp_int *a)
 /* }}} */
 
 /*------------------------------------------------------------------------*/
 /* {{{ Number theoretic functions */
 
 /* {{{ mp_gcd(a, b, c) */
 
 /*
-  Like the old mp_gcd() function, except computes the GCD using the
-  binary algorithm due to Josef Stein in 1961 (via Knuth).
+  Computes the GCD using the constant-time algorithm
+  by Bernstein and Yang (https://eprint.iacr.org/2019/266)
+  "Fast constant-time gcd computation and modular inversion"
  */
 mp_err
 mp_gcd(mp_int *a, mp_int *b, mp_int *c)
 {
     mp_err res;
-    mp_int u, v, t;
-    mp_size k = 0;
+    mp_digit cond = 0, mask = 0;
+    mp_int g, temp, f;
+    int i, j, m, bit = 1, delta = 1, shifts = 0, last = -1;
+    mp_size top, flen, glen;
+    mp_int *clear[3];
 
     ARGCHK(a != NULL && b != NULL && c != NULL, MP_BADARG);
-
-    if (mp_cmp_z(a) == MP_EQ && mp_cmp_z(b) == MP_EQ)
-        return MP_RANGE;
+    /*
+    Early exit if either of the inputs is zero.
+    Caller is responsible for the proper handling of inputs.
+    */
     if (mp_cmp_z(a) == MP_EQ) {
-        return mp_copy(b, c);
+        res = mp_copy(b, c);
+        SIGN(c) = ZPOS;
+        return res;
     } else if (mp_cmp_z(b) == MP_EQ) {
-        return mp_copy(a, c);
-    }
-
-    if ((res = mp_init(&t)) != MP_OKAY)
+        res = mp_copy(a, c);
+        SIGN(c) = ZPOS;
         return res;
-    if ((res = mp_init_copy(&u, a)) != MP_OKAY)
-        goto U;
-    if ((res = mp_init_copy(&v, b)) != MP_OKAY)
-        goto V;
-
-    SIGN(&u) = ZPOS;
-    SIGN(&v) = ZPOS;
-
-    /* Divide out common factors of 2 until at least 1 of a, b is even */
-    while (mp_iseven(&u) && mp_iseven(&v)) {
-        s_mp_div_2(&u);
-        s_mp_div_2(&v);
-        ++k;
+    }
+
+    MP_CHECKOK(mp_init(&temp));
+    clear[++last] = &temp;
+    MP_CHECKOK(mp_init_copy(&g, a));
+    clear[++last] = &g;
+    MP_CHECKOK(mp_init_copy(&f, b));
+    clear[++last] = &f;
+
+    /*
+    For even case compute the number of
+    shared powers of 2 in f and g.
+    */
+    for (i = 0; i < USED(&f) && i < USED(&g); i++) {
+        mask = ~(DIGIT(&f, i) | DIGIT(&g, i));
+        for (j = 0; j < MP_DIGIT_BIT; j++) {
+            bit &= mask;
+            shifts += bit;
+            mask >>= 1;
+        }
     }
-
-    /* Initialize t */
-    if (mp_isodd(&u)) {
-        if ((res = mp_copy(&v, &t)) != MP_OKAY)
-            goto CLEANUP;
-
-        /* t = -v */
-        if (SIGN(&v) == ZPOS)
-            SIGN(&t) = NEG;
-        else
-            SIGN(&t) = ZPOS;
-
-    } else {
-        if ((res = mp_copy(&u, &t)) != MP_OKAY)
-            goto CLEANUP;
+    /* Reduce to the odd case by removing the powers of 2. */
+    s_mp_div_2d(&f, shifts);
+    s_mp_div_2d(&g, shifts);
+
+    /* Allocate to the size of largest mp_int. */
+    top = (mp_size)1 + ((USED(&f) >= USED(&g)) ? USED(&f) : USED(&g));
+    MP_CHECKOK(s_mp_grow(&f, top));
+    MP_CHECKOK(s_mp_grow(&g, top));
+    MP_CHECKOK(s_mp_grow(&temp, top));
+
+    /* Make sure f contains the odd value. */
+    MP_CHECKOK(mp_cswap((~DIGIT(&f, 0) & 1), &f, &g, top));
+
+    /* Upper bound for the total iterations. */
+    flen = mpl_significant_bits(&f);
+    glen = mpl_significant_bits(&g);
+    m = 4 + 3 * ((flen >= glen) ? flen : glen);
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4146) // Thanks MSVC, we know what we're negating an unsigned mp_digit
+#endif
+
+    for (i = 0; i < m; i++) {
+        /* Step 1: conditional swap. */
+        /* Set cond if delta > 0 and g is odd. */
+        cond = (-delta >> (8 * sizeof(delta) - 1)) & DIGIT(&g, 0) & 1;
+        /* If cond is set replace (delta,f) with (-delta,-f). */
+        delta = (-cond & -delta) | ((cond - 1) & delta);
+        SIGN(&f) ^= cond;
+        /* If cond is set swap f with g. */
+        MP_CHECKOK(mp_cswap(cond, &f, &g, top));
+
+        /* Step 2: elimination. */
+        /* Update delta. */
+        delta++;
+        /* If g is odd, right shift (g+f) else right shift g. */
+        MP_CHECKOK(mp_add(&g, &f, &temp));
+        MP_CHECKOK(mp_cswap((DIGIT(&g, 0) & 1), &g, &temp, top));
+        s_mp_div_2(&g);
     }
 
-    for (;;) {
-        while (mp_iseven(&t)) {
-            s_mp_div_2(&t);
-        }
-
-        if (mp_cmp_z(&t) == MP_GT) {
-            if ((res = mp_copy(&t, &u)) != MP_OKAY)
-                goto CLEANUP;
-
-        } else {
-            if ((res = mp_copy(&t, &v)) != MP_OKAY)
-                goto CLEANUP;
-
-            /* v = -t */
-            if (SIGN(&t) == ZPOS)
-                SIGN(&v) = NEG;
-            else
-                SIGN(&v) = ZPOS;
-        }
-
-        if ((res = mp_sub(&u, &v, &t)) != MP_OKAY)
-            goto CLEANUP;
-
-        if (s_mp_cmp_d(&t, 0) == MP_EQ)
-            break;
-    }
-
-    s_mp_2expt(&v, k);       /* v = 2^k   */
-    res = mp_mul(&u, &v, c); /* c = u * v */
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+    /* GCD is in f, take the absolute value. */
+    SIGN(&f) = ZPOS;
+
+    /* Add back the removed powers of 2. */
+    MP_CHECKOK(s_mp_mul_2d(&f, shifts));
+
+    MP_CHECKOK(mp_copy(&f, c));
 
 CLEANUP:
-    mp_clear(&v);
-V:
-    mp_clear(&u);
-U:
-    mp_clear(&t);
-
+    while (last >= 0)
+        mp_clear(clear[last--]);
     return res;
-
 } /* end mp_gcd() */
 
 /* }}} */
 
 /* {{{ mp_lcm(a, b, c) */
 
 /* We compute the least common multiple using the rule:
 
@@ -2126,52 +2141,124 @@ s_mp_fixup_reciprocal(const mp_int *c, c
     s_mp_clamp(x);
     s_mp_div_2d(x, k_orig);
     res = MP_OKAY;
 
 CLEANUP:
     return res;
 }
 
-/* compute mod inverse using Schroeppel's method, only if m is odd */
+/*
+  Computes the modular inverse using the constant-time algorithm
+  by Bernstein and Yang (https://eprint.iacr.org/2019/266)
+  "Fast constant-time gcd computation and modular inversion"
+ */
 mp_err
 s_mp_invmod_odd_m(const mp_int *a, const mp_int *m, mp_int *c)
 {
-    int k;
     mp_err res;
-    mp_int x;
+    mp_digit cond = 0;
+    mp_int g, f, v, r, temp;
+    int i, its, delta = 1, last = -1;
+    mp_size top, flen, glen;
+    mp_int *clear[6];
 
     ARGCHK(a != NULL && m != NULL && c != NULL, MP_BADARG);
-
-    if (mp_cmp_z(a) == 0 || mp_cmp_z(m) == 0)
+    /* Check for invalid inputs. */
+    if (mp_cmp_z(a) == MP_EQ || mp_cmp_d(m, 2) == MP_LT)
         return MP_RANGE;
-    if (mp_iseven(m))
+
+    if (a == m || mp_iseven(m))
         return MP_UNDEF;
 
-    MP_DIGITS(&x) = 0;
-
-    if (a == c) {
-        if ((res = mp_init_copy(&x, a)) != MP_OKAY)
-            return res;
-        if (a == m)
-            m = &x;
-        a = &x;
-    } else if (m == c) {
-        if ((res = mp_init_copy(&x, m)) != MP_OKAY)
-            return res;
-        m = &x;
-    } else {
-        MP_DIGITS(&x) = 0;
+    MP_CHECKOK(mp_init(&temp));
+    clear[++last] = &temp;
+    MP_CHECKOK(mp_init(&v));
+    clear[++last] = &v;
+    MP_CHECKOK(mp_init(&r));
+    clear[++last] = &r;
+    MP_CHECKOK(mp_init_copy(&g, a));
+    clear[++last] = &g;
+    MP_CHECKOK(mp_init_copy(&f, m));
+    clear[++last] = &f;
+
+    mp_set(&v, 0);
+    mp_set(&r, 1);
+
+    /* Allocate to the size of largest mp_int. */
+    top = (mp_size)1 + ((USED(&f) >= USED(&g)) ? USED(&f) : USED(&g));
+    MP_CHECKOK(s_mp_grow(&f, top));
+    MP_CHECKOK(s_mp_grow(&g, top));
+    MP_CHECKOK(s_mp_grow(&temp, top));
+    MP_CHECKOK(s_mp_grow(&v, top));
+    MP_CHECKOK(s_mp_grow(&r, top));
+
+    /* Upper bound for the total iterations. */
+    flen = mpl_significant_bits(&f);
+    glen = mpl_significant_bits(&g);
+    its = 4 + 3 * ((flen >= glen) ? flen : glen);
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4146) // Thanks MSVC, we know what we're negating an unsigned mp_digit
+#endif
+
+    for (i = 0; i < its; i++) {
+        /* Step 1: conditional swap. */
+        /* Set cond if delta > 0 and g is odd. */
+        cond = (-delta >> (8 * sizeof(delta) - 1)) & DIGIT(&g, 0) & 1;
+        /* If cond is set replace (delta,f,v) with (-delta,-f,-v). */
+        delta = (-cond & -delta) | ((cond - 1) & delta);
+        SIGN(&f) ^= cond;
+        SIGN(&v) ^= cond;
+        /* If cond is set swap (f,v) with (g,r). */
+        MP_CHECKOK(mp_cswap(cond, &f, &g, top));
+        MP_CHECKOK(mp_cswap(cond, &v, &r, top));
+
+        /* Step 2: elimination. */
+        /* Update delta */
+        delta++;
+        /* If g is odd replace r with (r+v). */
+        MP_CHECKOK(mp_add(&r, &v, &temp));
+        MP_CHECKOK(mp_cswap((DIGIT(&g, 0) & 1), &r, &temp, top));
+        /* If g is odd, right shift (g+f) else right shift g. */
+        MP_CHECKOK(mp_add(&g, &f, &temp));
+        MP_CHECKOK(mp_cswap((DIGIT(&g, 0) & 1), &g, &temp, top));
+        s_mp_div_2(&g);
+        /*
+        If r is even, right shift it.
+        If r is odd, right shift (r+m) which is even because m is odd.
+        We want the result modulo m so adding in multiples of m here vanish.
+        */
+        MP_CHECKOK(mp_add(&r, m, &temp));
+        MP_CHECKOK(mp_cswap((DIGIT(&r, 0) & 1), &r, &temp, top));
+        s_mp_div_2(&r);
     }
 
-    MP_CHECKOK(s_mp_almost_inverse(a, m, c));
-    k = res;
-    MP_CHECKOK(s_mp_fixup_reciprocal(c, m, k, c));
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+    /* We have the inverse in v, propagate sign from f. */
+    SIGN(&v) ^= SIGN(&f);
+    /* GCD is in f, take the absolute value. */
+    SIGN(&f) = ZPOS;
+
+    /* If gcd != 1, not invertible. */
+    if (mp_cmp_d(&f, 1) != MP_EQ) {
+        res = MP_UNDEF;
+        goto CLEANUP;
+    }
+
+    /* Return inverse modulo m. */
+    MP_CHECKOK(mp_mod(&v, m, c));
+
 CLEANUP:
-    mp_clear(&x);
+    while (last >= 0)
+        mp_clear(clear[last--]);
     return res;
 }
 
 /* Known good algorithm for computing modular inverse.  But slow. */
 mp_err
 mp_invmod_xgcd(const mp_int *a, const mp_int *m, mp_int *c)
 {
     mp_int g, x;
@@ -2213,23 +2300,34 @@ s_mp_invmod_2d(const mp_int *a, mp_size 
     mp_size ix = k + 4;
     mp_int t0, t1, val, tmp, two2k;
 
     static const mp_digit d2 = 2;
     static const mp_int two = { MP_ZPOS, 1, 1, (mp_digit *)&d2 };
 
     if (mp_iseven(a))
         return MP_UNDEF;
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable : 4146) // Thanks MSVC, we know what we're negating an unsigned mp_digit
+#endif
     if (k <= MP_DIGIT_BIT) {
         mp_digit i = s_mp_invmod_radix(MP_DIGIT(a, 0));
+        /* propagate the sign from mp_int */
+        i = (i ^ -(mp_digit)SIGN(a)) + (mp_digit)SIGN(a);
         if (k < MP_DIGIT_BIT)
             i &= ((mp_digit)1 << k) - (mp_digit)1;
         mp_set(c, i);
         return MP_OKAY;
     }
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
     MP_DIGITS(&t0) = 0;
     MP_DIGITS(&t1) = 0;
     MP_DIGITS(&val) = 0;
     MP_DIGITS(&tmp) = 0;
     MP_DIGITS(&two2k) = 0;
     MP_CHECKOK(mp_init_copy(&val, a));
     s_mp_mod_2d(&val, k);
     MP_CHECKOK(mp_init_copy(&t0, &val));
@@ -2826,16 +2924,18 @@ s_mp_free(void *ptr)
 /* Remove leading zeroes from the given value                             */
 void
 s_mp_clamp(mp_int *mp)
 {
     mp_size used = MP_USED(mp);
     while (used > 1 && DIGIT(mp, used - 1) == 0)
         --used;
     MP_USED(mp) = used;
+    if (used == 1 && DIGIT(mp, 0) == 0)
+        MP_SIGN(mp) = ZPOS;
 } /* end s_mp_clamp() */
 
 /* }}} */
 
 /* {{{ s_mp_exch(a, b) */
 
 /* Exchange the data for a and b; (b, a) = (a, b)                         */
 void
@@ -2903,47 +3003,46 @@ s_mp_lshd(mp_int *mp, mp_size p)
 /*
   Multiply the integer by 2^d, where d is a number of bits.  This
   amounts to a bitwise shift of the value.
  */
 mp_err
 s_mp_mul_2d(mp_int *mp, mp_digit d)
 {
     mp_err res;
-    mp_digit dshift, bshift;
-    mp_digit mask;
+    mp_digit dshift, rshift, mask, x, prev = 0;
+    mp_digit *pa = NULL;
+    int i;
 
     ARGCHK(mp != NULL, MP_BADARG);
 
     dshift = d / MP_DIGIT_BIT;
-    bshift = d % MP_DIGIT_BIT;
+    d %= MP_DIGIT_BIT;
+    /* mp_digit >> rshift is undefined behavior for rshift >= MP_DIGIT_BIT */
+    /* mod and corresponding mask logic avoid that when d = 0 */
+    rshift = MP_DIGIT_BIT - d;
+    rshift %= MP_DIGIT_BIT;
+    /* mask = (2**d - 1) * 2**(w-d) mod 2**w */
+    mask = (DIGIT_MAX << rshift) + 1;
+    mask &= DIGIT_MAX - 1;
     /* bits to be shifted out of the top word */
-    if (bshift) {
-        mask = (mp_digit)~0 << (MP_DIGIT_BIT - bshift);
-        mask &= MP_DIGIT(mp, MP_USED(mp) - 1);
-    } else {
-        mask = 0;
-    }
-
-    if (MP_OKAY != (res = s_mp_pad(mp, MP_USED(mp) + dshift + (mask != 0))))
+    x = MP_DIGIT(mp, MP_USED(mp) - 1) & mask;
+
+    if (MP_OKAY != (res = s_mp_pad(mp, MP_USED(mp) + dshift + (x != 0))))
         return res;
 
     if (dshift && MP_OKAY != (res = s_mp_lshd(mp, dshift)))
         return res;
 
-    if (bshift) {
-        mp_digit *pa = MP_DIGITS(mp);
-        mp_digit *alim = pa + MP_USED(mp);
-        mp_digit prev = 0;
-
-        for (pa += dshift; pa < alim;) {
-            mp_digit x = *pa;
-            *pa++ = (x << bshift) | prev;
-            prev = x >> (DIGIT_BIT - bshift);
-        }
+    pa = MP_DIGITS(mp) + dshift;
+
+    for (i = MP_USED(mp) - dshift; i > 0; i--) {
+        x = *pa;
+        *pa++ = (x << d) | prev;
+        prev = (x & mask) >> rshift;
     }
 
     s_mp_clamp(mp);
     return MP_OKAY;
 } /* end s_mp_mul_2d() */
 
 /* {{{ s_mp_rshd(mp, p) */
 
@@ -3072,28 +3171,30 @@ s_mp_mod_2d(mp_int *mp, mp_digit d)
   Divide the integer by 2^d, where d is a number of bits.  This
   amounts to a bitwise shift of the value, and does not require the
   full division code (used in Barrett reduction, see below)
  */
 void
 s_mp_div_2d(mp_int *mp, mp_digit d)
 {
     int ix;
-    mp_digit save, next, mask;
+    mp_digit save, next, mask, lshift;
 
     s_mp_rshd(mp, d / DIGIT_BIT);
     d %= DIGIT_BIT;
-    if (d) {
-        mask = ((mp_digit)1 << d) - 1;
-        save = 0;
-        for (ix = USED(mp) - 1; ix >= 0; ix--) {
-            next = DIGIT(mp, ix) & mask;
-            DIGIT(mp, ix) = (DIGIT(mp, ix) >> d) | (save << (DIGIT_BIT - d));
-            save = next;
-        }
+    /* mp_digit << lshift is undefined behavior for lshift >= MP_DIGIT_BIT */
+    /* mod and corresponding mask logic avoid that when d = 0 */
+    lshift = DIGIT_BIT - d;
+    lshift %= DIGIT_BIT;
+    mask = ((mp_digit)1 << d) - 1;
+    save = 0;
+    for (ix = USED(mp) - 1; ix >= 0; ix--) {
+        next = DIGIT(mp, ix) & mask;
+        DIGIT(mp, ix) = (save << lshift) | (DIGIT(mp, ix) >> d);
+        save = next;
     }
     s_mp_clamp(mp);
 
 } /* end s_mp_div_2d() */
 
 /* }}} */
 
 /* {{{ s_mp_norm(a, b, *d) */
@@ -4836,10 +4937,49 @@ mp_to_fixlen_octets(const mp_int *mp, un
         for (jx = MP_DIGIT_SIZE - 1; jx >= 0; jx--) {
             *str++ = d >> (jx * CHAR_BIT);
         }
     }
     return MP_OKAY;
 } /* end mp_to_fixlen_octets() */
 /* }}} */
 
+/* {{{ mp_cswap(condition, a, b, numdigits) */
+/* performs a conditional swap between mp_int. */
+mp_err
+mp_cswap(mp_digit condition, mp_int *a, mp_int *b, mp_size numdigits)
+{
+    mp_digit x;
+    unsigned int i;
+    mp_err res = 0;
+
+    /* if pointers are equal return */
+    if (a == b)
+        return res;
+
+    if (MP_ALLOC(a) < numdigits || MP_ALLOC(b) < numdigits) {
+        MP_CHECKOK(s_mp_grow(a, numdigits));
+        MP_CHECKOK(s_mp_grow(b, numdigits));
+    }
+
+    condition = ((~condition & ((condition - 1))) >> (MP_DIGIT_BIT - 1)) - 1;
+
+    x = (USED(a) ^ USED(b)) & condition;
+    USED(a) ^= x;
+    USED(b) ^= x;
+
+    x = (SIGN(a) ^ SIGN(b)) & condition;
+    SIGN(a) ^= x;
+    SIGN(b) ^= x;
+
+    for (i = 0; i < numdigits; i++) {
+        x = (DIGIT(a, i) ^ DIGIT(b, i)) & condition;
+        DIGIT(a, i) ^= x;
+        DIGIT(b, i) ^= x;
+    }
+
+CLEANUP:
+    return res;
+} /* end mp_cswap() */
+/* }}} */
+
 /*------------------------------------------------------------------------*/
 /* HERE THERE BE DRAGONS                                                  */
--- a/lib/freebl/mpi/mpi.h
+++ b/lib/freebl/mpi/mpi.h
@@ -262,16 +262,17 @@ mp_err mp_to_unsigned_octets(const mp_in
 mp_err mp_to_signed_octets(const mp_int *mp, unsigned char *str, mp_size maxlen);
 mp_err mp_to_fixlen_octets(const mp_int *mp, unsigned char *str, mp_size len);
 
 /* Miscellaneous */
 mp_size mp_trailing_zeros(const mp_int *mp);
 void freebl_cpuid(unsigned long op, unsigned long *eax,
                   unsigned long *ebx, unsigned long *ecx,
                   unsigned long *edx);
+mp_err mp_cswap(mp_digit condition, mp_int *a, mp_int *b, mp_size numdigits);
 
 #define MP_CHECKOK(x)          \
     if (MP_OKAY > (res = (x))) \
     goto CLEANUP
 #define MP_CHECKERR(x)         \
     if (MP_OKAY > (res = (x))) \
     goto CLEANUP
 
--- a/lib/freebl/mpi/mplogic.c
+++ b/lib/freebl/mpi/mplogic.c
@@ -402,40 +402,59 @@ mpl_get_bits(const mp_int *a, mp_size ls
         (lsWndx + 1 >= MP_USED(a))) {
         mask &= (digit[0] >> rshift);
     } else {
         mask &= ((digit[0] >> rshift) | (digit[1] << (MP_DIGIT_BIT - rshift)));
     }
     return (mp_err)mask;
 }
 
+#define LZCNTLOOP(i)                               \
+    do {                                           \
+        x = d >> (i);                              \
+        mask = (0 - x);                            \
+        mask = (0 - (mask >> (MP_DIGIT_BIT - 1))); \
+        bits += (i)&mask;                          \
+        d ^= (x ^ d) & mask;                       \
+    } while (0)
+
 /*
   mpl_significant_bits
-  returns number of significnant bits in abs(a).
+  returns number of significant bits in abs(a).
+  In other words: floor(lg(abs(a))) + 1.
   returns 1 if value is zero.
  */
 mp_size
 mpl_significant_bits(const mp_int *a)
 {
-    mp_size bits = 0;
+    /*
+      start bits at 1.
+      lg(0) = 0 => bits = 1 by function semantics.
+      below does a binary search for the _position_ of the top bit set,
+      which is floor(lg(abs(a))) for a != 0.
+     */
+    mp_size bits = 1;
     int ix;
 
     ARGCHK(a != NULL, MP_BADARG);
 
     for (ix = MP_USED(a); ix > 0;) {
-        mp_digit d;
-        d = MP_DIGIT(a, --ix);
-        if (d) {
-            while (d) {
-                ++bits;
-                d >>= 1;
-            }
-            break;
-        }
+        mp_digit d, x, mask;
+        if ((d = MP_DIGIT(a, --ix)) == 0)
+            continue;
+#if !defined(MP_USE_UINT_DIGIT)
+        LZCNTLOOP(32);
+#endif
+        LZCNTLOOP(16);
+        LZCNTLOOP(8);
+        LZCNTLOOP(4);
+        LZCNTLOOP(2);
+        LZCNTLOOP(1);
+        break;
     }
     bits += ix * MP_DIGIT_BIT;
-    if (!bits)
-        bits = 1;
     return bits;
 }
 
+#undef LZCNTLOOP
+
 /*------------------------------------------------------------------------*/
 /* HERE THERE BE DRAGONS                                                  */