1use std::collections::VecDeque;
2
3use crate::msgs::codec;
4use crate::msgs::enums::{ContentType, ProtocolVersion};
5use crate::msgs::handshake::HandshakeMessagePayload;
6use crate::msgs::message::{Message, MessagePayload};
7
/// Length of a handshake message header: a 1-byte message type followed by a
/// 3-byte (u24, big-endian) body length.
const HEADER_SIZE: usize = 4;
9
10pub struct HandshakeJoiner {
15 pub frames: VecDeque<Message>,
17
18 buf: Vec<u8>,
20}
21
22impl Default for HandshakeJoiner {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl HandshakeJoiner {
29 pub fn new() -> HandshakeJoiner {
31 HandshakeJoiner {
32 frames: VecDeque::new(),
33 buf: Vec::new(),
34 }
35 }
36
37 pub fn want_message(&self, msg: &Message) -> bool {
39 msg.is_content_type(ContentType::Handshake)
40 }
41
42 pub fn is_empty(&self) -> bool {
44 self.buf.is_empty()
45 }
46
47 pub fn take_message(&mut self, mut msg: Message) -> Option<usize> {
55 let payload = msg.take_opaque_payload().unwrap();
58
59 self.buf
60 .extend_from_slice(&payload.0[..]);
61
62 let mut count = 0;
63 while self.buf_contains_message() {
64 if !self.deframe_one(msg.version) {
65 return None;
66 }
67
68 count += 1;
69 }
70
71 Some(count)
72 }
73
74 fn buf_contains_message(&self) -> bool {
78 self.buf.len() >= HEADER_SIZE
79 && self.buf.len()
80 >= (codec::u24::decode(&self.buf[1..4])
81 .unwrap()
82 .0 as usize)
83 + HEADER_SIZE
84 }
85
86 fn deframe_one(&mut self, version: ProtocolVersion) -> bool {
91 let used = {
92 let mut rd = codec::Reader::init(&self.buf);
93 let payload = HandshakeMessagePayload::read_version(&mut rd, version);
94
95 if payload.is_none() {
96 return false;
97 }
98
99 let m = Message {
100 typ: ContentType::Handshake,
101 version,
102 payload: MessagePayload::Handshake(payload.unwrap()),
103 };
104
105 self.frames.push_back(m);
106 rd.used()
107 };
108 self.buf = self.buf.split_off(used);
109 true
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::HandshakeJoiner;
    use crate::msgs::base::Payload;
    use crate::msgs::enums::{ContentType, HandshakeType, ProtocolVersion};
    use crate::msgs::handshake::{HandshakeMessagePayload, HandshakePayload};
    use crate::msgs::message::{Message, MessagePayload};

    /// Builds a TLS 1.2 record of content type `typ` whose payload is the
    /// still-undecoded bytes `bytes`.
    fn opaque(typ: ContentType, bytes: &[u8]) -> Message {
        Message {
            typ,
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::new_opaque(bytes.to_vec()),
        }
    }

    /// Shorthand for an opaque `Handshake` record.
    fn handshake(bytes: &[u8]) -> Message {
        opaque(ContentType::Handshake, bytes)
    }

    /// Asserts that `hj` wants `msg`, hands it over, and returns the result
    /// of `take_message`.
    fn feed(hj: &mut HandshakeJoiner, msg: Message) -> Option<usize> {
        assert!(hj.want_message(&msg));
        hj.take_message(msg)
    }

    /// Builds the decoded TLS 1.2 `Handshake` message the joiner is expected
    /// to emit.
    fn decoded(typ: HandshakeType, payload: HandshakePayload) -> Message {
        Message {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::Handshake(HandshakeMessagePayload { typ, payload }),
        }
    }

    /// Pops the oldest completed frame and asserts it matches `expect`. Type
    /// and version are compared directly; payloads are compared by their
    /// encoded bytes.
    fn pop_eq(expect: &Message, hj: &mut HandshakeJoiner) {
        let got = hj.frames.pop_front().unwrap();
        assert_eq!(got.typ, expect.typ);
        assert_eq!(got.version, expect.version);

        let (mut left, mut right) = (Vec::new(), Vec::new());
        got.payload.encode(&mut left);
        expect.payload.encode(&mut right);

        assert_eq!(left, right);
    }

    #[test]
    fn want() {
        let hj = HandshakeJoiner::new();
        assert!(hj.is_empty());

        assert!(hj.want_message(&handshake(b"hello world")));
        assert!(!hj.want_message(&opaque(ContentType::Alert, b"ponytown")));
    }

    #[test]
    fn default_starts_empty() {
        let hj = HandshakeJoiner::default();
        assert!(hj.is_empty());
        assert!(hj.frames.is_empty());
    }

    #[test]
    fn split() {
        let mut hj = HandshakeJoiner::new();

        // Two empty HelloRequests (type 0, length 0) packed into one record.
        let record = handshake(b"\x00\x00\x00\x00\x00\x00\x00\x00");
        assert_eq!(feed(&mut hj, record), Some(2));
        assert!(hj.is_empty());

        let expect = decoded(HandshakeType::HelloRequest, HandshakePayload::HelloRequest);
        pop_eq(&expect, &mut hj);
        pop_eq(&expect, &mut hj);
    }

    #[test]
    fn broken() {
        let mut hj = HandshakeJoiner::new();

        // A complete message (type 1, length 2) whose 2-byte body does not
        // decode; nothing should be queued.
        let record = handshake(b"\x01\x00\x00\x02\xff\xff");
        assert_eq!(feed(&mut hj, record), None);
        assert!(hj.frames.is_empty());
    }

    #[test]
    fn join() {
        let mut hj = HandshakeJoiner::new();
        assert!(hj.is_empty());

        // A Finished message (type 0x14, length 0x10) spread over three records.
        let first = handshake(b"\x14\x00\x00\x10\x00\x01\x02\x03\x04");
        assert_eq!(feed(&mut hj, first), Some(0));
        assert!(!hj.is_empty());

        let second = handshake(b"\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e");
        assert_eq!(feed(&mut hj, second), Some(0));
        assert!(!hj.is_empty());

        assert_eq!(feed(&mut hj, handshake(b"\x0f")), Some(1));
        assert!(hj.is_empty());

        let body = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f".to_vec();
        let expect = decoded(
            HandshakeType::Finished,
            HandshakePayload::Finished(Payload::new(body)),
        );
        pop_eq(&expect, &mut hj);
    }

    #[test]
    fn test_rejoins_then_rejects_giant_certs() {
        let mut hj = HandshakeJoiner::new();

        // Certificate header (type 0x0b, length 0x010004 = 65540 bytes)
        // followed by the first 6 body bytes.
        let head = handshake(b"\x0b\x01\x00\x04\x01\x00\x01\x00\xff\xfe");
        assert_eq!(feed(&mut hj, head), Some(0));
        assert!(!hj.is_empty());

        // 8191 * 8 = 65528 more body bytes: still short of the declared length.
        for _ in 0..8191 {
            let chunk = handshake(b"\x01\x02\x03\x04\x05\x06\x07\x08");
            assert_eq!(feed(&mut hj, chunk), Some(0));
            assert!(!hj.is_empty());
        }

        // The final 6 bytes complete the message, which is then rejected by
        // the decoder.
        assert_eq!(feed(&mut hj, handshake(b"\x01\x02\x03\x04\x05\x06")), None);
    }
}