1use crate::error::TLSError;
2use crate::key_schedule::{derive_traffic_iv, derive_traffic_key};
3use crate::msgs::codec;
4use crate::msgs::codec::Codec;
5use crate::msgs::enums::{ContentType, ProtocolVersion};
6use crate::msgs::fragmenter::MAX_FRAGMENT_LEN;
7use crate::msgs::message::{BorrowMessage, Message, MessagePayload};
8use crate::session::SessionSecrets;
9use crate::suites::SupportedCipherSuite;
10use ring::{aead, hkdf};
11use std::io::Write;
12
/// Removes record protection from incoming TLS records.
///
/// The `Send + Sync` bound allows boxed implementations to be moved to and
/// shared between threads.
pub trait MessageDecrypter: Send + Sync {
    /// Decrypts `m`, which was received with record sequence number `seq`,
    /// returning the recovered plaintext message or an error.
    fn decrypt(&self, m: Message, seq: u64) -> Result<Message, TLSError>;
}
17
/// Applies record protection to outgoing TLS records.
///
/// The `Send + Sync` bound allows boxed implementations to be moved to and
/// shared between threads.
pub trait MessageEncrypter: Send + Sync {
    /// Encrypts the borrowed plaintext `m` as the record with sequence number
    /// `seq`, returning an owned protected message or an error.
    fn encrypt(&self, m: BorrowMessage, seq: u64) -> Result<Message, TLSError>;
}
22
23impl dyn MessageEncrypter {
24 pub fn invalid() -> Box<dyn MessageEncrypter> {
25 Box::new(InvalidMessageEncrypter {})
26 }
27}
28
29impl dyn MessageDecrypter {
30 pub fn invalid() -> Box<dyn MessageDecrypter> {
31 Box::new(InvalidMessageDecrypter {})
32 }
33}
34
/// A `(decrypter, encrypter)` pair: read direction first, write direction second.
pub type MessageCipherPair = (Box<dyn MessageDecrypter>, Box<dyn MessageEncrypter>);
36
37const TLS12_AAD_SIZE: usize = 8 + 1 + 2 + 2;
38fn make_tls12_aad(
39 seq: u64,
40 typ: ContentType,
41 vers: ProtocolVersion,
42 len: usize,
43) -> ring::aead::Aad<[u8; TLS12_AAD_SIZE]> {
44 let mut out = [0; TLS12_AAD_SIZE];
45 codec::put_u64(seq, &mut out[0..]);
46 out[8] = typ.get_u8();
47 codec::put_u16(vers.get_u16(), &mut out[9..]);
48 codec::put_u16(len as u16, &mut out[11..]);
49 ring::aead::Aad::from(out)
50}
51
52fn make_tls12_gcm_nonce(write_iv: &[u8], explicit: &[u8]) -> Iv {
53 debug_assert_eq!(write_iv.len(), 4);
54 debug_assert_eq!(explicit.len(), 8);
55
56 let mut iv = Iv(Default::default());
64 iv.0[..4].copy_from_slice(write_iv);
65 iv.0[4..].copy_from_slice(explicit);
66 iv
67}
68
/// Constructor for a TLS 1.2 decrypter, called as `(key, iv)`.
pub type BuildTLS12Decrypter = fn(&[u8], &[u8]) -> Box<dyn MessageDecrypter>;
/// Constructor for a TLS 1.2 encrypter, called as `(key, iv, extra)`. Here
/// `extra` is the key-block material that follows the IVs; some suites ignore it.
pub type BuildTLS12Encrypter = fn(&[u8], &[u8], &[u8]) -> Box<dyn MessageEncrypter>;
71
72pub fn build_tls12_gcm_128_decrypter(key: &[u8], iv: &[u8]) -> Box<dyn MessageDecrypter> {
73 Box::new(GCMMessageDecrypter::new(&aead::AES_128_GCM, key, iv))
74}
75
76pub fn build_tls12_gcm_128_encrypter(
77 key: &[u8],
78 iv: &[u8],
79 extra: &[u8],
80) -> Box<dyn MessageEncrypter> {
81 let nonce = make_tls12_gcm_nonce(iv, extra);
82 Box::new(GCMMessageEncrypter::new(&aead::AES_128_GCM, key, nonce))
83}
84
85pub fn build_tls12_gcm_256_decrypter(key: &[u8], iv: &[u8]) -> Box<dyn MessageDecrypter> {
86 Box::new(GCMMessageDecrypter::new(&aead::AES_256_GCM, key, iv))
87}
88
89pub fn build_tls12_gcm_256_encrypter(
90 key: &[u8],
91 iv: &[u8],
92 extra: &[u8],
93) -> Box<dyn MessageEncrypter> {
94 let nonce = make_tls12_gcm_nonce(iv, extra);
95 Box::new(GCMMessageEncrypter::new(&aead::AES_256_GCM, key, nonce))
96}
97
98pub fn build_tls12_chacha_decrypter(key: &[u8], iv: &[u8]) -> Box<dyn MessageDecrypter> {
99 Box::new(ChaCha20Poly1305MessageDecrypter::new(
100 &aead::CHACHA20_POLY1305,
101 key,
102 Iv::copy(iv),
103 ))
104}
105
106pub fn build_tls12_chacha_encrypter(key: &[u8], iv: &[u8], _: &[u8]) -> Box<dyn MessageEncrypter> {
107 Box::new(ChaCha20Poly1305MessageEncrypter::new(
108 &aead::CHACHA20_POLY1305,
109 key,
110 Iv::copy(iv),
111 ))
112}
113
114pub fn new_tls12(
117 scs: &'static SupportedCipherSuite,
118 secrets: &SessionSecrets,
119) -> MessageCipherPair {
120 let key_block = secrets.make_key_block(scs.key_block_len());
123
124 let mut offs = 0;
125 let client_write_key = &key_block[offs..offs + scs.enc_key_len];
126 offs += scs.enc_key_len;
127 let server_write_key = &key_block[offs..offs + scs.enc_key_len];
128 offs += scs.enc_key_len;
129 let client_write_iv = &key_block[offs..offs + scs.fixed_iv_len];
130 offs += scs.fixed_iv_len;
131 let server_write_iv = &key_block[offs..offs + scs.fixed_iv_len];
132 offs += scs.fixed_iv_len;
133
134 let (write_key, write_iv) = if secrets.randoms.we_are_client {
135 (client_write_key, client_write_iv)
136 } else {
137 (server_write_key, server_write_iv)
138 };
139
140 let (read_key, read_iv) = if secrets.randoms.we_are_client {
141 (server_write_key, server_write_iv)
142 } else {
143 (client_write_key, client_write_iv)
144 };
145
146 (
147 scs.build_tls12_decrypter.unwrap()(read_key, read_iv),
148 scs.build_tls12_encrypter.unwrap()(write_key, write_iv, &key_block[offs..]),
149 )
150}
151
152pub fn new_tls13_read(
153 scs: &'static SupportedCipherSuite,
154 secret: &hkdf::Prk,
155) -> Box<dyn MessageDecrypter> {
156 let key = derive_traffic_key(secret, scs.aead_algorithm);
157 let iv = derive_traffic_iv(secret);
158
159 Box::new(TLS13MessageDecrypter::new(key, iv))
160}
161
162pub fn new_tls13_write(
163 scs: &'static SupportedCipherSuite,
164 secret: &hkdf::Prk,
165) -> Box<dyn MessageEncrypter> {
166 let key = derive_traffic_key(secret, scs.aead_algorithm);
167 let iv = derive_traffic_iv(secret);
168
169 Box::new(TLS13MessageEncrypter::new(key, iv))
170}
171
/// TLS 1.2 AES-GCM record encrypter.
pub struct GCMMessageEncrypter {
    /// AEAD key used to seal outgoing records.
    enc_key: aead::LessSafeKey,
    /// Base nonce (4-byte salt || 8-byte explicit seed), XORed with the
    /// sequence number for each record.
    iv: Iv,
}
177
/// TLS 1.2 AES-GCM record decrypter.
pub struct GCMMessageDecrypter {
    /// AEAD key used to open incoming records.
    dec_key: aead::LessSafeKey,
    /// 4-byte implicit nonce salt, prefixed to each record's explicit nonce.
    dec_salt: [u8; 4],
}
183
/// Length of the explicit nonce that prefixes every TLS 1.2 GCM record.
const GCM_EXPLICIT_NONCE_LEN: usize = 8;
/// Per-record GCM expansion: the explicit nonce plus the 16-byte tag.
const GCM_OVERHEAD: usize = GCM_EXPLICIT_NONCE_LEN + 16;
186
187impl MessageDecrypter for GCMMessageDecrypter {
188 fn decrypt(&self, mut msg: Message, seq: u64) -> Result<Message, TLSError> {
189 let payload = msg
190 .take_opaque_payload()
191 .ok_or(TLSError::DecryptError)?;
192 let mut buf = payload.0;
193
194 if buf.len() < GCM_OVERHEAD {
195 return Err(TLSError::DecryptError);
196 }
197
198 let nonce = {
199 let mut nonce = [0u8; 12];
200 nonce
201 .as_mut()
202 .write_all(&self.dec_salt)
203 .unwrap();
204 nonce[4..]
205 .as_mut()
206 .write_all(&buf[..8])
207 .unwrap();
208 aead::Nonce::assume_unique_for_key(nonce)
209 };
210
211 let aad = make_tls12_aad(seq, msg.typ, msg.version, buf.len() - GCM_OVERHEAD);
212
213 let plain_len = self
214 .dec_key
215 .open_within(nonce, aad, &mut buf, GCM_EXPLICIT_NONCE_LEN..)
216 .map_err(|_| TLSError::DecryptError)?
217 .len();
218
219 if plain_len > MAX_FRAGMENT_LEN {
220 return Err(TLSError::PeerSentOversizedRecord);
221 }
222
223 buf.truncate(plain_len);
224
225 Ok(Message {
226 typ: msg.typ,
227 version: msg.version,
228 payload: MessagePayload::new_opaque(buf),
229 })
230 }
231}
232
233impl MessageEncrypter for GCMMessageEncrypter {
234 fn encrypt(&self, msg: BorrowMessage, seq: u64) -> Result<Message, TLSError> {
235 let nonce = make_tls13_nonce(&self.iv, seq);
236 let aad = make_tls12_aad(seq, msg.typ, msg.version, msg.payload.len());
237
238 let total_len = msg.payload.len() + self.enc_key.algorithm().tag_len();
239 let mut payload = Vec::with_capacity(GCM_EXPLICIT_NONCE_LEN + total_len);
240 payload.extend_from_slice(&nonce.as_ref()[4..]);
241 payload.extend_from_slice(&msg.payload);
242
243 self.enc_key
244 .seal_in_place_separate_tag(nonce, aad, &mut payload[GCM_EXPLICIT_NONCE_LEN..])
245 .map(|tag| payload.extend(tag.as_ref()))
246 .map_err(|_| TLSError::General("encrypt failed".to_string()))?;
247
248 Ok(Message {
249 typ: msg.typ,
250 version: msg.version,
251 payload: MessagePayload::new_opaque(payload),
252 })
253 }
254}
255
256impl GCMMessageEncrypter {
257 fn new(alg: &'static aead::Algorithm, enc_key: &[u8], iv: Iv) -> GCMMessageEncrypter {
258 let key = aead::UnboundKey::new(alg, enc_key).unwrap();
259 GCMMessageEncrypter {
260 enc_key: aead::LessSafeKey::new(key),
261 iv,
262 }
263 }
264}
265
266impl GCMMessageDecrypter {
267 fn new(alg: &'static aead::Algorithm, dec_key: &[u8], dec_iv: &[u8]) -> GCMMessageDecrypter {
268 let key = aead::UnboundKey::new(alg, dec_key).unwrap();
269 let mut ret = GCMMessageDecrypter {
270 dec_key: aead::LessSafeKey::new(key),
271 dec_salt: [0u8; 4],
272 };
273
274 debug_assert_eq!(dec_iv.len(), 4);
275 ret.dec_salt
276 .as_mut()
277 .write_all(dec_iv)
278 .unwrap();
279 ret
280 }
281}
282
/// An AEAD nonce base (IV), exactly `ring::aead::NONCE_LEN` (12) bytes long.
pub(crate) struct Iv([u8; ring::aead::NONCE_LEN]);
285
286impl Iv {
287 pub(crate) fn new(value: [u8; ring::aead::NONCE_LEN]) -> Self {
288 Self(value)
289 }
290
291 fn copy(value: &[u8]) -> Self {
292 debug_assert_eq!(value.len(), ring::aead::NONCE_LEN);
293 let mut iv = Iv::new(Default::default());
294 iv.0.copy_from_slice(value);
295 iv
296 }
297
298 #[cfg(test)]
299 pub(crate) fn value(&self) -> &[u8; 12] {
300 &self.0
301 }
302}
303
/// HKDF output-length marker for deriving an `Iv` (one AEAD nonce).
pub(crate) struct IvLen;
305
/// Requests exactly one AEAD nonce's worth of HKDF output.
impl hkdf::KeyType for IvLen {
    fn len(&self) -> usize {
        aead::NONCE_LEN
    }
}
311
312impl From<hkdf::Okm<'_, IvLen>> for Iv {
313 fn from(okm: hkdf::Okm<IvLen>) -> Self {
314 let mut r = Iv(Default::default());
315 okm.fill(&mut r.0[..]).unwrap();
316 r
317 }
318}
319
/// TLS 1.3 record encrypter. The AEAD algorithm is whatever `enc_key` was
/// created for.
struct TLS13MessageEncrypter {
    /// AEAD key used to seal outgoing records.
    enc_key: aead::LessSafeKey,
    /// Nonce base, XORed with the sequence number for each record.
    iv: Iv,
}
324
/// TLS 1.3 record decrypter. The AEAD algorithm is whatever `dec_key` was
/// created for.
struct TLS13MessageDecrypter {
    /// AEAD key used to open incoming records.
    dec_key: aead::LessSafeKey,
    /// Nonce base, XORed with the sequence number for each record.
    iv: Iv,
}
329
330fn unpad_tls13(v: &mut Vec<u8>) -> ContentType {
331 loop {
332 match v.pop() {
333 Some(0) => {}
334
335 Some(content_type) => return ContentType::read_bytes(&[content_type]).unwrap(),
336
337 None => return ContentType::Unknown(0),
338 }
339 }
340}
341
342fn make_tls13_nonce(iv: &Iv, seq: u64) -> ring::aead::Nonce {
343 let mut nonce = [0u8; ring::aead::NONCE_LEN];
344 codec::put_u64(seq, &mut nonce[4..]);
345
346 nonce
347 .iter_mut()
348 .zip(iv.0.iter())
349 .for_each(|(nonce, iv)| {
350 *nonce ^= *iv;
351 });
352
353 aead::Nonce::assume_unique_for_key(nonce)
354}
355
356fn make_tls13_aad(len: usize) -> ring::aead::Aad<[u8; 1 + 2 + 2]> {
357 ring::aead::Aad::from([
358 0x17, 0x3, 0x3, (len >> 8) as u8,
362 len as u8,
363 ])
364}
365
366impl MessageEncrypter for TLS13MessageEncrypter {
367 fn encrypt(&self, msg: BorrowMessage, seq: u64) -> Result<Message, TLSError> {
368 let total_len = msg.payload.len() + 1 + self.enc_key.algorithm().tag_len();
369 let mut buf = Vec::with_capacity(total_len);
370 buf.extend_from_slice(&msg.payload);
371 msg.typ.encode(&mut buf);
372
373 let nonce = make_tls13_nonce(&self.iv, seq);
374 let aad = make_tls13_aad(total_len);
375
376 self.enc_key
377 .seal_in_place_append_tag(nonce, aad, &mut buf)
378 .map_err(|_| TLSError::General("encrypt failed".to_string()))?;
379
380 Ok(Message {
381 typ: ContentType::ApplicationData,
382 version: ProtocolVersion::TLSv1_2,
383 payload: MessagePayload::new_opaque(buf),
384 })
385 }
386}
387
388impl MessageDecrypter for TLS13MessageDecrypter {
389 fn decrypt(&self, mut msg: Message, seq: u64) -> Result<Message, TLSError> {
390 let payload = msg
391 .take_opaque_payload()
392 .ok_or(TLSError::DecryptError)?;
393 let mut buf = payload.0;
394
395 if buf.len() < self.dec_key.algorithm().tag_len() {
396 return Err(TLSError::DecryptError);
397 }
398
399 let nonce = make_tls13_nonce(&self.iv, seq);
400 let aad = make_tls13_aad(buf.len());
401 let plain_len = self
402 .dec_key
403 .open_in_place(nonce, aad, &mut buf)
404 .map_err(|_| TLSError::DecryptError)?
405 .len();
406
407 buf.truncate(plain_len);
408
409 if buf.len() > MAX_FRAGMENT_LEN + 1 {
410 return Err(TLSError::PeerSentOversizedRecord);
411 }
412
413 let content_type = unpad_tls13(&mut buf);
414 if content_type == ContentType::Unknown(0) {
415 let msg = "peer sent bad TLSInnerPlaintext".to_string();
416 return Err(TLSError::PeerMisbehavedError(msg));
417 }
418
419 if buf.len() > MAX_FRAGMENT_LEN {
420 return Err(TLSError::PeerSentOversizedRecord);
421 }
422
423 Ok(Message {
424 typ: content_type,
425 version: ProtocolVersion::TLSv1_3,
426 payload: MessagePayload::new_opaque(buf),
427 })
428 }
429}
430
431impl TLS13MessageEncrypter {
432 fn new(key: aead::UnboundKey, enc_iv: Iv) -> TLS13MessageEncrypter {
433 TLS13MessageEncrypter {
434 enc_key: aead::LessSafeKey::new(key),
435 iv: enc_iv,
436 }
437 }
438}
439
440impl TLS13MessageDecrypter {
441 fn new(key: aead::UnboundKey, dec_iv: Iv) -> TLS13MessageDecrypter {
442 TLS13MessageDecrypter {
443 dec_key: aead::LessSafeKey::new(key),
444 iv: dec_iv,
445 }
446 }
447}
448
/// TLS 1.2 ChaCha20-Poly1305 record encrypter.
pub struct ChaCha20Poly1305MessageEncrypter {
    /// AEAD key used to seal outgoing records.
    enc_key: aead::LessSafeKey,
    /// Full nonce base from the key block, XORed with the sequence number
    /// for each record.
    enc_offset: Iv,
}
456
/// TLS 1.2 ChaCha20-Poly1305 record decrypter.
pub struct ChaCha20Poly1305MessageDecrypter {
    /// AEAD key used to open incoming records.
    dec_key: aead::LessSafeKey,
    /// Full nonce base from the key block, XORed with the sequence number
    /// for each record.
    dec_offset: Iv,
}
464
465impl ChaCha20Poly1305MessageEncrypter {
466 fn new(
467 alg: &'static aead::Algorithm,
468 enc_key: &[u8],
469 enc_iv: Iv,
470 ) -> ChaCha20Poly1305MessageEncrypter {
471 let key = aead::UnboundKey::new(alg, enc_key).unwrap();
472 ChaCha20Poly1305MessageEncrypter {
473 enc_key: aead::LessSafeKey::new(key),
474 enc_offset: enc_iv,
475 }
476 }
477}
478
479impl ChaCha20Poly1305MessageDecrypter {
480 fn new(
481 alg: &'static aead::Algorithm,
482 dec_key: &[u8],
483 dec_iv: Iv,
484 ) -> ChaCha20Poly1305MessageDecrypter {
485 let key = aead::UnboundKey::new(alg, dec_key).unwrap();
486 ChaCha20Poly1305MessageDecrypter {
487 dec_key: aead::LessSafeKey::new(key),
488 dec_offset: dec_iv,
489 }
490 }
491}
492
/// Per-record ChaCha20-Poly1305 expansion: the 16-byte Poly1305 tag (no
/// explicit nonce is transmitted).
const CHACHAPOLY1305_OVERHEAD: usize = 16;
494
495impl MessageDecrypter for ChaCha20Poly1305MessageDecrypter {
496 fn decrypt(&self, mut msg: Message, seq: u64) -> Result<Message, TLSError> {
497 let payload = msg
498 .take_opaque_payload()
499 .ok_or(TLSError::DecryptError)?;
500 let mut buf = payload.0;
501
502 if buf.len() < CHACHAPOLY1305_OVERHEAD {
503 return Err(TLSError::DecryptError);
504 }
505
506 let nonce = make_tls13_nonce(&self.dec_offset, seq);
507 let aad = make_tls12_aad(
508 seq,
509 msg.typ,
510 msg.version,
511 buf.len() - CHACHAPOLY1305_OVERHEAD,
512 );
513
514 let plain_len = self
515 .dec_key
516 .open_in_place(nonce, aad, &mut buf)
517 .map_err(|_| TLSError::DecryptError)?
518 .len();
519
520 if plain_len > MAX_FRAGMENT_LEN {
521 return Err(TLSError::PeerSentOversizedRecord);
522 }
523
524 buf.truncate(plain_len);
525
526 Ok(Message {
527 typ: msg.typ,
528 version: msg.version,
529 payload: MessagePayload::new_opaque(buf),
530 })
531 }
532}
533
534impl MessageEncrypter for ChaCha20Poly1305MessageEncrypter {
535 fn encrypt(&self, msg: BorrowMessage, seq: u64) -> Result<Message, TLSError> {
536 let nonce = make_tls13_nonce(&self.enc_offset, seq);
537 let aad = make_tls12_aad(seq, msg.typ, msg.version, msg.payload.len());
538
539 let total_len = msg.payload.len() + self.enc_key.algorithm().tag_len();
540 let mut buf = Vec::with_capacity(total_len);
541 buf.extend_from_slice(&msg.payload);
542
543 self.enc_key
544 .seal_in_place_append_tag(nonce, aad, &mut buf)
545 .map_err(|_| TLSError::General("encrypt failed".to_string()))?;
546
547 Ok(Message {
548 typ: msg.typ,
549 version: msg.version,
550 payload: MessagePayload::new_opaque(buf),
551 })
552 }
553}
554
/// Placeholder encrypter whose `encrypt` always returns an error.
pub struct InvalidMessageEncrypter {}
557
558impl MessageEncrypter for InvalidMessageEncrypter {
559 fn encrypt(&self, _m: BorrowMessage, _seq: u64) -> Result<Message, TLSError> {
560 Err(TLSError::General("encrypt not yet available".to_string()))
561 }
562}
563
/// Placeholder decrypter whose `decrypt` always returns an error.
pub struct InvalidMessageDecrypter {}
566
impl MessageDecrypter for InvalidMessageDecrypter {
    /// Always fails with `TLSError::DecryptError`, ignoring its inputs.
    fn decrypt(&self, _m: Message, _seq: u64) -> Result<Message, TLSError> {
        Err(TLSError::DecryptError)
    }
}