1use super::PUBLIC_KEY_PUBLIC_MODULUS_MAX_LEN;
16use crate::{bits, digest, error, io::der};
17
18#[cfg(feature = "alloc")]
19use crate::rand;
20
21pub trait Padding: 'static + Sync + crate::sealed::Sealed + core::fmt::Debug {
23 fn digest_alg(&self) -> &'static digest::Algorithm;
26}
27
#[cfg(feature = "alloc")]
/// A padding scheme able to encode a message digest for RSA signing.
pub trait RsaEncoding
where
    Self: Padding,
{
    /// Writes the encoded form of `m_hash` into `m_out`.
    ///
    /// `mod_bits` is the bit length of the RSA modulus. `rng` supplies any
    /// randomness the scheme needs: PSS draws its salt from it, while
    /// PKCS#1 v1.5 ignores it.
    #[doc(hidden)]
    fn encode(
        &self,
        m_hash: &digest::Digest,
        m_out: &mut [u8],
        mod_bits: bits::BitLength,
        rng: &dyn rand::SecureRandom,
    ) -> Result<(), error::Unspecified>;
}
42
43pub trait Verification: Padding {
48 fn verify(
49 &self,
50 m_hash: &digest::Digest,
51 m: &mut untrusted::Reader,
52 mod_bits: bits::BitLength,
53 ) -> Result<(), error::Unspecified>;
54}
55
56#[derive(Debug)]
63pub struct PKCS1 {
64 digest_alg: &'static digest::Algorithm,
65 digestinfo_prefix: &'static [u8],
66}
67
68impl crate::sealed::Sealed for PKCS1 {}
69
70impl Padding for PKCS1 {
71 fn digest_alg(&self) -> &'static digest::Algorithm {
72 self.digest_alg
73 }
74}
75
#[cfg(feature = "alloc")]
impl RsaEncoding for PKCS1 {
    /// Encodes `m_hash` per EMSA-PKCS1-v1_5.
    ///
    /// The encoding is deterministic and fills all of `m_out`, so neither
    /// `mod_bits` nor `rng` is consulted.
    fn encode(
        &self,
        m_hash: &digest::Digest,
        m_out: &mut [u8],
        _mod_bits: bits::BitLength,
        _rng: &dyn rand::SecureRandom,
    ) -> Result<(), error::Unspecified> {
        pkcs1_encode(self, m_hash, m_out);
        Ok(())
    }
}
89
90impl Verification for PKCS1 {
91 fn verify(
92 &self,
93 m_hash: &digest::Digest,
94 m: &mut untrusted::Reader,
95 mod_bits: bits::BitLength,
96 ) -> Result<(), error::Unspecified> {
97 let mut calculated = [0u8; PUBLIC_KEY_PUBLIC_MODULUS_MAX_LEN];
100 let calculated = &mut calculated[..mod_bits.as_usize_bytes_rounded_up()];
101 pkcs1_encode(&self, m_hash, calculated);
102 if m.read_bytes_to_end() != *calculated {
103 return Err(error::Unspecified);
104 }
105 Ok(())
106 }
107}
108
109fn pkcs1_encode(pkcs1: &PKCS1, m_hash: &digest::Digest, m_out: &mut [u8]) {
114 let em = m_out;
115
116 let digest_len = pkcs1.digestinfo_prefix.len() + pkcs1.digest_alg.output_len;
117
118 assert!(em.len() >= digest_len + 11);
121 let pad_len = em.len() - digest_len - 3;
122 em[0] = 0;
123 em[1] = 1;
124 for i in 0..pad_len {
125 em[2 + i] = 0xff;
126 }
127 em[2 + pad_len] = 0;
128
129 let (digest_prefix, digest_dst) = em[3 + pad_len..].split_at_mut(pkcs1.digestinfo_prefix.len());
130 digest_prefix.copy_from_slice(pkcs1.digestinfo_prefix);
131 digest_dst.copy_from_slice(m_hash.as_ref());
132}
133
// Defines a public `PKCS1` static named `$name` that pairs `$digest_alg`
// with the DigestInfo prefix `$prefix`. `$doc` becomes the static's rustdoc.
macro_rules! rsa_pkcs1_padding {
    ( $name:ident, $digest_alg:expr, $prefix:expr,
      $doc:expr ) => {
        #[doc = $doc]
        pub static $name: PKCS1 = PKCS1 {
            digestinfo_prefix: $prefix,
            digest_alg: $digest_alg,
        };
    };
}
144
// Public PKCS#1 v1.5 signature padding schemes, one per supported digest.
// Each static pairs a digest algorithm with its DER DigestInfo prefix
// (generated further down by `pkcs1_digestinfo_prefix!`).
rsa_pkcs1_padding!(
    RSA_PKCS1_SHA1_FOR_LEGACY_USE_ONLY,
    &digest::SHA1_FOR_LEGACY_USE_ONLY,
    &SHA1_PKCS1_DIGESTINFO_PREFIX,
    "PKCS#1 1.5 padding using SHA-1 for RSA signatures."
);
rsa_pkcs1_padding!(
    RSA_PKCS1_SHA256,
    &digest::SHA256,
    &SHA256_PKCS1_DIGESTINFO_PREFIX,
    "PKCS#1 1.5 padding using SHA-256 for RSA signatures."
);
rsa_pkcs1_padding!(
    RSA_PKCS1_SHA384,
    &digest::SHA384,
    &SHA384_PKCS1_DIGESTINFO_PREFIX,
    "PKCS#1 1.5 padding using SHA-384 for RSA signatures."
);
rsa_pkcs1_padding!(
    RSA_PKCS1_SHA512,
    &digest::SHA512,
    &SHA512_PKCS1_DIGESTINFO_PREFIX,
    "PKCS#1 1.5 padding using SHA-512 for RSA signatures."
);
169
// Defines `static $name`: the DER encoding of a PKCS#1 `DigestInfo` up to
// and including the OCTET STRING header, i.e. every byte that precedes the
// raw digest:
//
//   SEQUENCE {
//     SEQUENCE { OBJECT IDENTIFIER <digest oid>, NULL },
//     OCTET STRING <digest>
//   }
//
// Arguments:
// - `$hash_len`: the digest length.
// - `$oid_len`: the number of OID content bytes.
// - `$oid_byte`s: the OID content bytes themselves.
//
// Every length is emitted as a single byte, which is only valid DER
// (short form) for values below 128.
macro_rules! pkcs1_digestinfo_prefix {
    ( $name:ident, $hash_len:expr, $oid_len:expr,
      [ $( $oid_byte:expr ),* ] ) => {
        static $name: [u8; 2 + 8 + $oid_len] = [
            // Outer SEQUENCE. Its content is:
            // - the inner SEQUENCE (2 + 2 + oid_len + 2),
            // - the OCTET STRING header (2),
            // - the digest itself.
            der::Tag::Sequence as u8, 8 + $oid_len + $hash_len,
            // AlgorithmIdentifier SEQUENCE: OID TLV followed by a NULL TLV.
            der::Tag::Sequence as u8, 2 + $oid_len + 2,
            der::Tag::OID as u8, $oid_len, $( $oid_byte ),*,
            der::Tag::Null as u8, 0,
            // OCTET STRING header; the digest bytes follow it.
            der::Tag::OctetString as u8, $hash_len,
        ];
    }
}
182
// DigestInfo prefixes used by the PKCS#1 v1.5 schemes. Each invocation
// supplies, in order:
// - the digest length,
// - the number of OID content bytes,
// - the OID content bytes.

