1use std::collections::VecDeque;
2use std::io;
3
4use crate::msgs::codec;
5use crate::msgs::message::{Message, MessageError};
6
/// Accumulates bytes pulled from an `io::Read` and splits them into complete
/// [`Message`]s.
///
/// Parsed messages are queued in `frames`; bytes that do not yet form a whole
/// message stay buffered until more input arrives.
pub struct MessageDeframer {
    /// Complete messages parsed so far, oldest first. Callers pop from the
    /// front to consume them.
    pub frames: VecDeque<Message>,

    /// Set to `true` once the front of the buffer fails to parse for a reason
    /// other than being incomplete. `read` only ever sets this; it never
    /// resets it to `false`.
    pub desynced: bool,

    /// Fixed-size receive buffer, one maximum-size wire message long. Only
    /// `buf[..used]` holds meaningful data.
    buf: Box<[u8; Message::MAX_WIRE_SIZE]>,

    /// Number of valid bytes at the front of `buf`.
    used: usize,
}
26
/// Outcome of a single attempt to parse one message off the front of the
/// deframer's buffer.
enum BufferContents {
    /// The buffered bytes cannot be parsed as a message, for a reason other
    /// than there being too few of them.
    Invalid,

    /// Not enough bytes are buffered to finish a message (possibly none at
    /// all); more input is needed.
    Partial,

