1use crate::codec::UserError;
2use crate::codec::UserError::*;
3use crate::frame::{self, Frame, FrameSize};
4use crate::hpack;
5
6use bytes::{Buf, BufMut, BytesMut};
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10use tokio_util::io::poll_write_buf;
11
12use std::io::{self, Cursor};
13
/// Yields the encoder's write buffer (`&mut BytesMut`) wrapped in a
/// `BufMut::limit` adaptor capped at one maximum-size frame: the current
/// max frame payload size plus the frame header length.
///
/// This is a macro rather than a method so that the mutable borrow it creates
/// covers only the `buf` field; callers can still borrow other encoder fields
/// (such as `hpack`) mutably while the limited buffer is alive.
macro_rules! limited_write_buf {
    ($self:expr) => {{
        let frame_cap = frame::HEADER_LEN + $self.max_frame_size();
        $self.buf.get_mut().limit(frame_cap)
    }};
}
21
/// Encodes HTTP/2 frames into an internal buffer and writes them to the
/// wrapped I/O object `T`. `B` is the payload buffer type carried by DATA
/// frames.
#[derive(Debug)]
pub struct FramedWrite<T, B> {
    /// The wrapped I/O object that encoded frames are written to. When it is
    /// also `AsyncRead`, reads on `FramedWrite` are delegated to it.
    inner: T,

    /// Frame encoder holding the write buffer and any queued frame.
    encoder: Encoder<B>,
}
29
/// Frame encoding state: the write buffer plus any work queued behind it.
#[derive(Debug)]
struct Encoder<B> {
    /// HPACK encoder used to encode HEADERS and PUSH_PROMISE header blocks.
    hpack: hpack::Encoder,

    /// Encoded frame bytes awaiting write. Writing advances the cursor; it is
    /// rewound and the buffer cleared once everything has been written.
    buf: Cursor<BytesMut>,

    /// Work queued behind `buf`: a DATA frame whose remaining payload is
    /// written by chaining it after `buf`, or a CONTINUATION frame to encode
    /// once `buf` has been fully written.
    next: Option<Next<B>>,

    /// Last DATA frame whose payload has been consumed by the encoder.
    last_data_frame: Option<frame::Data<B>>,

    /// Maximum frame payload size; larger DATA payloads are rejected.
    max_frame_size: FrameSize,

    /// DATA payloads of at least this many bytes are chained after `buf`
    /// instead of being copied into it.
    chain_threshold: usize,

    /// Minimum spare capacity `buf` must have before another frame is accepted.
    min_buffer_capacity: usize,
}
55
/// Work queued behind the encoder's write buffer.
#[derive(Debug)]
enum Next<B> {
    /// DATA frame whose head is already in the buffer; the rest of its
    /// payload is written by chaining it after the buffer.
    Data(frame::Data<B>),
    /// Remaining part of a header block, encoded after the buffer is flushed.
    Continuation(frame::Continuation),
}
61
/// Initial capacity, in bytes, of the encoder's write buffer.
const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1_024;

/// Chaining threshold used when the I/O object supports vectored writes:
/// DATA payloads of at least this many bytes are chained, not copied.
const CHAIN_THRESHOLD: usize = 256;