// OID 1.3.14.3.2.26 (SHA-1), 20-byte digest.
pkcs1_digestinfo_prefix!(
    SHA1_PKCS1_DIGESTINFO_PREFIX,
    20,
    5,
    [0x2b, 0x0e, 0x03, 0x02, 0x1a]
);

// OID 2.16.840.1.101.3.4.2.1 (SHA-256), 32-byte digest.
pkcs1_digestinfo_prefix!(
    SHA256_PKCS1_DIGESTINFO_PREFIX,
    32,
    9,
    [0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01]
);

// OID 2.16.840.1.101.3.4.2.2 (SHA-384), 48-byte digest.
pkcs1_digestinfo_prefix!(
    SHA384_PKCS1_DIGESTINFO_PREFIX,
    48,
    9,
    [0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02]
);

// OID 2.16.840.1.101.3.4.2.3 (SHA-512), 64-byte digest.
pkcs1_digestinfo_prefix!(
    SHA512_PKCS1_DIGESTINFO_PREFIX,
    64,
    9,
    [0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03]
);
210
211#[derive(Debug)]
218pub struct PSS {
219 digest_alg: &'static digest::Algorithm,
220}
221
222impl crate::sealed::Sealed for PSS {}
223
224const MAX_SALT_LEN: usize = digest::MAX_OUTPUT_LEN;
227
228impl Padding for PSS {
229 fn digest_alg(&self) -> &'static digest::Algorithm {
230 self.digest_alg
231 }
232}
233
// Signing-side EMSA-PSS encoding (RFC 8017, Section 9.1.1).
//
// `RsaEncoding` and the `rand` import used in the signature below both exist
// only with the `alloc` feature, so this impl must be gated the same way.
// Without the gate the crate fails to build when `alloc` is disabled.
#[cfg(feature = "alloc")]
impl RsaEncoding for PSS {
    fn encode(
        &self,
        m_hash: &digest::Digest,
        m_out: &mut [u8],
        mod_bits: bits::BitLength,
        rng: &dyn rand::SecureRandom,
    ) -> Result<(), error::Unspecified> {
        let metrics = PSSMetrics::new(self.digest_alg, mod_bits)?;

        // When emBits is a multiple of 8, the encoded message is one byte
        // shorter than the modulus. Emit a leading zero byte in that case.
        let em = if metrics.top_byte_mask == 0xff {
            m_out[0] = 0;
            &mut m_out[1..]
        } else {
            m_out
        };
        assert_eq!(em.len(), metrics.em_len);

        // Step 4: a random salt exactly as long as the digest.
        let mut salt = [0u8; MAX_SALT_LEN];
        let salt = &mut salt[..metrics.s_len];
        rng.fill(salt)?;

        // Steps 5-6: H = Hash(0x00 * 8 || mHash || salt).
        let h_hash = pss_digest(self.digest_alg, m_hash, salt);

        // EM = maskedDB || H || 0xbc.
        let (masked_db, digest_terminator) = em.split_at_mut(metrics.db_len);

        // Step 9: write dbMask = MGF1(H) in place; DB is XORed onto it next.
        mgf1(self.digest_alg, h_hash.as_ref(), masked_db)?;

        // Steps 7-8 and 10: DB = PS || 0x01 || salt. PS is all zero bytes, so
        // XORing it is a no-op and those positions are simply skipped.
        let mut db_bytes = masked_db.iter_mut().skip(metrics.ps_len);
        *(db_bytes.next().ok_or(error::Unspecified)?) ^= 0x01;
        for (masked_db_b, salt_b) in db_bytes.zip(salt.iter()) {
            *masked_db_b ^= *salt_b;
        }

        // Step 11: clear the bits of the first byte that lie above emBits.
        masked_db[0] &= metrics.top_byte_mask;

        // Step 12: append H and the 0xbc trailer.
        digest_terminator[..metrics.h_len].copy_from_slice(h_hash.as_ref());
        digest_terminator[metrics.h_len] = 0xbc;

        Ok(())
    }
}
305
306impl Verification for PSS {
307 fn verify(
310 &self,
311 m_hash: &digest::Digest,
312 m: &mut untrusted::Reader,
313 mod_bits: bits::BitLength,
314 ) -> Result<(), error::Unspecified> {
315 let metrics = PSSMetrics::new(self.digest_alg, mod_bits)?;
316
317 if metrics.top_byte_mask == 0xff {
327 if m.read_byte()? != 0 {
328 return Err(error::Unspecified);
329 }
330 };
331 let em = m;
332
333 let masked_db = em.read_bytes(metrics.db_len)?;
342 let h_hash = em.read_bytes(metrics.h_len)?;
343
344 if em.read_byte()? != 0xbc {
346 return Err(error::Unspecified);
347 }
348
349 let mut db = [0u8; PUBLIC_KEY_PUBLIC_MODULUS_MAX_LEN];
351 let db = &mut db[..metrics.db_len];
352
353 mgf1(self.digest_alg, h_hash.as_slice_less_safe(), db)?;
354
355 masked_db.read_all(error::Unspecified, |masked_bytes| {
356 let b = masked_bytes.read_byte()?;
358 if b & !metrics.top_byte_mask != 0 {
359 return Err(error::Unspecified);
360 }
361 db[0] ^= b;
362
363 for i in 1..db.len() {
365 db[i] ^= masked_bytes.read_byte()?;
366 }
367 Ok(())
368 })?;
369
370 db[0] &= metrics.top_byte_mask;
372
373 let ps_len = metrics.ps_len;
375 for i in 0..ps_len {
376 if db[i] != 0 {
377 return Err(error::Unspecified);
378 }
379 }
380 if db[metrics.ps_len] != 1 {
381 return Err(error::Unspecified);
382 }
383
384 let salt = &db[(db.len() - metrics.s_len)..];
386
387 let h_prime = pss_digest(self.digest_alg, m_hash, salt);
389
390 if h_hash != *h_prime.as_ref() {
392 return Err(error::Unspecified);
393 }
394
395 Ok(())
396 }
397}
398
399struct PSSMetrics {
400 #[cfg_attr(not(feature = "alloc"), allow(dead_code))]
401 em_len: usize,
402 db_len: usize,
403 ps_len: usize,
404 s_len: usize,
405 h_len: usize,
406 top_byte_mask: u8,
407}
408
409impl PSSMetrics {
410 fn new(
411 digest_alg: &'static digest::Algorithm,
412 mod_bits: bits::BitLength,
413 ) -> Result<PSSMetrics, error::Unspecified> {
414 let em_bits = mod_bits.try_sub_1()?;
415 let em_len = em_bits.as_usize_bytes_rounded_up();
416 let leading_zero_bits = (8 * em_len) - em_bits.as_usize_bits();
417 debug_assert!(leading_zero_bits < 8);
418 let top_byte_mask = 0xffu8 >> leading_zero_bits;
419
420 let h_len = digest_alg.output_len;
421
422 let s_len = h_len;
424
425 let db_len = em_len.checked_sub(1 + s_len).ok_or(error::Unspecified)?;
433 let ps_len = db_len.checked_sub(h_len + 1).ok_or(error::Unspecified)?;
434
435 debug_assert!(em_bits.as_usize_bits() >= (8 * h_len) + (8 * s_len) + 9);
436
437 Ok(PSSMetrics {
438 em_len,
439 db_len,
440 ps_len,
441 s_len,
442 h_len,
443 top_byte_mask,
444 })
445 }
446}
447
448fn mgf1(
451 digest_alg: &'static digest::Algorithm,
452 seed: &[u8],
453 mask: &mut [u8],
454) -> Result<(), error::Unspecified> {
455 let digest_len = digest_alg.output_len;
456
457 let ctr_max = (mask.len() - 1) / digest_len;
459 assert!(ctr_max <= u32::max_value() as usize);
460 for (i, mask_chunk) in mask.chunks_mut(digest_len).enumerate() {
461 let mut ctx = digest::Context::new(digest_alg);
462 ctx.update(seed);
463 ctx.update(&u32::to_be_bytes(i as u32));
464 let digest = ctx.finish();
465 let mask_chunk_len = mask_chunk.len();
466 mask_chunk.copy_from_slice(&digest.as_ref()[..mask_chunk_len]);
467 }
468
469 Ok(())
470}
471
472fn pss_digest(
473 digest_alg: &'static digest::Algorithm,
474 m_hash: &digest::Digest,
475 salt: &[u8],
476) -> digest::Digest {
477 const PREFIX_ZEROS: [u8; 8] = [0u8; 8];
479
480 let mut ctx = digest::Context::new(digest_alg);
482 ctx.update(&PREFIX_ZEROS);
483 ctx.update(m_hash.as_ref());
484 ctx.update(salt);
485 ctx.finish()
486}
487
// Defines a public `PSS` static named `$name` that uses `$digest_alg`.
// `$doc` becomes the static's rustdoc.
macro_rules! rsa_pss_padding {
    ( $name:ident, $digest_alg:expr, $doc:expr ) => {
        #[doc = $doc]
        pub static $name: PSS = PSS { digest_alg: $digest_alg };
    };
}
496
// Public RSA-PSS padding schemes, one per supported digest. SHA-1 is
// intentionally not offered for PSS.
rsa_pss_padding!(
    RSA_PSS_SHA256,
    &digest::SHA256,
    "RSA PSS padding using SHA-256 for RSA signatures.\n\nSee
 \"`RSA_PSS_*` Details\" in `ring::signature`'s module-level
 documentation for more details."
);
rsa_pss_padding!(
    RSA_PSS_SHA384,
    &digest::SHA384,
    "RSA PSS padding using SHA-384 for RSA signatures.\n\nSee
 \"`RSA_PSS_*` Details\" in `ring::signature`'s module-level
 documentation for more details."
);
rsa_pss_padding!(
    RSA_PSS_SHA512,
    &digest::SHA512,
    "RSA PSS padding using SHA-512 for RSA signatures.\n\nSee
 \"`RSA_PSS_*` Details\" in `ring::signature`'s module-level
 documentation for more details."
);
518
#[cfg(test)]
mod test {
    use super::*;
    use crate::{digest, error, test};
    use alloc::vec;

