1use crate::{
40 arithmetic::montgomery::*,
41 bits, bssl, c, error,
42 limb::{self, Limb, LimbMask, LIMB_BITS, LIMB_BYTES},
43};
44use alloc::{borrow::ToOwned as _, boxed::Box, vec, vec::Vec};
45use core::{
46 marker::PhantomData,
47 ops::{Deref, DerefMut},
48};
49
/// Marker trait for modulus types whose value is prime.
///
/// # Safety
///
/// NOTE(review): presumably implementors must only be moduli that really are
/// prime; `PrivateExponent::for_flt` (and therefore `elem_inverse_consttime`)
/// is only available for `M: Prime` and computes inverses as `a**(p-2)`.
/// Confirm the full contract with the implementors.
pub unsafe trait Prime {}

/// The number of limbs used by values associated with modulus type `M`.
struct Width<M> {
    num_limbs: usize,

    /// Tags the width with its modulus type; carries no data.
    m: PhantomData<M>,
}

/// A heap-allocated array of limbs, least-significant limb first, tagged with
/// the modulus type `M` it belongs to.
struct BoxedLimbs<M> {
    limbs: Box<[Limb]>,

    /// Tags the limbs with their modulus type; carries no data.
    m: PhantomData<M>,
}
66
67impl<M> Deref for BoxedLimbs<M> {
68 type Target = [Limb];
69 #[inline]
70 fn deref(&self) -> &Self::Target {
71 &self.limbs
72 }
73}
74
75impl<M> DerefMut for BoxedLimbs<M> {
76 #[inline]
77 fn deref_mut(&mut self) -> &mut Self::Target {
78 &mut self.limbs
79 }
80}
81
82impl<M> Clone for BoxedLimbs<M> {
85 fn clone(&self) -> Self {
86 Self {
87 limbs: self.limbs.clone(),
88 m: self.m,
89 }
90 }
91}
92
93impl<M> BoxedLimbs<M> {
94 fn positive_minimal_width_from_be_bytes(
95 input: untrusted::Input,
96 ) -> Result<Self, error::KeyRejected> {
97 if untrusted::Reader::new(input).peek(0) {
100 return Err(error::KeyRejected::invalid_encoding());
101 }
102 let num_limbs = (input.len() + LIMB_BYTES - 1) / LIMB_BYTES;
103 let mut r = Self::zero(Width {
104 num_limbs,
105 m: PhantomData,
106 });
107 limb::parse_big_endian_and_pad_consttime(input, &mut r)
108 .map_err(|error::Unspecified| error::KeyRejected::unexpected_error())?;
109 Ok(r)
110 }
111
112 fn minimal_width_from_unpadded(limbs: &[Limb]) -> Self {
113 debug_assert_ne!(limbs.last(), Some(&0));
114 Self {
115 limbs: limbs.to_owned().into_boxed_slice(),
116 m: PhantomData,
117 }
118 }
119
120 fn from_be_bytes_padded_less_than(
121 input: untrusted::Input,
122 m: &Modulus<M>,
123 ) -> Result<Self, error::Unspecified> {
124 let mut r = Self::zero(m.width());
125 limb::parse_big_endian_and_pad_consttime(input, &mut r)?;
126 if limb::limbs_less_than_limbs_consttime(&r, &m.limbs) != LimbMask::True {
127 return Err(error::Unspecified);
128 }
129 Ok(r)
130 }
131
132 #[inline]
133 fn is_zero(&self) -> bool {
134 limb::limbs_are_zero_constant_time(&self.limbs) == LimbMask::True
135 }
136
137 fn zero(width: Width<M>) -> Self {
138 Self {
139 limbs: vec![0; width.num_limbs].into_boxed_slice(),
140 m: PhantomData,
141 }
142 }
143
144 fn width(&self) -> Width<M> {
145 Width {
146 num_limbs: self.limbs.len(),
147 m: PhantomData,
148 }
149 }
150}
151
/// Marker: modulus type `Self` is smaller than modulus type `L`.
///
/// # Safety
///
/// NOTE(review): `Modulus::to_elem` and `elem_widen` treat values of `Self`
/// as values of `L` on the assumption that they fit. Confirm the exact size
/// relationship implementors must guarantee.
pub unsafe trait SmallerModulus<L> {}

/// Marker: `Self` is smaller than `L`, but only slightly. This is required
/// by `elem_reduced_once`, which reduces a value modulo `L` into `Self` with
/// `limbs_reduce_once_constant_time`.
///
/// # Safety
///
/// NOTE(review): presumably implementors must guarantee that a single
/// conditional reduction suffices. Confirm against callers.
pub unsafe trait SlightlySmallerModulus<L>: SmallerModulus<L> {}

/// Marker: `Self` is smaller than `L`, but not by much. This is required by
/// `elem_reduced`, which reduces a value modulo `L` with a Montgomery
/// reduction modulo `Self`.
///
/// # Safety
///
/// NOTE(review): the precise bound implementors must satisfy is not visible
/// here. Confirm against callers.
pub unsafe trait NotMuchSmallerModulus<L>: SmallerModulus<L> {}

/// Marker for modulus types whose value is public. Only these get a `Debug`
/// implementation for `Modulus`.
pub unsafe trait PublicModulus {}

/// Minimum number of limbs that `Modulus::from_boxed_limbs` accepts.
pub const MODULUS_MIN_LIMBS: usize = 4;

/// Maximum number of limbs that `Modulus::from_boxed_limbs` accepts (8192 bits).
pub const MODULUS_MAX_LIMBS: usize = 8192 / LIMB_BITS;
176
/// A validated modulus together with the precomputed values needed for
/// Montgomery arithmetic modulo it.
///
/// `Modulus::from_boxed_limbs` ensures it is odd, at least 3, and has between
/// `MODULUS_MIN_LIMBS` and `MODULUS_MAX_LIMBS` limbs.
pub struct Modulus<M> {
    /// The modulus value, least-significant limb first.
    limbs: BoxedLimbs<M>,
    /// The result of `GFp_bn_neg_inv_mod_r_u64` on the low 64 bits of the
    /// modulus. Judging by its name this is `-n**-1 (mod 2**64)` (TODO confirm).
    n0: N0,

    /// `One<M, RR>` computed by `One::newRR` for this modulus.
    oneRR: One<M, RR>,
}
223
224impl<M: PublicModulus> core::fmt::Debug for Modulus<M> {
225 fn fmt(&self, fmt: &mut ::core::fmt::Formatter) -> Result<(), ::core::fmt::Error> {
226 fmt.debug_struct("Modulus")
227 .finish()
229 }
230}
231
impl<M> Modulus<M> {
    /// Parses a big-endian, minimally-encoded positive modulus and returns it
    /// together with its bit length.
    ///
    /// # Errors
    ///
    /// * `invalid_encoding` if the first byte is zero.
    /// * `unexpected_error` if the limb parser fails.
    /// * Otherwise, any error from the validation in `from_boxed_limbs`.
    pub fn from_be_bytes_with_bit_length(
        input: untrusted::Input,
    ) -> Result<(Self, bits::BitLength), error::KeyRejected> {
        let limbs = BoxedLimbs::positive_minimal_width_from_be_bytes(input)?;
        Self::from_boxed_limbs(limbs)
    }

