]> granicus.if.org Git - esp-idf/commitdiff
mbedtls hardware RSA: Combine methods for calculating M' & r inverse
authorAngus Gratton <angus@espressif.com>
Fri, 18 Nov 2016 02:44:37 +0000 (13:44 +1100)
committerAngus Gratton <angus@espressif.com>
Fri, 18 Nov 2016 03:10:20 +0000 (14:10 +1100)
Remove redundant gcd calculation, use consistent terminology.
Also remove leftover debugging code

components/mbedtls/port/esp_bignum.c

index 401d3fc24bb8e762fb6b4c016e50db8c00e48119..076234b55bcaa99f3654852e2f82abd5ec3c58ae 100644 (file)
@@ -39,46 +39,10 @@ static const char *TAG = "bignum";
 
 #if defined(MBEDTLS_MPI_MUL_MPI_ALT) || defined(MBEDTLS_MPI_EXP_MOD_ALT)
 
-/* Constants from mbedTLS bignum.c */
-#define ciL    (sizeof(mbedtls_mpi_uint))         /* chars in limb  */
-#define biL    (ciL << 3)               /* bits  in limb  */
-
 static _lock_t mpi_lock;
 
-/* Temporary debugging function to print an MPI number to
-   stdout. Happens to be in a format compatible with Python.
-*/
-void mbedtls_mpi_printf(const char *name, const mbedtls_mpi *X)
-{
-    static char buf[1024];
-    size_t n;
-    memset(buf, 0, sizeof(buf));
-    mbedtls_mpi_write_string(X, 16, buf, sizeof(buf)-1, &n);
-    if(n) {
-        ESP_LOGI(TAG, "%s = 0x%s", name, buf);
-    } else {
-        ESP_LOGI(TAG, "TOOLONG");
-    }
-}
-
-/* Temporary debug function to dump a memory block's contents to stdout
-   TODO remove
- */
-static void __attribute__((unused)) dump_memory_block(const char *label, uint32_t addr)
-{
-    printf("Dumping %s @ %08x\n", label, addr);
-    for(int i = 0; i < (4096 / 8); i += 4) {
-        if(i % 32 == 0) {
-            printf("\n %04x:", i);
-        }
-        printf("%08x ", REG_READ(addr + i));
-    }
-    printf("Done\n");
-}
-
 /* At the moment these hardware locking functions aren't exposed publically
-   for MPI. If you want to use the ROM bigint functions and co-exist with mbedTLS,
-   please raise a feature request.
+   for MPI. If you want to use the ROM bigint functions and co-exist with mbedTLS, please raise a feature request.
 */
 static void esp_mpi_acquire_hardware( void )
 {
@@ -116,6 +80,14 @@ static inline size_t hardware_words_needed(const mbedtls_mpi *mpi)
     return res;
 }
 
+/* Convert number of bits to number of words, rounded up to nearest
+   512 bit (16 word) block count.
+*/
+static inline size_t bits_to_hardware_words(size_t num_bits)
+{
+    return ((num_bits + 511) / 512) * 16;
+}
+
 /* Copy mbedTLS MPI bignum 'mpi' to hardware memory block at 'mem_base'.
 
    If num_words is higher than the number of words in the bignum then
@@ -156,79 +128,62 @@ static inline int mem_block_to_mpi(mbedtls_mpi *x, uint32_t mem_base, int num_wo
     return ret;
 }
 
-/* Given a & b, determine u & v such that
-
-   gcd(a,b) = d = au - bv
-
-   This is suitable for calculating values for montgomery multiplication:
-
-   gcd(R, M) = R * Rinv - M * Mprime = 1
-
-   Conditions which must be true:
-   - argument 'a' (R) is a power of 2.
-   - argument 'b' (M) is odd.
 
-   Underlying algorithm comes from:
-   http://www.hackersdelight.org/hdcodetxt/mont64.c.txt
-   http://www.ucl.ac.uk/~ucahcjm/combopt/ext_gcd_python_programs.pdf
+/**
+ *
+ * There is a need for the value of integer N' such that B^-1(B-1)-N^-1N'=1,
+ * where B^-1(B-1) mod N=1. Actually, only the least significant part of
+ * N' is needed, hence the definition N0'=N' mod b. We reproduce below the
+ * simple algorithm from an article by Dusse and Kaliski to efficiently
+ * find N0' from N0 and b
  */
