rustls/msgs/hsjoiner.rs

1use std::collections::VecDeque;
2
3use crate::msgs::codec;
4use crate::msgs::enums::{ContentType, ProtocolVersion};
5use crate::msgs::handshake::HandshakeMessagePayload;
6use crate::msgs::message::{Message, MessagePayload};
7
/// Size in bytes of a handshake message header: a 1-byte message type
/// followed by a 3-byte (u24) body length.
const HEADER_SIZE: usize = 4;
9
10/// This works to reconstruct TLS handshake messages
11/// from individual TLS messages.  It's guaranteed that
12/// TLS messages output from this layer contain precisely
13/// one handshake payload.
14pub struct HandshakeJoiner {
15    /// Completed handshake frames for output.
16    pub frames: VecDeque<Message>,
17
18    /// The message payload we're currently accumulating.
19    buf: Vec<u8>,
20}
21
22impl Default for HandshakeJoiner {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl HandshakeJoiner {
29    /// Make a new HandshakeJoiner.
30    pub fn new() -> HandshakeJoiner {
31        HandshakeJoiner {
32            frames: VecDeque::new(),
33            buf: Vec::new(),
34        }
35    }
36
37    /// Do we want to process this message?
38    pub fn want_message(&self, msg: &Message) -> bool {
39        msg.is_content_type(ContentType::Handshake)
40    }
41
42    /// Do we have any buffered data?
43    pub fn is_empty(&self) -> bool {
44        self.buf.is_empty()
45    }
46
47    /// Take the message, and join/split it as needed.
48    /// Return the number of new messages added to the
49    /// output deque as a result of this message.
50    ///
51    /// Returns None if msg or a preceding message was corrupt.
52    /// You cannot recover from this situation.  Otherwise returns
53    /// a count of how many messages we queued.
54    pub fn take_message(&mut self, mut msg: Message) -> Option<usize> {
55        // Input must be opaque, otherwise we might have already
56        // lost information!
57        let payload = msg.take_opaque_payload().unwrap();
58
59        self.buf
60            .extend_from_slice(&payload.0[..]);
61
62        let mut count = 0;
63        while self.buf_contains_message() {
64            if !self.deframe_one(msg.version) {
65                return None;
66            }
67
68            count += 1;
69        }
70
71        Some(count)
72    }
73
74    /// Does our `buf` contain a full handshake payload?  It does if it is big
75    /// enough to contain a header, and that header has a length which falls
76    /// within `buf`.
77    fn buf_contains_message(&self) -> bool {
78        self.buf.len() >= HEADER_SIZE
79            && self.buf.len()
80                >= (codec::u24::decode(&self.buf[1..4])
81                    .unwrap()
82                    .0 as usize)
83                    + HEADER_SIZE
84    }
85
86    /// Take a TLS handshake payload off the front of `buf`, and put it onto
87    /// the back of our `frames` deque inside a normal `Message`.
88    ///
89    /// Returns false if the stream is desynchronised beyond repair.
90    fn deframe_one(&mut self, version: ProtocolVersion) -> bool {
91        let used = {
92            let mut rd = codec::Reader::init(&self.buf);
93            let payload = HandshakeMessagePayload::read_version(&mut rd, version);
94
95            if payload.is_none() {
96                return false;
97            }
98
99            let m = Message {
100                typ: ContentType::Handshake,
101                version,
102                payload: MessagePayload::Handshake(payload.unwrap()),
103            };
104
105            self.frames.push_back(m);
106            rd.used()
107        };
108        self.buf = self.buf.split_off(used);
109        true
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::HandshakeJoiner;
    use crate::msgs::base::Payload;
    use crate::msgs::enums::{ContentType, HandshakeType, ProtocolVersion};
    use crate::msgs::handshake::{HandshakeMessagePayload, HandshakePayload};
    use crate::msgs::message::{Message, MessagePayload};

    /// Wrap `bytes` as an opaque TLS 1.2 handshake record, as the joiner
    /// would receive it from the record layer.
    fn handshake_record(bytes: &[u8]) -> Message {
        Message {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::new_opaque(bytes.to_vec()),
        }
    }

    /// The decoded form of a TLS 1.2 HelloRequest (empty body).
    fn hello_request() -> Message {
        Message {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::Handshake(HandshakeMessagePayload {
                typ: HandshakeType::HelloRequest,
                payload: HandshakePayload::HelloRequest,
            }),
        }
    }

    /// Pop the oldest completed frame from `hj` and check it matches
    /// `expect` by content type, version and encoded payload.
    fn pop_eq(expect: &Message, hj: &mut HandshakeJoiner) {
        let got = hj
            .frames
            .pop_front()
            .expect("expected a completed handshake frame");
        assert_eq!(got.typ, expect.typ);
        assert_eq!(got.version, expect.version);

        let (mut left, mut right) = (Vec::new(), Vec::new());
        got.payload.encode(&mut left);
        expect.payload.encode(&mut right);

        assert_eq!(left, right);
    }

    #[test]
    fn want() {
        let hj = HandshakeJoiner::new();
        assert!(hj.is_empty());

        let wanted = handshake_record(b"hello world");

        let unwanted = Message {
            typ: ContentType::Alert,
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::new_opaque(b"ponytown".to_vec()),
        };

        assert!(hj.want_message(&wanted));
        assert!(!hj.want_message(&unwanted));
    }

    #[test]
    fn split() {
        // Check we split two handshake messages within one PDU.
        let mut hj = HandshakeJoiner::new();

        // two HelloRequests
        let msg = handshake_record(b"\x00\x00\x00\x00\x00\x00\x00\x00");

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(2));
        assert!(hj.is_empty());

        let expect = hello_request();
        pop_eq(&expect, &mut hj);
        pop_eq(&expect, &mut hj);
        assert!(hj.frames.is_empty());
    }

    #[test]
    fn broken() {
        // Check obvious crap payloads are reported as errors, not panics.
        let mut hj = HandshakeJoiner::new();

        // short ClientHello
        let msg = handshake_record(b"\x01\x00\x00\x02\xff\xff");

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), None);
    }

    #[test]
    fn join() {
        // Check we join one handshake message split over three PDUs.
        let mut hj = HandshakeJoiner::new();
        assert!(hj.is_empty());

        // Introduce a Finished with a 16-byte body, providing the header
        // and the first 5 body bytes.
        let msg = handshake_record(b"\x14\x00\x00\x10\x00\x01\x02\x03\x04");

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(0));
        assert!(!hj.is_empty());

        // 10 more bytes.
        let msg = handshake_record(b"\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e");

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(0));
        assert!(!hj.is_empty());

        // Final 1 byte.
        let msg = handshake_record(b"\x0f");

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(1));
        assert!(hj.is_empty());

        let payload = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f".to_vec();
        let expect = Message {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::Handshake(HandshakeMessagePayload {
                typ: HandshakeType::Finished,
                payload: HandshakePayload::Finished(Payload::new(payload)),
            }),
        };

        pop_eq(&expect, &mut hj);
    }

    #[test]
    fn join_split_header() {
        // Check we rejoin a message whose 4-byte header is itself split
        // across two PDUs.
        let mut hj = HandshakeJoiner::new();

        // First half of a HelloRequest header: not enough to read a length.
        let msg = handshake_record(b"\x00\x00");

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(0));
        assert!(!hj.is_empty());

        // Second half completes a zero-length HelloRequest.
        let msg = handshake_record(b"\x00\x00");

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(1));
        assert!(hj.is_empty());

        pop_eq(&hello_request(), &mut hj);
    }

    #[test]
    fn test_rejoins_then_rejects_giant_certs() {
        let mut hj = HandshakeJoiner::new();

        // Certificate header declaring a 0x010004 (65540) byte body,
        // followed by the first 6 body bytes.
        let msg = handshake_record(b"\x0b\x01\x00\x04\x01\x00\x01\x00\xff\xfe");

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), Some(0));
        assert!(!hj.is_empty());

        // 8191 * 8 = 65528 more body bytes; the message is still incomplete.
        for _ in 0..8191 {
            let msg = handshake_record(b"\x01\x02\x03\x04\x05\x06\x07\x08");

            assert!(hj.want_message(&msg));
            assert_eq!(hj.take_message(msg), Some(0));
            assert!(!hj.is_empty());
        }

        // final 6 bytes complete the body, which is then rejected.
        let msg = handshake_record(b"\x01\x02\x03\x04\x05\x06");

        assert!(hj.want_message(&msg));
        assert_eq!(hj.take_message(msg), None);
    }
}
290}