rustls/msgs/
fragmenter.rs

1use crate::msgs::enums::{ContentType, ProtocolVersion};
2use crate::msgs::message::{BorrowMessage, Message, MessagePayload};
3use std::collections::VecDeque;
4
/// Upper bound that `MessageFragmenter::new` debug-asserts its
/// `max_fragment_len` against: 2^14 = 16384 bytes of payload per fragment.
pub const MAX_FRAGMENT_LEN: usize = 1 << 14;
/// Bytes of framing an encoded message adds on top of its payload; the
/// tests in this file check that an encoded fragment is `PACKET_OVERHEAD`
/// plus its payload length. Presumably 1 byte content type + 2 bytes
/// protocol version + 2 bytes length -- TODO confirm against `Message::encode`.
pub const PACKET_OVERHEAD: usize = 1 + 2 + 2;
7
8pub struct MessageFragmenter {
9    max_frag: usize,
10}
11
12impl MessageFragmenter {
13    /// Make a new fragmenter.  `max_fragment_len` is the maximum
14    /// fragment size that will be produced -- this does not
15    /// include overhead (so a `max_fragment_len` of 5 will produce
16    /// 10 byte packets).
17    pub fn new(max_fragment_len: usize) -> MessageFragmenter {
18        debug_assert!(max_fragment_len <= MAX_FRAGMENT_LEN);
19        MessageFragmenter {
20            max_frag: max_fragment_len,
21        }
22    }
23
24    /// Take the Message `msg` and re-fragment it into new
25    /// messages whose fragment is no more than max_frag.
26    /// The new messages are appended to the `out` deque.
27    /// Payloads are copied.
28    pub fn fragment(&self, msg: Message, out: &mut VecDeque<Message>) {
29        // Non-fragment path
30        if msg.payload.length() <= self.max_frag {
31            out.push_back(msg.into_opaque());
32            return;
33        }
34
35        let typ = msg.typ;
36        let version = msg.version;
37        let payload = msg.take_payload();
38
39        for chunk in payload.chunks(self.max_frag) {
40            let m = Message {
41                typ,
42                version,
43                payload: MessagePayload::new_opaque(chunk.to_vec()),
44            };
45            out.push_back(m);
46        }
47    }
48
49    /// Enqueue borrowed fragments of (version, typ, payload) which
50    /// are no longer than max_frag onto the `out` deque.
51    pub fn fragment_borrow<'a>(
52        &self,
53        typ: ContentType,
54        version: ProtocolVersion,
55        payload: &'a [u8],
56        out: &mut VecDeque<BorrowMessage<'a>>,
57    ) {
58        for chunk in payload.chunks(self.max_frag) {
59            let cm = BorrowMessage {
60                typ,
61                version,
62                payload: chunk,
63            };
64            out.push_back(cm);
65        }
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::{MessageFragmenter, PACKET_OVERHEAD};
    use crate::msgs::codec::Codec;
    use crate::msgs::enums::{ContentType, ProtocolVersion};
    use crate::msgs::message::{Message, MessagePayload};
    use std::collections::VecDeque;

    /// Asserts that `mm` holds a message with the given type, version and
    /// opaque payload `bytes`, and that it encodes to exactly `total_len`
    /// bytes.
    fn msg_eq(
        mm: Option<Message>,
        total_len: usize,
        typ: &ContentType,
        version: &ProtocolVersion,
        bytes: &[u8],
    ) {
        let mut m = mm.expect("expected another fragment in the queue");

        let mut buf = Vec::new();
        m.encode(&mut buf);

        assert_eq!(&m.typ, typ);
        assert_eq!(&m.version, version);
        assert_eq!(m.take_opaque_payload().unwrap().0, bytes.to_vec());

        assert_eq!(total_len, buf.len());
    }

    #[test]
    fn smoke() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let m = Message {
            typ,
            version,
            payload: MessagePayload::new_opaque(b"\x01\x02\x03\x04\x05\x06\x07\x08".to_vec()),
        };

        let frag = MessageFragmenter::new(3);
        let mut q = VecDeque::new();
        frag.fragment(m, &mut q);
        msg_eq(
            q.pop_front(),
            PACKET_OVERHEAD + 3,
            &typ,
            &version,
            b"\x01\x02\x03",
        );
        msg_eq(
            q.pop_front(),
            PACKET_OVERHEAD + 3,
            &typ,
            &version,
            b"\x04\x05\x06",
        );
        msg_eq(
            q.pop_front(),
            PACKET_OVERHEAD + 2,
            &typ,
            &version,
            b"\x07\x08",
        );
        assert_eq!(q.len(), 0);
    }

    #[test]
    fn exact_multiple() {
        // A payload that is an exact multiple of the fragment size must not
        // produce a trailing empty fragment.
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let m = Message {
            typ,
            version,
            payload: MessagePayload::new_opaque(b"\x01\x02\x03\x04\x05\x06".to_vec()),
        };

        let frag = MessageFragmenter::new(3);
        let mut q = VecDeque::new();
        frag.fragment(m, &mut q);
        msg_eq(
            q.pop_front(),
            PACKET_OVERHEAD + 3,
            &typ,
            &version,
            b"\x01\x02\x03",
        );
        msg_eq(
            q.pop_front(),
            PACKET_OVERHEAD + 3,
            &typ,
            &version,
            b"\x04\x05\x06",
        );
        assert_eq!(q.len(), 0);
    }

    #[test]
    fn non_fragment() {
        let m = Message {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::new_opaque(b"\x01\x02\x03\x04\x05\x06\x07\x08".to_vec()),
        };

        let frag = MessageFragmenter::new(8);
        let mut q = VecDeque::new();
        frag.fragment(m, &mut q);
        msg_eq(
            q.pop_front(),
            PACKET_OVERHEAD + 8,
            &ContentType::Handshake,
            &ProtocolVersion::TLSv1_2,
            b"\x01\x02\x03\x04\x05\x06\x07\x08",
        );
        assert_eq!(q.len(), 0);
    }

    #[test]
    fn borrow_fragments() {
        let typ = ContentType::Handshake;
        let version = ProtocolVersion::TLSv1_2;
        let payload = b"\x01\x02\x03\x04\x05\x06\x07\x08";

        let frag = MessageFragmenter::new(3);
        let mut q = VecDeque::new();
        frag.fragment_borrow(typ, version, payload, &mut q);

        let chunks: Vec<&[u8]> = q.iter().map(|m| m.payload).collect();
        assert_eq!(
            chunks,
            vec![&b"\x01\x02\x03"[..], &b"\x04\x05\x06"[..], &b"\x07\x08"[..]]
        );
        assert!(q.iter().all(|m| m.typ == typ && m.version == version));
    }

    #[test]
    fn borrow_empty_payload_enqueues_nothing() {
        let frag = MessageFragmenter::new(3);
        let mut q = VecDeque::new();
        frag.fragment_borrow(
            ContentType::Handshake,
            ProtocolVersion::TLSv1_2,
            &[],
            &mut q,
        );
        assert_eq!(q.len(), 0);
    }
}