-static void extended_binary_gcd(const mbedtls_mpi *a, const mbedtls_mpi *b,
-                                   mbedtls_mpi *u, mbedtls_mpi *v)
+static mbedtls_mpi_uint modular_inverse(const mbedtls_mpi *M)
 {
-    mbedtls_mpi a_, ta;
-
-    /* These checks degrade performance, TODO remove them... */
-    assert(b->p[0] & 1);
-    assert(mbedtls_mpi_bitlen(a) == mbedtls_mpi_lsb(a)+1);
-    assert(mbedtls_mpi_cmp_mpi(a, b) > 0);
-
-    mbedtls_mpi_lset(u, 1);
-    mbedtls_mpi_lset(v, 0);
+    int i;
+    uint64_t t = 1;
+    uint64_t two_2_i_minus_1 = 2;   /* 2^(i-1) */
+    uint64_t two_2_i = 4;           /* 2^i */
+    uint64_t N = M->p[0];
 
-    /* 'a' needs to be half its real value for this algorithm
-       TODO see if we can halve the number in the caller to avoid
-       allocating a bignum here.
-     */
-    mbedtls_mpi_init(&a_);
-    mbedtls_mpi_copy(&a_, a);
-    mbedtls_mpi_shift_r(&a_, 1);
+    for (i = 2; i <= 32; i++) {
+        if ((mbedtls_mpi_uint) N * t % two_2_i >= two_2_i_minus_1) {
+            t += two_2_i_minus_1;
+        }
 
-    mbedtls_mpi_init(&ta);
-    mbedtls_mpi_copy(&ta, &a_);
+        two_2_i_minus_1 <<= 1;
+        two_2_i <<= 1;
+    }
 
-    //mbedtls_mpi_printf("a", &a_);
-    //mbedtls_mpi_printf("b", b);
+    return (mbedtls_mpi_uint)(UINT32_MAX - t + 1);
+}
 
-    /* Loop invariant:
-      2*ta = u*2*a - v*b.
+/* Calculate Rinv = RR^2 mod M, where:
+ *
+ *  R = b^n where b = 2^32, n=num_words,
+ *  R = 2^N (where N=num_bits)
+ *  RR = R^2 = 2^(2*N) (where N=num_bits=num_words*32)
+ *
+ * This calculation is computationally expensive (mbedtls_mpi_mod_mpi)
+ * so caller should cache the result where possible.
+ *
+ * DO NOT call this function while holding esp_mpi_acquire_hardware().
+ *
+ */
+static int calculate_rinv(mbedtls_mpi *Rinv, const mbedtls_mpi *M, int num_words)
+{
+    int ret;
+    size_t num_bits = num_words * 32;
+    mbedtls_mpi RR;
+    mbedtls_mpi_init(&RR);
+    MBEDTLS_MPI_CHK(mbedtls_mpi_set_bit(&RR, num_bits * 2, 1));
+    MBEDTLS_MPI_CHK(mbedtls_mpi_mod_mpi(Rinv, &RR, M));
 
-      Loop until ta == 0
-    */
-    while (mbedtls_mpi_cmp_int(&ta, 0) != 0) {
-        //mbedtls_mpi_printf("ta", &ta);
-        //mbedtls_mpi_printf("u", u);
-        //mbedtls_mpi_printf("v", v);
-        //printf("2*ta == u*2*a - v*b\n");
-
-        mbedtls_mpi_shift_r(&ta, 1);
-        if (mbedtls_mpi_get_bit(u, 0) == 0) {
-            // Remove common factor of 2 in u & v
-            mbedtls_mpi_shift_r(u, 1);
-            mbedtls_mpi_shift_r(v, 1);
-        }
-        else {
-            /* u = (u + b) >> 1 */
-            mbedtls_mpi_add_mpi(u, u, b);
-            mbedtls_mpi_shift_r(u, 1);
-            /* v = (v - a) >> 1 */
-            mbedtls_mpi_shift_r(v, 1);
-            mbedtls_mpi_add_mpi(v, v, &a_);
-        }
-    }
-    mbedtls_mpi_free(&ta);
-    mbedtls_mpi_free(&a_);
+ cleanup:
+    mbedtls_mpi_free(&RR);
+    return ret;
 }
 
