_lock_release(&mpi_lock);
}
+/* Given a & b, determine u & v such that
+
+ gcd(a,b) = d = au + bv
+
+ Underlying algorithm comes from:
+ http://www.ucl.ac.uk/~ucahcjm/combopt/ext_gcd_python_programs.pdf
+ http://www.hackersdelight.org/hdcodetxt/mont64.c.txt
+ */
+static void extended_binary_gcd(const mbedtls_mpi *a, const mbedtls_mpi *b,
+ mbedtls_mpi *u, mbedtls_mpi *v)
+{
+ mbedtls_mpi ta, tb;
+
+ mbedtls_mpi_init(&ta);
+ mbedtls_mpi_copy(&ta, a);
+ mbedtls_mpi_init(&tb);
+ mbedtls_mpi_copy(&tb, b);
+
+ mbedtls_mpi_lset(u, 1);
+ mbedtls_mpi_lset(v, 0);
+
+ /* Loop invariant:
+ ta = u*2*a - v*b. */
+ while (mbedtls_mpi_cmp_int(&ta, 0) != 0) {
+ mbedtls_mpi_shift_r(&ta, 1);
+ if (mbedtls_mpi_get_bit(u, 0) == 0) {
+ // Remove common factor of 2 in u & v
+ mbedtls_mpi_shift_r(u, 1);
+ mbedtls_mpi_shift_r(v, 1);
+ }
+ else {
+ /* u = (u + b) >> 1 */
+ mbedtls_mpi_add_mpi(u, u, b);
+ mbedtls_mpi_shift_r(u, 1);
+ /* v = (v >> 1) + a */
+ mbedtls_mpi_shift_r(v, 1);
+ mbedtls_mpi_add_mpi(v, v, a);
+ }
+ }
+ mbedtls_mpi_free(&ta);
+ mbedtls_mpi_free(&tb);
+
+ /* u = u * 2, so 1 = u*a - v*b */
+ mbedtls_mpi_shift_l(u, 1);
+}
+
+/* inner part of MPI modular multiply, after Rinv & Mprime are calculated */
+static int mpi_mul_mpi_mod_inner(mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B, const mbedtls_mpi *M, mbedtls_mpi *Rinv, uint32_t Mprime, size_t num_words)
+{
+ int ret;
+ mbedtls_mpi TA, TB;
+ size_t num_bits = num_words * 32;
+
+ mbedtls_mpi_grow(Rinv, num_words);
+
+ /* TODO: fill memory blocks directly so this isn't needed */
+ mbedtls_mpi_init(&TA);
+ mbedtls_mpi_copy(&TA, A);
+ mbedtls_mpi_grow(&TA, num_words);
+ A = &TA;
+ mbedtls_mpi_init(&TB);
+ mbedtls_mpi_copy(&TB, B);
+ mbedtls_mpi_grow(&TB, num_words);
+ B = &TB;
+
+ esp_mpi_acquire_hardware();
+
+ if(ets_bigint_mod_mult_prepare(A->p, B->p, M->p, Mprime,
+ Rinv->p, num_bits, false)) {
+ mbedtls_mpi_grow(X, num_words);
+ ets_bigint_wait_finish();
+ if(ets_bigint_mod_mult_getz(M->p, X->p, num_bits)) {
+ X->s = A->s * B->s;
+ ret = 0;
+ } else {
+ printf("ets_bigint_mod_mult_getz failed\n");
+ ret = MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
+ }
+ } else {
+ printf("ets_bigint_mod_mult_prepare failed\n");
+ ret = MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
+ }
+ esp_mpi_release_hardware();
+
+ /* unclear why this is necessary, but the result seems
+ to come back rotated 32 bits to the right... */
+ uint32_t last_word = X->p[num_words-1];
+ X->p[num_words-1] = 0;
+ mbedtls_mpi_shift_l(X, 32);
+ X->p[0] = last_word;
+
+ mbedtls_mpi_free(&TA);
+ mbedtls_mpi_free(&TB);
+
+ return ret;
+}
+
+/* X = (A * B) mod M
+
+ Not an mbedTLS function
+
+ num_bits guaranteed to be a multiple of 512 already.
+
+ TODO: ensure M is odd
+ */
+int esp_mpi_mul_mpi_mod(mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B, const mbedtls_mpi *M, size_t num_bits)
+{
+ int ret = 0;
+ mbedtls_mpi RR, Rinv, Mprime;
+ uint32_t Mprime_int;
+ size_t num_words = num_bits / 32;
+
+ /* Rinv & Mprime are calculated via extended binary gcd
+ algorithm, see references on extended_binary_gcd above.
+ */
+ mbedtls_mpi_init(&Rinv);
+ mbedtls_mpi_init(&RR);
+ mbedtls_mpi_set_bit(&RR, num_bits+32, 1);
+ mbedtls_mpi_init(&Mprime);
+ extended_binary_gcd(&RR, M, &Rinv, &Mprime);
+
+ /* M' is mod 2^32 */
+ Mprime_int = Mprime.p[0];
+
+ ret = mpi_mul_mpi_mod_inner(X, A, B, M, &Rinv, Mprime_int, num_words);
+
+ mbedtls_mpi_free(&RR);
+ mbedtls_mpi_free(&Mprime);
+ mbedtls_mpi_free(&Rinv);
+
+ return ret;
+}
+
+
/*
* Helper for mbedtls_mpi multiplication
* copied/trimmed from mbedtls bignum.c
return res;
}
+
+/* Special-case multiply, where we use hardware montgomery mod
+ multiplication to solve the case where A or B are >2048 bits so
+ can't do standard multiplication.
+
+ the modulus here is chosen with M=(2^num_bits-1)
+ to guarantee the output isn't actually modulo anything. This means
+ we don't need to calculate M' and Rinv, they are predictable
+ as follows:
+ M' = 1
+ Rinv = (1 << (num_bits - 32)
+
+ (See RSA Accelerator section in Technical Reference for derivation
+ of M', Rinv)
+*/
+static int esp_mpi_mult_mpi_failover_mod_mult(mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B, size_t num_words)
+ {
+ mbedtls_mpi M, Rinv;
+ int ret;
+ size_t mprime;
+ size_t num_bits = num_words * 32;
+
+ mbedtls_mpi_init(&M);
+ mbedtls_mpi_init(&Rinv);
+
+ /* TODO: it may be faster to just use 4096-bit arithmetic every time,
+ and make these constants rather than runtime derived
+ derived. */
+ /* M = (2^num_words)-1 */
+ mbedtls_mpi_grow(&M, num_words);
+ for(int i = 0; i < num_words*32; i++) {
+ mbedtls_mpi_set_bit(&M, i, 1);
+ }
+
+ /* Rinv = (2^num_words-32) */
+ mbedtls_mpi_grow(&Rinv, num_words);
+ mbedtls_mpi_set_bit(&Rinv, num_bits - 32, 1);
+
+ mprime = 1;
+
+ ret = mpi_mul_mpi_mod_inner(X, A, B, &M, &Rinv, mprime, num_words);
+
+ mbedtls_mpi_free(&M);
+ mbedtls_mpi_free(&Rinv);
+ return ret;
+ }
+
int mbedtls_mpi_mul_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B )
{
int ret = -1;
words_a = hardware_words_needed(A);
words_b = hardware_words_needed(B);
+ words_mult = (words_a > words_b ? words_a : words_b);
+
/* Take a copy of A if either X == A OR if A isn't long enough
to hold the number of words needed for hardware.
RAM. But we need to reimplement ets_bigint_mult_prepare() in
software for this.
*/
- if( X == A || A->n < words_a) {
+ if( X == A || A->n < words_mult) {
MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &TA, A ) );
- MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &TA, words_a) );
+ MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &TA, words_mult) );
A = &TA;
}
/* Same for B */
- if( X == B || B->n < words_b ) {
+ if( X == B || B->n < words_mult ) {
MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &TB, B ) );
- MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &TB, words_b) );
+ MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &TB, words_mult) );
B = &TB;
}
/* Result X has to have room for double the larger operand */
- words_mult = (words_a > words_b ? words_a : words_b);
words_x = words_mult * 2;
MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, words_x ) );
/* TODO: check if lset here is necessary, hardware should zero */
MBEDTLS_MPI_CHK( mbedtls_mpi_lset( X, 0 ) );
- esp_mpi_acquire_hardware();
+ /* If either operand is over 2048 bits, we can't use the standard hardware multiplier
+ (it assumes result is double longest operand, and result is max 4096 bits.)
+ However, we can fail over to mod_mult for up to 4096 bits.
+ */
if(words_mult * 32 > 2048) {
- printf("WARNING: %d bit operands (%d bits * %d bits) too large for hardware unit\n", words_mult * 32, mbedtls_mpi_bitlen(A), mbedtls_mpi_bitlen(B));
- }
-
- if (ets_bigint_mult_prepare(A->p, B->p, words_mult * 32)) {
- ets_bigint_wait_finish();
- /* NB: argument to bigint_mult_getz is length of inputs, double this number (words_x) is
- copied to output X->p.
+ /* TODO: check if there's an overflow condition if words_a & words_b are both
+ the bit lengths of the operands, result could be 1 bit longer
*/
- if (ets_bigint_mult_getz(X->p, words_mult * 32) == true) {
- ret = 0;
- } else {
- printf("ets_bigint_mult_getz failed\n");
- }
- } else{
- printf("Baseline multiplication failed\n");
- }
- esp_mpi_release_hardware();
-
- X->s = A->s * B->s;
+ if((words_a + words_b) * 32 > 4096) {
+ printf("ERROR: %d bit operands (%d bits * %d bits) too large for hardware unit\n", words_mult * 32, mbedtls_mpi_bitlen(A), mbedtls_mpi_bitlen(B));
+ ret = MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
+ }
+ else {
+ ret = esp_mpi_mult_mpi_failover_mod_mult(X, A, B, words_a + words_b);
+ }
+ }
+ else {
+
+ /* normal mpi multiplication */
+ esp_mpi_acquire_hardware();
+ if (ets_bigint_mult_prepare(A->p, B->p, words_mult * 32)) {
+ ets_bigint_wait_finish();
+ /* NB: argument to bigint_mult_getz is length of inputs, double this number (words_x) is
+ copied to output X->p.
+ */
+ if (ets_bigint_mult_getz(X->p, words_mult * 32) == true) {
+ X->s = A->s * B->s;
+ ret = 0;
+ } else {
+ printf("ets_bigint_mult_getz failed\n");
+ ret = MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
+ }
+ } else{
+ printf("Baseline multiplication failed\n");
+ ret = MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
+ }
+ esp_mpi_release_hardware();
+ }
cleanup:
mbedtls_mpi_free( &TB ); mbedtls_mpi_free( &TA );