]> granicus.if.org Git - esp-idf/commitdiff
mbedtls hardware RSA: Fix "mbedtls_mpi_exp_mod" hardware calculations
authorDong Heng <dongheng@espressif.com>
Wed, 16 Nov 2016 12:37:51 +0000 (20:37 +0800)
committerAngus Gratton <angus@espressif.com>
Fri, 18 Nov 2016 03:09:59 +0000 (14:09 +1100)
components/mbedtls/port/esp_bignum.c

index 0a835c9e8d315442b7237255550f51738343f751..401d3fc24bb8e762fb6b4c016e50db8c00e48119 100644 (file)
@@ -53,12 +53,11 @@ void mbedtls_mpi_printf(const char *name, const mbedtls_mpi *X)
     static char buf[1024];
     size_t n;
     memset(buf, 0, sizeof(buf));
-    printf("%s = 0x", name);
     mbedtls_mpi_write_string(X, 16, buf, sizeof(buf)-1, &n);
     if(n) {
-        puts(buf);
+        ESP_LOGI(TAG, "%s = 0x%s", name, buf);
     } else {
-        puts("TOOLONG");
+        ESP_LOGI(TAG, "TOOLONG");
     }
 }
 
@@ -278,6 +277,7 @@ int esp_mpi_mul_mpi_mod(mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi
 /*
  * Sliding-window exponentiation: Z = X^Y mod M  (HAC 14.85)
  */
+ #if 0
 int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi* Y, const mbedtls_mpi* M, mbedtls_mpi* _RR )
 {
     int ret;
@@ -336,6 +336,155 @@ int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi
     return ret;
 }
 