+
 /* Execute RSA operation. op_reg specifies which 'START' register
    to write to.
 */
@@ -247,28 +202,45 @@ static inline void execute_op(uint32_t op_reg)
 }
 
 /* Sub-stages of modulo multiplication/exponentiation operations */
-static int modular_op_prepare(const mbedtls_mpi *X, const mbedtls_mpi *M, size_t num_words);
 inline static int modular_multiply_finish(mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, size_t num_words);
 
 /* Z = (X * Y) mod M
 
    Not an mbedTLS function
- */
+*/
 int esp_mpi_mul_mpi_mod(mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, const mbedtls_mpi *M)
 {
     int ret;
     size_t num_words = hardware_words_needed(M);
+    mbedtls_mpi Rinv;
+    mbedtls_mpi_uint Mprime;
 
     /* Calculate and load the first stage montgomery multiplication */
-    MBEDTLS_MPI_CHK( modular_op_prepare(X, M, num_words) );
+    mbedtls_mpi_init(&Rinv);
+    MBEDTLS_MPI_CHK(calculate_rinv(&Rinv, M, num_words));
+    Mprime = modular_inverse(M);
+
+    esp_mpi_acquire_hardware();
+
+    /* Load M, X, Rinv, Mprime (Mprime is mod 2^32) */
+    mpi_to_mem_block(RSA_MEM_M_BLOCK_BASE, M, num_words);
+    mpi_to_mem_block(RSA_MEM_X_BLOCK_BASE, X, num_words);
+    mpi_to_mem_block(RSA_MEM_RB_BLOCK_BASE, &Rinv, num_words);
+    REG_WRITE(RSA_M_DASH_REG, (uint32_t)Mprime);
+
+    /* "mode" register loaded with number of 512-bit blocks, minus 1 */
+    REG_WRITE(RSA_MULT_MODE_REG, (num_words / 16) - 1);
 
+    /* Execute first stage montgomery multiplication */
     execute_op(RSA_MULT_START_REG);
 
+    /* execute second stage */
     MBEDTLS_MPI_CHK( modular_multiply_finish(Z, X, Y, num_words) );
 
     esp_mpi_release_hardware();
 
  cleanup:
+    mbedtls_mpi_free(&Rinv);
     return ret;
 }
 