    /// Builds a modulus from an already-parsed `Nonnegative`, converting its
    /// limb vector into the modulus's boxed storage.
    pub fn from_nonnegative_with_bit_length(
        n: Nonnegative,
    ) -> Result<(Self, bits::BitLength), error::KeyRejected> {
        let limbs = BoxedLimbs {
            limbs: n.limbs.into_boxed_slice(),
            m: PhantomData,
        };
        Self::from_boxed_limbs(limbs)
    }

    /// Validates `n` and precomputes `n0` and `oneRR` for it.
    ///
    /// Checks run in this order, and the first failure determines the error:
    /// 1. More than `MODULUS_MAX_LIMBS` limbs gives `too_large`.
    /// 2. Fewer than `MODULUS_MIN_LIMBS` limbs gives `unexpected_error`.
    /// 3. An even value gives `invalid_component`.
    /// 4. A value less than 3 gives `unexpected_error`.
    fn from_boxed_limbs(n: BoxedLimbs<M>) -> Result<(Self, bits::BitLength), error::KeyRejected> {
        if n.len() > MODULUS_MAX_LIMBS {
            return Err(error::KeyRejected::too_large());
        }
        if n.len() < MODULUS_MIN_LIMBS {
            return Err(error::KeyRejected::unexpected_error());
        }
        if limb::limbs_are_even_constant_time(&n) != LimbMask::False {
            return Err(error::KeyRejected::invalid_component());
        }
        if limb::limbs_less_than_limb_constant_time(&n, 3) != LimbMask::False {
            return Err(error::KeyRejected::unexpected_error());
        }

        // With 64-bit limbs, `u64::from(n[0])` is an identity conversion,
        // which clippy would otherwise flag.
        #[allow(clippy::useless_conversion)]
        let n0 = {
            extern "C" {
                fn GFp_bn_neg_inv_mod_r_u64(n: u64) -> u64;
            }

            // Only the low 64 bits of `n` are needed.
            let mut n_mod_r: u64 = u64::from(n[0]);

            if N0_LIMBS_USED == 2 {
                // With 32-bit limbs the low 64 bits span two limbs; `n[1]`
                // exists because `n` has at least `MODULUS_MIN_LIMBS` limbs.
                debug_assert_eq!(LIMB_BITS, 32);
                n_mod_r |= u64::from(n[1]) << 32;
            }
            N0::from(unsafe { GFp_bn_neg_inv_mod_r_u64(n_mod_r) })
        };

        let bits = limb::limbs_minimal_bits(&n.limbs);
        // `oneRR` is computed before `Self` exists, so a `PartialModulus`
        // view (limbs + n0, without `oneRR`) is used.
        let oneRR = {
            let partial = PartialModulus {
                limbs: &n.limbs,
                n0: n0.clone(),
                m: PhantomData,
            };

            One::newRR(&partial, bits)
        };

        Ok((
            Self {
                limbs: n,
                n0,
                oneRR,
            },
            bits,
        ))
    }

    /// Returns the number of limbs of the modulus.
    #[inline]
    fn width(&self) -> Width<M> {
        self.limbs.width()
    }

    /// Returns zero, in any encoding `E`, at the width of this modulus.
    fn zero<E>(&self) -> Elem<M, E> {
        Elem {
            limbs: BoxedLimbs::zero(self.width()),
            encoding: PhantomData,
        }
    }

    /// Returns the unencoded integer 1 at the width of this modulus.
    fn one(&self) -> Elem<M, Unencoded> {
        let mut r = self.zero();
        r.limbs[0] = 1;
        r
    }

    /// Returns the precomputed `One<M, RR>`.
    ///
    /// `elem_mul` with it converts an `Unencoded` element into `R` encoding,
    /// as `elem_exp_vartime` does.
    pub fn oneRR(&self) -> &One<M, RR> {
        &self.oneRR
    }

    /// Reinterprets this modulus's value as an (unencoded) element modulo `l`.
    ///
    /// # Panics
    ///
    /// Panics unless both moduli have the same number of limbs.
    pub fn to_elem<L>(&self, l: &Modulus<L>) -> Elem<L, Unencoded>
    where
        M: SmallerModulus<L>,
    {
        assert_eq!(self.width().num_limbs, l.width().num_limbs);
        let limbs = self.limbs.clone();
        Elem {
            limbs: BoxedLimbs {
                limbs: limbs.limbs,
                m: PhantomData,
            },
            encoding: PhantomData,
        }
    }

    /// Borrows this modulus as a `PartialModulus` (clones `n0`).
    fn as_partial(&self) -> PartialModulus<M> {
        PartialModulus {
            limbs: &self.limbs,
            n0: self.n0.clone(),
            m: PhantomData,
        }
    }
}
352
353struct PartialModulus<'a, M> {
354 limbs: &'a [Limb],
355 n0: N0,
356 m: PhantomData<M>,
357}
358
359impl<M> PartialModulus<'_, M> {
360 fn zero(&self) -> Elem<M, R> {
362 let width = Width {
363 num_limbs: self.limbs.len(),
364 m: PhantomData,
365 };
366 Elem {
367 limbs: BoxedLimbs::zero(width),
368 encoding: PhantomData,
369 }
370 }
371}
372
373pub struct Elem<M, E = Unencoded> {
379 limbs: BoxedLimbs<M>,
380
381 encoding: PhantomData<E>,
384}
385
386impl<M, E> Clone for Elem<M, E> {
389 fn clone(&self) -> Self {
390 Self {
391 limbs: self.limbs.clone(),
392 encoding: self.encoding,
393 }
394 }
395}
396
397impl<M, E> Elem<M, E> {
398 #[inline]
399 pub fn is_zero(&self) -> bool {
400 self.limbs.is_zero()
401 }
402}
403
impl<M, E: ReductionEncoding> Elem<M, E> {
    /// Applies one Montgomery reduction to this element. The result type is
    /// `ReductionEncoding::Output` for `E`.
    ///
    /// This is done as a Montgomery multiplication by the plain integer 1,
    /// i.e. `self * 1 * R**-1 (mod m)`. That assumes `limbs_mont_mul` is a
    /// standard Montgomery product.
    fn decode_once(self, m: &Modulus<M>) -> Elem<M, <E as ReductionEncoding>::Output> {
        let mut limbs = self.limbs;
        let num_limbs = m.width().num_limbs;
        // Build the integer 1 in a stack buffer to avoid an allocation.
        // `Modulus::from_boxed_limbs` guarantees
        // `num_limbs <= MODULUS_MAX_LIMBS`, so the slice below cannot panic.
        let mut one = [0; MODULUS_MAX_LIMBS];
        one[0] = 1;
        let one = &one[..num_limbs];
        limbs_mont_mul(&mut limbs, &one, &m.limbs, &m.n0);
        Elem {
            limbs,
            encoding: PhantomData,
        }
    }
}

