rustls/msgs/
fragmenter.rs1use crate::msgs::enums::{ContentType, ProtocolVersion};
2use crate::msgs::message::{BorrowMessage, Message, MessagePayload};
3use std::collections::VecDeque;
4
/// Largest fragment payload length, in bytes, accepted by `MessageFragmenter::new` (2^14).
pub const MAX_FRAGMENT_LEN: usize = 1 << 14;
/// Bytes each fragment occupies on the wire in addition to its payload:
/// content type (1) + protocol version (2) + payload length (2).
pub const PACKET_OVERHEAD: usize = 1 /* type */ + 2 /* version */ + 2 /* length */;
7
/// Splits outgoing messages into fragments whose payloads do not exceed a
/// fixed maximum length.
pub struct MessageFragmenter {
    /// Upper bound, in bytes, on the payload of every fragment produced
    /// (the per-record header is not counted).
    max_frag: usize,
}
11
12impl MessageFragmenter {
13 pub fn new(max_fragment_len: usize) -> MessageFragmenter {
18 debug_assert!(max_fragment_len <= MAX_FRAGMENT_LEN);
19 MessageFragmenter {
20 max_frag: max_fragment_len,
21 }
22 }
23
24 pub fn fragment(&self, msg: Message, out: &mut VecDeque<Message>) {
29 if msg.payload.length() <= self.max_frag {
31 out.push_back(msg.into_opaque());
32 return;
33 }
34
35 let typ = msg.typ;
36 let version = msg.version;
37 let payload = msg.take_payload();
38
39 for chunk in payload.chunks(self.max_frag) {
40 let m = Message {
41 typ,
42 version,
43 payload: MessagePayload::new_opaque(chunk.to_vec()),
44 };
45 out.push_back(m);
46 }
47 }
48
49 pub fn fragment_borrow<'a>(
52 &self,
53 typ: ContentType,
54 version: ProtocolVersion,
55 payload: &'a [u8],
56 out: &mut VecDeque<BorrowMessage<'a>>,
57 ) {
58 for chunk in payload.chunks(self.max_frag) {
59 let cm = BorrowMessage {
60 typ,
61 version,
62 payload: chunk,
63 };
64 out.push_back(cm);
65 }
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::{MessageFragmenter, PACKET_OVERHEAD};
    use crate::msgs::codec::Codec;
    use crate::msgs::enums::{ContentType, ProtocolVersion};
    use crate::msgs::message::{Message, MessagePayload};
    use std::collections::VecDeque;

    /// Checks one message taken from a fragment queue.
    ///
    /// Asserts that `mm` is present and carries `typ`, `version` and the
    /// opaque payload `bytes`. Also asserts that its encoded form is exactly
    /// `total_len` bytes long.
    fn msg_eq(
        mm: Option<Message>,
        total_len: usize,
        typ: &ContentType,
        version: &ProtocolVersion,
        bytes: &[u8],
    ) {
        let mut m = mm.expect("expected another fragment in the queue");

        let mut buf = Vec::new();
        m.encode(&mut buf);

        assert_eq!(&m.typ, typ);
        assert_eq!(&m.version, version);
        assert_eq!(m.take_opaque_payload().unwrap().0, bytes.to_vec());

        assert_eq!(total_len, buf.len());
    }

    #[test]
    fn smoke() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let m = Message {
            typ,
            version,
            payload: MessagePayload::new_opaque(b"\x01\x02\x03\x04\x05\x06\x07\x08".to_vec()),
        };

        let frag = MessageFragmenter::new(3);
        let mut q = VecDeque::new();
        frag.fragment(m, &mut q);
        msg_eq(
            q.pop_front(),
            PACKET_OVERHEAD + 3,
            &typ,
            &version,
            b"\x01\x02\x03",
        );
        msg_eq(
            q.pop_front(),
            PACKET_OVERHEAD + 3,
            &typ,
            &version,
            b"\x04\x05\x06",
        );
        msg_eq(
            q.pop_front(),
            PACKET_OVERHEAD + 2,
            &typ,
            &version,
            b"\x07\x08",
        );
        assert_eq!(q.len(), 0);
    }

    #[test]
    fn non_fragment() {
        let m = Message {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::new_opaque(b"\x01\x02\x03\x04\x05\x06\x07\x08".to_vec()),
        };

        let frag = MessageFragmenter::new(8);
        let mut q = VecDeque::new();
        frag.fragment(m, &mut q);
        msg_eq(
            q.pop_front(),
            PACKET_OVERHEAD + 8,
            &ContentType::Handshake,
            &ProtocolVersion::TLSv1_2,
            b"\x01\x02\x03\x04\x05\x06\x07\x08",
        );
        assert_eq!(q.len(), 0);
    }

    /// A payload that is an exact multiple of the fragment length must not
    /// produce a trailing empty fragment.
    #[test]
    fn exact_multiple_of_fragment_len() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let m = Message {
            typ,
            version,
            payload: MessagePayload::new_opaque(b"\x01\x02\x03\x04\x05\x06".to_vec()),
        };

        let frag = MessageFragmenter::new(3);
        let mut q = VecDeque::new();
        frag.fragment(m, &mut q);
        msg_eq(
            q.pop_front(),
            PACKET_OVERHEAD + 3,
            &typ,
            &version,
            b"\x01\x02\x03",
        );
        msg_eq(
            q.pop_front(),
            PACKET_OVERHEAD + 3,
            &typ,
            &version,
            b"\x04\x05\x06",
        );
        assert_eq!(q.len(), 0);
    }

    /// `fragment_borrow` slices the caller's buffer into consecutive chunks,
    /// tagging each with the given type and version.
    #[test]
    fn borrow_fragment() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let payload = b"\x01\x02\x03\x04\x05\x06\x07\x08";

        let frag = MessageFragmenter::new(3);
        let mut q = VecDeque::new();
        frag.fragment_borrow(typ, version, payload, &mut q);

        let chunks: Vec<&[u8]> = q.iter().map(|m| m.payload).collect();
        assert_eq!(
            chunks,
            vec![
                &b"\x01\x02\x03"[..],
                &b"\x04\x05\x06"[..],
                &b"\x07\x08"[..],
            ]
        );
        assert!(q
            .iter()
            .all(|m| m.typ == typ && m.version == version));
    }
}