+#else
+
+/**
+ * There is a need for the value of integer N' such that B^-1(B-1)-N^-1N'=1, 
+ * where B^-1(B-1) mod N=1. Actually, only the least significant part of 
+ * N' is needed, hence the definition N0'=N' mod b. We reproduce below the 
+ * simple algorithm from an article by Dusse and Kaliski to efficiently 
+ * find N0' from N0 and b 
+ */
+static mbedtls_mpi_uint modular_inverse(const mbedtls_mpi *M)
+{
+    int i;
+    uint64_t t = 1;
+    uint64_t two_2_i_minus_1 = 2;   /* 2^(i-1) */
+    uint64_t two_2_i = 4;           /* 2^i */
+    uint64_t N = M->p[0];
+
+    for (i = 2; i <= 32; i++) {
+        if ((mbedtls_mpi_uint) N * t % two_2_i >= two_2_i_minus_1) {
+            t += two_2_i_minus_1;
+        }
+
+        two_2_i_minus_1 <<= 1;
+        two_2_i <<= 1;
+    }
+
+    return (mbedtls_mpi_uint)(UINT32_MAX - t + 1);
+}
+
+static int bignum_param_init(const mbedtls_mpi *M, mbedtls_mpi *_RR, mbedtls_mpi *r, mbedtls_mpi_uint *Mi, size_t num_words)
+{
+    int ret = 0;
+    size_t num_bits;
+    mbedtls_mpi RR;
+
+    /* Calculate number of bits */
+    num_bits = num_words * 32;
+    ESP_LOGI(TAG, "num_bits = %d\n", num_bits);
+
+    /* 
+     *  R = b^n where b = 2^32, n=num_words,
+     *  R = 2^N (where N=num_bits)
+     *  RR(R^2) = 2^(2*N) (where N=num_bits)
+     *
+     *  r = RR(R^2) mod M
+     *
+     *  Get the RR(RR == r) value from up level if RR and RR->p is not NULL
+     */
+    ESP_LOGI(TAG, "r = RR(R^2) mod M\n");
+    if (_RR == NULL || _RR->p == NULL) {
+        ESP_LOGI(TAG, "RR(R^2) = 2^(2*N) (where N=num_bits)\n");
+        mbedtls_mpi_init(&RR);
+        MBEDTLS_MPI_CHK(mbedtls_mpi_set_bit(&RR, num_bits * 2, 1));
+        mbedtls_mpi_printf("RR", &RR);
+        
+        MBEDTLS_MPI_CHK(mbedtls_mpi_mod_mpi(r, &RR, M));
+        
+        if (_RR != NULL)
+            memcpy(_RR, r, sizeof( mbedtls_mpi ) );
+    } else {
+        memcpy(r, _RR, sizeof( mbedtls_mpi ) );
+    }
+    mbedtls_mpi_printf("r", r);
+
+    *Mi = modular_inverse(M);
+
+cleanup:
+    mbedtls_mpi_free(&RR);
+
+    return ret;
+}
+
+static void bignum_param_deinit(mbedtls_mpi *_RR, mbedtls_mpi *r)
+{
+    if (_RR == NULL || _RR->p == NULL)
+        mbedtls_mpi_free(r);
+}
+
+/*
+ * Sliding-window exponentiation: Z = X^Y mod M  (HAC 14.85)
+ */
+int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi* Y, const mbedtls_mpi* M, mbedtls_mpi* _RR )
+{
+    int ret = 0;
+    size_t z_words = hardware_words_needed(Z);
+    size_t x_words = hardware_words_needed(X);
+    size_t y_words = hardware_words_needed(Y);
+    size_t m_words = hardware_words_needed(M);
+    size_t num_words;
+
+    mbedtls_mpi r;
+    mbedtls_mpi_uint Mi = 0;
+
+    /* "all numbers must be the same length", so choose longest number
+       as cardinal length of operation...
+    */
+    num_words = z_words;
+    if (x_words > num_words) {
+        num_words = x_words;
+    }
+    if (y_words > num_words) {
+        num_words = y_words;
+    }
+    if (m_words > num_words) {
+        num_words = m_words;
+    }
+    ESP_LOGI(TAG, "num_words = %d  # %d, %d, %d\n", num_words, x_words, y_words, m_words);
+
+    if (num_words * 32 > 4096)
+        return MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
+    
+    mbedtls_mpi_init(&r);
+    ret = bignum_param_init(M, _RR, &r, &Mi, num_words);
+    if (ret != 0) {
+        return ret;
+    }
+
+    mbedtls_mpi_printf("X",X);
+    mbedtls_mpi_printf("Y",Y);
+
+    esp_mpi_acquire_hardware();
+
+    /* "mode" register loaded with number of 512-bit blocks, minus 1 */
+    REG_WRITE(RSA_MODEXP_MODE_REG, (num_words / 16) - 1);
+
+    /* Load M, X, Rinv, M-prime (M-prime is mod 2^32) */
+    mpi_to_mem_block(RSA_MEM_X_BLOCK_BASE, X, num_words);
+    mpi_to_mem_block(RSA_MEM_Y_BLOCK_BASE, Y, num_words);
+    mpi_to_mem_block(RSA_MEM_M_BLOCK_BASE, M, num_words);
+    mpi_to_mem_block(RSA_MEM_RB_BLOCK_BASE, &r, num_words);
+    REG_WRITE(RSA_M_DASH_REG, Mi);
+
+    execute_op(RSA_START_MODEXP_REG);
+
+    ret = mem_block_to_mpi(Z, RSA_MEM_Z_BLOCK_BASE, num_words);
+
+    esp_mpi_release_hardware();
+
+    mbedtls_mpi_printf("Z",Z);
+    ESP_LOGI(TAG, "print (Z == (X ** Y) %% M)\n");
+
+    bignum_param_deinit(_RR, &r);
+
+    return ret;
+}
+
+
+#endif
+
 #endif /* MBEDTLS_MPI_EXP_MOD_ALT */
 
 
@@ -385,7 +534,7 @@ static int modular_op_prepare(const mbedtls_mpi *X, const mbedtls_mpi *M, size_t
     /* Block of debugging data, output suitable to paste into Python
        TODO remove
     */
-    mbedtls_mpi_printf("R", &RR);
+    mbedtls_mpi_printf("RR", &RR);
     mbedtls_mpi_printf("M", M);
     mbedtls_mpi_printf("Rinv", &Rinv);
     mbedtls_mpi_printf("Mprime", &Mprime);
@@ -463,6 +612,7 @@ int mbedtls_mpi_mul_mpi( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi
        multiplication doesn't have the same restriction, so result is simply the
        number of bits in X plus number of bits in in Y.)
     */
+    //ESP_LOGE(TAG, "INFO: %d bit result (%d bits * %d bits)\n", words_z * 32, mbedtls_mpi_bitlen(X), mbedtls_mpi_bitlen(Y));
     if (words_mult * 32 > 2048) {
         /* Calculate new length of Z */
         words_z = words_x + words_y;