impl<M> Elem<M, R> {
    /// Converts an `R`-encoded (Montgomery form) element back to `Unencoded`.
    #[inline]
    pub fn into_unencoded(self, m: &Modulus<M>) -> Elem<M, Unencoded> {
        self.decode_once(m)
    }
}
429
430impl<M> Elem<M, Unencoded> {
431 pub fn from_be_bytes_padded(
432 input: untrusted::Input,
433 m: &Modulus<M>,
434 ) -> Result<Self, error::Unspecified> {
435 Ok(Elem {
436 limbs: BoxedLimbs::from_be_bytes_padded_less_than(input, m)?,
437 encoding: PhantomData,
438 })
439 }
440
441 #[inline]
442 pub fn fill_be_bytes(&self, out: &mut [u8]) {
443 limb::big_endian_from_limbs(&self.limbs, out)
445 }
446
447 pub fn into_modulus<MM>(self) -> Result<Modulus<MM>, error::KeyRejected> {
448 let (m, _bits) =
449 Modulus::from_boxed_limbs(BoxedLimbs::minimal_width_from_unpadded(&self.limbs))?;
450 Ok(m)
451 }
452
453 fn is_one(&self) -> bool {
454 limb::limbs_equal_limb_constant_time(&self.limbs, 1) == LimbMask::True
455 }
456}
457
458pub fn elem_mul<M, AF, BF>(
459 a: &Elem<M, AF>,
460 b: Elem<M, BF>,
461 m: &Modulus<M>,
462) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
463where
464 (AF, BF): ProductEncoding,
465{
466 elem_mul_(a, b, &m.as_partial())
467}
468
469fn elem_mul_<M, AF, BF>(
470 a: &Elem<M, AF>,
471 mut b: Elem<M, BF>,
472 m: &PartialModulus<M>,
473) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
474where
475 (AF, BF): ProductEncoding,
476{
477 limbs_mont_mul(&mut b.limbs, &a.limbs, &m.limbs, &m.n0);
478 Elem {
479 limbs: b.limbs,
480 encoding: PhantomData,
481 }
482}
483
/// Doubles `a` modulo `m` in place (`a = 2*a mod m`) via the C helper
/// `LIMBS_shl_mod`.
///
/// NOTE(review): assumes `a` has `m.limbs.len()` limbs and is already reduced
/// modulo `m`; the C helper is only given `m`'s length.
fn elem_mul_by_2<M, AF>(a: &mut Elem<M, AF>, m: &PartialModulus<M>) {
    extern "C" {
        fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
    }
    // The output and input pointers deliberately alias (in-place update).
    unsafe {
        LIMBS_shl_mod(
            a.limbs.as_mut_ptr(),
            a.limbs.as_ptr(),
            m.limbs.as_ptr(),
            m.limbs.len(),
        );
    }
}
497
498pub fn elem_reduced_once<Larger, Smaller: SlightlySmallerModulus<Larger>>(
499 a: &Elem<Larger, Unencoded>,
500 m: &Modulus<Smaller>,
501) -> Elem<Smaller, Unencoded> {
502 let mut r = a.limbs.clone();
503 assert!(r.len() <= m.limbs.len());
504 limb::limbs_reduce_once_constant_time(&mut r, &m.limbs);
505 Elem {
506 limbs: BoxedLimbs {
507 limbs: r.limbs,
508 m: PhantomData,
509 },
510 encoding: PhantomData,
511 }
512}
513
514#[inline]
515pub fn elem_reduced<Larger, Smaller: NotMuchSmallerModulus<Larger>>(
516 a: &Elem<Larger, Unencoded>,
517 m: &Modulus<Smaller>,
518) -> Elem<Smaller, RInverse> {
519 let mut tmp = [0; MODULUS_MAX_LIMBS];
520 let tmp = &mut tmp[..a.limbs.len()];
521 tmp.copy_from_slice(&a.limbs);
522
523 let mut r = m.zero();
524 limbs_from_mont_in_place(&mut r.limbs, tmp, &m.limbs, &m.n0);
525 r
526}
527
528fn elem_squared<M, E>(
529 mut a: Elem<M, E>,
530 m: &PartialModulus<M>,
531) -> Elem<M, <(E, E) as ProductEncoding>::Output>
532where
533 (E, E): ProductEncoding,
534{
535 limbs_mont_square(&mut a.limbs, &m.limbs, &m.n0);
536 Elem {
537 limbs: a.limbs,
538 encoding: PhantomData,
539 }
540}
541
542pub fn elem_widen<Larger, Smaller: SmallerModulus<Larger>>(
543 a: Elem<Smaller, Unencoded>,
544 m: &Modulus<Larger>,
545) -> Elem<Larger, Unencoded> {
546 let mut r = m.zero();
547 r.limbs[..a.limbs.len()].copy_from_slice(&a.limbs);
548 r
549}
550
/// Returns `a + b (mod m)`, computed in place in `a`'s buffer by the C
/// helper `LIMBS_add_mod`.
///
/// NOTE(review): assumes both inputs have `m.limbs.len()` limbs and are
/// already reduced modulo `m`; only `m`'s length is passed to C.
pub fn elem_add<M, E>(mut a: Elem<M, E>, b: Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
    extern "C" {
        fn LIMBS_add_mod(
            r: *mut Limb,
            a: *const Limb,
            b: *const Limb,
            m: *const Limb,
            num_limbs: c::size_t,
        );
    }
    // The output `r` deliberately aliases the input `a`.
    unsafe {
        LIMBS_add_mod(
            a.limbs.as_mut_ptr(),
            a.limbs.as_ptr(),
            b.limbs.as_ptr(),
            m.limbs.as_ptr(),
            m.limbs.len(),
        )
    }
    a
}
574
/// Returns `a - b (mod m)`, computed in place in `a`'s buffer by the C
/// helper `LIMBS_sub_mod`.
///
/// NOTE(review): assumes both inputs have `m.limbs.len()` limbs and are
/// already reduced modulo `m`; only `m`'s length is passed to C.
pub fn elem_sub<M, E>(mut a: Elem<M, E>, b: &Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
    extern "C" {
        fn LIMBS_sub_mod(
            r: *mut Limb,
            a: *const Limb,
            b: *const Limb,
            m: *const Limb,
            num_limbs: c::size_t,
        );
    }
    // The output `r` deliberately aliases the input `a`.
    unsafe {
        LIMBS_sub_mod(
            a.limbs.as_mut_ptr(),
            a.limbs.as_ptr(),
            b.limbs.as_ptr(),
            m.limbs.as_ptr(),
            m.limbs.len(),
        );
    }
    a
}
598
/// A precomputed constant stored as an element in encoding `E`.
///
/// The only constructor in this file is `One::newRR`.
pub struct One<M, E>(Elem<M, E>);

