1#[cfg(feature = "compression")]
2use super::compression::{compress, CompressionEncoding, SingleMessageCompressionOverride};
3use super::{EncodeBuf, Encoder, HEADER_SIZE};
4use crate::{Code, Status};
5use bytes::{BufMut, Bytes, BytesMut};
6use futures_core::{Stream, TryStream};
7use futures_util::{ready, StreamExt, TryStreamExt};
8use http::HeaderMap;
9use http_body::Body;
10use pin_project::pin_project;
11use std::{
12 pin::Pin,
13 task::{Context, Poll},
14};
15
16pub(super) const BUFFER_SIZE: usize = 8 * 1024;
17
18pub(crate) fn encode_server<T, U>(
19 encoder: T,
20 source: U,
21 #[cfg(feature = "compression")] compression_encoding: Option<CompressionEncoding>,
22 #[cfg(feature = "compression")] compression_override: SingleMessageCompressionOverride,
23) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>>
24where
25 T: Encoder<Error = Status>,
26 U: Stream<Item = Result<T::Item, Status>>,
27{
28 let stream = encode(
29 encoder,
30 source,
31 #[cfg(feature = "compression")]
32 compression_encoding,
33 #[cfg(feature = "compression")]
34 compression_override,
35 )
36 .into_stream();
37
38 EncodeBody::new_server(stream)
39}
40
41pub(crate) fn encode_client<T, U>(
42 encoder: T,
43 source: U,
44 #[cfg(feature = "compression")] compression_encoding: Option<CompressionEncoding>,
45) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>>
46where
47 T: Encoder<Error = Status>,
48 U: Stream<Item = T::Item>,
49{
50 let stream = encode(
51 encoder,
52 source.map(Ok),
53 #[cfg(feature = "compression")]
54 compression_encoding,
55 #[cfg(feature = "compression")]
56 SingleMessageCompressionOverride::default(),
57 )
58 .into_stream();
59 EncodeBody::new_client(stream)
60}
61
62fn encode<T, U>(
63 mut encoder: T,
64 source: U,
65 #[cfg(feature = "compression")] compression_encoding: Option<CompressionEncoding>,
66 #[cfg(feature = "compression")] compression_override: SingleMessageCompressionOverride,
67) -> impl TryStream<Ok = Bytes, Error = Status>
68where
69 T: Encoder<Error = Status>,
70 U: Stream<Item = Result<T::Item, Status>>,
71{
72 async_stream::stream! {
73 let mut buf = BytesMut::with_capacity(BUFFER_SIZE);
74
75 #[cfg(feature = "compression")]
76 let (compression_enabled_for_stream, mut uncompression_buf) = match compression_encoding {
77 Some(CompressionEncoding::Gzip) => (true, BytesMut::with_capacity(BUFFER_SIZE)),
78 None => (false, BytesMut::new()),
79 };
80
81 #[cfg(feature = "compression")]
82 let compress_item = compression_enabled_for_stream && compression_override == SingleMessageCompressionOverride::Inherit;
83
84 #[cfg(not(feature = "compression"))]
85 let compress_item = false;
86
87 futures_util::pin_mut!(source);
88
89 loop {
90 match source.next().await {
91 Some(Ok(item)) => {
92 buf.reserve(HEADER_SIZE);
93 unsafe {
94 buf.advance_mut(HEADER_SIZE);
95 }
96
97 if compress_item {
98 #[cfg(feature = "compression")]
99 {
100 uncompression_buf.clear();
101
102 encoder.encode(item, &mut EncodeBuf::new(&mut uncompression_buf))
103 .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
104
105 let uncompressed_len = uncompression_buf.len();
106
107 compress(
108 compression_encoding.unwrap(),
109 &mut uncompression_buf,
110 &mut buf,
111 uncompressed_len,
112 ).map_err(|err| Status::internal(format!("Error compressing: {}", err)))?;
113 }
114
115 #[cfg(not(feature = "compression"))]
116 unreachable!("compression disabled, should not take this branch");
117 } else {
118 encoder.encode(item, &mut EncodeBuf::new(&mut buf))
119 .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
120 }
121
122 let len = buf.len() - HEADER_SIZE;
124 assert!(len <= std::u32::MAX as usize);
125 {
126 let mut buf = &mut buf[..HEADER_SIZE];
127 buf.put_u8(compress_item as u8);
128 buf.put_u32(len as u32);
129 }
130
131 yield Ok(buf.split_to(len + HEADER_SIZE).freeze());
132 },
133 Some(Err(status)) => yield Err(status),
134 None => break,
135 }
136 }
137 }
138}
139
/// Which side of the RPC an [`EncodeBody`] is sending for.
///
/// This decides how stream errors are surfaced: a client returns them from `poll_data`, and
/// a server reports them in its trailers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Role {
    /// Request body sent by a client.
    Client,
    /// Response body sent by a server.
    Server,
}
145
/// HTTP body whose data frames are the pre-encoded gRPC frames yielded by `inner`.
///
/// How an `Err` from `inner` is surfaced depends on `role` (see the `Body` impl).
#[pin_project]
#[derive(Debug)]
pub(crate) struct EncodeBody<S> {
    /// Stream of encoded frames (or errors) sent as body data.
    #[pin]
    inner: S,
    /// Server only: status taken from `inner` in `poll_data`, reported later by `poll_trailers`.
    error: Option<Status>,
    /// Whether this body is sent by a client or a server.
    role: Role,
    /// Value returned by `Body::is_end_stream`; flipped to `true` inside `poll_trailers`.
    is_end_stream: bool,
}
155
156impl<S> EncodeBody<S>
157where
158 S: Stream<Item = Result<Bytes, Status>>,
159{
160 pub(crate) fn new_client(inner: S) -> Self {
161 Self {
162 inner,
163 error: None,
164 role: Role::Client,
165 is_end_stream: false,
166 }
167 }
168
169 pub(crate) fn new_server(inner: S) -> Self {
170 Self {
171 inner,
172 error: None,
173 role: Role::Server,
174 is_end_stream: false,
175 }
176 }
177}
178
179impl<S> Body for EncodeBody<S>
180where
181 S: Stream<Item = Result<Bytes, Status>>,
182{
183 type Data = Bytes;
184 type Error = Status;
185
186 fn is_end_stream(&self) -> bool {
187 self.is_end_stream
188 }
189
190 fn poll_data(
191 self: Pin<&mut Self>,
192 cx: &mut Context<'_>,
193 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
194 let mut self_proj = self.project();
195 match ready!(self_proj.inner.try_poll_next_unpin(cx)) {
196 Some(Ok(d)) => Some(Ok(d)).into(),
197 Some(Err(status)) => match self_proj.role {
198 Role::Client => Some(Err(status)).into(),
199 Role::Server => {
200 *self_proj.error = Some(status);
201 None.into()
202 }
203 },
204 None => None.into(),
205 }
206 }
207
208 fn poll_trailers(
209 self: Pin<&mut Self>,
210 _cx: &mut Context<'_>,
211 ) -> Poll<Result<Option<HeaderMap>, Status>> {
212 match self.role {
213 Role::Client => Poll::Ready(Ok(None)),
214 Role::Server => {
215 let self_proj = self.project();
216
217 if *self_proj.is_end_stream {
218 return Poll::Ready(Ok(None));
219 }
220
221 let status = if let Some(status) = self_proj.error.take() {
222 *self_proj.is_end_stream = true;
223 status
224 } else {
225 Status::new(Code::Ok, "")
226 };
227
228 Poll::Ready(Ok(Some(status.to_header_map()?)))
229 }
230 }
231 }
232}