// h2/codec/framed_write.rs

1use crate::codec::UserError;
2use crate::codec::UserError::*;
3use crate::frame::{self, Frame, FrameSize};
4use crate::hpack;
5
6use bytes::{Buf, BufMut, BytesMut};
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10use tokio_util::io::poll_write_buf;
11
12use std::io::{self, Cursor};
13
// Produces a `Limit`-wrapped mutable view of `$self.buf` that accepts at most
// `max_frame_size()` payload bytes plus one frame head.
//
// This is a macro instead of a method: a `&mut self` method returning the
// wrapper would keep the whole encoder borrowed while the wrapper lives, but
// callers also pass `&mut self.hpack` alongside it. Expanding in place lets the
// borrow checker see that only the `buf` field is mutably borrowed.
macro_rules! limited_write_buf {
    ($self:expr) => {{
        let max_len = frame::HEADER_LEN + $self.max_frame_size();
        $self.buf.get_mut().limit(max_len)
    }};
}
21
22#[derive(Debug)]
23pub struct FramedWrite<T, B> {
24    /// Upstream `AsyncWrite`
25    inner: T,
26
27    encoder: Encoder<B>,
28}
29
30#[derive(Debug)]
31struct Encoder<B> {
32    /// HPACK encoder
33    hpack: hpack::Encoder,
34
35    /// Write buffer
36    ///
37    /// TODO: Should this be a ring buffer?
38    buf: Cursor<BytesMut>,
39
40    /// Next frame to encode
41    next: Option<Next<B>>,
42
43    /// Last data frame
44    last_data_frame: Option<frame::Data<B>>,
45
46    /// Max frame size, this is specified by the peer
47    max_frame_size: FrameSize,
48
49    /// Chain payloads bigger than this.
50    chain_threshold: usize,
51
52    /// Min buffer required to attempt to write a frame
53    min_buffer_capacity: usize,
54}
55
56#[derive(Debug)]
57enum Next<B> {
58    Data(frame::Data<B>),
59    Continuation(frame::Continuation),
60}
61
/// Initial capacity of the connection's write buffer.
///
/// HTTP/2 forbids a peer from advertising a MAX_FRAME_SIZE below 16 KiB, so
/// the buffer is sized after that minimum in order to hold a HEADERS frame of
/// roughly that size up front.
const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024;

/// When the transport supports vectored I/O, DATA payloads of at least this
/// many bytes are chained behind their frame head instead of being copied into
/// the write buffer.
///
/// Because a peer can never advertise a max frame size below 16 KiB, this
/// threshold is always far smaller than the largest payload that can be sent.
const CHAIN_THRESHOLD: usize = 256;

