rustls/msgs/deframer.rs

1use std::collections::VecDeque;
2use std::io;
3
4use crate::msgs::codec;
5use crate::msgs::message::{Message, MessageError};
6
7/// This deframer works to reconstruct TLS messages
8/// from arbitrary-sized reads, buffering as necessary.
9/// The input is `read()`, the output is the `frames` deque.
10pub struct MessageDeframer {
11    /// Completed frames for output.
12    pub frames: VecDeque<Message>,
13
14    /// Set to true if the peer is not talking TLS, but some other
15    /// protocol.  The caller should abort the connection, because
16    /// the deframer cannot recover.
17    pub desynced: bool,
18
19    /// A fixed-size buffer containing the currently-accumulating
20    /// TLS message.
21    buf: Box<[u8; Message::MAX_WIRE_SIZE]>,
22
23    /// What size prefix of `buf` is used.
24    used: usize,
25}
26
/// What the front of the deframer's buffer turned out to hold after one
/// attempt to decode a message from it.
enum BufferContents {
    /// The leading bytes cannot begin a valid message.
    Invalid,

    /// Too few bytes to decide yet; more input may complete a message.
    /// This includes the case where nothing is buffered at all.
    Partial,

    /// A complete message was present at the front of the buffer.
    Valid,
}
38
39impl Default for MessageDeframer {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl MessageDeframer {
46    pub fn new() -> MessageDeframer {
47        MessageDeframer {
48            frames: VecDeque::new(),
49            desynced: false,
50            buf: Box::new([0u8; Message::MAX_WIRE_SIZE]),
51            used: 0,
52        }
53    }
54
55    /// Read some bytes from `rd`, and add them to our internal
56    /// buffer.  If this means our internal buffer contains
57    /// full messages, decode them all.
58    pub fn read(&mut self, rd: &mut dyn io::Read) -> io::Result<usize> {
59        // Try to do the largest reads possible.  Note that if
60        // we get a message with a length field out of range here,
61        // we do a zero length read.  That looks like an EOF to
62        // the next layer up, which is fine.
63        debug_assert!(self.used <= Message::MAX_WIRE_SIZE);
64        let new_bytes = rd.read(&mut self.buf[self.used..])?;
65
66        self.used += new_bytes;
67
68        loop {
69            match self.try_deframe_one() {
70                BufferContents::Invalid => {
71                    self.desynced = true;
72                    break;
73                }
74                BufferContents::Valid => continue,
75                BufferContents::Partial => break,
76            }
77        }
78
79        Ok(new_bytes)
80    }
81
82    /// Returns true if we have messages for the caller
83    /// to process, either whole messages in our output
84    /// queue or partial messages in our buffer.
85    pub fn has_pending(&self) -> bool {
86        !self.frames.is_empty() || self.used > 0
87    }
88
89    /// Does our `buf` contain a full message?  It does if it is big enough to
90    /// contain a header, and that header has a length which falls within `buf`.
91    /// If so, deframe it and place the message onto the frames output queue.
92    fn try_deframe_one(&mut self) -> BufferContents {
93        // Try to decode a message off the front of buf.
94        let mut rd = codec::Reader::init(&self.buf[..self.used]);
95
96        match Message::read_with_detailed_error(&mut rd) {
97            Ok(m) => {
98                let used = rd.used();
99                self.frames.push_back(m);
100                self.buf_consume(used);
101                BufferContents::Valid
102            }
103            Err(MessageError::TooShortForHeader) | Err(MessageError::TooShortForLength) => {
104                BufferContents::Partial
105            }
106            Err(_) => BufferContents::Invalid,
107        }
108    }
109
110    fn buf_consume(&mut self, taken: usize) {
111        if taken < self.used {
112            /* Before:
113             * +----------+----------+----------+
114             * | taken    | pending  |xxxxxxxxxx|
115             * +----------+----------+----------+
116             * 0          ^ taken    ^ self.used
117             *
118             * After:
119             * +----------+----------+----------+
120             * | pending  |xxxxxxxxxxxxxxxxxxxxx|
121             * +----------+----------+----------+
122             * 0          ^ self.used
123             */
124
125            self.buf
126                .copy_within(taken..self.used, 0);
127            self.used = self.used - taken;
128        } else if taken == self.used {
129            self.used = 0;
130        }
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::MessageDeframer;
    use crate::msgs;
    use std::io;

    const FIRST_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.1.bin");
    const SECOND_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.2.bin");

    /// An `io::Read` over a borrowed slice that hands out as many bytes as
    /// fit in the caller's buffer on each call.
    struct ByteRead<'a> {
        buf: &'a [u8],
        offs: usize,
    }

    impl<'a> ByteRead<'a> {
        fn new(bytes: &'a [u8]) -> Self {
            ByteRead {
                buf: bytes,
                offs: 0,
            }
        }
    }

    impl<'a> io::Read for ByteRead<'a> {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            let remaining = &self.buf[self.offs..];
            let len = remaining.len().min(buf.len());
            buf[..len].copy_from_slice(&remaining[..len]);
            self.offs += len;
            Ok(len)
        }
    }

    /// Feeds `bytes` to `d` in a single `read()` call.
    fn input_bytes(d: &mut MessageDeframer, bytes: &[u8]) -> io::Result<usize> {
        let mut rd = ByteRead::new(bytes);
        d.read(&mut rd)
    }

    /// Feeds `bytes1` immediately followed by `bytes2` to `d` in a single
    /// `read()` call.
    fn input_bytes_concat(
        d: &mut MessageDeframer,
        bytes1: &[u8],
        bytes2: &[u8],
    ) -> io::Result<usize> {
        let bytes = [bytes1, bytes2].concat();
        input_bytes(d, &bytes)
    }

    /// An `io::Read` that scribbles over the caller's buffer and then fails
    /// once with the stored error.
    struct ErrorRead {
        error: Option<io::Error>,
    }

    impl ErrorRead {
        fn new(error: io::Error) -> Self {
            ErrorRead { error: Some(error) }
        }
    }

    impl io::Read for ErrorRead {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            // Dirty the buffer so a deframer that trusted it on error would
            // see garbage; the wrap-around of `as u8` is intentional.
            for (i, b) in buf.iter_mut().enumerate() {
                *b = i as u8;
            }

            let error = self.error.take().unwrap();
            Err(error)
        }
    }

    /// Checks that a reader error is propagated out of `MessageDeframer::read`.
    fn input_error(d: &mut MessageDeframer) {
        let error = io::Error::from(io::ErrorKind::TimedOut);
        let mut rd = ErrorRead::new(error);
        d.read(&mut rd).expect_err("error not propagated");
    }

    /// Feeds one whole message to `d` a single byte per `read()` call and
    /// checks that exactly one new frame appears, and only after the final byte.
    fn input_whole_incremental(d: &mut MessageDeframer, bytes: &[u8]) {
        let frames_before = d.frames.len();

        for (i, chunk) in bytes.chunks(1).enumerate() {
            assert_len(1, input_bytes(d, chunk));
            assert!(d.has_pending());

            if i + 1 < bytes.len() {
                assert_eq!(frames_before, d.frames.len());
            }
        }

        assert_eq!(frames_before + 1, d.frames.len());
    }

    /// Asserts that `got` is `Ok(want)`, reporting the I/O error otherwise.
    fn assert_len(want: usize, got: io::Result<usize>) {
        match got {
            Ok(gotval) => assert_eq!(gotval, want),
            Err(err) => panic!("read failed, expected {:?} bytes: {:?}", want, err),
        }
    }

    fn pop_first(d: &mut MessageDeframer) {
        let mut m = d.frames.pop_front().unwrap();
        m.decode_payload();
        assert_eq!(m.typ, msgs::enums::ContentType::Handshake);
    }

    fn pop_second(d: &mut MessageDeframer) {
        let mut m = d.frames.pop_front().unwrap();
        m.decode_payload();
        assert_eq!(m.typ, msgs::enums::ContentType::Alert);
    }

    #[test]
    fn check_incremental() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        input_whole_incremental(&mut d, FIRST_MESSAGE);
        assert!(d.has_pending());
        assert_eq!(1, d.frames.len());
        pop_first(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn check_incremental_2() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        input_whole_incremental(&mut d, FIRST_MESSAGE);
        assert!(d.has_pending());
        input_whole_incremental(&mut d, SECOND_MESSAGE);
        assert!(d.has_pending());
        assert_eq!(2, d.frames.len());
        pop_first(&mut d);
        assert!(d.has_pending());
        pop_second(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn check_whole() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE));
        assert!(d.has_pending());
        assert_eq!(d.frames.len(), 1);
        pop_first(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn check_whole_2() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE));
        assert_len(SECOND_MESSAGE.len(), input_bytes(&mut d, SECOND_MESSAGE));
        assert_eq!(d.frames.len(), 2);
        pop_first(&mut d);
        pop_second(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn test_two_in_one_read() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        assert_len(
            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
            input_bytes_concat(&mut d, FIRST_MESSAGE, SECOND_MESSAGE),
        );
        assert_eq!(d.frames.len(), 2);
        pop_first(&mut d);
        pop_second(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn test_two_in_one_read_shortest_first() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        assert_len(
            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
            input_bytes_concat(&mut d, SECOND_MESSAGE, FIRST_MESSAGE),
        );
        assert_eq!(d.frames.len(), 2);
        pop_second(&mut d);
        pop_first(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn test_incremental_with_nonfatal_read_error() {
        let mut d = MessageDeframer::new();
        assert_len(3, input_bytes(&mut d, &FIRST_MESSAGE[..3]));
        input_error(&mut d);
        assert_len(
            FIRST_MESSAGE.len() - 3,
            input_bytes(&mut d, &FIRST_MESSAGE[3..]),
        );
        assert_eq!(d.frames.len(), 1);
        pop_first(&mut d);
        assert!(!d.has_pending());
    }
}