From 6b687b43f48252649af40f26ccca1e1599b65296 Mon Sep 17 00:00:00 2001
From: Dong Heng
Date: Wed, 16 Nov 2016 20:37:51 +0800
Subject: [PATCH] mbedtls hardware RSA: Fix "mbedtls_mpi_exp_mod" hardware calculations

---
 components/mbedtls/port/esp_bignum.c | 195 ++++++++++++++++++++++++++-
 1 file changed, 191 insertions(+), 4 deletions(-)

diff --git a/components/mbedtls/port/esp_bignum.c b/components/mbedtls/port/esp_bignum.c
index 0a835c9e8d..401d3fc24b 100644
--- a/components/mbedtls/port/esp_bignum.c
+++ b/components/mbedtls/port/esp_bignum.c
@@ -53,12 +53,11 @@ void mbedtls_mpi_printf(const char *name, const mbedtls_mpi *X)
     static char buf[1024];
     size_t n;
     memset(buf, 0, sizeof(buf));
-    printf("%s = 0x", name);
     mbedtls_mpi_write_string(X, 16, buf, sizeof(buf)-1, &n);
     if(n) {
-        puts(buf);
+        ESP_LOGI(TAG, "%s = 0x%s", name, buf);
     } else {
-        puts("TOOLONG");
+        ESP_LOGI(TAG, "TOOLONG");
     }
 }
 
@@ -278,6 +277,7 @@ int esp_mpi_mul_mpi_mod(mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi
 /*
  * Sliding-window exponentiation: Z = X^Y mod M (HAC 14.85)
  */
+ #if 0
 int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi* Y, const mbedtls_mpi* M, mbedtls_mpi* _RR )
 {
     int ret;
@@ -336,6 +336,192 @@ int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi
     return ret;
 }
 
