1use crate::cipher;
2use crate::error::TLSError;
3use crate::key;
4#[cfg(feature = "logging")]
5use crate::log::{debug, error, warn};
6use crate::msgs::base::Payload;
7use crate::msgs::codec::Codec;
8use crate::msgs::deframer::MessageDeframer;
9use crate::msgs::enums::{AlertDescription, AlertLevel, ContentType, ProtocolVersion};
10use crate::msgs::fragmenter::{MessageFragmenter, MAX_FRAGMENT_LEN};
11use crate::msgs::hsjoiner::HandshakeJoiner;
12use crate::msgs::message::{BorrowMessage, Message, MessagePayload};
13use crate::prf;
14use crate::quic;
15use crate::rand;
16use crate::record_layer;
17use crate::suites::SupportedCipherSuite;
18use crate::vecbuf::ChunkVecBuffer;
19use ring;
20use std::io::{Read, Write};
21
22use std::collections::VecDeque;
23use std::io;
24
25pub trait Session: quic::QuicExt + Read + Write + Send + Sync {
27 fn read_tls(&mut self, rd: &mut dyn Read) -> Result<usize, io::Error>;
41
42 fn write_tls(&mut self, wr: &mut dyn Write) -> Result<usize, io::Error>;
54
55 fn process_new_packets(&mut self) -> Result<(), TLSError>;
63
64 fn wants_read(&self) -> bool;
67
68 fn wants_write(&self) -> bool;
71
72 fn is_handshaking(&self) -> bool;
76
77 fn set_buffer_limit(&mut self, limit: usize);
84
85 fn send_close_notify(&mut self);
89
90 fn get_peer_certificates(&self) -> Option<Vec<key::Certificate>>;
104
105 fn get_alpn_protocol(&self) -> Option<&[u8]>;
111
112 fn get_protocol_version(&self) -> Option<ProtocolVersion>;
116
117 fn export_keying_material(
131 &self,
132 output: &mut [u8],
133 label: &[u8],
134 context: Option<&[u8]>,
135 ) -> Result<(), TLSError>;
136
137 fn get_negotiated_ciphersuite(&self) -> Option<&'static SupportedCipherSuite>;
141
142 fn complete_io<T>(&mut self, io: &mut T) -> Result<(usize, usize), io::Error>
165 where
166 Self: Sized,
167 T: Read + Write,
168 {
169 let until_handshaked = self.is_handshaking();
170 let mut eof = false;
171 let mut wrlen = 0;
172 let mut rdlen = 0;
173
174 loop {
175 while self.wants_write() {
176 wrlen += self.write_tls(io)?;
177 }
178
179 if !until_handshaked && wrlen > 0 {
180 return Ok((rdlen, wrlen));
181 }
182
183 if !eof && self.wants_read() {
184 match self.read_tls(io)? {
185 0 => eof = true,
186 n => rdlen += n,
187 }
188 }
189
190 match self.process_new_packets() {
191 Ok(_) => {}
192 Err(e) => {
193 let _ignored = self.write_tls(io);
197
198 return Err(io::Error::new(io::ErrorKind::InvalidData, e));
199 }
200 };
201
202 match (eof, until_handshaked, self.is_handshaking()) {
203 (_, true, false) => return Ok((rdlen, wrlen)),
204 (_, false, _) => return Ok((rdlen, wrlen)),
205 (true, true, true) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
206 (..) => {}
207 }
208 }
209 }
210}
211
/// The transport carrying this session's TLS messages.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum Protocol {
    /// Plain TLS: messages are framed into TLS records (the default chosen by
    /// `SessionCommon::new`).
    Tls13,
    /// TLS used only for the QUIC cryptographic handshake: messages are handed
    /// to the QUIC layer rather than framed into records.
    #[cfg(feature = "quic")]
    Quic,
}
218
/// The 32-byte client and server random values of a handshake, plus which
/// side of the connection we are.
#[derive(Debug, Clone)]
pub struct SessionRandoms {
    /// `true` when this endpoint is the client.
    pub we_are_client: bool,
    /// The client's 32-byte random value.
    pub client: [u8; 32],
    /// The server's 32-byte random value.
    pub server: [u8; 32],
}
225
/// The ASCII bytes "DOWNGRD" followed by 0x01.
///
/// These occupy the last 8 bytes of the server random when signalling a
/// TLS 1.2 downgrade (RFC 8446 §4.1.3).
static TLS12_DOWNGRADE_SENTINEL: &[u8] = b"DOWNGRD\x01";
227
228impl SessionRandoms {
229 pub fn for_server() -> SessionRandoms {
230 let mut ret = SessionRandoms {
231 we_are_client: false,
232 client: [0u8; 32],
233 server: [0u8; 32],
234 };
235
236 rand::fill_random(&mut ret.server);
237 ret
238 }
239
240 pub fn for_client() -> SessionRandoms {
241 let mut ret = SessionRandoms {
242 we_are_client: true,
243 client: [0u8; 32],
244 server: [0u8; 32],
245 };
246
247 rand::fill_random(&mut ret.client);
248 ret
249 }
250
251 pub fn set_tls12_downgrade_marker(&mut self) {
252 assert!(!self.we_are_client);
253 self.server[24..]
254 .as_mut()
255 .write_all(TLS12_DOWNGRADE_SENTINEL)
256 .unwrap();
257 }
258
259 pub fn has_tls12_downgrade_marker(&mut self) -> bool {
260 assert!(self.we_are_client);
261 &self.server[24..] == TLS12_DOWNGRADE_SENTINEL
264 }
265}
266
/// Concatenates two random values into a 64-byte PRF seed.
///
/// `first` is placed at offset 0 and `second` at offset 32. Any bytes not
/// covered by the inputs stay zero.
///
/// # Panics
///
/// If `first` is longer than 64 bytes or `second` is longer than 32 bytes.
fn join_randoms(first: &[u8], second: &[u8]) -> [u8; 64] {
    let mut joined = [0u8; 64];
    joined[..first.len()].copy_from_slice(first);

    let tail_end = 32 + second.len();
    joined[32..tail_end].copy_from_slice(second);

    joined
}
279
280pub struct SessionSecrets {
282 pub randoms: SessionRandoms,
283 hash: &'static ring::digest::Algorithm,
284 pub master_secret: [u8; 48],
285}
286
287impl SessionSecrets {
288 pub fn new(
289 randoms: &SessionRandoms,
290 hashalg: &'static ring::digest::Algorithm,
291 pms: &[u8],
292 ) -> SessionSecrets {
293 let mut ret = SessionSecrets {
294 randoms: randoms.clone(),
295 hash: hashalg,
296 master_secret: [0u8; 48],
297 };
298
299 let randoms = join_randoms(&ret.randoms.client, &ret.randoms.server);
300 prf::prf(
301 &mut ret.master_secret,
302 ret.hash,
303 pms,
304 b"master secret",
305 &randoms,
306 );
307 ret
308 }
309
310 pub fn new_ems(
311 randoms: &SessionRandoms,
312 hs_hash: &[u8],
313 hashalg: &'static ring::digest::Algorithm,
314 pms: &[u8],
315 ) -> SessionSecrets {
316 let mut ret = SessionSecrets {
317 randoms: randoms.clone(),
318 hash: hashalg,
319 master_secret: [0u8; 48],
320 };
321
322 prf::prf(
323 &mut ret.master_secret,
324 ret.hash,
325 pms,
326 b"extended master secret",
327 hs_hash,
328 );
329 ret
330 }
331
332 pub fn new_resume(
333 randoms: &SessionRandoms,
334 hashalg: &'static ring::digest::Algorithm,
335 master_secret: &[u8],
336 ) -> SessionSecrets {
337 let mut ret = SessionSecrets {
338 randoms: randoms.clone(),
339 hash: hashalg,
340 master_secret: [0u8; 48],
341 };
342 ret.master_secret
343 .as_mut()
344 .write_all(master_secret)
345 .unwrap();
346 ret
347 }
348
349 pub fn make_key_block(&self, len: usize) -> Vec<u8> {
350 let mut out = Vec::new();
351 out.resize(len, 0u8);
352
353 let randoms = join_randoms(&self.randoms.server, &self.randoms.client);
356 prf::prf(
357 &mut out,
358 self.hash,
359 &self.master_secret,
360 b"key expansion",
361 &randoms,
362 );
363
364 out
365 }
366
367 pub fn get_master_secret(&self) -> Vec<u8> {
368 let mut ret = Vec::new();
369 ret.extend_from_slice(&self.master_secret);
370 ret
371 }
372
373 pub fn make_verify_data(&self, handshake_hash: &[u8], label: &[u8]) -> Vec<u8> {
374 let mut out = Vec::new();
375 out.resize(12, 0u8);
376
377 prf::prf(
378 &mut out,
379 self.hash,
380 &self.master_secret,
381 label,
382 handshake_hash,
383 );
384 out
385 }
386
387 pub fn client_verify_data(&self, handshake_hash: &[u8]) -> Vec<u8> {
388 self.make_verify_data(handshake_hash, b"client finished")
389 }
390
391 pub fn server_verify_data(&self, handshake_hash: &[u8]) -> Vec<u8> {
392 self.make_verify_data(handshake_hash, b"server finished")
393 }
394
395 pub fn export_keying_material(&self, output: &mut [u8], label: &[u8], context: Option<&[u8]>) {
396 let mut randoms = Vec::new();
397 randoms.extend_from_slice(&self.randoms.client);
398 randoms.extend_from_slice(&self.randoms.server);
399 if let Some(context) = context {
400 assert!(context.len() <= 0xffff);
401 (context.len() as u16).encode(&mut randoms);
402 randoms.extend_from_slice(context);
403 }
404
405 prf::prf(output, self.hash, &self.master_secret, label, &randoms)
406 }
407}
408
/// Whether an outgoing write must respect the configured buffer limit.
enum Limit {
    /// Accept only as much as the buffer limit allows.
    Yes,
    /// Accept everything, ignoring the limit.
    No,
}
415
/// Verdict of `SessionCommon::filter_tls13_ccs` on an incoming message.
pub enum MiddleboxCCS {
    /// Process the message normally.
    Process,