    /// A complete message was parsed, queued, and removed from the buffer.
    Valid,
}
38
39impl Default for MessageDeframer {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45impl MessageDeframer {
46 pub fn new() -> MessageDeframer {
47 MessageDeframer {
48 frames: VecDeque::new(),
49 desynced: false,
50 buf: Box::new([0u8; Message::MAX_WIRE_SIZE]),
51 used: 0,
52 }
53 }
54
55 pub fn read(&mut self, rd: &mut dyn io::Read) -> io::Result<usize> {
59 debug_assert!(self.used <= Message::MAX_WIRE_SIZE);
64 let new_bytes = rd.read(&mut self.buf[self.used..])?;
65
66 self.used += new_bytes;
67
68 loop {
69 match self.try_deframe_one() {
70 BufferContents::Invalid => {
71 self.desynced = true;
72 break;
73 }
74 BufferContents::Valid => continue,
75 BufferContents::Partial => break,
76 }
77 }
78
79 Ok(new_bytes)
80 }
81
82 pub fn has_pending(&self) -> bool {
86 !self.frames.is_empty() || self.used > 0
87 }
88
89 fn try_deframe_one(&mut self) -> BufferContents {
93 let mut rd = codec::Reader::init(&self.buf[..self.used]);
95
96 match Message::read_with_detailed_error(&mut rd) {
97 Ok(m) => {
98 let used = rd.used();
99 self.frames.push_back(m);
100 self.buf_consume(used);
101 BufferContents::Valid
102 }
103 Err(MessageError::TooShortForHeader) | Err(MessageError::TooShortForLength) => {
104 BufferContents::Partial
105 }
106 Err(_) => BufferContents::Invalid,
107 }
108 }
109
110 fn buf_consume(&mut self, taken: usize) {
111 if taken < self.used {
112 self.buf
126 .copy_within(taken..self.used, 0);
127 self.used = self.used - taken;
128 } else if taken == self.used {
129 self.used = 0;
130 }
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::MessageDeframer;
    use crate::msgs;
    use std::io;

    // Recorded wire messages. The first decodes to a Handshake record and the
    // second to an Alert record (checked by `pop_first` / `pop_second`).
    const FIRST_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.1.bin");
    const SECOND_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.2.bin");

    /// Feeds `bytes` to the deframer in a single `read` call.
    ///
    /// `&[u8]` implements `io::Read` itself. It copies as much as fits into
    /// the destination and advances the slice, so no custom reader is needed.
    fn input_bytes(d: &mut MessageDeframer, bytes: &[u8]) -> io::Result<usize> {
        let mut rd = bytes;
        d.read(&mut rd)
    }

    /// Feeds `bytes1` immediately followed by `bytes2` in a single `read` call.
    fn input_bytes_concat(
        d: &mut MessageDeframer,
        bytes1: &[u8],
        bytes2: &[u8],
    ) -> io::Result<usize> {
        input_bytes(d, &[bytes1, bytes2].concat())
    }

    /// Reader that scribbles over the buffer it is given and then fails.
    ///
    /// It is used to check that a failed read neither counts nor deframes the
    /// garbage it left behind.
    struct ErrorRead {
        error: Option<io::Error>,
    }

    impl ErrorRead {
        fn new(error: io::Error) -> ErrorRead {
            ErrorRead { error: Some(error) }
        }
    }

    impl io::Read for ErrorRead {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            for (i, b) in buf.iter_mut().enumerate() {
                *b = i as u8;
            }

            let error = self
                .error
                .take()
                .expect("ErrorRead can only fail once");
            Err(error)
        }
    }

    /// Performs one failing read and asserts the error reaches the caller.
    fn input_error(d: &mut MessageDeframer) {
        let error = io::Error::from(io::ErrorKind::TimedOut);
        let mut rd = ErrorRead::new(error);
        d.read(&mut rd)
            .expect_err("error not propagated");
    }

    /// Feeds `bytes` one byte per `read` call.
    ///
    /// Asserts that exactly one new frame appears, and only after the final
    /// byte.
    fn input_whole_incremental(d: &mut MessageDeframer, bytes: &[u8]) {
        let frames_before = d.frames.len();
        let last = bytes.len().saturating_sub(1);

        for (i, byte) in bytes.chunks(1).enumerate() {
            assert_len(1, input_bytes(d, byte));
            assert!(d.has_pending());

            if i < last {
                assert_eq!(frames_before, d.frames.len());
            }
        }

        assert_eq!(frames_before + 1, d.frames.len());
    }

    /// Asserts that a read succeeded and returned exactly `want` bytes.
    fn assert_len(want: usize, got: io::Result<usize>) {
        match got {
            Ok(gotval) => assert_eq!(gotval, want),
            Err(_) => panic!("read failed, expected {:?} bytes", want),
        }
    }

    /// Pops the next frame and checks it is `FIRST_MESSAGE` (a Handshake).
    fn pop_first(d: &mut MessageDeframer) {
        let mut m = d.frames.pop_front().unwrap();
        m.decode_payload();
        assert_eq!(m.typ, msgs::enums::ContentType::Handshake);
    }

    /// Pops the next frame and checks it is `SECOND_MESSAGE` (an Alert).
    fn pop_second(d: &mut MessageDeframer) {
        let mut m = d.frames.pop_front().unwrap();
        m.decode_payload();
        assert_eq!(m.typ, msgs::enums::ContentType::Alert);
    }

    #[test]
    fn check_incremental() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        input_whole_incremental(&mut d, FIRST_MESSAGE);
        assert!(d.has_pending());
        assert_eq!(1, d.frames.len());
        pop_first(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn check_incremental_2() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        input_whole_incremental(&mut d, FIRST_MESSAGE);
        assert!(d.has_pending());
        input_whole_incremental(&mut d, SECOND_MESSAGE);
        assert!(d.has_pending());
        assert_eq!(2, d.frames.len());
        pop_first(&mut d);
        assert!(d.has_pending());
        pop_second(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn check_whole() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE));
        assert!(d.has_pending());
        assert_eq!(d.frames.len(), 1);
        pop_first(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn check_whole_2() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE));
        assert_len(SECOND_MESSAGE.len(), input_bytes(&mut d, SECOND_MESSAGE));
        assert_eq!(d.frames.len(), 2);
        pop_first(&mut d);
        pop_second(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn test_two_in_one_read() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        assert_len(
            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
            input_bytes_concat(&mut d, FIRST_MESSAGE, SECOND_MESSAGE),
        );
        assert_eq!(d.frames.len(), 2);
        pop_first(&mut d);
        pop_second(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn test_two_in_one_read_shortest_first() {
        let mut d = MessageDeframer::new();
        assert!(!d.has_pending());
        assert_len(
            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
            input_bytes_concat(&mut d, SECOND_MESSAGE, FIRST_MESSAGE),
        );
        assert_eq!(d.frames.len(), 2);
        pop_second(&mut d);
        pop_first(&mut d);
        assert!(!d.has_pending());
    }

    #[test]
    fn test_incremental_with_nonfatal_read_error() {
        let mut d = MessageDeframer::new();
        assert_len(3, input_bytes(&mut d, &FIRST_MESSAGE[..3]));
        input_error(&mut d);
        assert_len(
            FIRST_MESSAGE.len() - 3,
            input_bytes(&mut d, &FIRST_MESSAGE[3..]),
        );
        assert_eq!(d.frames.len(), 1);
        pop_first(&mut d);
        assert!(!d.has_pending());
    }
}