@@ -276,20 +248,24 @@ int esp_mpi_mul_mpi_mod(mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi
 
 /*
  * Sliding-window exponentiation: Z = X^Y mod M  (HAC 14.85)
+ *
+ * _Rinv is optional pre-calculated version of Rinv (via calculate_rinv()).
+ *
+ * (See RSA Accelerator section in Technical Reference for more about Mprime, Rinv)
+ *
  */
- #if 0
-int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi* Y, const mbedtls_mpi* M, mbedtls_mpi* _RR )
+int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi* Y, const mbedtls_mpi* M, mbedtls_mpi* _Rinv )
 {
-    int ret;
+    int ret = 0;
     size_t z_words = hardware_words_needed(Z);
     size_t x_words = hardware_words_needed(X);
     size_t y_words = hardware_words_needed(Y);
     size_t m_words = hardware_words_needed(M);
     size_t num_words;
 
-    mbedtls_mpi_printf("X",X);
-    mbedtls_mpi_printf("Y",Y);
-    mbedtls_mpi_printf("M",M);
+    mbedtls_mpi Rinv_new; /* used if _Rinv == NULL */
+    mbedtls_mpi *Rinv;    /* points to _Rinv (if not NULL) othwerwise &RR_new */
+    mbedtls_mpi_uint Mprime;
 
     /* "all numbers must be the same length", so choose longest number
        as cardinal length of operation...
@@ -304,157 +280,24 @@ int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi
     if (m_words > num_words) {
         num_words = m_words;
     }
-    printf("num_words = %d  # %d, %d, %d\n", num_words, x_words, y_words, m_words);
-
-    /* TODO: _RR parameter currently ignored */
-
-    ret = modular_op_prepare(X, M, num_words);
-    if (ret != 0) {
-        return ret;
-    }
-
-    mpi_to_mem_block(RSA_MEM_Y_BLOCK_BASE, Y, num_words);
-
-    //dump_memory_block("X_BLOCK", RSA_MEM_X_BLOCK_BASE);
-    //dump_memory_block("Y_BLOCK", RSA_MEM_Y_BLOCK_BASE);
-    //dump_memory_block("M_BLOCK", RSA_MEM_M_BLOCK_BASE);
-
-    REG_WRITE(RSA_MODEXP_MODE_REG, (num_words / 16) - 1);
-
-    execute_op(RSA_START_MODEXP_REG);
-
-    //dump_memory_block("Z_BLOCK", RSA_MEM_Z_BLOCK_BASE);
-
-    /* TODO: only need to read m_words not num_words, provided result is correct... */
-    ret = mem_block_to_mpi(Z, RSA_MEM_Z_BLOCK_BASE, num_words);
-
-    esp_mpi_release_hardware();
-
-    mbedtls_mpi_printf("Z",Z);
-    printf("print (Z == (X ** Y) %% M)\n");
-
-    return ret;
-}
-
-#else
-
-/**
- * There is a need for the value of integer N' such that B^-1(B-1)-N^-1N'=1, 
- * where B^-1(B-1) mod N=1. Actually, only the least significant part of 
- * N' is needed, hence the definition N0'=N' mod b. We reproduce below the 
- * simple algorithm from an article by Dusse and Kaliski to efficiently 
- * find N0' from N0 and b 
- */
-static mbedtls_mpi_uint modular_inverse(const mbedtls_mpi *M)
-{
-    int i;
-    uint64_t t = 1;
-    uint64_t two_2_i_minus_1 = 2;   /* 2^(i-1) */
-    uint64_t two_2_i = 4;           /* 2^i */
-    uint64_t N = M->p[0];
-
-    for (i = 2; i <= 32; i++) {
-        if ((mbedtls_mpi_uint) N * t % two_2_i >= two_2_i_minus_1) {
-            t += two_2_i_minus_1;
-        }
 
-        two_2_i_minus_1 <<= 1;
-        two_2_i <<= 1;
+    if (num_words * 32 > 4096) {
+        return MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
     }
 
-    return (mbedtls_mpi_uint)(UINT32_MAX - t + 1);
-}
-
-static int bignum_param_init(const mbedtls_mpi *M, mbedtls_mpi *_RR, mbedtls_mpi *r, mbedtls_mpi_uint *Mi, size_t num_words)
-{
-    int ret = 0;
-    size_t num_bits;
-    mbedtls_mpi RR;
-
-    /* Calculate number of bits */
-    num_bits = num_words * 32;
-    ESP_LOGI(TAG, "num_bits = %d\n", num_bits);
-
-    /* 
-     *  R = b^n where b = 2^32, n=num_words,
-     *  R = 2^N (where N=num_bits)
-     *  RR(R^2) = 2^(2*N) (where N=num_bits)
-     *
-     *  r = RR(R^2) mod M
-     *
-     *  Get the RR(RR == r) value from up level if RR and RR->p is not NULL
-     */
-    ESP_LOGI(TAG, "r = RR(R^2) mod M\n");
-    if (_RR == NULL || _RR->p == NULL) {
-        ESP_LOGI(TAG, "RR(R^2) = 2^(2*N) (where N=num_bits)\n");
-        mbedtls_mpi_init(&RR);
-        MBEDTLS_MPI_CHK(mbedtls_mpi_set_bit(&RR, num_bits * 2, 1));
-        mbedtls_mpi_printf("RR", &RR);
-        
-        MBEDTLS_MPI_CHK(mbedtls_mpi_mod_mpi(r, &RR, M));
-        
-        if (_RR != NULL)
-            memcpy(_RR, r, sizeof( mbedtls_mpi ) );
+    /* Determine RR pointer, either _RR for cached value
+       or local RR_new */
+    if (_Rinv == NULL) {
+        mbedtls_mpi_init(&Rinv_new);
+        Rinv = &Rinv_new;
     } else {
-        memcpy(r, _RR, sizeof( mbedtls_mpi ) );
+        Rinv = _Rinv;
     }
-    mbedtls_mpi_printf("r", r);
-
-    *Mi = modular_inverse(M);
-
-cleanup:
-    mbedtls_mpi_free(&RR);
-
-    return ret;
-}
-
-static void bignum_param_deinit(mbedtls_mpi *_RR, mbedtls_mpi *r)
-{
-    if (_RR == NULL || _RR->p == NULL)
-        mbedtls_mpi_free(r);
-}
-
-/*
- * Sliding-window exponentiation: Z = X^Y mod M  (HAC 14.85)
- */
-int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi* Y, const mbedtls_mpi* M, mbedtls_mpi* _RR )
-{
-    int ret = 0;
-    size_t z_words = hardware_words_needed(Z);
-    size_t x_words = hardware_words_needed(X);
-    size_t y_words = hardware_words_needed(Y);
-    size_t m_words = hardware_words_needed(M);
-    size_t num_words;
-
-    mbedtls_mpi r;
-    mbedtls_mpi_uint Mi = 0;
-
-    /* "all numbers must be the same length", so choose longest number
-       as cardinal length of operation...
-    */
-    num_words = z_words;
-    if (x_words > num_words) {
-        num_words = x_words;
-    }
-    if (y_words > num_words) {
-        num_words = y_words;
+    if (Rinv->p == NULL) {
+        MBEDTLS_MPI_CHK(calculate_rinv(Rinv, M, num_words));
     }
-    if (m_words > num_words) {
-        num_words = m_words;
-    }
-    ESP_LOGI(TAG, "num_words = %d  # %d, %d, %d\n", num_words, x_words, y_words, m_words);
 
-    if (num_words * 32 > 4096)
-        return MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
-    
-    mbedtls_mpi_init(&r);
-    ret = bignum_param_init(M, _RR, &r, &Mi, num_words);
-    if (ret != 0) {
-        return ret;
-    }
-
-    mbedtls_mpi_printf("X",X);
-    mbedtls_mpi_printf("Y",Y);
+    Mprime = modular_inverse(M);
 
     esp_mpi_acquire_hardware();
 
@@ -465,8 +308,8 @@ int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi
     mpi_to_mem_block(RSA_MEM_X_BLOCK_BASE, X, num_words);
     mpi_to_mem_block(RSA_MEM_Y_BLOCK_BASE, Y, num_words);
     mpi_to_mem_block(RSA_MEM_M_BLOCK_BASE, M, num_words);
-    mpi_to_mem_block(RSA_MEM_RB_BLOCK_BASE, &r, num_words);
-    REG_WRITE(RSA_M_DASH_REG, Mi);
+    mpi_to_mem_block(RSA_MEM_RB_BLOCK_BASE, Rinv, num_words);
+    REG_WRITE(RSA_M_DASH_REG, Mprime);
 
     execute_op(RSA_START_MODEXP_REG);
 
@@ -474,91 +317,16 @@ int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi
 
     esp_mpi_release_hardware();
 
-    mbedtls_mpi_printf("Z",Z);
-    ESP_LOGI(TAG, "print (Z == (X ** Y) %% M)\n");
-
-    bignum_param_deinit(_RR, &r);
+ cleanup:
+    if (_Rinv == NULL) {
+        mbedtls_mpi_free(&Rinv_new);
+    }
 
     return ret;
 }
 
-
-#endif
-
 #endif /* MBEDTLS_MPI_EXP_MOD_ALT */
 
-
-/* The common parts of modulo multiplication and modular sliding
- *  window exponentiation:
- *
- * @param X first multiplication factor and/or base of exponent.
- * @param M modulo value for result
- * @param num_words size of modulo operation, in words (limbs).
- *        Should already be rounded up to a multiple of 16 words (512 bits) & range checked.
- *
- * Steps:
- * Calculate Rinv & Mprime based on M & num_words
- * Load all coefficients to memory
- * Set mode register
- *
- * @note This function calls esp_mpi_acquire_hardware. If successful,
- * returns 0 and it becomes the callers responsibility to call
- * esp_mpi_release_hardware(). If failure is returned, the caller does
- * not need to call esp_mpi_release_hardware().
- */
-static int modular_op_prepare(const mbedtls_mpi *X, const mbedtls_mpi *M, size_t num_words)
-{
-    int ret = 0;
-    mbedtls_mpi RR, Rinv, Mprime;
-    size_t num_bits;
-
-    /* Calculate number of bits */
-    num_bits = num_words * 32;
-
-    if(num_bits > 4096) {
-        return MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
-    }
-
-    /* Rinv & Mprime are calculated via extended binary gcd
-       algorithm, see references on extended_binary_gcd() above.
-    */
-    mbedtls_mpi_init(&Rinv);
-    mbedtls_mpi_init(&RR);
-    mbedtls_mpi_init(&Mprime);
-
-    mbedtls_mpi_set_bit(&RR, num_bits, 1); /* R = b^n where b = 2^32, n=num_words,
-                                              ie R = 2^N (where N=num_bits) */
-    /* calculate Rinv & Mprime */
-    extended_binary_gcd(&RR, M, &Rinv, &Mprime);
-
-    /* Block of debugging data, output suitable to paste into Python
-       TODO remove
-    */
-    mbedtls_mpi_printf("RR", &RR);
-    mbedtls_mpi_printf("M", M);
-    mbedtls_mpi_printf("Rinv", &Rinv);
-    mbedtls_mpi_printf("Mprime", &Mprime);
-    printf("print (R * Rinv - M * Mprime == 1)\n");
-    printf("print (Rinv == (R * R) %% M)\n");
-
-    esp_mpi_acquire_hardware();
-
-    /* Load M, X, Rinv, M-prime (M-prime is mod 2^32) */
-    mpi_to_mem_block(RSA_MEM_M_BLOCK_BASE, M, num_words);
-    mpi_to_mem_block(RSA_MEM_X_BLOCK_BASE, X, num_words);
-    mpi_to_mem_block(RSA_MEM_RB_BLOCK_BASE, &Rinv, num_words);
-    REG_WRITE(RSA_M_DASH_REG, Mprime.p[0]);
-
-    /* "mode" register loaded with number of 512-bit blocks, minus 1 */
-    REG_WRITE(RSA_MULT_MODE_REG, (num_words / 16) - 1);
-
-    mbedtls_mpi_free(&Rinv);
-    mbedtls_mpi_free(&RR);
-    mbedtls_mpi_free(&Mprime);
-
-    return ret;
-}
-
 /* Second & final step of a modular multiply - load second multiplication
  * factor Y, run the multiply, read back the result into Z.
  *
@@ -594,17 +362,43 @@ static int mpi_mult_mpi_failover_mod_mult(mbedtls_mpi *Z, const mbedtls_mpi *X,
 int mbedtls_mpi_mul_mpi( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y )
 {
     int ret;
-    size_t words_x, words_y, words_mult, words_z;
+    size_t bits_x, bits_y, words_x, words_y, words_mult, words_z;
 
     /* Count words needed for X & Y in hardware */
-    words_x = hardware_words_needed(X);
-    words_y = hardware_words_needed(Y);
+    bits_x = mbedtls_mpi_bitlen(X);
+    bits_y = mbedtls_mpi_bitlen(Y);
+    /* Convert bit counts to words, rounded up to 512-bit
+       (16 word) blocks */
+    words_x = bits_to_hardware_words(bits_x);
+    words_y = bits_to_hardware_words(bits_y);
+
+    /* Short-circuit eval if either argument is 0 or 1.
+
+       This is needed as the mpi modular division
+       argument will sometimes call in here when one
+       argument is too large for the hardware unit, but the other
+       argument is zero or one.
+
+       This leaks some timing information, although overall there is a
+       lot less timing variation than a software MPI approach.
+    */
+    if (bits_x == 0 || bits_y == 0) {
+        mbedtls_mpi_lset(Z, 0);
+        return 0;
+    }
+    if (bits_x == 1) {
+        return mbedtls_mpi_copy(Z, Y);
+    }
+    if (bits_y == 1) {
+        return mbedtls_mpi_copy(Z, X);
+    }
 
     words_mult = (words_x > words_y ? words_x : words_y);
 
     /* Result Z has to have room for double the larger factor */
     words_z = words_mult * 2;
 
+
     /* If either factor is over 2048 bits, we can't use the standard hardware multiplier
        (it assumes result is double longest factor, and result is max 4096 bits.)
 
@@ -612,12 +406,11 @@ int mbedtls_mpi_mul_mpi( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi
        multiplication doesn't have the same restriction, so result is simply the
        number of bits in X plus number of bits in in Y.)
     */
-    //ESP_LOGE(TAG, "INFO: %d bit result (%d bits * %d bits)\n", words_z * 32, mbedtls_mpi_bitlen(X), mbedtls_mpi_bitlen(Y));
     if (words_mult * 32 > 2048) {
         /* Calculate new length of Z */
-        words_z = words_x + words_y;
+        words_z = bits_to_hardware_words(bits_x + bits_y);
         if (words_z * 32 > 4096) {
-            ESP_LOGE(TAG, "ERROR: %d bit result (%d bits * %d bits) too large for hardware unit\n", words_z * 32, mbedtls_mpi_bitlen(X), mbedtls_mpi_bitlen(Y));
+            ESP_LOGE(TAG, "ERROR: %d bit result %d bits * %d bits too large for hardware unit\n", words_z * 32, bits_x, bits_y);
             return MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
         }
         else {
@@ -640,7 +433,7 @@ int mbedtls_mpi_mul_mpi( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi
 
     /* "mode" register loaded with number of 512-bit blocks in result,
        plus 7 (for range 9-12). (this is ((N~ / 32) - 1) + 8))
-     */
+    */
     REG_WRITE(RSA_MULT_MODE_REG, (words_z / 16) + 7);
 
     execute_op(RSA_MULT_START_REG);
@@ -656,54 +449,57 @@ int mbedtls_mpi_mul_mpi( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi
 }
 
 /* Special-case of mbedtls_mpi_mult_mpi(), where we use hardware montgomery mod
-   multiplication to solve the case where A or B are >2048 bits so
-   can't use the standard multiplication method.
+   multiplication to calculate an mbedtls_mpi_mult_mpi result where either
+   A or B are >2048 bits so can't use the standard multiplication method.
+
+   Result (A bits + B bits) must still be less than 4096 bits.
 
-   This case is simpler than esp_mpi_mul_mpi_mod() as we control the arguments:
+   This case is simpler than the general case modulo multiply of
+   esp_mpi_mul_mpi_mod() because we can control the other arguments:
 
    * Modulus is chosen with M=(2^num_bits - 1) (ie M=R-1), so output
-     isn't actually modulo anything.
-   * Therefore of of M' and Rinv are predictable as follows:
-      M' = 1
-      Rinv = 1
+   isn't actually modulo anything.
+   * Mprime and Rinv are therefore predictable as follows:
+   Mprime = 1
+   Rinv = 1
 
-   (See RSA Accelerator section in Technical Reference *
-   extended_binary_gcd() function above for more about M', Rinv)
+   (See RSA Accelerator section in Technical Reference for more about Mprime, Rinv)
 */
 static int mpi_mult_mpi_failover_mod_mult(mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi *Y, size_t num_words)
- {
-     int ret = 0;
+{
+    int ret = 0;
 
-     /* Load coefficients to hardware */
-     esp_mpi_acquire_hardware();
+    /* Load coefficients to hardware */
+    esp_mpi_acquire_hardware();
 
-     /* M = 2^num_words - 1, so block is entirely FF */
-     for(int i = 0; i < num_words; i++) {
-         REG_WRITE(RSA_MEM_M_BLOCK_BASE + i * 4, UINT32_MAX);
-     }
-     /* Mprime = 1 */
-     REG_WRITE(RSA_M_DASH_REG, 1);
+    /* M = 2^num_words - 1, so block is entirely FF */
+    for(int i = 0; i < num_words; i++) {
+        REG_WRITE(RSA_MEM_M_BLOCK_BASE + i * 4, UINT32_MAX);
+    }
+    /* Mprime = 1 */
+    REG_WRITE(RSA_M_DASH_REG, 1);
 
-     /* "mode" register loaded with number of 512-bit blocks, minus 1 */
-     REG_WRITE(RSA_MULT_MODE_REG, (num_words / 16) - 1);
+    /* "mode" register loaded with number of 512-bit blocks, minus 1 */
+    REG_WRITE(RSA_MULT_MODE_REG, (num_words / 16) - 1);
 
-     /* Load X */
-     mpi_to_mem_block(RSA_MEM_X_BLOCK_BASE, X, num_words);
+    /* Load X */
+    mpi_to_mem_block(RSA_MEM_X_BLOCK_BASE, X, num_words);
 
-     /* Rinv = 1 */
-     REG_WRITE(RSA_MEM_RB_BLOCK_BASE, 1);
-     for(int i = 1; i < num_words; i++) {
-         REG_WRITE(RSA_MEM_RB_BLOCK_BASE + i * 4, 0);
-     }
+    /* Rinv = 1 */
+    REG_WRITE(RSA_MEM_RB_BLOCK_BASE, 1);
+    for(int i = 1; i < num_words; i++) {
+        REG_WRITE(RSA_MEM_RB_BLOCK_BASE + i * 4, 0);
+    }
 
-     execute_op(RSA_MULT_START_REG);
+    execute_op(RSA_MULT_START_REG);
 
-     MBEDTLS_MPI_CHK( modular_multiply_finish(Z, X, Y, num_words) );
+    /* finish the modular multiplication */
+    MBEDTLS_MPI_CHK( modular_multiply_finish(Z, X, Y, num_words) );
 
-     esp_mpi_release_hardware();
+    esp_mpi_release_hardware();
 
  cleanup:
-     return ret;
+    return ret;
 }
 
 #endif /* MBEDTLS_MPI_MUL_MPI_ALT */