    /// Silently discard it: a one-off TLS 1.3 middlebox-compatibility
    /// ChangeCipherSpec.
    Drop,
}
425
426pub struct SessionCommon {
427 pub negotiated_version: Option<ProtocolVersion>,
428 pub is_client: bool,
429 pub record_layer: record_layer::RecordLayer,
430 suite: Option<&'static SupportedCipherSuite>,
431 peer_eof: bool,
432 pub traffic: bool,
433 pub early_traffic: bool,
434 sent_fatal_alert: bool,
435 received_middlebox_ccs: bool,
436 pub message_deframer: MessageDeframer,
437 pub handshake_joiner: HandshakeJoiner,
438 pub message_fragmenter: MessageFragmenter,
439 received_plaintext: ChunkVecBuffer,
440 sendable_plaintext: ChunkVecBuffer,
441 pub sendable_tls: ChunkVecBuffer,
442 pub protocol: Protocol,
444 #[cfg(feature = "quic")]
445 pub(crate) quic: Quic,
446}
447
448impl SessionCommon {
449 pub fn new(mtu: Option<usize>, client: bool) -> SessionCommon {
450 SessionCommon {
451 negotiated_version: None,
452 is_client: client,
453 record_layer: record_layer::RecordLayer::new(),
454 suite: None,
455 peer_eof: false,
456 traffic: false,
457 early_traffic: false,
458 sent_fatal_alert: false,
459 received_middlebox_ccs: false,
460 message_deframer: MessageDeframer::new(),
461 handshake_joiner: HandshakeJoiner::new(),
462 message_fragmenter: MessageFragmenter::new(mtu.unwrap_or(MAX_FRAGMENT_LEN)),
463 received_plaintext: ChunkVecBuffer::new(),
464 sendable_plaintext: ChunkVecBuffer::new(),
465 sendable_tls: ChunkVecBuffer::new(),
466 protocol: Protocol::Tls13,
467 #[cfg(feature = "quic")]
468 quic: Quic::new(),
469 }
470 }
471
472 pub fn is_tls13(&self) -> bool {
473 match self.negotiated_version {
474 Some(ProtocolVersion::TLSv1_3) => true,
475 _ => false,
476 }
477 }
478
479 pub fn get_suite(&self) -> Option<&'static SupportedCipherSuite> {
480 self.suite
481 }
482
483 pub fn get_suite_assert(&self) -> &'static SupportedCipherSuite {
484 self.suite.as_ref().unwrap()
485 }
486
487 pub fn set_suite(&mut self, suite: &'static SupportedCipherSuite) -> bool {
488 match self.suite {
489 None => {
490 self.suite = Some(suite);
491 true
492 }
493 Some(s) if s == suite => {
494 self.suite = Some(suite);
495 true
496 }
497 _ => false,
498 }
499 }
500
501 pub fn filter_tls13_ccs(&mut self, msg: &Message) -> Result<MiddleboxCCS, TLSError> {
502 if !self.is_tls13() || !msg.is_content_type(ContentType::ChangeCipherSpec) || self.traffic {
508 return Ok(MiddleboxCCS::Process);
509 }
510
511 if self.received_middlebox_ccs {
512 Err(TLSError::PeerMisbehavedError(
513 "illegal middlebox CCS received".into(),
514 ))
515 } else {
516 self.received_middlebox_ccs = true;
517 Ok(MiddleboxCCS::Drop)
518 }
519 }
520
521 pub fn decrypt_incoming(&mut self, encr: Message) -> Result<Message, TLSError> {
522 if self
523 .record_layer
524 .wants_close_before_decrypt()
525 {
526 self.send_close_notify();
527 }
528
529 let rc = self.record_layer.decrypt_incoming(encr);
530 if let Err(TLSError::PeerSentOversizedRecord) = rc {
531 self.send_fatal_alert(AlertDescription::RecordOverflow);
532 }
533 rc
534 }
535
536 pub fn has_readable_plaintext(&self) -> bool {
537 !self.received_plaintext.is_empty()
538 }
539
540 pub fn set_buffer_limit(&mut self, limit: usize) {
541 self.sendable_plaintext.set_limit(limit);
542 self.sendable_tls.set_limit(limit);
543 }
544
545 pub fn process_alert(&mut self, msg: Message) -> Result<(), TLSError> {
546 if let MessagePayload::Alert(ref alert) = msg.payload {
547 if let AlertLevel::Unknown(_) = alert.level {
549 self.send_fatal_alert(AlertDescription::IllegalParameter);
550 }
551
552 if alert.description == AlertDescription::CloseNotify {
555 self.peer_eof = true;
556 return Ok(());
557 }
558
559 if alert.level == AlertLevel::Warning {
562 if self.is_tls13() && alert.description != AlertDescription::UserCanceled {
563 self.send_fatal_alert(AlertDescription::DecodeError);
564 } else {
565 warn!("TLS alert warning received: {:#?}", msg);
566 return Ok(());
567 }
568 }
569
570 error!("TLS alert received: {:#?}", msg);
571 Err(TLSError::AlertReceived(alert.description))
572 } else {
573 Err(TLSError::CorruptMessagePayload(ContentType::Alert))
574 }
575 }
576
577 pub fn send_msg_encrypt(&mut self, m: Message) {
580 let mut plain_messages = VecDeque::new();
581 self.message_fragmenter
582 .fragment(m, &mut plain_messages);
583
584 for m in plain_messages {
585 self.send_single_fragment(m.to_borrowed());
586 }
587 }
588
589 fn send_appdata_encrypt(&mut self, payload: &[u8], limit: Limit) -> usize {
591 let len = match limit {
596 Limit::Yes => self
597 .sendable_tls
598 .apply_limit(payload.len()),
599 Limit::No => payload.len(),
600 };
601
602 let mut plain_messages = VecDeque::new();
603 self.message_fragmenter.fragment_borrow(
604 ContentType::ApplicationData,
605 ProtocolVersion::TLSv1_2,
606 &payload[..len],
607 &mut plain_messages,
608 );
609
610 for m in plain_messages {
611 self.send_single_fragment(m);
612 }
613
614 len
615 }
616
617 fn send_single_fragment(&mut self, m: BorrowMessage) {
618 if self
621 .record_layer
622 .wants_close_before_encrypt()
623 {
624 self.send_close_notify();
625 }
626
627 if self.record_layer.encrypt_exhausted() {
630 return;
631 }
632
633 let em = self.record_layer.encrypt_outgoing(m);
634 self.queue_tls_message(em);
635 }
636
637 pub fn connection_at_eof(&self) -> bool {
641 self.peer_eof && !self.message_deframer.has_pending()
642 }
643
644 pub fn read_tls(&mut self, rd: &mut dyn Read) -> io::Result<usize> {
648 self.message_deframer.read(rd)
649 }
650
651 pub fn write_tls(&mut self, wr: &mut dyn Write) -> io::Result<usize> {
652 self.sendable_tls.write_to(wr)
653 }
654
655 pub fn send_some_plaintext(&mut self, data: &[u8]) -> usize {
661 self.send_plain(data, Limit::Yes)
662 }
663
664 pub fn send_early_plaintext(&mut self, data: &[u8]) -> usize {
665 debug_assert!(self.early_traffic);
666 debug_assert!(self.record_layer.is_encrypting());
667
668 if data.is_empty() {
669 return 0;
671 }
672
673 self.send_appdata_encrypt(data, Limit::Yes)
674 }
675
676 fn send_plain(&mut self, data: &[u8], limit: Limit) -> usize {
682 if !self.traffic {
683 let len = match limit {
686 Limit::Yes => self
687 .sendable_plaintext
688 .append_limited_copy(data),
689 Limit::No => self
690 .sendable_plaintext
691 .append(data.to_vec()),
692 };
693 return len;
694 }
695
696 debug_assert!(self.record_layer.is_encrypting());
697
698 if data.is_empty() {
699 return 0;
701 }
702
703 self.send_appdata_encrypt(data, limit)
704 }
705
706 pub fn start_traffic(&mut self) {
707 self.traffic = true;
708 self.flush_plaintext();
709 }
710
711 pub fn flush_plaintext(&mut self) {
714 if !self.traffic {
715 return;
716 }
717
718 while !self.sendable_plaintext.is_empty() {
719 let buf = self.sendable_plaintext.take_one();
720 self.send_plain(&buf, Limit::No);
721 }
722 }
723
724 fn queue_tls_message(&mut self, m: Message) {
726 self.sendable_tls
727 .append(m.get_encoding());
728 }
729
730 pub fn send_msg(&mut self, m: Message, must_encrypt: bool) {
732 #[cfg(feature = "quic")]
733 {
734 if let Protocol::Quic = self.protocol {
735 if let MessagePayload::Alert(alert) = m.payload {
736 self.quic.alert = Some(alert.description);
737 } else {
738 debug_assert!(
739 if let MessagePayload::Handshake(_) = m.payload {
740 true
741 } else {
742 false
743 },
744 "QUIC uses TLS for the cryptographic handshake only"
745 );
746 let mut bytes = Vec::new();
747 m.payload.encode(&mut bytes);
748 self.quic
749 .hs_queue
750 .push_back((must_encrypt, bytes));
751 }
752 return;
753 }
754 }
755 if !must_encrypt {
756 let mut to_send = VecDeque::new();
757 self.message_fragmenter
758 .fragment(m, &mut to_send);
759 for mm in to_send {
760 self.queue_tls_message(mm);
761 }
762 } else {
763 self.send_msg_encrypt(m);
764 }
765 }
766
767 pub fn take_received_plaintext(&mut self, bytes: Payload) {
768 self.received_plaintext.append(bytes.0);
769 }
770
771 pub fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
772 let len = self.received_plaintext.read(buf)?;
773
774 if len == 0 && self.connection_at_eof() && self.received_plaintext.is_empty() {
775 return Err(io::Error::new(
776 io::ErrorKind::ConnectionAborted,
777 "CloseNotify alert received",
778 ));
779 }
780
781 Ok(len)
782 }
783
784 pub fn start_encryption_tls12(&mut self, secrets: &SessionSecrets) {
785 let (dec, enc) = cipher::new_tls12(self.get_suite_assert(), secrets);
786 self.record_layer
787 .prepare_message_encrypter(enc);
788 self.record_layer
789 .prepare_message_decrypter(dec);
790 }
791
792 pub fn send_warning_alert(&mut self, desc: AlertDescription) {
793 warn!("Sending warning alert {:?}", desc);
794 self.send_warning_alert_no_log(desc);
795 }
796
797 pub fn send_fatal_alert(&mut self, desc: AlertDescription) {
798 warn!("Sending fatal alert {:?}", desc);
799 debug_assert!(!self.sent_fatal_alert);
800 let m = Message::build_alert(AlertLevel::Fatal, desc);
801 self.send_msg(m, self.record_layer.is_encrypting());
802 self.sent_fatal_alert = true;
803 }
804
805 pub fn send_close_notify(&mut self) {
806 debug!("Sending warning alert {:?}", AlertDescription::CloseNotify);
807 self.send_warning_alert_no_log(AlertDescription::CloseNotify);
808 }
809
810 fn send_warning_alert_no_log(&mut self, desc: AlertDescription) {
811 let m = Message::build_alert(AlertLevel::Warning, desc);
812 self.send_msg(m, self.record_layer.is_encrypting());
813 }
814
815 pub fn is_quic(&self) -> bool {
816 #[cfg(feature = "quic")]
817 {
818 self.protocol == Protocol::Quic
819 }
820 #[cfg(not(feature = "quic"))]
821 false
822 }
823}
824
/// Per-session state used when TLS carries a QUIC handshake.
#[cfg(feature = "quic")]
pub(crate) struct Quic {
    /// QUIC transport parameters (NOTE(review): whether ours or the peer's is
    /// not visible here — confirm).
    pub params: Option<Vec<u8>>,
    /// Alert recorded by `SessionCommon::send_msg` instead of being sent as a record.
    pub alert: Option<AlertDescription>,
    /// Outgoing handshake payloads queued by `SessionCommon::send_msg`, each
    /// paired with its `must_encrypt` flag.
    pub hs_queue: VecDeque<(bool, Vec<u8>)>,
    /// Presumably the early (0-RTT) secret, once derived.
    pub early_secret: Option<ring::hkdf::Prk>,
    /// Presumably the handshake-phase secrets, once derived.
    pub hs_secrets: Option<quic::Secrets>,
    /// Presumably the application-traffic secrets, once derived.
    pub traffic_secrets: Option<quic::Secrets>,
    /// Presumably tracks whether traffic keys were handed to the QUIC layer.
    pub returned_traffic_keys: bool,
}
837
#[cfg(feature = "quic")]
impl Quic {
    /// Creates QUIC state with an empty handshake queue, no secrets, no
    /// parameters and no recorded alert.
    pub fn new() -> Self {
        Quic {
            hs_queue: VecDeque::new(),
            params: None,
            alert: None,
            early_secret: None,
            hs_secrets: None,
            traffic_secrets: None,
            returned_traffic_keys: false,
        }
    }
}
851}