rustls/
check.rs

1use crate::error::TLSError;
2#[cfg(feature = "logging")]
3use crate::log::warn;
4use crate::msgs::enums::{ContentType, HandshakeType};
5use crate::msgs::message::{Message, MessagePayload};
6
/// Borrow a specific handshake payload out of a `Message`.
///
/// Arguments:
/// - `$m`: the message to inspect.
/// - `$handshake_type`: the `HandshakeType` to report as expected on mismatch.
/// - `$payload_type`: the `HandshakePayload` enum variant to match.
///
/// The expansion evaluates to:
/// - `Ok(&payload)` if `$m` carries a handshake payload of variant `$payload_type`;
/// - `Err(TLSError::InappropriateHandshakeMessage)` with `$handshake_type` as the
///   expected type, if `$m` is a handshake message of some other kind;
/// - `Err(TLSError::InappropriateMessage)` with `ContentType::Handshake` as the
///   expected type, if `$m` does not carry a handshake payload at all.
///
/// Crate types are referenced through `$crate::`, so the expansion does not depend
/// on what the calling module happens to import.
///
/// `$m` is evaluated again in the non-handshake arm (to read `typ`). Pass a place
/// expression such as a local variable, not an expression with side effects.
macro_rules! require_handshake_msg(
    ( $m:expr, $handshake_type:path, $payload_type:path ) => (
        match $m.payload {
            $crate::msgs::message::MessagePayload::Handshake(ref hsp) => match hsp.payload {
                $payload_type(ref hm) => Ok(hm),
                _ => Err($crate::error::TLSError::InappropriateHandshakeMessage {
                    expect_types: vec![$handshake_type],
                    got_type: hsp.typ,
                }),
            },
            _ => Err($crate::error::TLSError::InappropriateMessage {
                expect_types: vec![$crate::msgs::enums::ContentType::Handshake],
                got_type: $m.typ,
            }),
        }
    )
);
26
/// Like `require_handshake_msg`, but takes the handshake payload by value,
/// moving it out of `$m` instead of borrowing it.
///
/// Because the payload is bound by value, `$m` must be a place that owns its
/// payload (for example a `Message` held by value). A `&Message` will not compile,
/// since the payload cannot be moved out of a borrow.
///
/// The expansion evaluates to:
/// - `Ok(payload)` if `$m` carries a handshake payload of variant `$payload_type`;
/// - `Err(TLSError::InappropriateHandshakeMessage)` with `$handshake_type` as the
///   expected type, if `$m` is a handshake message of some other kind;
/// - `Err(TLSError::InappropriateMessage)` with `ContentType::Handshake` as the
///   expected type, if `$m` does not carry a handshake payload at all.
///
/// Crate types are referenced through `$crate::`, so the expansion does not depend
/// on what the calling module happens to import.
///
/// `$m` is evaluated again in the non-handshake arm (to read `typ`). Pass a place
/// expression such as a local variable.
macro_rules! require_handshake_msg_mut(
    ( $m:expr, $handshake_type:path, $payload_type:path ) => (
        match $m.payload {
            $crate::msgs::message::MessagePayload::Handshake(hsp) => match hsp.payload {
                $payload_type(hm) => Ok(hm),
                _ => Err($crate::error::TLSError::InappropriateHandshakeMessage {
                    expect_types: vec![$handshake_type],
                    got_type: hsp.typ,
                }),
            },
            _ => Err($crate::error::TLSError::InappropriateMessage {
                expect_types: vec![$crate::msgs::enums::ContentType::Handshake],
                got_type: $m.typ,
            }),
        }
    )
);
43
44/// Validate the message `m`: return an error if:
45///
46/// - the type of m does not appear in `content_types`.
47/// - if m is a handshake message, the handshake message type does
48///   not appear in `handshake_types`.
49pub fn check_message(
50    m: &Message,
51    content_types: &[ContentType],
52    handshake_types: &[HandshakeType],
53) -> Result<(), TLSError> {
54    if !content_types.contains(&m.typ) {
55        warn!(
56            "Received a {:?} message while expecting {:?}",
57            m.typ, content_types
58        );
59        return Err(TLSError::InappropriateMessage {
60            expect_types: content_types.to_vec(),
61            got_type: m.typ,
62        });
63    }
64
65    if let MessagePayload::Handshake(ref hsp) = m.payload {
66        if !handshake_types.is_empty() && !handshake_types.contains(&hsp.typ) {
67            warn!(
68                "Received a {:?} handshake message while expecting {:?}",
69                hsp.typ, handshake_types
70            );
71            return Err(TLSError::InappropriateHandshakeMessage {
72                expect_types: handshake_types.to_vec(),
73                got_type: hsp.typ,
74            });
75        }
76    }
77
78    Ok(())
79}