/// Chaining threshold used when the transport does **not** support vectored
/// I/O.
///
/// A larger value here reduces the number of small, fragmented writes sent to
/// the transport and thereby improves throughput.
const CHAIN_THRESHOLD_WITHOUT_VECTORED_IO: usize = 1024;
77
78// TODO: Make generic
79impl<T, B> FramedWrite<T, B>
80where
81    T: AsyncWrite + Unpin,
82    B: Buf,
83{
84    pub fn new(inner: T) -> FramedWrite<T, B> {
85        let chain_threshold = if inner.is_write_vectored() {
86            CHAIN_THRESHOLD
87        } else {
88            CHAIN_THRESHOLD_WITHOUT_VECTORED_IO
89        };
90        FramedWrite {
91            inner,
92            encoder: Encoder {
93                hpack: hpack::Encoder::default(),
94                buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)),
95                next: None,
96                last_data_frame: None,
97                max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE,
98                chain_threshold,
99                min_buffer_capacity: chain_threshold + frame::HEADER_LEN,
100            },
101        }
102    }
103
104    /// Returns `Ready` when `send` is able to accept a frame
105    ///
106    /// Calling this function may result in the current contents of the buffer
107    /// to be flushed to `T`.
108    pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
109        if !self.encoder.has_capacity() {
110            // Try flushing
111            ready!(self.flush(cx))?;
112
113            if !self.encoder.has_capacity() {
114                return Poll::Pending;
115            }
116        }
117
118        Poll::Ready(Ok(()))
119    }
120
121    /// Buffer a frame.
122    ///
123    /// `poll_ready` must be called first to ensure that a frame may be
124    /// accepted.
125    pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
126        self.encoder.buffer(item)
127    }
128
129    /// Flush buffered data to the wire
130    pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
131        let span = tracing::trace_span!("FramedWrite::flush");
132        let _e = span.enter();
133
134        loop {
135            while !self.encoder.is_empty() {
136                let n = match self.encoder.next {
137                    Some(Next::Data(ref mut frame)) => {
138                        tracing::trace!(queued_data_frame = true);
139                        let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut());
140                        ready!(poll_write_buf(Pin::new(&mut self.inner), cx, &mut buf))?
141                    }
142                    _ => {
143                        tracing::trace!(queued_data_frame = false);
144                        ready!(poll_write_buf(
145                            Pin::new(&mut self.inner),
146                            cx,
147                            &mut self.encoder.buf
148                        ))?
149                    }
150                };
151                if n == 0 {
152                    return Poll::Ready(Err(io::Error::new(
153                        io::ErrorKind::WriteZero,
154                        "failed to write frame to socket",
155                    )));
156                }
157            }
158
159            match self.encoder.unset_frame() {
160                ControlFlow::Continue => (),
161                ControlFlow::Break => break,
162            }
163        }
164
165        tracing::trace!("flushing buffer");
166        // Flush the upstream
167        ready!(Pin::new(&mut self.inner).poll_flush(cx))?;
168
169        Poll::Ready(Ok(()))
170    }
171
172    /// Close the codec
173    pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
174        ready!(self.flush(cx))?;
175        Pin::new(&mut self.inner).poll_shutdown(cx)
176    }
177}
178
/// Outcome of `Encoder::unset_frame`: tells the flush loop whether to keep
/// writing.
#[must_use]
enum ControlFlow {
    /// A CONTINUATION frame was encoded into the buffer; write again.
    Continue,
    /// Nothing else is pending; leave the write loop.
    Break,
}
184
185impl<B> Encoder<B>
186where
187    B: Buf,
188{
189    fn unset_frame(&mut self) -> ControlFlow {
190        // Clear internal buffer
191        self.buf.set_position(0);
192        self.buf.get_mut().clear();
193
194        // The data frame has been written, so unset it
195        match self.next.take() {
196            Some(Next::Data(frame)) => {
197                self.last_data_frame = Some(frame);
198                debug_assert!(self.is_empty());
199                ControlFlow::Break
200            }
201            Some(Next::Continuation(frame)) => {
202                // Buffer the continuation frame, then try to write again
203                let mut buf = limited_write_buf!(self);
204                if let Some(continuation) = frame.encode(&mut buf) {
205                    self.next = Some(Next::Continuation(continuation));
206                }
207                ControlFlow::Continue
208            }
209            None => ControlFlow::Break,
210        }
211    }
212
213    fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
214        // Ensure that we have enough capacity to accept the write.
215        assert!(self.has_capacity());
216        let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item);
217        let _e = span.enter();
218
219        tracing::debug!(frame = ?item, "send");
220
221        match item {
222            Frame::Data(mut v) => {
223                // Ensure that the payload is not greater than the max frame.
224                let len = v.payload().remaining();
225
226                if len > self.max_frame_size() {
227                    return Err(PayloadTooBig);
228                }
229
230                if len >= self.chain_threshold {
231                    let head = v.head();
232
233                    // Encode the frame head to the buffer
234                    head.encode(len, self.buf.get_mut());
235
236                    if self.buf.get_ref().remaining() < self.chain_threshold {
237                        let extra_bytes = self.chain_threshold - self.buf.remaining();
238                        self.buf.get_mut().put(v.payload_mut().take(extra_bytes));
239                    }
240
241                    // Save the data frame
242                    self.next = Some(Next::Data(v));
243                } else {
244                    v.encode_chunk(self.buf.get_mut());
245
246                    // The chunk has been fully encoded, so there is no need to
247                    // keep it around
248                    assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded");
249
250                    // Save off the last frame...
251                    self.last_data_frame = Some(v);
252                }
253            }
254            Frame::Headers(v) => {
255                let mut buf = limited_write_buf!(self);
256                if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
257                    self.next = Some(Next::Continuation(continuation));
258                }
259            }
260            Frame::PushPromise(v) => {
261                let mut buf = limited_write_buf!(self);
262                if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
263                    self.next = Some(Next::Continuation(continuation));
264                }
265            }
266            Frame::Settings(v) => {
267                v.encode(self.buf.get_mut());
268                tracing::trace!(rem = self.buf.remaining(), "encoded settings");
269            }
270            Frame::GoAway(v) => {
271                v.encode(self.buf.get_mut());
272                tracing::trace!(rem = self.buf.remaining(), "encoded go_away");
273            }
274            Frame::Ping(v) => {
275                v.encode(self.buf.get_mut());
276                tracing::trace!(rem = self.buf.remaining(), "encoded ping");
277            }
278            Frame::WindowUpdate(v) => {
279                v.encode(self.buf.get_mut());
280                tracing::trace!(rem = self.buf.remaining(), "encoded window_update");
281            }
282
283            Frame::Priority(_) => {
284                /*
285                v.encode(self.buf.get_mut());
286                tracing::trace!("encoded priority; rem={:?}", self.buf.remaining());
287                */
288                unimplemented!();
289            }
290            Frame::Reset(v) => {
291                v.encode(self.buf.get_mut());
292                tracing::trace!(rem = self.buf.remaining(), "encoded reset");
293            }
294        }
295
296        Ok(())
297    }
298
299    fn has_capacity(&self) -> bool {
300        self.next.is_none()
301            && (self.buf.get_ref().capacity() - self.buf.get_ref().len()
302                >= self.min_buffer_capacity)
303    }
304
305    fn is_empty(&self) -> bool {
306        match self.next {
307            Some(Next::Data(ref frame)) => !frame.payload().has_remaining(),
308            _ => !self.buf.has_remaining(),
309        }
310    }
311}
312
313impl<B> Encoder<B> {
314    fn max_frame_size(&self) -> usize {
315        self.max_frame_size as usize
316    }
317}
318
319impl<T, B> FramedWrite<T, B> {
320    /// Returns the max frame size that can be sent
321    pub fn max_frame_size(&self) -> usize {
322        self.encoder.max_frame_size()
323    }
324
325    /// Set the peer's max frame size.
326    pub fn set_max_frame_size(&mut self, val: usize) {
327        assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize);
328        self.encoder.max_frame_size = val as FrameSize;
329    }
330
331    /// Set the peer's header table size.
332    pub fn set_header_table_size(&mut self, val: usize) {
333        self.encoder.hpack.update_max_size(val);
334    }
335
336    /// Retrieve the last data frame that has been sent
337    pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> {
338        self.encoder.last_data_frame.take()
339    }
340
341    pub fn get_mut(&mut self) -> &mut T {
342        &mut self.inner
343    }
344}
345
346impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> {
347    fn poll_read(
348        mut self: Pin<&mut Self>,
349        cx: &mut Context<'_>,
350        buf: &mut ReadBuf,
351    ) -> Poll<io::Result<()>> {
352        Pin::new(&mut self.inner).poll_read(cx, buf)
353    }
354}
355
356// We never project the Pin to `B`.
357impl<T: Unpin, B> Unpin for FramedWrite<T, B> {}
358
#[cfg(feature = "unstable")]
mod unstable {
    use super::FramedWrite;

    impl<T, B> FramedWrite<T, B> {
        /// Shared access to the underlying transport (`unstable` feature only).
        pub fn get_ref(&self) -> &T {
            &self.inner
        }
    }
}