ring/arithmetic/
bigint.rs

1// Copyright 2015-2016 Brian Smith.
2//
3// Permission to use, copy, modify, and/or distribute this software for any
4// purpose with or without fee is hereby granted, provided that the above
5// copyright notice and this permission notice appear in all copies.
6//
7// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
8// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
10// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14
15//! Multi-precision integers.
16//!
17//! # Modular Arithmetic.
18//!
19//! Modular arithmetic is done in finite commutative rings ℤ/mℤ for some
20//! modulus *m*. We work in finite commutative rings instead of finite fields
21//! because the RSA public modulus *n* is not prime, which means ℤ/nℤ contains
22//! nonzero elements that have no multiplicative inverse, so ℤ/nℤ is not a
23//! finite field.
24//!
25//! In some calculations we need to deal with multiple rings at once. For
26//! example, RSA private key operations operate in the rings ℤ/nℤ, ℤ/pℤ, and
27//! ℤ/qℤ. Types and functions dealing with such rings are all parameterized
28//! over a type `M` to ensure that we don't wrongly mix up the math, e.g. by
29//! multiplying an element of ℤ/pℤ by an element of ℤ/qℤ modulo q. This follows
30//! the "unit" pattern described in [Static checking of units in Servo].
31//!
32//! `Elem` also uses the static unit checking pattern to statically track the
//! Montgomery factors that need to be canceled out in each value using its
//! `E` parameter.
35//!
36//! [Static checking of units in Servo]:
37//!     https://blog.mozilla.org/research/2014/06/23/static-checking-of-units-in-servo/
38
39use crate::{
40    arithmetic::montgomery::*,
41    bits, bssl, c, error,
42    limb::{self, Limb, LimbMask, LIMB_BITS, LIMB_BYTES},
43};
44use alloc::{borrow::ToOwned as _, boxed::Box, vec, vec::Vec};
45use core::{
46    marker::PhantomData,
47    ops::{Deref, DerefMut},
48};
49
50pub unsafe trait Prime {}
51
52struct Width<M> {
53    num_limbs: usize,
54
55    /// The modulus *m* that the width originated from.
56    m: PhantomData<M>,
57}
58
59/// All `BoxedLimbs<M>` are stored in the same number of limbs.
60struct BoxedLimbs<M> {
61    limbs: Box<[Limb]>,
62
63    /// The modulus *m* that determines the size of `limbx`.
64    m: PhantomData<M>,
65}
66
67impl<M> Deref for BoxedLimbs<M> {
68    type Target = [Limb];
69    #[inline]
70    fn deref(&self) -> &Self::Target {
71        &self.limbs
72    }
73}
74
75impl<M> DerefMut for BoxedLimbs<M> {
76    #[inline]
77    fn deref_mut(&mut self) -> &mut Self::Target {
78        &mut self.limbs
79    }
80}
81
82// TODO: `derive(Clone)` after https://github.com/rust-lang/rust/issues/26925
83// is resolved or restrict `M: Clone`.
84impl<M> Clone for BoxedLimbs<M> {
85    fn clone(&self) -> Self {
86        Self {
87            limbs: self.limbs.clone(),
88            m: self.m,
89        }
90    }
91}
92
93impl<M> BoxedLimbs<M> {
94    fn positive_minimal_width_from_be_bytes(
95        input: untrusted::Input,
96    ) -> Result<Self, error::KeyRejected> {
97        // Reject leading zeros. Also reject the value zero ([0]) because zero
98        // isn't positive.
99        if untrusted::Reader::new(input).peek(0) {
100            return Err(error::KeyRejected::invalid_encoding());
101        }
102        let num_limbs = (input.len() + LIMB_BYTES - 1) / LIMB_BYTES;
103        let mut r = Self::zero(Width {
104            num_limbs,
105            m: PhantomData,
106        });
107        limb::parse_big_endian_and_pad_consttime(input, &mut r)
108            .map_err(|error::Unspecified| error::KeyRejected::unexpected_error())?;
109        Ok(r)
110    }
111
112    fn minimal_width_from_unpadded(limbs: &[Limb]) -> Self {
113        debug_assert_ne!(limbs.last(), Some(&0));
114        Self {
115            limbs: limbs.to_owned().into_boxed_slice(),
116            m: PhantomData,
117        }
118    }
119
120    fn from_be_bytes_padded_less_than(
121        input: untrusted::Input,
122        m: &Modulus<M>,
123    ) -> Result<Self, error::Unspecified> {
124        let mut r = Self::zero(m.width());
125        limb::parse_big_endian_and_pad_consttime(input, &mut r)?;
126        if limb::limbs_less_than_limbs_consttime(&r, &m.limbs) != LimbMask::True {
127            return Err(error::Unspecified);
128        }
129        Ok(r)
130    }
131
132    #[inline]
133    fn is_zero(&self) -> bool {
134        limb::limbs_are_zero_constant_time(&self.limbs) == LimbMask::True
135    }
136
137    fn zero(width: Width<M>) -> Self {
138        Self {
139            limbs: vec![0; width.num_limbs].into_boxed_slice(),
140            m: PhantomData,
141        }
142    }
143
144    fn width(&self) -> Width<M> {
145        Width {
146            num_limbs: self.limbs.len(),
147            m: PhantomData,
148        }
149    }
150}
151
152/// A modulus *s* that is smaller than another modulus *l* so every element of
153/// ℤ/sℤ is also an element of ℤ/lℤ.
154pub unsafe trait SmallerModulus<L> {}
155
156/// A modulus *s* where s < l < 2*s for the given larger modulus *l*. This is
157/// the precondition for reduction by conditional subtraction,
158/// `elem_reduce_once()`.
159pub unsafe trait SlightlySmallerModulus<L>: SmallerModulus<L> {}
160
161/// A modulus *s* where √l <= s < l for the given larger modulus *l*. This is
162/// the precondition for the more general Montgomery reduction from ℤ/lℤ to
163/// ℤ/sℤ.
164pub unsafe trait NotMuchSmallerModulus<L>: SmallerModulus<L> {}
165
166pub unsafe trait PublicModulus {}
167
168/// The x86 implementation of `GFp_bn_mul_mont`, at least, requires at least 4
169/// limbs. For a long time we have required 4 limbs for all targets, though
170/// this may be unnecessary. TODO: Replace this with
171/// `n.len() < 256 / LIMB_BITS` so that 32-bit and 64-bit platforms behave the
172/// same.
173pub const MODULUS_MIN_LIMBS: usize = 4;
174
175pub const MODULUS_MAX_LIMBS: usize = 8192 / LIMB_BITS;
176
177/// The modulus *m* for a ring ℤ/mℤ, along with the precomputed values needed
178/// for efficient Montgomery multiplication modulo *m*. The value must be odd
179/// and larger than 2. The larger-than-1 requirement is imposed, at least, by
180/// the modular inversion code.
181pub struct Modulus<M> {
182    limbs: BoxedLimbs<M>, // Also `value >= 3`.
183
184    // n0 * N == -1 (mod r).
185    //
186    // r == 2**(N0_LIMBS_USED * LIMB_BITS) and LG_LITTLE_R == lg(r). This
187    // ensures that we can do integer division by |r| by simply ignoring
188    // `N0_LIMBS_USED` limbs. Similarly, we can calculate values modulo `r` by
189    // just looking at the lowest `N0_LIMBS_USED` limbs. This is what makes
190    // Montgomery multiplication efficient.
191    //
192    // As shown in Algorithm 1 of "Fast Prime Field Elliptic Curve Cryptography
193    // with 256 Bit Primes" by Shay Gueron and Vlad Krasnov, in the loop of a
194    // multi-limb Montgomery multiplication of a * b (mod n), given the
195    // unreduced product t == a * b, we repeatedly calculate:
196    //
197    //    t1 := t % r         |t1| is |t|'s lowest limb (see previous paragraph).
198    //    t2 := t1*n0*n
199    //    t3 := t + t2
200    //    t := t3 / r         copy all limbs of |t3| except the lowest to |t|.
201    //
202    // In the last step, it would only make sense to ignore the lowest limb of
203    // |t3| if it were zero. The middle steps ensure that this is the case:
204    //
205    //                            t3 ==  0 (mod r)
206    //                        t + t2 ==  0 (mod r)
207    //                   t + t1*n0*n ==  0 (mod r)
208    //                       t1*n0*n == -t (mod r)
209    //                        t*n0*n == -t (mod r)
210    //                          n0*n == -1 (mod r)
211    //                            n0 == -1/n (mod r)
212    //
213    // Thus, in each iteration of the loop, we multiply by the constant factor
214    // n0, the negative inverse of n (mod r).
215    //
216    // TODO(perf): Not all 32-bit platforms actually make use of n0[1]. For the
217    // ones that don't, we could use a shorter `R` value and use faster `Limb`
218    // calculations instead of double-precision `u64` calculations.
219    n0: N0,
220
221    oneRR: One<M, RR>,
222}
223
224impl<M: PublicModulus> core::fmt::Debug for Modulus<M> {
225    fn fmt(&self, fmt: &mut ::core::fmt::Formatter) -> Result<(), ::core::fmt::Error> {
226        fmt.debug_struct("Modulus")
227            // TODO: Print modulus value.
228            .finish()
229    }
230}
231
232impl<M> Modulus<M> {
233    pub fn from_be_bytes_with_bit_length(
234        input: untrusted::Input,
235    ) -> Result<(Self, bits::BitLength), error::KeyRejected> {
236        let limbs = BoxedLimbs::positive_minimal_width_from_be_bytes(input)?;
237        Self::from_boxed_limbs(limbs)
238    }
239
240    pub fn from_nonnegative_with_bit_length(
241        n: Nonnegative,
242    ) -> Result<(Self, bits::BitLength), error::KeyRejected> {
243        let limbs = BoxedLimbs {
244            limbs: n.limbs.into_boxed_slice(),
245            m: PhantomData,
246        };
247        Self::from_boxed_limbs(limbs)
248    }
249
250    fn from_boxed_limbs(n: BoxedLimbs<M>) -> Result<(Self, bits::BitLength), error::KeyRejected> {
251        if n.len() > MODULUS_MAX_LIMBS {
252            return Err(error::KeyRejected::too_large());
253        }
254        if n.len() < MODULUS_MIN_LIMBS {
255            return Err(error::KeyRejected::unexpected_error());
256        }
257        if limb::limbs_are_even_constant_time(&n) != LimbMask::False {
258            return Err(error::KeyRejected::invalid_component());
259        }
260        if limb::limbs_less_than_limb_constant_time(&n, 3) != LimbMask::False {
261            return Err(error::KeyRejected::unexpected_error());
262        }
263
264        // n_mod_r = n % r. As explained in the documentation for `n0`, this is
265        // done by taking the lowest `N0_LIMBS_USED` limbs of `n`.
266        #[allow(clippy::useless_conversion)]
267        let n0 = {
268            extern "C" {
269                fn GFp_bn_neg_inv_mod_r_u64(n: u64) -> u64;
270            }
271
272            // XXX: u64::from isn't guaranteed to be constant time.
273            let mut n_mod_r: u64 = u64::from(n[0]);
274
275            if N0_LIMBS_USED == 2 {
276                // XXX: If we use `<< LIMB_BITS` here then 64-bit builds
277                // fail to compile because of `deny(exceeding_bitshifts)`.
278                debug_assert_eq!(LIMB_BITS, 32);
279                n_mod_r |= u64::from(n[1]) << 32;
280            }
281            N0::from(unsafe { GFp_bn_neg_inv_mod_r_u64(n_mod_r) })
282        };
283
284        let bits = limb::limbs_minimal_bits(&n.limbs);
285        let oneRR = {
286            let partial = PartialModulus {
287                limbs: &n.limbs,
288                n0: n0.clone(),
289                m: PhantomData,
290            };
291
292            One::newRR(&partial, bits)
293        };
294
295        Ok((
296            Self {
297                limbs: n,
298                n0,
299                oneRR,
300            },
301            bits,
302        ))
303    }
304
305    #[inline]
306    fn width(&self) -> Width<M> {
307        self.limbs.width()
308    }
309
310    fn zero<E>(&self) -> Elem<M, E> {
311        Elem {
312            limbs: BoxedLimbs::zero(self.width()),
313            encoding: PhantomData,
314        }
315    }
316
317    // TODO: Get rid of this
318    fn one(&self) -> Elem<M, Unencoded> {
319        let mut r = self.zero();
320        r.limbs[0] = 1;
321        r
322    }
323
324    pub fn oneRR(&self) -> &One<M, RR> {
325        &self.oneRR
326    }
327
328    pub fn to_elem<L>(&self, l: &Modulus<L>) -> Elem<L, Unencoded>
329    where
330        M: SmallerModulus<L>,
331    {
332        // TODO: Encode this assertion into the `where` above.
333        assert_eq!(self.width().num_limbs, l.width().num_limbs);
334        let limbs = self.limbs.clone();
335        Elem {
336            limbs: BoxedLimbs {
337                limbs: limbs.limbs,
338                m: PhantomData,
339            },
340            encoding: PhantomData,
341        }
342    }
343
344    fn as_partial(&self) -> PartialModulus<M> {
345        PartialModulus {
346            limbs: &self.limbs,
347            n0: self.n0.clone(),
348            m: PhantomData,
349        }
350    }
351}
352
353struct PartialModulus<'a, M> {
354    limbs: &'a [Limb],
355    n0: N0,
356    m: PhantomData<M>,
357}
358
359impl<M> PartialModulus<'_, M> {
360    // TODO: XXX Avoid duplication with `Modulus`.
361    fn zero(&self) -> Elem<M, R> {
362        let width = Width {
363            num_limbs: self.limbs.len(),
364            m: PhantomData,
365        };
366        Elem {
367            limbs: BoxedLimbs::zero(width),
368            encoding: PhantomData,
369        }
370    }
371}
372
373/// Elements of ℤ/mℤ for some modulus *m*.
374//
375// Defaulting `E` to `Unencoded` is a convenience for callers from outside this
376// submodule. However, for maximum clarity, we always explicitly use
377// `Unencoded` within the `bigint` submodule.
378pub struct Elem<M, E = Unencoded> {
379    limbs: BoxedLimbs<M>,
380
381    /// The number of Montgomery factors that need to be canceled out from
382    /// `value` to get the actual value.
383    encoding: PhantomData<E>,
384}
385
386// TODO: `derive(Clone)` after https://github.com/rust-lang/rust/issues/26925
387// is resolved or restrict `M: Clone` and `E: Clone`.
388impl<M, E> Clone for Elem<M, E> {
389    fn clone(&self) -> Self {
390        Self {
391            limbs: self.limbs.clone(),
392            encoding: self.encoding,
393        }
394    }
395}
396
397impl<M, E> Elem<M, E> {
398    #[inline]
399    pub fn is_zero(&self) -> bool {
400        self.limbs.is_zero()
401    }
402}
403
404impl<M, E: ReductionEncoding> Elem<M, E> {
405    fn decode_once(self, m: &Modulus<M>) -> Elem<M, <E as ReductionEncoding>::Output> {
406        // A multiplication isn't required since we're multiplying by the
407        // unencoded value one (1); only a Montgomery reduction is needed.
408        // However the only non-multiplication Montgomery reduction function we
409        // have requires the input to be large, so we avoid using it here.
410        let mut limbs = self.limbs;
411        let num_limbs = m.width().num_limbs;
412        let mut one = [0; MODULUS_MAX_LIMBS];
413        one[0] = 1;
414        let one = &one[..num_limbs]; // assert!(num_limbs <= MODULUS_MAX_LIMBS);
415        limbs_mont_mul(&mut limbs, &one, &m.limbs, &m.n0);
416        Elem {
417            limbs,
418            encoding: PhantomData,
419        }
420    }
421}
422
423impl<M> Elem<M, R> {
424    #[inline]
425    pub fn into_unencoded(self, m: &Modulus<M>) -> Elem<M, Unencoded> {
426        self.decode_once(m)
427    }
428}
429
430impl<M> Elem<M, Unencoded> {
431    pub fn from_be_bytes_padded(
432        input: untrusted::Input,
433        m: &Modulus<M>,
434    ) -> Result<Self, error::Unspecified> {
435        Ok(Elem {
436            limbs: BoxedLimbs::from_be_bytes_padded_less_than(input, m)?,
437            encoding: PhantomData,
438        })
439    }
440
441    #[inline]
442    pub fn fill_be_bytes(&self, out: &mut [u8]) {
443        // See Falko Strenzke, "Manger's Attack revisited", ICICS 2010.
444        limb::big_endian_from_limbs(&self.limbs, out)
445    }
446
447    pub fn into_modulus<MM>(self) -> Result<Modulus<MM>, error::KeyRejected> {
448        let (m, _bits) =
449            Modulus::from_boxed_limbs(BoxedLimbs::minimal_width_from_unpadded(&self.limbs))?;
450        Ok(m)
451    }
452
453    fn is_one(&self) -> bool {
454        limb::limbs_equal_limb_constant_time(&self.limbs, 1) == LimbMask::True
455    }
456}
457
458pub fn elem_mul<M, AF, BF>(
459    a: &Elem<M, AF>,
460    b: Elem<M, BF>,
461    m: &Modulus<M>,
462) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
463where
464    (AF, BF): ProductEncoding,
465{
466    elem_mul_(a, b, &m.as_partial())
467}
468
469fn elem_mul_<M, AF, BF>(
470    a: &Elem<M, AF>,
471    mut b: Elem<M, BF>,
472    m: &PartialModulus<M>,
473) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
474where
475    (AF, BF): ProductEncoding,
476{
477    limbs_mont_mul(&mut b.limbs, &a.limbs, &m.limbs, &m.n0);
478    Elem {
479        limbs: b.limbs,
480        encoding: PhantomData,
481    }
482}
483
484fn elem_mul_by_2<M, AF>(a: &mut Elem<M, AF>, m: &PartialModulus<M>) {
485    extern "C" {
486        fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
487    }
488    unsafe {
489        LIMBS_shl_mod(
490            a.limbs.as_mut_ptr(),
491            a.limbs.as_ptr(),
492            m.limbs.as_ptr(),
493            m.limbs.len(),
494        );
495    }
496}
497
498pub fn elem_reduced_once<Larger, Smaller: SlightlySmallerModulus<Larger>>(
499    a: &Elem<Larger, Unencoded>,
500    m: &Modulus<Smaller>,
501) -> Elem<Smaller, Unencoded> {
502    let mut r = a.limbs.clone();
503    assert!(r.len() <= m.limbs.len());
504    limb::limbs_reduce_once_constant_time(&mut r, &m.limbs);
505    Elem {
506        limbs: BoxedLimbs {
507            limbs: r.limbs,
508            m: PhantomData,
509        },
510        encoding: PhantomData,
511    }
512}
513
514#[inline]
515pub fn elem_reduced<Larger, Smaller: NotMuchSmallerModulus<Larger>>(
516    a: &Elem<Larger, Unencoded>,
517    m: &Modulus<Smaller>,
518) -> Elem<Smaller, RInverse> {
519    let mut tmp = [0; MODULUS_MAX_LIMBS];
520    let tmp = &mut tmp[..a.limbs.len()];
521    tmp.copy_from_slice(&a.limbs);
522
523    let mut r = m.zero();
524    limbs_from_mont_in_place(&mut r.limbs, tmp, &m.limbs, &m.n0);
525    r
526}
527
528fn elem_squared<M, E>(
529    mut a: Elem<M, E>,
530    m: &PartialModulus<M>,
531) -> Elem<M, <(E, E) as ProductEncoding>::Output>
532where
533    (E, E): ProductEncoding,
534{
535    limbs_mont_square(&mut a.limbs, &m.limbs, &m.n0);
536    Elem {
537        limbs: a.limbs,
538        encoding: PhantomData,
539    }
540}
541
542pub fn elem_widen<Larger, Smaller: SmallerModulus<Larger>>(
543    a: Elem<Smaller, Unencoded>,
544    m: &Modulus<Larger>,
545) -> Elem<Larger, Unencoded> {
546    let mut r = m.zero();
547    r.limbs[..a.limbs.len()].copy_from_slice(&a.limbs);
548    r
549}
550
551// TODO: Document why this works for all Montgomery factors.
552pub fn elem_add<M, E>(mut a: Elem<M, E>, b: Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
553    extern "C" {
554        // `r` and `a` may alias.
555        fn LIMBS_add_mod(
556            r: *mut Limb,
557            a: *const Limb,
558            b: *const Limb,
559            m: *const Limb,
560            num_limbs: c::size_t,
561        );
562    }
563    unsafe {
564        LIMBS_add_mod(
565            a.limbs.as_mut_ptr(),
566            a.limbs.as_ptr(),
567            b.limbs.as_ptr(),
568            m.limbs.as_ptr(),
569            m.limbs.len(),
570        )
571    }
572    a
573}
574
575// TODO: Document why this works for all Montgomery factors.
576pub fn elem_sub<M, E>(mut a: Elem<M, E>, b: &Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
577    extern "C" {
578        // `r` and `a` may alias.
579        fn LIMBS_sub_mod(
580            r: *mut Limb,
581            a: *const Limb,
582            b: *const Limb,
583            m: *const Limb,
584            num_limbs: c::size_t,
585        );
586    }
587    unsafe {
588        LIMBS_sub_mod(
589            a.limbs.as_mut_ptr(),
590            a.limbs.as_ptr(),
591            b.limbs.as_ptr(),
592            m.limbs.as_ptr(),
593            m.limbs.len(),
594        );
595    }
596    a
597}
598
599// The value 1, Montgomery-encoded some number of times.
600pub struct One<M, E>(Elem<M, E>);
601
602impl<M> One<M, RR> {
603    // Returns RR = = R**2 (mod n) where R = 2**r is the smallest power of
604    // 2**LIMB_BITS such that R > m.
605    //
606    // Even though the assembly on some 32-bit platforms works with 64-bit
607    // values, using `LIMB_BITS` here, rather than `N0_LIMBS_USED * LIMB_BITS`,
608    // is correct because R**2 will still be a multiple of the latter as
609    // `N0_LIMBS_USED` is either one or two.
610    fn newRR(m: &PartialModulus<M>, m_bits: bits::BitLength) -> Self {
611        let m_bits = m_bits.as_usize_bits();
612        let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS;
613
614        // base = 2**(lg m - 1).
615        let bit = m_bits - 1;
616        let mut base = m.zero();
617        base.limbs[bit / LIMB_BITS] = 1 << (bit % LIMB_BITS);
618
619        // Double `base` so that base == R == 2**r (mod m). For normal moduli
620        // that have the high bit of the highest limb set, this requires one
621        // doubling. Unusual moduli require more doublings but we are less
622        // concerned about the performance of those.
623        //
624        // Then double `base` again so that base == 2*R (mod n), i.e. `2` in
625        // Montgomery form (`elem_exp_vartime_()` requires the base to be in
626        // Montgomery form). Then compute
627        // RR = R**2 == base**r == R**r == (2**r)**r (mod n).
628        //
629        // Take advantage of the fact that `elem_mul_by_2` is faster than
630        // `elem_squared` by replacing some of the early squarings with shifts.
631        // TODO: Benchmark shift vs. squaring performance to determine the
632        // optimal value of `lg_base`.
633        let lg_base = 2usize; // Shifts vs. squaring trade-off.
634        debug_assert_eq!(lg_base.count_ones(), 1); // Must 2**n for n >= 0.
635        let shifts = r - bit + lg_base;
636        let exponent = (r / lg_base) as u64;
637        for _ in 0..shifts {
638            elem_mul_by_2(&mut base, m)
639        }
640        let RR = elem_exp_vartime_(base, exponent, m);
641
642        Self(Elem {
643            limbs: RR.limbs,
644            encoding: PhantomData, // PhantomData<RR>
645        })
646    }
647}
648
649impl<M, E> AsRef<Elem<M, E>> for One<M, E> {
650    fn as_ref(&self) -> &Elem<M, E> {
651        &self.0
652    }
653}
654
655/// A non-secret odd positive value in the range
656/// [3, PUBLIC_EXPONENT_MAX_VALUE].
657#[derive(Clone, Copy, Debug)]
658pub struct PublicExponent(u64);
659
660impl PublicExponent {
661    pub fn from_be_bytes(
662        input: untrusted::Input,
663        min_value: u64,
664    ) -> Result<Self, error::KeyRejected> {
665        if input.len() > 5 {
666            return Err(error::KeyRejected::too_large());
667        }
668        let value = input.read_all(error::KeyRejected::invalid_encoding(), |input| {
669            // The exponent can't be zero and it can't be prefixed with
670            // zero-valued bytes.
671            if input.peek(0) {
672                return Err(error::KeyRejected::invalid_encoding());
673            }
674            let mut value = 0u64;
675            loop {
676                let byte = input
677                    .read_byte()
678                    .map_err(|untrusted::EndOfInput| error::KeyRejected::invalid_encoding())?;
679                value = (value << 8) | u64::from(byte);
680                if input.at_end() {
681                    return Ok(value);
682                }
683            }
684        })?;
685
686        // Step 2 / Step b. NIST SP800-89 defers to FIPS 186-3, which requires
687        // `e >= 65537`. We enforce this when signing, but are more flexible in
688        // verification, for compatibility. Only small public exponents are
689        // supported.
690        if value & 1 != 1 {
691            return Err(error::KeyRejected::invalid_component());
692        }
693        debug_assert!(min_value & 1 == 1);
694        debug_assert!(min_value <= PUBLIC_EXPONENT_MAX_VALUE);
695        if min_value < 3 {
696            return Err(error::KeyRejected::invalid_component());
697        }
698        if value < min_value {
699            return Err(error::KeyRejected::too_small());
700        }
701        if value > PUBLIC_EXPONENT_MAX_VALUE {
702            return Err(error::KeyRejected::too_large());
703        }
704
705        Ok(Self(value))
706    }
707}
708
709// This limit was chosen to bound the performance of the simple
710// exponentiation-by-squaring implementation in `elem_exp_vartime`. In
711// particular, it helps mitigate theoretical resource exhaustion attacks. 33
712// bits was chosen as the limit based on the recommendations in [1] and
713// [2]. Windows CryptoAPI (at least older versions) doesn't support values
714// larger than 32 bits [3], so it is unlikely that exponents larger than 32
715// bits are being used for anything Windows commonly does.
716//
717// [1] https://www.imperialviolet.org/2012/03/16/rsae.html
718// [2] https://www.imperialviolet.org/2012/03/17/rsados.html
719// [3] https://msdn.microsoft.com/en-us/library/aa387685(VS.85).aspx
720const PUBLIC_EXPONENT_MAX_VALUE: u64 = (1u64 << 33) - 1;
721
722/// Calculates base**exponent (mod m).
723// TODO: The test coverage needs to be expanded, e.g. test with the largest
724// accepted exponent and with the most common values of 65537 and 3.
725pub fn elem_exp_vartime<M>(
726    base: Elem<M, Unencoded>,
727    PublicExponent(exponent): PublicExponent,
728    m: &Modulus<M>,
729) -> Elem<M, R> {
730    let base = elem_mul(m.oneRR().as_ref(), base, &m);
731    elem_exp_vartime_(base, exponent, &m.as_partial())
732}
733
734/// Calculates base**exponent (mod m).
735fn elem_exp_vartime_<M>(base: Elem<M, R>, exponent: u64, m: &PartialModulus<M>) -> Elem<M, R> {
736    // Use what [Knuth] calls the "S-and-X binary method", i.e. variable-time
737    // square-and-multiply that scans the exponent from the most significant
738    // bit to the least significant bit (left-to-right). Left-to-right requires
739    // less storage compared to right-to-left scanning, at the cost of needing
740    // to compute `exponent.leading_zeros()`, which we assume to be cheap.
741    //
742    // During RSA public key operations the exponent is almost always either 65537
743    // (0b10000000000000001) or 3 (0b11), both of which have a Hamming weight
744    // of 2. During Montgomery setup the exponent is almost always a power of two,
745    // with Hamming weight 1. As explained in [Knuth], exponentiation by squaring
746    // is the most efficient algorithm when the Hamming weight is 2 or less. It
747    // isn't the most efficient for all other, uncommon, exponent values but any
748    // suboptimality is bounded by `PUBLIC_EXPONENT_MAX_VALUE`.
749    //
750    // This implementation is slightly simplified by taking advantage of the
751    // fact that we require the exponent to be a positive integer.
752    //
753    // [Knuth]: The Art of Computer Programming, Volume 2: Seminumerical
754    //          Algorithms (3rd Edition), Section 4.6.3.
755    assert!(exponent >= 1);
756    assert!(exponent <= PUBLIC_EXPONENT_MAX_VALUE);
757    let mut acc = base.clone();
758    let mut bit = 1 << (64 - 1 - exponent.leading_zeros());
759    debug_assert!((exponent & bit) != 0);
760    while bit > 1 {
761        bit >>= 1;
762        acc = elem_squared(acc, m);
763        if (exponent & bit) != 0 {
764            acc = elem_mul_(&base, acc, m);
765        }
766    }
767    acc
768}
769
770// `M` represents the prime modulus for which the exponent is in the interval
771// [1, `m` - 1).
772pub struct PrivateExponent<M> {
773    limbs: BoxedLimbs<M>,
774}
775
776impl<M> PrivateExponent<M> {
777    pub fn from_be_bytes_padded(
778        input: untrusted::Input,
779        p: &Modulus<M>,
780    ) -> Result<Self, error::Unspecified> {
781        let dP = BoxedLimbs::from_be_bytes_padded_less_than(input, p)?;
782
783        // Proof that `dP < p - 1`:
784        //
785        // If `dP < p` then either `dP == p - 1` or `dP < p - 1`. Since `p` is
786        // odd, `p - 1` is even. `d` is odd, and an odd number modulo an even
787        // number is odd. Therefore `dP` must be odd. But then it cannot be
788        // `p - 1` and so we know `dP < p - 1`.
789        //
790        // Further we know `dP != 0` because `dP` is not even.
791        if limb::limbs_are_even_constant_time(&dP) != LimbMask::False {
792            return Err(error::Unspecified);
793        }
794
795        Ok(Self { limbs: dP })
796    }
797}
798
799impl<M: Prime> PrivateExponent<M> {
800    // Returns `p - 2`.
801    fn for_flt(p: &Modulus<M>) -> Self {
802        let two = elem_add(p.one(), p.one(), p);
803        let p_minus_2 = elem_sub(p.zero(), &two, p);
804        Self {
805            limbs: p_minus_2.limbs,
806        }
807    }
808}
809
810#[cfg(not(target_arch = "x86_64"))]
811pub fn elem_exp_consttime<M>(
812    base: Elem<M, R>,
813    exponent: &PrivateExponent<M>,
814    m: &Modulus<M>,
815) -> Result<Elem<M, Unencoded>, error::Unspecified> {
816    use crate::limb::Window;
817
818    const WINDOW_BITS: usize = 5;
819    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;
820
821    let num_limbs = m.limbs.len();
822
823    let mut table = vec![0; TABLE_ENTRIES * num_limbs];
824
825    fn gather<M>(table: &[Limb], i: Window, r: &mut Elem<M, R>) {
826        extern "C" {
827            fn LIMBS_select_512_32(
828                r: *mut Limb,
829                table: *const Limb,
830                num_limbs: c::size_t,
831                i: Window,
832            ) -> bssl::Result;
833        }
834        Result::from(unsafe {
835            LIMBS_select_512_32(r.limbs.as_mut_ptr(), table.as_ptr(), r.limbs.len(), i)
836        })
837        .unwrap();
838    }
839
840    fn power<M>(
841        table: &[Limb],
842        i: Window,
843        mut acc: Elem<M, R>,
844        mut tmp: Elem<M, R>,
845        m: &Modulus<M>,
846    ) -> (Elem<M, R>, Elem<M, R>) {
847        for _ in 0..WINDOW_BITS {
848            acc = elem_squared(acc, &m.as_partial());
849        }
850        gather(table, i, &mut tmp);
851        let acc = elem_mul(&tmp, acc, m);
852        (acc, tmp)
853    }
854
855    let tmp = m.one();
856    let tmp = elem_mul(m.oneRR().as_ref(), tmp, m);
857
858    fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
859        &table[(i * num_limbs)..][..num_limbs]
860    }
861    fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
862        &mut table[(i * num_limbs)..][..num_limbs]
863    }
864    let num_limbs = m.limbs.len();
865    entry_mut(&mut table, 0, num_limbs).copy_from_slice(&tmp.limbs);
866    entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);
867    for i in 2..TABLE_ENTRIES {
868        let (src1, src2) = if i % 2 == 0 {
869            (i / 2, i / 2)
870        } else {
871            (i - 1, 1)
872        };
873        let (previous, rest) = table.split_at_mut(num_limbs * i);
874        let src1 = entry(previous, src1, num_limbs);
875        let src2 = entry(previous, src2, num_limbs);
876        let dst = entry_mut(rest, 0, num_limbs);
877        limbs_mont_product(dst, src1, src2, &m.limbs, &m.n0);
878    }
879
880    let (r, _) = limb::fold_5_bit_windows(
881        &exponent.limbs,
882        |initial_window| {
883            let mut r = Elem {
884                limbs: base.limbs,
885                encoding: PhantomData,
886            };
887            gather(&table, initial_window, &mut r);
888            (r, tmp)
889        },
890        |(acc, tmp), window| power(&table, window, acc, tmp, m),
891    );
892
893    let r = r.into_unencoded(m);
894
895    Ok(r)
896}
897
898/// Uses Fermat's Little Theorem to calculate modular inverse in constant time.
899pub fn elem_inverse_consttime<M: Prime>(
900    a: Elem<M, R>,
901    m: &Modulus<M>,
902) -> Result<Elem<M, Unencoded>, error::Unspecified> {
903    elem_exp_consttime(a, &PrivateExponent::for_flt(&m), m)
904}
905
/// Constant-time modular exponentiation: computes `base**exponent (mod m)`
/// and returns the result unencoded (converted out of Montgomery form).
///
/// This x86_64 implementation uses a fixed 5-bit window. A 32-entry table of
/// powers of `base` is stored and loaded with the `GFp_bn_scatter5` /
/// `GFp_bn_gather5` assembly helpers. The exponent is then folded window by
/// window with `GFp_bn_power5`.
///
/// Returns `Err(error::Unspecified)` if the final `GFp_bn_from_montgomery`
/// call reports failure.
#[cfg(target_arch = "x86_64")]
pub fn elem_exp_consttime<M>(
    base: Elem<M, R>,
    exponent: &PrivateExponent<M>,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    // The x86_64 assembly was written under the assumption that the input data
    // is aligned to `MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH` bytes, which was/is
    // 64 in OpenSSL. Similarly, OpenSSL uses the x86_64 assembly functions by
    // giving it only inputs `tmp`, `am`, and `np` that immediately follow the
    // table. The code seems to "work" even when the inputs aren't exactly
    // like that but the side channel defenses might not be as effective. All
    // the awkwardness here stems from trying to use the assembly code like
    // OpenSSL does.

    use crate::limb::Window;

    const WINDOW_BITS: usize = 5;
    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;

    let num_limbs = m.limbs.len();

    // Buffer layout: `TABLE_ENTRIES` table entries followed by 3 state
    // entries (ACC, BASE, M), each `num_limbs` limbs long. The extra
    // `ALIGNMENT` limbs of slack exceed the at most `ALIGNMENT / LIMB_BYTES`
    // limbs skipped below to reach a 64-byte boundary.
    const ALIGNMENT: usize = 64;
    assert_eq!(ALIGNMENT % LIMB_BYTES, 0);
    let mut table = vec![0; ((TABLE_ENTRIES + 3) * num_limbs) + ALIGNMENT];
    let (table, state) = {
        // Skip forward to the next 64-byte boundary. When the buffer is
        // already aligned, this skips a full `ALIGNMENT` bytes.
        let misalignment = (table.as_ptr() as usize) % ALIGNMENT;
        let table = &mut table[((ALIGNMENT - misalignment) / LIMB_BYTES)..];
        assert_eq!((table.as_ptr() as usize) % ALIGNMENT, 0);
        // `table` gets the 32 power slots; `state` is everything after them.
        table.split_at_mut(TABLE_ENTRIES * num_limbs)
    };

    // Returns the `i`-th `num_limbs`-limb entry of `table`.
    fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
        &table[(i * num_limbs)..][..num_limbs]
    }
    // Mutable counterpart of `entry`.
    fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
        &mut table[(i * num_limbs)..][..num_limbs]
    }

    // Entry indices within `state`, laid out to match OpenSSL's usage.
    const ACC: usize = 0; // `tmp` in OpenSSL
    const BASE: usize = ACC + 1; // `am` in OpenSSL
    const M: usize = BASE + 1; // `np` in OpenSSL

    entry_mut(state, BASE, num_limbs).copy_from_slice(&base.limbs);
    entry_mut(state, M, num_limbs).copy_from_slice(&m.limbs);

    // Stores the ACC entry of `state` into table slot `i`.
    fn scatter(table: &mut [Limb], state: &[Limb], i: Window, num_limbs: usize) {
        extern "C" {
            fn GFp_bn_scatter5(a: *const Limb, a_len: c::size_t, table: *mut Limb, i: Window);
        }
        unsafe {
            GFp_bn_scatter5(
                entry(state, ACC, num_limbs).as_ptr(),
                num_limbs,
                table.as_mut_ptr(),
                i,
            )
        }
    }

    // Loads table slot `i` into the ACC entry of `state`.
    fn gather(table: &[Limb], state: &mut [Limb], i: Window, num_limbs: usize) {
        extern "C" {
            fn GFp_bn_gather5(r: *mut Limb, a_len: c::size_t, table: *const Limb, i: Window);
        }
        unsafe {
            GFp_bn_gather5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                num_limbs,
                table.as_ptr(),
                i,
            )
        }
    }

    // ACC = table[i], then ACC is Montgomery-squared in place.
    fn gather_square(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        gather(table, state, i, num_limbs);
        assert_eq!(ACC, 0);
        let (acc, rest) = state.split_at_mut(num_limbs);
        // `rest` starts at BASE, so the modulus entry is at `M - 1` within it.
        let m = entry(rest, M - 1, num_limbs);
        limbs_mont_square(acc, m, n0);
    }

    // Presumably ACC = BASE * table[i] (Montgomery product), with the table
    // lookup done inside the assembly. TODO confirm against
    // `GFp_bn_mul_mont_gather5`.
    fn gather_mul_base(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        extern "C" {
            fn GFp_bn_mul_mont_gather5(
                rp: *mut Limb,
                ap: *const Limb,
                table: *const Limb,
                np: *const Limb,
                n0: &N0,
                num: c::size_t,
                power: Window,
            );
        }
        unsafe {
            GFp_bn_mul_mont_gather5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                entry(state, BASE, num_limbs).as_ptr(),
                table.as_ptr(),
                entry(state, M, num_limbs).as_ptr(),
                n0,
                num_limbs,
                i,
            );
        }
    }

    // ACC is passed as both the output `r` and the input `a`. Presumably this
    // mirrors the portable `power` (five squarings of ACC, then a
    // multiplication by table[i]). TODO confirm against `GFp_bn_power5`.
    fn power(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        extern "C" {
            fn GFp_bn_power5(
                r: *mut Limb,
                a: *const Limb,
                table: *const Limb,
                n: *const Limb,
                n0: &N0,
                num: c::size_t,
                i: Window,
            );
        }
        unsafe {
            GFp_bn_power5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                table.as_ptr(),
                entry(state, M, num_limbs).as_ptr(),
                n0,
                num_limbs,
                i,
            );
        }
    }

    // table[0] = base**0.
    // ACC is still all zeros from `vec![0; ...]`, so setting limb 0 makes it
    // the integer 1. The Montgomery multiplication by `m.oneRR` then puts it
    // into Montgomery form (presumably oneRR is R**2 mod m; TODO confirm).
    {
        let acc = entry_mut(state, ACC, num_limbs);
        acc[0] = 1;
        limbs_mont_mul(acc, &m.oneRR.0.limbs, &m.limbs, &m.n0);
    }
    scatter(table, state, 0, num_limbs);

    // table[1] = base**1.
    entry_mut(state, ACC, num_limbs).copy_from_slice(&base.limbs);
    scatter(table, state, 1, num_limbs);

    // Fill the rest of the table:
    // - even i: table[i] = table[i/2]**2
    // - odd i: table[i] = table[i-1] * base
    for i in 2..(TABLE_ENTRIES as Window) {
        if i % 2 == 0 {
            // TODO: Optimize this to avoid gathering
            gather_square(table, state, &m.n0, i / 2, num_limbs);
        } else {
            gather_mul_base(table, state, &m.n0, i - 1, num_limbs)
        };
        scatter(table, state, i, num_limbs);
    }

    // Walk the exponent 5 bits at a time. The first window seeds ACC from the
    // table; every later window is handled by `power`.
    let state = limb::fold_5_bit_windows(
        &exponent.limbs,
        |initial_window| {
            gather(table, state, initial_window, num_limbs);
            state
        },
        |state, window| {
            power(table, state, &m.n0, window, num_limbs);
            state
        },
    );

    // Convert ACC out of Montgomery form in place. The `not_used` argument is
    // passed as null.
    extern "C" {
        fn GFp_bn_from_montgomery(
            r: *mut Limb,
            a: *const Limb,
            not_used: *const Limb,
            n: *const Limb,
            n0: &N0,
            num: c::size_t,
        ) -> bssl::Result;
    }
    Result::from(unsafe {
        GFp_bn_from_montgomery(
            entry_mut(state, ACC, num_limbs).as_mut_ptr(),
            entry(state, ACC, num_limbs).as_ptr(),
            core::ptr::null(),
            entry(state, M, num_limbs).as_ptr(),
            &m.n0,
            num_limbs,
        )
    })?;
    // Reuse `base`'s limb allocation for the result. Every limb is then
    // overwritten with ACC.
    let mut r = Elem {
        limbs: base.limbs,
        encoding: PhantomData,
    };
    r.limbs.copy_from_slice(entry(state, ACC, num_limbs));
    Ok(r)
}
1099
1100/// Verified a == b**-1 (mod m), i.e. a**-1 == b (mod m).
1101pub fn verify_inverses_consttime<M>(
1102    a: &Elem<M, R>,
1103    b: Elem<M, Unencoded>,
1104    m: &Modulus<M>,
1105) -> Result<(), error::Unspecified> {
1106    if elem_mul(a, b, m).is_one() {
1107        Ok(())
1108    } else {
1109        Err(error::Unspecified)
1110    }
1111}
1112
1113#[inline]
1114pub fn elem_verify_equal_consttime<M, E>(
1115    a: &Elem<M, E>,
1116    b: &Elem<M, E>,
1117) -> Result<(), error::Unspecified> {
1118    if limb::limbs_equal_limbs_consttime(&a.limbs, &b.limbs) == LimbMask::True {
1119        Ok(())
1120    } else {
1121        Err(error::Unspecified)
1122    }
1123}
1124
1125/// Nonnegative integers.
1126pub struct Nonnegative {
1127    limbs: Vec<Limb>,
1128}
1129
1130impl Nonnegative {
1131    pub fn from_be_bytes_with_bit_length(
1132        input: untrusted::Input,
1133    ) -> Result<(Self, bits::BitLength), error::Unspecified> {
1134        let mut limbs = vec![0; (input.len() + LIMB_BYTES - 1) / LIMB_BYTES];
1135        // Rejects empty inputs.
1136        limb::parse_big_endian_and_pad_consttime(input, &mut limbs)?;
1137        while limbs.last() == Some(&0) {
1138            let _ = limbs.pop();
1139        }
1140        let r_bits = limb::limbs_minimal_bits(&limbs);
1141        Ok((Self { limbs }, r_bits))
1142    }
1143
1144    #[inline]
1145    pub fn is_odd(&self) -> bool {
1146        limb::limbs_are_even_constant_time(&self.limbs) != LimbMask::True
1147    }
1148
1149    pub fn verify_less_than(&self, other: &Self) -> Result<(), error::Unspecified> {
1150        if !greater_than(other, self) {
1151            return Err(error::Unspecified);
1152        }
1153        Ok(())
1154    }
1155
1156    pub fn to_elem<M>(&self, m: &Modulus<M>) -> Result<Elem<M, Unencoded>, error::Unspecified> {
1157        self.verify_less_than_modulus(&m)?;
1158        let mut r = m.zero();
1159        r.limbs[0..self.limbs.len()].copy_from_slice(&self.limbs);
1160        Ok(r)
1161    }
1162
1163    pub fn verify_less_than_modulus<M>(&self, m: &Modulus<M>) -> Result<(), error::Unspecified> {
1164        if self.limbs.len() > m.limbs.len() {
1165            return Err(error::Unspecified);
1166        }
1167        if self.limbs.len() == m.limbs.len() {
1168            if limb::limbs_less_than_limbs_consttime(&self.limbs, &m.limbs) != LimbMask::True {
1169                return Err(error::Unspecified);
1170            }
1171        }
1172        Ok(())
1173    }
1174}
1175
1176// Returns a > b.
1177fn greater_than(a: &Nonnegative, b: &Nonnegative) -> bool {
1178    if a.limbs.len() == b.limbs.len() {
1179        limb::limbs_less_than_limbs_vartime(&b.limbs, &a.limbs)
1180    } else {
1181        a.limbs.len() > b.limbs.len()
1182    }
1183}
1184
1185#[derive(Clone)]
1186#[repr(transparent)]
1187struct N0([Limb; 2]);
1188
1189const N0_LIMBS_USED: usize = 64 / LIMB_BITS;
1190
1191impl From<u64> for N0 {
1192    #[inline]
1193    fn from(n0: u64) -> Self {
1194        #[cfg(target_pointer_width = "64")]
1195        {
1196            Self([n0, 0])
1197        }
1198
1199        #[cfg(target_pointer_width = "32")]
1200        {
1201            Self([n0 as Limb, (n0 >> LIMB_BITS) as Limb])
1202        }
1203    }
1204}
1205
1206/// r *= a
1207fn limbs_mont_mul(r: &mut [Limb], a: &[Limb], m: &[Limb], n0: &N0) {
1208    debug_assert_eq!(r.len(), m.len());
1209    debug_assert_eq!(a.len(), m.len());
1210
1211    #[cfg(any(
1212        target_arch = "aarch64",
1213        target_arch = "arm",
1214        target_arch = "x86_64",
1215        target_arch = "x86"
1216    ))]
1217    unsafe {
1218        GFp_bn_mul_mont(
1219            r.as_mut_ptr(),
1220            r.as_ptr(),
1221            a.as_ptr(),
1222            m.as_ptr(),
1223            n0,
1224            r.len(),
1225        )
1226    }
1227
1228    #[cfg(not(any(
1229        target_arch = "aarch64",
1230        target_arch = "arm",
1231        target_arch = "x86_64",
1232        target_arch = "x86"
1233    )))]
1234    {
1235        let mut tmp = [0; 2 * MODULUS_MAX_LIMBS];
1236        let tmp = &mut tmp[..(2 * a.len())];
1237        limbs_mul(tmp, r, a);
1238        limbs_from_mont_in_place(r, tmp, m, n0);
1239    }
1240}
1241
1242fn limbs_from_mont_in_place(r: &mut [Limb], tmp: &mut [Limb], m: &[Limb], n0: &N0) {
1243    extern "C" {
1244        fn GFp_bn_from_montgomery_in_place(
1245            r: *mut Limb,
1246            num_r: c::size_t,
1247            a: *mut Limb,
1248            num_a: c::size_t,
1249            n: *const Limb,
1250            num_n: c::size_t,
1251            n0: &N0,
1252        ) -> bssl::Result;
1253    }
1254    Result::from(unsafe {
1255        GFp_bn_from_montgomery_in_place(
1256            r.as_mut_ptr(),
1257            r.len(),
1258            tmp.as_mut_ptr(),
1259            tmp.len(),
1260            m.as_ptr(),
1261            m.len(),
1262            &n0,
1263        )
1264    })
1265    .unwrap()
1266}
1267
/// Schoolbook multiplication: r = a * b.
///
/// `a` and `b` must have equal lengths, and `r` must be twice as long.
#[cfg(not(any(
    target_arch = "aarch64",
    target_arch = "arm",
    target_arch = "x86_64",
    target_arch = "x86"
)))]
fn limbs_mul(r: &mut [Limb], a: &[Limb], b: &[Limb]) {
    debug_assert_eq!(r.len(), 2 * a.len());
    debug_assert_eq!(a.len(), b.len());
    let width = a.len();

    // Only the low half is cleared up front. Row `row` writes r[width + row]
    // before row `row + 1` first reads it.
    crate::polyfill::slice::fill(&mut r[..width], 0);
    for (row, &multiplier) in b.iter().enumerate() {
        let window = &mut r[row..][..width];
        // SAFETY: `window` and `a` both have `width` limbs. `window` comes
        // from the exclusive borrow `r`, so it cannot alias `a`.
        let carry =
            unsafe { GFp_limbs_mul_add_limb(window.as_mut_ptr(), a.as_ptr(), multiplier, width) };
        r[width + row] = carry;
    }
}
1291
/// r = a * b, as a Montgomery multiplication modulo `m` (`n0` is the
/// Montgomery constant for `m`).
///
/// # Panics
///
/// Panics if `r`, `a`, `b`, and `m` do not all have the same length.
#[cfg(not(target_arch = "x86_64"))]
fn limbs_mont_product(r: &mut [Limb], a: &[Limb], b: &[Limb], m: &[Limb], n0: &N0) {
    // These checks are enforced in release builds too. The assembly path
    // below reads `r.len()` limbs from `a`, `b`, and `m` through raw
    // pointers, so with `debug_assert`s a mismatch would be an out-of-bounds
    // read in release builds instead of a panic.
    assert_eq!(r.len(), m.len());
    assert_eq!(a.len(), m.len());
    assert_eq!(b.len(), m.len());

    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    ))]
    // SAFETY: all four buffers have `r.len()` limbs (asserted above).
    unsafe {
        GFp_bn_mul_mont(
            r.as_mut_ptr(),
            a.as_ptr(),
            b.as_ptr(),
            m.as_ptr(),
            n0,
            r.len(),
        )
    }

    #[cfg(not(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    )))]
    {
        // Portable fallback: full double-width product, then Montgomery
        // reduction into `r`.
        let mut tmp = [0; 2 * MODULUS_MAX_LIMBS];
        let tmp = &mut tmp[..(2 * a.len())];
        limbs_mul(tmp, a, b);
        limbs_from_mont_in_place(r, tmp, m, n0)
    }
}
1329
1330/// r = r**2
1331fn limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0) {
1332    debug_assert_eq!(r.len(), m.len());
1333    #[cfg(any(
1334        target_arch = "aarch64",
1335        target_arch = "arm",
1336        target_arch = "x86_64",
1337        target_arch = "x86"
1338    ))]
1339    unsafe {
1340        GFp_bn_mul_mont(
1341            r.as_mut_ptr(),
1342            r.as_ptr(),
1343            r.as_ptr(),
1344            m.as_ptr(),
1345            n0,
1346            r.len(),
1347        )
1348    }
1349
1350    #[cfg(not(any(
1351        target_arch = "aarch64",
1352        target_arch = "arm",
1353        target_arch = "x86_64",
1354        target_arch = "x86"
1355    )))]
1356    {
1357        let mut tmp = [0; 2 * MODULUS_MAX_LIMBS];
1358        let tmp = &mut tmp[..(2 * r.len())];
1359        limbs_mul(tmp, r, r);
1360        limbs_from_mont_in_place(r, tmp, m, n0)
1361    }
1362}
1363
extern "C" {
    // Montgomery multiplication of `a` and `b` modulo `n`, where `n0` is the
    // Montgomery constant for `n`; the result is written to `r`. The callers
    // in this file (`limbs_mont_mul`, `limbs_mont_product`,
    // `limbs_mont_square`) pass `num_limbs`-limb buffers for `r`, `a`, `b`,
    // and `n`.
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    ))]
    // `r` and/or `a` and/or `b` may alias.
    fn GFp_bn_mul_mont(
        r: *mut Limb,
        a: *const Limb,
        b: *const Limb,
        n: *const Limb,
        n0: &N0,
        num_limbs: c::size_t,
    );

    // Adds `a * b` into the `num_limbs`-limb value at `r` and returns the
    // carry-out limb. The vectors in `tests::test_mul_add_words` are
    // consistent with `r_out + carry * 2**(LIMB_BITS * num_limbs) ==
    // r_in + a * b`. `#[must_use]` because dropping the carry silently
    // truncates the sum. Declared for the portable `limbs_mul` fallback and
    // for tests.
    // `r` must not alias `a`.
    #[cfg(any(
        test,
        not(any(
            target_arch = "aarch64",
            target_arch = "arm",
            target_arch = "x86_64",
            target_arch = "x86"
        ))
    ))]
    #[must_use]
    fn GFp_limbs_mul_add_limb(r: *mut Limb, a: *const Limb, b: Limb, num_limbs: c::size_t) -> Limb;
}
1394
// Known-answer tests for the modular arithmetic in this module, driven by the
// test-vector files loaded with `test_file!`, plus a table-driven test of the
// `GFp_limbs_mul_add_limb` primitive.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;
    use alloc::format;

    // Type-level representation of an arbitrary modulus.
    struct M {}

    // Lets `M` satisfy `PublicModulus` bounds (presumably required by some of
    // the APIs exercised below; confirm).
    unsafe impl PublicModulus for M {}

    // Checks `elem_exp_consttime(A, E) == ModExp` for each vector. `A` is
    // moved into Montgomery form first because the function takes
    // `Elem<M, R>`.
    #[test]
    fn test_elem_exp_consttime() {
        test::run(
            test_file!("bigint_elem_exp_consttime_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "ModExp", &m);
                let base = consume_elem(test_case, "A", &m);
                let e = {
                    let bytes = test_case.consume_bytes("E");
                    PrivateExponent::from_be_bytes_padded(untrusted::Input::from(&bytes), &m)
                        .expect("valid exponent")
                };
                let base = into_encoded(base, &m);
                let actual_result = elem_exp_consttime(base, &e, &m).unwrap();
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    // TODO: fn test_elem_exp_vartime() using
    // "src/rsa/bigint_elem_exp_vartime_tests.txt". See that file for details.
    // In the meantime, the function is tested indirectly via the RSA
    // verification and signing tests.
    // Checks `elem_mul` on Montgomery-encoded `A` and `B`; the product is
    // decoded with `into_unencoded` before comparing with `ModMul`.
    #[test]
    fn test_elem_mul() {
        test::run(
            test_file!("bigint_elem_mul_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "ModMul", &m);
                let a = consume_elem(test_case, "A", &m);
                let b = consume_elem(test_case, "B", &m);

                let b = into_encoded(b, &m);
                let a = into_encoded(a, &m);
                let actual_result = elem_mul(&a, b, &m);
                let actual_result = actual_result.into_unencoded(&m);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    // Checks `elem_squared` on a Montgomery-encoded `A`; the result is decoded
    // before comparing with `ModSquare`.
    #[test]
    fn test_elem_squared() {
        test::run(
            test_file!("bigint_elem_squared_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "ModSquare", &m);
                let a = consume_elem(test_case, "A", &m);

                let a = into_encoded(a, &m);
                let actual_result = elem_squared(a, &m.as_partial());
                let actual_result = actual_result.into_unencoded(&m);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    // Checks `elem_reduced` on a double-width `A` (built without a range
    // check). The extra multiplication by `oneRR` presumably cancels a
    // Montgomery factor left by `elem_reduced`; confirm against
    // `elem_reduced`.
    #[test]
    fn test_elem_reduced() {
        test::run(
            test_file!("bigint_elem_reduced_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                struct MM {}
                unsafe impl SmallerModulus<MM> for M {}
                unsafe impl NotMuchSmallerModulus<MM> for M {}

                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "R", &m);
                let a =
                    consume_elem_unchecked::<MM>(test_case, "A", expected_result.limbs.len() * 2);

                let actual_result = elem_reduced(&a, &m);
                let oneRR = m.oneRR();
                let actual_result = elem_mul(oneRR.as_ref(), actual_result, &m);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    // Checks `elem_reduced_once`, which maps an element modulo `N` to one
    // modulo the slightly smaller `QQ`.
    #[test]
    fn test_elem_reduced_once() {
        test::run(
            test_file!("bigint_elem_reduced_once_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                struct N {}
                struct QQ {}
                unsafe impl SmallerModulus<N> for QQ {}
                unsafe impl SlightlySmallerModulus<N> for QQ {}

                let qq = consume_modulus::<QQ>(test_case, "QQ");
                let expected_result = consume_elem::<QQ>(test_case, "R", &qq);
                let n = consume_modulus::<N>(test_case, "N");
                let a = consume_elem::<N>(test_case, "A", &n);

                let actual_result = elem_reduced_once(&a, &qq);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    // `Modulus`'s `Debug` output must not reveal the value.
    #[test]
    fn test_modulus_debug() {
        let (modulus, _) = Modulus::<M>::from_be_bytes_with_bit_length(untrusted::Input::from(
            &[0xff; LIMB_BYTES * MODULUS_MIN_LIMBS],
        ))
        .unwrap();
        assert_eq!("Modulus", format!("{:?}", modulus));
    }

    #[test]
    fn test_public_exponent_debug() {
        let exponent =
            PublicExponent::from_be_bytes(untrusted::Input::from(&[0x1, 0x00, 0x01]), 65537)
                .unwrap();
        assert_eq!("PublicExponent(65537)", format!("{:?}", exponent));
    }

    // Reads hex field `name` as an unencoded element of ℤ/mℤ; panics if the
    // value is rejected by `Elem::from_be_bytes_padded`.
    fn consume_elem<M>(
        test_case: &mut test::TestCase,
        name: &str,
        m: &Modulus<M>,
    ) -> Elem<M, Unencoded> {
        let value = test_case.consume_bytes(name);
        Elem::from_be_bytes_padded(untrusted::Input::from(&value), m).unwrap()
    }

    // Reads field `name` into a zero-padded `num_limbs`-limb element without
    // checking it against any modulus (used for double-width inputs).
    fn consume_elem_unchecked<M>(
        test_case: &mut test::TestCase,
        name: &str,
        num_limbs: usize,
    ) -> Elem<M, Unencoded> {
        let value = consume_nonnegative(test_case, name);
        let mut limbs = BoxedLimbs::zero(Width {
            num_limbs,
            m: PhantomData,
        });
        limbs[0..value.limbs.len()].copy_from_slice(&value.limbs);
        Elem {
            limbs,
            encoding: PhantomData,
        }
    }

    fn consume_modulus<M>(test_case: &mut test::TestCase, name: &str) -> Modulus<M> {
        let value = test_case.consume_bytes(name);
        let (value, _) =
            Modulus::from_be_bytes_with_bit_length(untrusted::Input::from(&value)).unwrap();
        value
    }

    fn consume_nonnegative(test_case: &mut test::TestCase, name: &str) -> Nonnegative {
        let bytes = test_case.consume_bytes(name);
        let (r, _r_bits) =
            Nonnegative::from_be_bytes_with_bit_length(untrusted::Input::from(&bytes)).unwrap();
        r
    }

    // Panics with both limb vectors (hex) if `a` and `b` differ.
    fn assert_elem_eq<M, E>(a: &Elem<M, E>, b: &Elem<M, E>) {
        if elem_verify_equal_consttime(&a, b).is_err() {
            panic!("{:x?} != {:x?}", &*a.limbs, &*b.limbs);
        }
    }

    // Moves `a` into Montgomery form by multiplying it by `oneRR`.
    fn into_encoded<M>(a: Elem<M, Unencoded>, m: &Modulus<M>) -> Elem<M, R> {
        elem_mul(m.oneRR().as_ref(), a, m)
    }

    // Each case is `(r before, a, w, expected return value, r after)`. Every
    // row satisfies `r_after + retval * 2**(LIMB_BITS * len) == r_before + a * w`.
    #[test]
    // TODO: wasm
    fn test_mul_add_words() {
        const ZERO: Limb = 0;
        const MAX: Limb = ZERO.wrapping_sub(1);
        static TEST_CASES: &[(&[Limb], &[Limb], Limb, Limb, &[Limb])] = &[
            (&[0], &[0], 0, 0, &[0]),
            (&[MAX], &[0], MAX, 0, &[MAX]),
            (&[0], &[MAX], MAX, MAX - 1, &[1]),
            (&[MAX], &[MAX], MAX, MAX, &[0]),
            (&[0, 0], &[MAX, MAX], MAX, MAX - 1, &[1, MAX]),
            (&[1, 0], &[MAX, MAX], MAX, MAX - 1, &[2, MAX]),
            (&[MAX, 0], &[MAX, MAX], MAX, MAX, &[0, 0]),
            (&[0, 1], &[MAX, MAX], MAX, MAX, &[1, 0]),
            (&[MAX, MAX], &[MAX, MAX], MAX, MAX, &[0, MAX]),
        ];

        for (i, (r_input, a, w, expected_retval, expected_r)) in TEST_CASES.iter().enumerate() {
            extern crate std;
            let mut r = std::vec::Vec::from(*r_input);
            assert_eq!(r.len(), a.len()); // Sanity check
            let actual_retval =
                unsafe { GFp_limbs_mul_add_limb(r.as_mut_ptr(), a.as_ptr(), *w, a.len()) };
            assert_eq!(&r, expected_r, "{}: {:x?} != {:x?}", i, &r[..], expected_r);
            assert_eq!(
                actual_retval, *expected_retval,
                "{}: {:x?} != {:x?}",
                i, actual_retval, *expected_retval
            );
        }
    }
}