+#else
+
+/**
+ * Compute M' = -M^-1 mod 2^32, the M-prime value that is later written to
+ * RSA_M_DASH_REG, from the least significant word of M.
+ *
+ * The loop builds t one bit at a time so that N * t == 1 (mod 2^i) holds
+ * after iteration i (the method described by Dusse and Kaliski); the
+ * result is the negation of t modulo 2^32. The inverse only exists when
+ * M is odd, and M->p[0] is read unconditionally, so callers must pass a
+ * non-zero, odd modulus.
+ */
+static mbedtls_mpi_uint modular_inverse(const mbedtls_mpi *M)
+{
+    int i;
+    uint64_t t = 1;
+    uint64_t two_2_i_minus_1 = 2; /* 2^(i-1) */
+    uint64_t two_2_i = 4; /* 2^i */
+    uint64_t N = M->p[0];
+
+    for (i = 2; i <= 32; i++) {
+        if ((mbedtls_mpi_uint) N * t % two_2_i >= two_2_i_minus_1) {
+            t += two_2_i_minus_1;
+        }
+
+        two_2_i_minus_1 <<= 1;
+        two_2_i <<= 1;
+    }
+
+    return (mbedtls_mpi_uint)(UINT32_MAX - t + 1);
+}
+
+/*
+ * Prepare the per-modulus operands for a hardware modular exponentiation:
+ *
+ *   R  = 2^num_bits, where num_bits = num_words * 32
+ *   r  = R^2 mod M
+ *   Mi = -M^-1 mod 2^32 (see modular_inverse)
+ *
+ * If the caller supplied a populated _RR cache, r becomes a shallow copy of
+ * it. Otherwise r is computed and, when _RR is non-NULL, shallow-copied
+ * into _RR, which then shares r's limb buffer.
+ *
+ * NOTE(review): a cached _RR is only valid if it was computed for the same
+ * num_words; nothing here checks that - confirm against callers.
+ *
+ * Returns 0 on success or the mbedtls error code of the failing step.
+ */
+static int bignum_param_init(const mbedtls_mpi *M, mbedtls_mpi *_RR, mbedtls_mpi *r, mbedtls_mpi_uint *Mi, size_t num_words)
+{
+    int ret = 0;
+    size_t num_bits;
+    mbedtls_mpi RR;
+
+    /* Initialise unconditionally: the cleanup path frees RR on every exit,
+       including when the cached _RR is used and RR is never written. */
+    mbedtls_mpi_init(&RR);
+
+    /* Calculate number of bits */
+    num_bits = num_words * 32;
+    ESP_LOGI(TAG, "num_bits = %d\n", num_bits);
+
+    ESP_LOGI(TAG, "r = RR(R^2) mod M\n");
+    if (_RR == NULL || _RR->p == NULL) {
+        ESP_LOGI(TAG, "RR(R^2) = 2^(2*N) (where N=num_bits)\n");
+        MBEDTLS_MPI_CHK(mbedtls_mpi_set_bit(&RR, num_bits * 2, 1));
+        mbedtls_mpi_printf("RR", &RR);
+
+        MBEDTLS_MPI_CHK(mbedtls_mpi_mod_mpi(r, &RR, M));
+
+        if (_RR != NULL) {
+            memcpy(_RR, r, sizeof( mbedtls_mpi ) );
+        }
+    } else {
+        memcpy(r, _RR, sizeof( mbedtls_mpi ) );
+    }
+    mbedtls_mpi_printf("r", r);
+
+    *Mi = modular_inverse(M);
+
+cleanup:
+    mbedtls_mpi_free(&RR);
+
+    return ret;
+}
+
+/*
+ * Release r unless its limb buffer is owned by the caller's _RR cache.
+ *
+ * After a successful bignum_param_init() with a non-NULL _RR, r and _RR
+ * share one buffer and _RR->p is non-NULL, so r must not be freed.
+ */
+static void bignum_param_deinit(mbedtls_mpi *_RR, mbedtls_mpi *r)
+{
+    if (_RR == NULL || _RR->p == NULL) {
+        mbedtls_mpi_free(r);
+    }
+}
+
+/*
+ * Hardware modular exponentiation: Z = X^Y mod M
+ *
+ * M must be positive and odd; otherwise MBEDTLS_ERR_MPI_BAD_INPUT_DATA is
+ * returned. If the operation length exceeds 4096 bits the call fails with
+ * MBEDTLS_ERR_MPI_NOT_ACCEPTABLE. _RR may be NULL; when non-NULL it is read
+ * as, or filled with, R^2 mod M (see bignum_param_init).
+ *
+ * NOTE(review): X, Y and Z are logged at INFO level below; Y may be a
+ * private exponent - remove the logging before this ships.
+ */
+int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi* Y, const mbedtls_mpi* M, mbedtls_mpi* _RR )
+{
+    int ret = 0;
+    size_t z_words = hardware_words_needed(Z);
+    size_t x_words = hardware_words_needed(X);
+    size_t y_words = hardware_words_needed(Y);
+    size_t m_words = hardware_words_needed(M);
+    size_t num_words;
+
+    mbedtls_mpi r;
+    mbedtls_mpi_uint Mi = 0;
+
+    /* modular_inverse() reads M->p[0] and only has a result for odd M */
+    if (mbedtls_mpi_cmp_int(M, 0) <= 0 || (M->p[0] & 1) == 0) {
+        return MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
+    }
+
+    /* "all numbers must be the same length", so choose longest number
+       as cardinal length of operation...
+    */
+    num_words = z_words;
+    if (x_words > num_words) {
+        num_words = x_words;
+    }
+    if (y_words > num_words) {
+        num_words = y_words;
+    }
+    if (m_words > num_words) {
+        num_words = m_words;
+    }
+    ESP_LOGI(TAG, "num_words = %d # %d, %d, %d\n", num_words, x_words, y_words, m_words);
+
+    if (num_words * 32 > 4096) {
+        return MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
+    }
+
+    mbedtls_mpi_init(&r);
+    ret = bignum_param_init(M, _RR, &r, &Mi, num_words);
+    if (ret != 0) {
+        /* r may hold a partial allocation from the failed computation */
+        bignum_param_deinit(_RR, &r);
+        return ret;
+    }
+
+    mbedtls_mpi_printf("X",X);
+    mbedtls_mpi_printf("Y",Y);
+
+    esp_mpi_acquire_hardware();
+
+    /* "mode" register loaded with number of 512-bit blocks, minus 1 */
+    REG_WRITE(RSA_MODEXP_MODE_REG, (num_words / 16) - 1);
+
+    /* Load X, Y, M, r (R^2 mod M) and M-prime (M-prime is mod 2^32) */
+    mpi_to_mem_block(RSA_MEM_X_BLOCK_BASE, X, num_words);
+    mpi_to_mem_block(RSA_MEM_Y_BLOCK_BASE, Y, num_words);
+    mpi_to_mem_block(RSA_MEM_M_BLOCK_BASE, M, num_words);
+    mpi_to_mem_block(RSA_MEM_RB_BLOCK_BASE, &r, num_words);
+    REG_WRITE(RSA_M_DASH_REG, Mi);
+
+    execute_op(RSA_START_MODEXP_REG);
+
+    ret = mem_block_to_mpi(Z, RSA_MEM_Z_BLOCK_BASE, num_words);
+
+    esp_mpi_release_hardware();
+
+    mbedtls_mpi_printf("Z",Z);
+    ESP_LOGI(TAG, "print (Z == (X ** Y) %% M)\n");
+
+    bignum_param_deinit(_RR, &r);
+
+    return ret;
+}
+
+
+#endif
+
 #endif /* MBEDTLS_MPI_EXP_MOD_ALT */
 
 
@@ -385,7 +571,7 @@ static int modular_op_prepare(const mbedtls_mpi *X, const mbedtls_mpi *M, size_t
     /* Block of debugging data, output suitable to paste into Python
        TODO remove
     */
-    mbedtls_mpi_printf("R", &RR);
+    mbedtls_mpi_printf("RR", &RR);
     mbedtls_mpi_printf("M", M);
     mbedtls_mpi_printf("Rinv", &Rinv);
     mbedtls_mpi_printf("Mprime", &Mprime);
@@ -463,6 +649,7 @@ int mbedtls_mpi_mul_mpi( mbedtls_mpi *Z, const mbedtls_mpi *X, const mbedtls_mpi
        multiplication doesn't have the same restriction, so result is simply the
        number of bits in X plus number of bits in in Y.)
     */
+    //ESP_LOGE(TAG, "INFO: %d bit result (%d bits * %d bits)\n", words_z * 32, mbedtls_mpi_bitlen(X), mbedtls_mpi_bitlen(Y));
     if (words_mult * 32 > 2048) {
         /* Calculate new length of Z */
         words_z = words_x + words_y;
-- 
2.40.0