    // Maps a test vector's `Digest` attribute to the matching PSS scheme.
    fn pss_alg_for(digest_name: &str) -> &'static PSS {
        match digest_name {
            "SHA256" => &RSA_PSS_SHA256,
            "SHA384" => &RSA_PSS_SHA384,
            "SHA512" => &RSA_PSS_SHA512,
            _ => panic!("Unsupported digest: {}", digest_name),
        }
    }

    #[test]
    fn test_pss_padding_verify() {
        test::run(
            test_file!("rsa_pss_padding_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let alg = pss_alg_for(&test_case.consume_string("Digest"));

                let msg = test_case.consume_bytes("Msg");
                let m_hash = digest::digest(alg.digest_alg(), &msg);

                let encoded = test_case.consume_bytes("EM");

                // The salt is recovered from EM during verification; the
                // vector's copy is only needed by the encoding test.
                let _ = test_case.consume_bytes("Salt");

                let bit_len = test_case.consume_usize_bits("Len");
                let is_valid = test_case.consume_string("Result") == "P";

                let actual_result = untrusted::Input::from(&encoded)
                    .read_all(error::Unspecified, |m| alg.verify(&m_hash, m, bit_len));
                assert_eq!(actual_result.is_ok(), is_valid);

                Ok(())
            },
        );
    }

    // Encoding needs `RsaEncoding`, which is only available with `alloc`.
    #[cfg(feature = "alloc")]
    #[test]
    fn test_pss_padding_encode() {
        test::run(
            test_file!("rsa_pss_padding_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let alg = pss_alg_for(&test_case.consume_string("Digest"));
                let msg = test_case.consume_bytes("Msg");
                let salt = test_case.consume_bytes("Salt");
                let encoded = test_case.consume_bytes("EM");
                let bit_len = test_case.consume_usize_bits("Len");
                let expected_result = test_case.consume_string("Result");

                // Only vectors expected to verify can be reproduced by encoding.
                if expected_result != "P" {
                    return Ok(());
                }

                // Feed the vector's salt through the RNG so the output is
                // deterministic.
                let rng = test::rand::FixedSliceRandom { bytes: &salt };

                let mut m_out = vec![0u8; bit_len.as_usize_bytes_rounded_up()];
                let digest = digest::digest(alg.digest_alg(), &msg);
                alg.encode(&digest, &mut m_out, bit_len, &rng).unwrap();
                assert_eq!(m_out, encoded);

                Ok(())
            },
        );
    }
}