impl<M> One<M, RR> {
    /// Computes `R**2 (mod m)`.
    ///
    /// Here `R = 2**r`, where `r` is the bit length of `m` rounded up to a
    /// whole number of limbs. With `lg_base = 2`:
    ///
    /// * `base` starts as `2**(m_bits - 1)`, which is below the odd modulus `m`.
    /// * `shifts = r - (m_bits - 1) + lg_base` modular doublings make
    ///   `base == 2**(r + lg_base) == 2**lg_base * R (mod m)`. This is the
    ///   `R`-encoded (Montgomery) form of `2**lg_base`.
    /// * Raising that to `r / lg_base` in the Montgomery domain gives
    ///   `(2**lg_base)**(r / lg_base) == 2**r == R`. Its `R`-encoded form is
    ///   `R * R == R**2 (mod m)`.
    ///
    /// `lg_base` trades some of the squarings for `elem_mul_by_2` doublings.
    fn newRR(m: &PartialModulus<M>, m_bits: bits::BitLength) -> Self {
        let m_bits = m_bits.as_usize_bits();
        // `r` is `m_bits` rounded up to a multiple of `LIMB_BITS`.
        let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS;

        // base = 2**(m_bits - 1), i.e. the top bit of `m`.
        let bit = m_bits - 1;
        let mut base = m.zero();
        base.limbs[bit / LIMB_BITS] = 1 << (bit % LIMB_BITS);

        // Shifts vs. squarings trade-off; must be a power of two.
        let lg_base = 2usize;
        debug_assert_eq!(lg_base.count_ones(), 1);
        let shifts = r - bit + lg_base;
        let exponent = (r / lg_base) as u64;
        // Afterwards base == 2**lg_base * R (mod m).
        for _ in 0..shifts {
            elem_mul_by_2(&mut base, m)
        }
        // (2**lg_base)**(r / lg_base) == R, computed in `R` encoding: R**2.
        let RR = elem_exp_vartime_(base, exponent, m);

        // Re-tag the `R`-encoded result as `RR`.
        Self(Elem {
            limbs: RR.limbs,
            encoding: PhantomData,
        })
    }
}