/// Chaining threshold used when the I/O object does not support vectored
/// writes.
const CHAIN_THRESHOLD_WITHOUT_VECTORED_IO: usize = 1024;
77
78impl<T, B> FramedWrite<T, B>
80where
81 T: AsyncWrite + Unpin,
82 B: Buf,
83{
84 pub fn new(inner: T) -> FramedWrite<T, B> {
85 let chain_threshold = if inner.is_write_vectored() {
86 CHAIN_THRESHOLD
87 } else {
88 CHAIN_THRESHOLD_WITHOUT_VECTORED_IO
89 };
90 FramedWrite {
91 inner,
92 encoder: Encoder {
93 hpack: hpack::Encoder::default(),
94 buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)),
95 next: None,
96 last_data_frame: None,
97 max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE,
98 chain_threshold,
99 min_buffer_capacity: chain_threshold + frame::HEADER_LEN,
100 },
101 }
102 }
103
104 pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
109 if !self.encoder.has_capacity() {
110 ready!(self.flush(cx))?;
112
113 if !self.encoder.has_capacity() {
114 return Poll::Pending;
115 }
116 }
117
118 Poll::Ready(Ok(()))
119 }
120
121 pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
126 self.encoder.buffer(item)
127 }
128
129 pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
131 let span = tracing::trace_span!("FramedWrite::flush");
132 let _e = span.enter();
133
134 loop {
135 while !self.encoder.is_empty() {
136 let n = match self.encoder.next {
137 Some(Next::Data(ref mut frame)) => {
138 tracing::trace!(queued_data_frame = true);
139 let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut());
140 ready!(poll_write_buf(Pin::new(&mut self.inner), cx, &mut buf))?
141 }
142 _ => {
143 tracing::trace!(queued_data_frame = false);
144 ready!(poll_write_buf(
145 Pin::new(&mut self.inner),
146 cx,
147 &mut self.encoder.buf
148 ))?
149 }
150 };
151 if n == 0 {
152 return Poll::Ready(Err(io::Error::new(
153 io::ErrorKind::WriteZero,
154 "failed to write frame to socket",
155 )));
156 }
157 }
158
159 match self.encoder.unset_frame() {
160 ControlFlow::Continue => (),
161 ControlFlow::Break => break,
162 }
163 }
164
165 tracing::trace!("flushing buffer");
166 ready!(Pin::new(&mut self.inner).poll_flush(cx))?;
168
169 Poll::Ready(Ok(()))
170 }
171
172 pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> {
174 ready!(self.flush(cx))?;
175 Pin::new(&mut self.inner).poll_shutdown(cx)
176 }
177}
178
/// Outcome of `Encoder::unset_frame`. `Continue` means a CONTINUATION frame
/// was just encoded and `flush` must write again. `Break` means there is
/// nothing more to write.
#[must_use]
enum ControlFlow {
    Continue,
    Break,
}
184
185impl<B> Encoder<B>
186where
187 B: Buf,
188{
189 fn unset_frame(&mut self) -> ControlFlow {
190 self.buf.set_position(0);
192 self.buf.get_mut().clear();
193
194 match self.next.take() {
196 Some(Next::Data(frame)) => {
197 self.last_data_frame = Some(frame);
198 debug_assert!(self.is_empty());
199 ControlFlow::Break
200 }
201 Some(Next::Continuation(frame)) => {
202 let mut buf = limited_write_buf!(self);
204 if let Some(continuation) = frame.encode(&mut buf) {
205 self.next = Some(Next::Continuation(continuation));
206 }
207 ControlFlow::Continue
208 }
209 None => ControlFlow::Break,
210 }
211 }
212
213 fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> {
214 assert!(self.has_capacity());
216 let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item);
217 let _e = span.enter();
218
219 tracing::debug!(frame = ?item, "send");
220
221 match item {
222 Frame::Data(mut v) => {
223 let len = v.payload().remaining();
225
226 if len > self.max_frame_size() {
227 return Err(PayloadTooBig);
228 }
229
230 if len >= self.chain_threshold {
231 let head = v.head();
232
233 head.encode(len, self.buf.get_mut());
235
236 if self.buf.get_ref().remaining() < self.chain_threshold {
237 let extra_bytes = self.chain_threshold - self.buf.remaining();
238 self.buf.get_mut().put(v.payload_mut().take(extra_bytes));
239 }
240
241 self.next = Some(Next::Data(v));
243 } else {
244 v.encode_chunk(self.buf.get_mut());
245
246 assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded");
249
250 self.last_data_frame = Some(v);
252 }
253 }
254 Frame::Headers(v) => {
255 let mut buf = limited_write_buf!(self);
256 if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
257 self.next = Some(Next::Continuation(continuation));
258 }
259 }
260 Frame::PushPromise(v) => {
261 let mut buf = limited_write_buf!(self);
262 if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) {
263 self.next = Some(Next::Continuation(continuation));
264 }
265 }
266 Frame::Settings(v) => {
267 v.encode(self.buf.get_mut());
268 tracing::trace!(rem = self.buf.remaining(), "encoded settings");
269 }
270 Frame::GoAway(v) => {
271 v.encode(self.buf.get_mut());
272 tracing::trace!(rem = self.buf.remaining(), "encoded go_away");
273 }
274 Frame::Ping(v) => {
275 v.encode(self.buf.get_mut());
276 tracing::trace!(rem = self.buf.remaining(), "encoded ping");
277 }
278 Frame::WindowUpdate(v) => {
279 v.encode(self.buf.get_mut());
280 tracing::trace!(rem = self.buf.remaining(), "encoded window_update");
281 }
282
283 Frame::Priority(_) => {
284 unimplemented!();
289 }
290 Frame::Reset(v) => {
291 v.encode(self.buf.get_mut());
292 tracing::trace!(rem = self.buf.remaining(), "encoded reset");
293 }
294 }
295
296 Ok(())
297 }
298
299 fn has_capacity(&self) -> bool {
300 self.next.is_none()
301 && (self.buf.get_ref().capacity() - self.buf.get_ref().len()
302 >= self.min_buffer_capacity)
303 }
304
305 fn is_empty(&self) -> bool {
306 match self.next {
307 Some(Next::Data(ref frame)) => !frame.payload().has_remaining(),
308 _ => !self.buf.has_remaining(),
309 }
310 }
311}
312
313impl<B> Encoder<B> {
314 fn max_frame_size(&self) -> usize {
315 self.max_frame_size as usize
316 }
317}
318
319impl<T, B> FramedWrite<T, B> {
320 pub fn max_frame_size(&self) -> usize {
322 self.encoder.max_frame_size()
323 }
324
325 pub fn set_max_frame_size(&mut self, val: usize) {
327 assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize);
328 self.encoder.max_frame_size = val as FrameSize;
329 }
330
331 pub fn set_header_table_size(&mut self, val: usize) {
333 self.encoder.hpack.update_max_size(val);
334 }
335
336 pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> {
338 self.encoder.last_data_frame.take()
339 }
340
341 pub fn get_mut(&mut self) -> &mut T {
342 &mut self.inner
343 }
344}
345
346impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> {
347 fn poll_read(
348 mut self: Pin<&mut Self>,
349 cx: &mut Context<'_>,
350 buf: &mut ReadBuf,
351 ) -> Poll<io::Result<()>> {
352 Pin::new(&mut self.inner).poll_read(cx, buf)
353 }
354}
355
356impl<T: Unpin, B> Unpin for FramedWrite<T, B> {}
358
#[cfg(feature = "unstable")]
mod unstable {
    use super::*;

    impl<T, B> FramedWrite<T, B> {
        /// Shared access to the wrapped I/O object.
        pub fn get_ref(&self) -> &T {
            let FramedWrite { inner, .. } = self;
            inner
        }
    }
}