impl<M, E> AsRef<Elem<M, E>> for One<M, E> {
    fn as_ref(&self) -> &Elem<M, E> {
        &self.0
    }
}
654
655#[derive(Clone, Copy, Debug)]
658pub struct PublicExponent(u64);
659
660impl PublicExponent {
661 pub fn from_be_bytes(
662 input: untrusted::Input,
663 min_value: u64,
664 ) -> Result<Self, error::KeyRejected> {
665 if input.len() > 5 {
666 return Err(error::KeyRejected::too_large());
667 }
668 let value = input.read_all(error::KeyRejected::invalid_encoding(), |input| {
669 if input.peek(0) {
672 return Err(error::KeyRejected::invalid_encoding());
673 }
674 let mut value = 0u64;
675 loop {
676 let byte = input
677 .read_byte()
678 .map_err(|untrusted::EndOfInput| error::KeyRejected::invalid_encoding())?;
679 value = (value << 8) | u64::from(byte);
680 if input.at_end() {
681 return Ok(value);
682 }
683 }
684 })?;
685
686 if value & 1 != 1 {
691 return Err(error::KeyRejected::invalid_component());
692 }
693 debug_assert!(min_value & 1 == 1);
694 debug_assert!(min_value <= PUBLIC_EXPONENT_MAX_VALUE);
695 if min_value < 3 {
696 return Err(error::KeyRejected::invalid_component());
697 }
698 if value < min_value {
699 return Err(error::KeyRejected::too_small());
700 }
701 if value > PUBLIC_EXPONENT_MAX_VALUE {
702 return Err(error::KeyRejected::too_large());
703 }
704
705 Ok(Self(value))
706 }
707}
708
709const PUBLIC_EXPONENT_MAX_VALUE: u64 = (1u64 << 33) - 1;
721
722pub fn elem_exp_vartime<M>(
726 base: Elem<M, Unencoded>,
727 PublicExponent(exponent): PublicExponent,
728 m: &Modulus<M>,
729) -> Elem<M, R> {
730 let base = elem_mul(m.oneRR().as_ref(), base, &m);
731 elem_exp_vartime_(base, exponent, &m.as_partial())
732}
733
734fn elem_exp_vartime_<M>(base: Elem<M, R>, exponent: u64, m: &PartialModulus<M>) -> Elem<M, R> {
736 assert!(exponent >= 1);
756 assert!(exponent <= PUBLIC_EXPONENT_MAX_VALUE);
757 let mut acc = base.clone();
758 let mut bit = 1 << (64 - 1 - exponent.leading_zeros());
759 debug_assert!((exponent & bit) != 0);
760 while bit > 1 {
761 bit >>= 1;
762 acc = elem_squared(acc, m);
763 if (exponent & bit) != 0 {
764 acc = elem_mul_(&base, acc, m);
765 }
766 }
767 acc
768}
769
770pub struct PrivateExponent<M> {
773 limbs: BoxedLimbs<M>,
774}
775
776impl<M> PrivateExponent<M> {
777 pub fn from_be_bytes_padded(
778 input: untrusted::Input,
779 p: &Modulus<M>,
780 ) -> Result<Self, error::Unspecified> {
781 let dP = BoxedLimbs::from_be_bytes_padded_less_than(input, p)?;
782
783 if limb::limbs_are_even_constant_time(&dP) != LimbMask::False {
792 return Err(error::Unspecified);
793 }
794
795 Ok(Self { limbs: dP })
796 }
797}
798
799impl<M: Prime> PrivateExponent<M> {
800 fn for_flt(p: &Modulus<M>) -> Self {
802 let two = elem_add(p.one(), p.one(), p);
803 let p_minus_2 = elem_sub(p.zero(), &two, p);
804 Self {
805 limbs: p_minus_2.limbs,
806 }
807 }
808}
809
/// Computes `base**exponent (mod m)` for a secret exponent. This is the
/// portable (non-x86_64) implementation.
///
/// It uses fixed 5-bit windows over a 32-entry table of powers of `base`.
/// Table entries are read through `LIMBS_select_512_32` rather than by
/// indexing. NOTE(review): presumably this keeps the memory access pattern
/// independent of the secret window value; confirm against the C source.
///
/// `base` must be `R`-encoded; the result is returned unencoded. This
/// implementation always returns `Ok`, and it panics if the table selection
/// helper reports failure.
#[cfg(not(target_arch = "x86_64"))]
pub fn elem_exp_consttime<M>(
    base: Elem<M, R>,
    exponent: &PrivateExponent<M>,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    use crate::limb::Window;

    const WINDOW_BITS: usize = 5;
    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;

    let num_limbs = m.limbs.len();

    // table[i] holds base**i in `R` encoding, each entry `num_limbs` wide.
    let mut table = vec![0; TABLE_ENTRIES * num_limbs];

    // Copies table[i] into `r` via the C selection helper.
    fn gather<M>(table: &[Limb], i: Window, r: &mut Elem<M, R>) {
        extern "C" {
            fn LIMBS_select_512_32(
                r: *mut Limb,
                table: *const Limb,
                num_limbs: c::size_t,
                i: Window,
            ) -> bssl::Result;
        }
        Result::from(unsafe {
            LIMBS_select_512_32(r.limbs.as_mut_ptr(), table.as_ptr(), r.limbs.len(), i)
        })
        .unwrap();
    }

    // One window step: acc = acc**(2**WINDOW_BITS) * table[i]. `tmp` is a
    // scratch buffer threaded through to avoid reallocating each step.
    fn power<M>(
        table: &[Limb],
        i: Window,
        mut acc: Elem<M, R>,
        mut tmp: Elem<M, R>,
        m: &Modulus<M>,
    ) -> (Elem<M, R>, Elem<M, R>) {
        for _ in 0..WINDOW_BITS {
            acc = elem_squared(acc, &m.as_partial());
        }
        gather(table, i, &mut tmp);
        let acc = elem_mul(&tmp, acc, m);
        (acc, tmp)
    }

    // 1 in `R` encoding (= base**0), via a Montgomery multiplication by `oneRR`.
    let tmp = m.one();
    let tmp = elem_mul(m.oneRR().as_ref(), tmp, m);

    // Entry `i` of a table whose entries are `num_limbs` limbs each.
    fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
        &table[(i * num_limbs)..][..num_limbs]
    }
    fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
        &mut table[(i * num_limbs)..][..num_limbs]
    }
    let num_limbs = m.limbs.len();
    entry_mut(&mut table, 0, num_limbs).copy_from_slice(&tmp.limbs);
    entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);
    // table[i] = table[i/2]**2 for even i, table[i-1] * base for odd i.
    for i in 2..TABLE_ENTRIES {
        let (src1, src2) = if i % 2 == 0 {
            (i / 2, i / 2)
        } else {
            (i - 1, 1)
        };
        // Split so the already-filled entries can be read while entry `i`
        // is written.
        let (previous, rest) = table.split_at_mut(num_limbs * i);
        let src1 = entry(previous, src1, num_limbs);
        let src2 = entry(previous, src2, num_limbs);
        let dst = entry_mut(rest, 0, num_limbs);
        limbs_mont_product(dst, src1, src2, &m.limbs, &m.n0);
    }

    // The first window initializes the accumulator (reusing `base`'s buffer);
    // each later window applies `power`. NOTE(review): window order and
    // padding are defined by `limb::fold_5_bit_windows`.
    let (r, _) = limb::fold_5_bit_windows(
        &exponent.limbs,
        |initial_window| {
            let mut r = Elem {
                limbs: base.limbs,
                encoding: PhantomData,
            };
            gather(&table, initial_window, &mut r);
            (r, tmp)
        },
        |(acc, tmp), window| power(&table, window, acc, tmp, m),
    );

    let r = r.into_unencoded(m);

    Ok(r)
}
897
898pub fn elem_inverse_consttime<M: Prime>(
900 a: Elem<M, R>,
901 m: &Modulus<M>,
902) -> Result<Elem<M, Unencoded>, error::Unspecified> {
903 elem_exp_consttime(a, &PrivateExponent::for_flt(&m), m)
904}
905
/// Computes `base**exponent (mod m)` for a secret exponent. This is the
/// x86_64 implementation.
///
/// It is built on the assembly routines `GFp_bn_scatter5`, `GFp_bn_gather5`,
/// `GFp_bn_mul_mont_gather5`, `GFp_bn_power5`, and `GFp_bn_from_montgomery`.
/// It uses 5-bit windows over a 32-entry table of powers of `base`.
///
/// `base` must be `R`-encoded; the result is returned unencoded.
///
/// # Errors
///
/// Returns `error::Unspecified` if `GFp_bn_from_montgomery` reports failure.
#[cfg(target_arch = "x86_64")]
pub fn elem_exp_consttime<M>(
    base: Elem<M, R>,
    exponent: &PrivateExponent<M>,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    use crate::limb::Window;

    const WINDOW_BITS: usize = 5;
    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;

    let num_limbs = m.limbs.len();

    // One allocation holds the 32-entry table followed by three state
    // entries (ACC, BASE, M). It has `ALIGNMENT` extra limbs of slack so the
    // table can start on a 64-byte boundary. When the buffer is already
    // aligned, a full `ALIGNMENT` bytes are skipped, which the slack covers.
    // NOTE(review): the alignment and the state-right-after-table layout
    // presumably mirror what the OpenSSL-derived assembly expects; confirm.
    const ALIGNMENT: usize = 64;
    assert_eq!(ALIGNMENT % LIMB_BYTES, 0);
    let mut table = vec![0; ((TABLE_ENTRIES + 3) * num_limbs) + ALIGNMENT];
    let (table, state) = {
        let misalignment = (table.as_ptr() as usize) % ALIGNMENT;
        let table = &mut table[((ALIGNMENT - misalignment) / LIMB_BYTES)..];
        assert_eq!((table.as_ptr() as usize) % ALIGNMENT, 0);
        table.split_at_mut(TABLE_ENTRIES * num_limbs)
    };

    // Entry `i` of a buffer whose entries are `num_limbs` limbs each.
    fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
        &table[(i * num_limbs)..][..num_limbs]
    }
    fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
        &mut table[(i * num_limbs)..][..num_limbs]
    }

    // Entry indices within `state`: the accumulator, a copy of `base`, and a
    // copy of the modulus.
    const ACC: usize = 0;
    const BASE: usize = ACC + 1;
    const M: usize = BASE + 1;
    entry_mut(state, BASE, num_limbs).copy_from_slice(&base.limbs);
    entry_mut(state, M, num_limbs).copy_from_slice(&m.limbs);

    // Stores ACC into table slot `i` (in the assembly's table layout).
    fn scatter(table: &mut [Limb], state: &[Limb], i: Window, num_limbs: usize) {
        extern "C" {
            fn GFp_bn_scatter5(a: *const Limb, a_len: c::size_t, table: *mut Limb, i: Window);
        }
        unsafe {
            GFp_bn_scatter5(
                entry(state, ACC, num_limbs).as_ptr(),
                num_limbs,
                table.as_mut_ptr(),
                i,
            )
        }
    }

    // Loads table slot `i` into ACC.
    fn gather(table: &[Limb], state: &mut [Limb], i: Window, num_limbs: usize) {
        extern "C" {
            fn GFp_bn_gather5(r: *mut Limb, a_len: c::size_t, table: *const Limb, i: Window);
        }
        unsafe {
            GFp_bn_gather5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                num_limbs,
                table.as_ptr(),
                i,
            )
        }
    }

    // ACC = table[i]**2 (Montgomery square).
    fn gather_square(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        gather(table, state, i, num_limbs);
        assert_eq!(ACC, 0);
        // Split off ACC; `rest` starts at entry 1, hence `M - 1`.
        let (acc, rest) = state.split_at_mut(num_limbs);
        let m = entry(rest, M - 1, num_limbs);
        limbs_mont_square(acc, m, n0);
    }

    // ACC = BASE * table[i] (Montgomery product with a gathered entry).
    fn gather_mul_base(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        extern "C" {
            fn GFp_bn_mul_mont_gather5(
                rp: *mut Limb,
                ap: *const Limb,
                table: *const Limb,
                np: *const Limb,
                n0: &N0,
                num: c::size_t,
                power: Window,
            );
        }
        unsafe {
            GFp_bn_mul_mont_gather5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                entry(state, BASE, num_limbs).as_ptr(),
                table.as_ptr(),
                entry(state, M, num_limbs).as_ptr(),
                n0,
                num_limbs,
                i,
            );
        }
    }

    // One window step on ACC, in place. NOTE(review): presumably, following
    // OpenSSL's `bn_power5`, this is ACC = ACC**32 * table[i]; confirm.
    fn power(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        extern "C" {
            fn GFp_bn_power5(
                r: *mut Limb,
                a: *const Limb,
                table: *const Limb,
                n: *const Limb,
                n0: &N0,
                num: c::size_t,
                i: Window,
            );
        }
        unsafe {
            GFp_bn_power5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                table.as_ptr(),
                entry(state, M, num_limbs).as_ptr(),
                n0,
                num_limbs,
                i,
            );
        }
    }

    // table[0] = base**0 = 1 in `R` encoding. ACC starts zeroed (from
    // `vec!`), so setting its low limb makes it the integer 1, and a
    // Montgomery multiplication by R**2 gives R (mod m).
    {
        let acc = entry_mut(state, ACC, num_limbs);
        acc[0] = 1;
        limbs_mont_mul(acc, &m.oneRR.0.limbs, &m.limbs, &m.n0);
    }
    scatter(table, state, 0, num_limbs);

    // table[1] = base**1.
    entry_mut(state, ACC, num_limbs).copy_from_slice(&base.limbs);
    scatter(table, state, 1, num_limbs);

    // table[i] = table[i/2]**2 for even i, base * table[i-1] for odd i.
    for i in 2..(TABLE_ENTRIES as Window) {
        if i % 2 == 0 {
            gather_square(table, state, &m.n0, i / 2, num_limbs);
        } else {
            gather_mul_base(table, state, &m.n0, i - 1, num_limbs)
        };
        scatter(table, state, i, num_limbs);
    }

    // ACC = table[first window], then one `power` step per remaining window.
    let state = limb::fold_5_bit_windows(
        &exponent.limbs,
        |initial_window| {
            gather(table, state, initial_window, num_limbs);
            state
        },
        |state, window| {
            power(table, state, &m.n0, window, num_limbs);
            state
        },
    );

    // Convert ACC out of Montgomery form, in place.
    extern "C" {
        fn GFp_bn_from_montgomery(
            r: *mut Limb,
            a: *const Limb,
            not_used: *const Limb,
            n: *const Limb,
            n0: &N0,
            num: c::size_t,
        ) -> bssl::Result;
    }
    Result::from(unsafe {
        GFp_bn_from_montgomery(
            entry_mut(state, ACC, num_limbs).as_mut_ptr(),
            entry(state, ACC, num_limbs).as_ptr(),
            core::ptr::null(),
            entry(state, M, num_limbs).as_ptr(),
            &m.n0,
            num_limbs,
        )
    })?;
    // Reuse `base`'s allocation for the returned element.
    let mut r = Elem {
        limbs: base.limbs,
        encoding: PhantomData,
    };
    r.limbs.copy_from_slice(entry(state, ACC, num_limbs));
    Ok(r)
}
1099
1100pub fn verify_inverses_consttime<M>(
1102 a: &Elem<M, R>,
1103 b: Elem<M, Unencoded>,
1104 m: &Modulus<M>,
1105) -> Result<(), error::Unspecified> {
1106 if elem_mul(a, b, m).is_one() {
1107 Ok(())
1108 } else {
1109 Err(error::Unspecified)
1110 }
1111}
1112
1113#[inline]
1114pub fn elem_verify_equal_consttime<M, E>(
1115 a: &Elem<M, E>,
1116 b: &Elem<M, E>,
1117) -> Result<(), error::Unspecified> {
1118 if limb::limbs_equal_limbs_consttime(&a.limbs, &b.limbs) == LimbMask::True {
1119 Ok(())
1120 } else {
1121 Err(error::Unspecified)
1122 }
1123}
1124
1125pub struct Nonnegative {
1127 limbs: Vec<Limb>,
1128}
1129
1130impl Nonnegative {
1131 pub fn from_be_bytes_with_bit_length(
1132 input: untrusted::Input,
1133 ) -> Result<(Self, bits::BitLength), error::Unspecified> {
1134 let mut limbs = vec![0; (input.len() + LIMB_BYTES - 1) / LIMB_BYTES];
1135 limb::parse_big_endian_and_pad_consttime(input, &mut limbs)?;
1137 while limbs.last() == Some(&0) {
1138 let _ = limbs.pop();
1139 }
1140 let r_bits = limb::limbs_minimal_bits(&limbs);
1141 Ok((Self { limbs }, r_bits))
1142 }
1143
1144 #[inline]
1145 pub fn is_odd(&self) -> bool {
1146 limb::limbs_are_even_constant_time(&self.limbs) != LimbMask::True
1147 }
1148
1149 pub fn verify_less_than(&self, other: &Self) -> Result<(), error::Unspecified> {
1150 if !greater_than(other, self) {
1151 return Err(error::Unspecified);
1152 }
1153 Ok(())
1154 }
1155
1156 pub fn to_elem<M>(&self, m: &Modulus<M>) -> Result<Elem<M, Unencoded>, error::Unspecified> {
1157 self.verify_less_than_modulus(&m)?;
1158 let mut r = m.zero();
1159 r.limbs[0..self.limbs.len()].copy_from_slice(&self.limbs);
1160 Ok(r)
1161 }
1162
1163 pub fn verify_less_than_modulus<M>(&self, m: &Modulus<M>) -> Result<(), error::Unspecified> {
1164 if self.limbs.len() > m.limbs.len() {
1165 return Err(error::Unspecified);
1166 }
1167 if self.limbs.len() == m.limbs.len() {
1168 if limb::limbs_less_than_limbs_consttime(&self.limbs, &m.limbs) != LimbMask::True {
1169 return Err(error::Unspecified);
1170 }
1171 }
1172 Ok(())
1173 }
1174}
1175
1176fn greater_than(a: &Nonnegative, b: &Nonnegative) -> bool {
1178 if a.limbs.len() == b.limbs.len() {
1179 limb::limbs_less_than_limbs_vartime(&b.limbs, &a.limbs)
1180 } else {
1181 a.limbs.len() > b.limbs.len()
1182 }
1183}
1184
/// The Montgomery constant `n0`, always stored as two limbs (least
/// significant first). Only `N0_LIMBS_USED` of them carry data.
///
/// It is passed to C/assembly as `&N0`, hence `repr(transparent)`.
#[derive(Clone)]
#[repr(transparent)]
struct N0([Limb; 2]);

/// Number of limbs needed for the 64-bit `n0`: 1 with 64-bit limbs, 2 with
/// 32-bit limbs.
const N0_LIMBS_USED: usize = 64 / LIMB_BITS;

impl From<u64> for N0 {
    /// Splits a 64-bit `n0` into limbs. With 64-bit pointers the whole value
    /// fits in the first limb and the second is zero.
    #[inline]
    fn from(n0: u64) -> Self {
        #[cfg(target_pointer_width = "64")]
        {
            Self([n0, 0])
        }

        #[cfg(target_pointer_width = "32")]
        {
            // Low half first, then high half.
            Self([n0 as Limb, (n0 >> LIMB_BITS) as Limb])
        }
    }
}
1205
/// In-place Montgomery multiplication: `r = r * a * R**-1 (mod m)`.
///
/// NOTE(review): the formula assumes `GFp_bn_mul_mont` follows OpenSSL's
/// `bn_mul_mont` contract. The portable fallback below (a full product, then
/// a Montgomery reduction) is consistent with it.
///
/// `r`, `a`, and `m` must all have the same length (debug-asserted).
fn limbs_mont_mul(r: &mut [Limb], a: &[Limb], m: &[Limb], n0: &N0) {
    debug_assert_eq!(r.len(), m.len());
    debug_assert_eq!(a.len(), m.len());

    // Assembly path: `r` is both the output and the first input.
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    ))]
    unsafe {
        GFp_bn_mul_mont(
            r.as_mut_ptr(),
            r.as_ptr(),
            a.as_ptr(),
            m.as_ptr(),
            n0,
            r.len(),
        )
    }

    // Portable path: the double-width product goes in a stack buffer, then
    // is reduced back into `r`.
    #[cfg(not(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    )))]
    {
        let mut tmp = [0; 2 * MODULUS_MAX_LIMBS];
        let tmp = &mut tmp[..(2 * a.len())];
        limbs_mul(tmp, r, a);
        limbs_from_mont_in_place(r, tmp, m, n0);
    }
}
1241
/// Montgomery-reduces `tmp` modulo `m` into `r` via
/// `GFp_bn_from_montgomery_in_place`.
///
/// `tmp` is passed mutably. NOTE(review): presumably the C code clobbers it
/// as scratch, so callers should treat its contents as garbage afterwards.
///
/// # Panics
///
/// Panics if the C function reports failure.
fn limbs_from_mont_in_place(r: &mut [Limb], tmp: &mut [Limb], m: &[Limb], n0: &N0) {
    extern "C" {
        fn GFp_bn_from_montgomery_in_place(
            r: *mut Limb,
            num_r: c::size_t,
            a: *mut Limb,
            num_a: c::size_t,
            n: *const Limb,
            num_n: c::size_t,
            n0: &N0,
        ) -> bssl::Result;
    }
    Result::from(unsafe {
        GFp_bn_from_montgomery_in_place(
            r.as_mut_ptr(),
            r.len(),
            tmp.as_mut_ptr(),
            tmp.len(),
            m.as_ptr(),
            m.len(),
            &n0,
        )
    })
    .unwrap()
}
1267
/// Schoolbook multiplication: `r = a * b`, where `r` has twice as many limbs
/// as `a` and `b` (debug-asserted). Only compiled when no assembly Montgomery
/// multiplication is available.
#[cfg(not(any(
    target_arch = "aarch64",
    target_arch = "arm",
    target_arch = "x86_64",
    target_arch = "x86"
)))]
fn limbs_mul(r: &mut [Limb], a: &[Limb], b: &[Limb]) {
    debug_assert_eq!(r.len(), 2 * a.len());
    debug_assert_eq!(a.len(), b.len());
    let ab_len = a.len();

    // Only the lower half needs zeroing. Each iteration `i` reads
    // r[i..i + ab_len], whose top limb was assigned (not accumulated) as the
    // carry of iteration `i - 1`; the remaining upper limbs are assigned
    // below before they are ever read.
    crate::polyfill::slice::fill(&mut r[..ab_len], 0);
    for (i, &b_limb) in b.iter().enumerate() {
        // r[i..i + ab_len] += a * b_limb; the returned carry becomes the
        // next limb up.
        r[ab_len + i] = unsafe {
            GFp_limbs_mul_add_limb(
                (&mut r[i..][..ab_len]).as_mut_ptr(),
                a.as_ptr(),
                b_limb,
                ab_len,
            )
        };
    }
}
1291
/// Montgomery multiplication into a separate output:
/// `r = a * b * R**-1 (mod m)`.
///
/// Only compiled off x86_64, where the portable `elem_exp_consttime` uses it
/// to fill its table. Because of that outer `cfg`, the `target_arch =
/// "x86_64"` arm of the inner `cfg` can never be active; it only keeps the
/// architecture list identical to `limbs_mont_mul`.
///
/// All slices must have the same length (debug-asserted).
#[cfg(not(target_arch = "x86_64"))]
fn limbs_mont_product(r: &mut [Limb], a: &[Limb], b: &[Limb], m: &[Limb], n0: &N0) {
    debug_assert_eq!(r.len(), m.len());
    debug_assert_eq!(a.len(), m.len());
    debug_assert_eq!(b.len(), m.len());

    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    ))]
    unsafe {
        GFp_bn_mul_mont(
            r.as_mut_ptr(),
            a.as_ptr(),
            b.as_ptr(),
            m.as_ptr(),
            n0,
            r.len(),
        )
    }

    // Portable path: full product in a stack buffer, then reduce into `r`.
    #[cfg(not(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    )))]
    {
        let mut tmp = [0; 2 * MODULUS_MAX_LIMBS];
        let tmp = &mut tmp[..(2 * a.len())];
        limbs_mul(tmp, a, b);
        limbs_from_mont_in_place(r, tmp, m, n0)
    }
}
1329
/// In-place Montgomery squaring: `r = r * r * R**-1 (mod m)`.
///
/// `r` and `m` must have the same length (debug-asserted).
fn limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0) {
    debug_assert_eq!(r.len(), m.len());
    // Assembly path: `r` is the output and both inputs.
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    ))]
    unsafe {
        GFp_bn_mul_mont(
            r.as_mut_ptr(),
            r.as_ptr(),
            r.as_ptr(),
            m.as_ptr(),
            n0,
            r.len(),
        )
    }

    // Portable path: full square in a stack buffer, then reduce into `r`.
    #[cfg(not(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    )))]
    {
        let mut tmp = [0; 2 * MODULUS_MAX_LIMBS];
        let tmp = &mut tmp[..(2 * r.len())];
        limbs_mul(tmp, r, r);
        limbs_from_mont_in_place(r, tmp, m, n0)
    }
}
1363
extern "C" {
    /// Assembly Montgomery multiplication used by `limbs_mont_mul`,
    /// `limbs_mont_product`, and `limbs_mont_square`; its output may alias
    /// its inputs. NOTE(review): presumably `r = a * b * R**-1 (mod n)` per
    /// OpenSSL's `bn_mul_mont`; confirm.
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    ))]
    fn GFp_bn_mul_mont(
        r: *mut Limb,
        a: *const Limb,
        b: *const Limb,
        n: *const Limb,
        n0: &N0,
        num_limbs: c::size_t,
    );

    /// `r[..num_limbs] += a[..num_limbs] * b`, returning the carry limb (see
    /// `test_mul_add_words`). Needed by the portable `limbs_mul` and by tests.
    #[cfg(any(
        test,
        not(any(
            target_arch = "aarch64",
            target_arch = "arm",
            target_arch = "x86_64",
            target_arch = "x86"
        ))
    ))]
    #[must_use]
    fn GFp_limbs_mul_add_limb(r: *mut Limb, a: *const Limb, b: Limb, num_limbs: c::size_t) -> Limb;
}
1394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;
    use alloc::format;

    // Type-level tag for the moduli parsed by these tests.
    struct M {}

    unsafe impl PublicModulus for M {}

    // Checks `elem_exp_consttime` against the `ModExp` vectors. The base is
    // `R`-encoded before exponentiation; the result comes back unencoded.
    #[test]
    fn test_elem_exp_consttime() {
        test::run(
            test_file!("bigint_elem_exp_consttime_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "ModExp", &m);
                let base = consume_elem(test_case, "A", &m);
                let e = {
                    let bytes = test_case.consume_bytes("E");
                    PrivateExponent::from_be_bytes_padded(untrusted::Input::from(&bytes), &m)
                        .expect("valid exponent")
                };
                let base = into_encoded(base, &m);
                let actual_result = elem_exp_consttime(base, &e, &m).unwrap();
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    // Checks `elem_mul` with both operands `R`-encoded (result decoded
    // before comparing).
    #[test]
    fn test_elem_mul() {
        test::run(
            test_file!("bigint_elem_mul_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "ModMul", &m);
                let a = consume_elem(test_case, "A", &m);
                let b = consume_elem(test_case, "B", &m);

                let b = into_encoded(b, &m);
                let a = into_encoded(a, &m);
                let actual_result = elem_mul(&a, b, &m);
                let actual_result = actual_result.into_unencoded(&m);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    // Checks `elem_squared` on an `R`-encoded operand.
    #[test]
    fn test_elem_squared() {
        test::run(
            test_file!("bigint_elem_squared_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "ModSquare", &m);
                let a = consume_elem(test_case, "A", &m);

                let a = into_encoded(a, &m);
                let actual_result = elem_squared(a, &m.as_partial());
                let actual_result = actual_result.into_unencoded(&m);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    // Checks `elem_reduced` on a double-width input. Its `RInverse` result
    // is multiplied by `oneRR` to cancel the extra `R**-1` before comparing.
    #[test]
    fn test_elem_reduced() {
        test::run(
            test_file!("bigint_elem_reduced_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                struct MM {}
                unsafe impl SmallerModulus<MM> for M {}
                unsafe impl NotMuchSmallerModulus<MM> for M {}

                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "R", &m);
                let a =
                    consume_elem_unchecked::<MM>(test_case, "A", expected_result.limbs.len() * 2);

                let actual_result = elem_reduced(&a, &m);
                let oneRR = m.oneRR();
                let actual_result = elem_mul(oneRR.as_ref(), actual_result, &m);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    // Checks `elem_reduced_once`, reducing an element modulo `N` into `QQ`.
    #[test]
    fn test_elem_reduced_once() {
        test::run(
            test_file!("bigint_elem_reduced_once_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                struct N {}
                struct QQ {}
                unsafe impl SmallerModulus<N> for QQ {}
                unsafe impl SlightlySmallerModulus<N> for QQ {}

                let qq = consume_modulus::<QQ>(test_case, "QQ");
                let expected_result = consume_elem::<QQ>(test_case, "R", &qq);
                let n = consume_modulus::<N>(test_case, "N");
                let a = consume_elem::<N>(test_case, "A", &n);

                let actual_result = elem_reduced_once(&a, &qq);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    // The `Debug` output of a public modulus must not include its value.
    #[test]
    fn test_modulus_debug() {
        let (modulus, _) = Modulus::<M>::from_be_bytes_with_bit_length(untrusted::Input::from(
            &[0xff; LIMB_BYTES * MODULUS_MIN_LIMBS],
        ))
        .unwrap();
        assert_eq!("Modulus", format!("{:?}", modulus));
    }

    #[test]
    fn test_public_exponent_debug() {
        let exponent =
            PublicExponent::from_be_bytes(untrusted::Input::from(&[0x1, 0x00, 0x01]), 65537)
                .unwrap();
        assert_eq!("PublicExponent(65537)", format!("{:?}", exponent));
    }

    // Reads a value that must be less than `m`, padded to `m`'s width.
    fn consume_elem<M>(
        test_case: &mut test::TestCase,
        name: &str,
        m: &Modulus<M>,
    ) -> Elem<M, Unencoded> {
        let value = test_case.consume_bytes(name);
        Elem::from_be_bytes_padded(untrusted::Input::from(&value), m).unwrap()
    }

    // Reads a value without any bound check, zero-padded to `num_limbs`.
    fn consume_elem_unchecked<M>(
        test_case: &mut test::TestCase,
        name: &str,
        num_limbs: usize,
    ) -> Elem<M, Unencoded> {
        let value = consume_nonnegative(test_case, name);
        let mut limbs = BoxedLimbs::zero(Width {
            num_limbs,
            m: PhantomData,
        });
        limbs[0..value.limbs.len()].copy_from_slice(&value.limbs);
        Elem {
            limbs,
            encoding: PhantomData,
        }
    }

    fn consume_modulus<M>(test_case: &mut test::TestCase, name: &str) -> Modulus<M> {
        let value = test_case.consume_bytes(name);
        let (value, _) =
            Modulus::from_be_bytes_with_bit_length(untrusted::Input::from(&value)).unwrap();
        value
    }

    fn consume_nonnegative(test_case: &mut test::TestCase, name: &str) -> Nonnegative {
        let bytes = test_case.consume_bytes(name);
        let (r, _r_bits) =
            Nonnegative::from_be_bytes_with_bit_length(untrusted::Input::from(&bytes)).unwrap();
        r
    }

    // Panics with both limb vectors (hex) when the elements differ.
    fn assert_elem_eq<M, E>(a: &Elem<M, E>, b: &Elem<M, E>) {
        if elem_verify_equal_consttime(&a, b).is_err() {
            panic!("{:x?} != {:x?}", &*a.limbs, &*b.limbs);
        }
    }

    // Converts an unencoded element to `R` encoding by Montgomery-multiplying
    // it by `oneRR`.
    fn into_encoded<M>(a: Elem<M, Unencoded>, m: &Modulus<M>) -> Elem<M, R> {
        elem_mul(m.oneRR().as_ref(), a, m)
    }

    // Each case is (initial r, a, w, expected carry, expected r) for
    // `r += a * w`.
    #[test]
    fn test_mul_add_words() {
        const ZERO: Limb = 0;
        const MAX: Limb = ZERO.wrapping_sub(1);
        static TEST_CASES: &[(&[Limb], &[Limb], Limb, Limb, &[Limb])] = &[
            (&[0], &[0], 0, 0, &[0]),
            (&[MAX], &[0], MAX, 0, &[MAX]),
            (&[0], &[MAX], MAX, MAX - 1, &[1]),
            (&[MAX], &[MAX], MAX, MAX, &[0]),
            (&[0, 0], &[MAX, MAX], MAX, MAX - 1, &[1, MAX]),
            (&[1, 0], &[MAX, MAX], MAX, MAX - 1, &[2, MAX]),
            (&[MAX, 0], &[MAX, MAX], MAX, MAX, &[0, 0]),
            (&[0, 1], &[MAX, MAX], MAX, MAX, &[1, 0]),
            (&[MAX, MAX], &[MAX, MAX], MAX, MAX, &[0, MAX]),
        ];

        for (i, (r_input, a, w, expected_retval, expected_r)) in TEST_CASES.iter().enumerate() {
            extern crate std;
            let mut r = std::vec::Vec::from(*r_input);
            assert_eq!(r.len(), a.len());
            let actual_retval =
                unsafe { GFp_limbs_mul_add_limb(r.as_mut_ptr(), a.as_ptr(), *w, a.len()) };
            assert_eq!(&r, expected_r, "{}: {:x?} != {:x?}", i, &r[..], expected_r);
            assert_eq!(
                actual_retval, *expected_retval,
                "{}: {:x?} != {:x?}",
                i, actual_retval, *expected_retval
            );
        }
    }
}
1627}