tonic/transport/server/incoming.rs

1use super::{Connected, Server};
2use crate::transport::service::ServerIo;
3use futures_core::Stream;
4use futures_util::stream::TryStreamExt;
5use hyper::server::{
6    accept::Accept,
7    conn::{AddrIncoming, AddrStream},
8};
9use std::{
10    net::SocketAddr,
11    pin::Pin,
12    task::{Context, Poll},
13    time::Duration,
14};
15use tokio::io::{AsyncRead, AsyncWrite};
16
17#[cfg(not(feature = "tls"))]
18pub(crate) fn tcp_incoming<IO, IE, L>(
19    incoming: impl Stream<Item = Result<IO, IE>>,
20    _server: Server<L>,
21) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>>
22where
23    IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
24    IE: Into<crate::Error>,
25{
26    async_stream::try_stream! {
27        futures_util::pin_mut!(incoming);
28
29        while let Some(stream) = incoming.try_next().await? {
30            yield ServerIo::new_io(stream);
31        }
32    }
33}
34
/// Wraps connections yielded by `incoming` as [`ServerIo`]s, performing a TLS
/// handshake first when `server.tls` is configured.
///
/// Each TLS handshake runs in its own `tokio::spawn`ed task and is driven
/// concurrently with accepting new connections (see `select`), so a slow client
/// does not stall the accept loop. When `server.tls` is `None`, connections are
/// yielded immediately as plain IO.
///
/// Errors from `incoming`, failed handshakes, and handshake tasks that fail to
/// join are all logged at `debug` level and skipped; they never end the stream.
/// The stream ends once `incoming` is exhausted.
#[cfg(feature = "tls")]
pub(crate) fn tcp_incoming<IO, IE, L>(
    incoming: impl Stream<Item = Result<IO, IE>>,
    server: Server<L>,
) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>>
where
    IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
    IE: Into<crate::Error>,
{
    async_stream::try_stream! {
        futures_util::pin_mut!(incoming);

        // Join handles of in-flight TLS handshakes. The whole function is already
        // gated on the `tls` feature, so this binding needs no `#[cfg]` of its own.
        let mut tasks = futures_util::stream::futures_unordered::FuturesUnordered::new();

        loop {
            match select(&mut incoming, &mut tasks).await {
                SelectOutput::Incoming(stream) => {
                    if let Some(tls) = &server.tls {
                        let tls = tls.clone();

                        // NOTE(review): no timeout is applied to the handshake, so a
                        // client that never completes it keeps this task (and its
                        // socket) alive indefinitely.
                        let accept = tokio::spawn(async move {
                            let io = tls.accept(stream).await?;
                            Ok(ServerIo::new_tls_io(io))
                        });

                        tasks.push(accept);
                    } else {
                        yield ServerIo::new_io(stream);
                    }
                }

                SelectOutput::Io(io) => {
                    yield io;
                }

                // A single bad connection or handshake must not take down the
                // server, so errors are logged and the loop keeps accepting.
                SelectOutput::Err(e) => {
                    tracing::debug!(message = "Accept loop error.", error = %e);
                }

                // `incoming` is exhausted. Handshakes still in `tasks` are dropped
                // with it; dropping a `JoinHandle` detaches (does not cancel) the task.
                SelectOutput::Done => {
                    break;
                }
            }
        }
    }
}
82
/// Waits for whichever happens first: a new connection (or error, or end of
/// stream) from `incoming`, or the completion of one pending TLS handshake in
/// `tasks`.
///
/// Errors from `incoming`, handshake failures, and task join failures are all
/// reported as [`SelectOutput::Err`]. [`SelectOutput::Done`] is returned when
/// `incoming` yields `None`.
#[cfg(feature = "tls")]
async fn select<IO, IE>(
    incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
    tasks: &mut futures_util::stream::futures_unordered::FuturesUnordered<
        tokio::task::JoinHandle<Result<ServerIo<IO>, crate::Error>>,
    >,
) -> SelectOutput<IO>
where
    IE: Into<crate::Error>,
{
    use futures_util::StreamExt;

    // An empty `FuturesUnordered` resolves to `None` straight away, so the
    // handshake branch is only enabled while at least one handshake is in flight.
    // The incoming branch has no precondition, so `select!` always has a live branch.
    let has_pending = !tasks.is_empty();

    tokio::select! {
        next = incoming.try_next() => {
            match next {
                Ok(Some(conn)) => SelectOutput::Incoming(conn),
                Ok(None) => SelectOutput::Done,
                Err(err) => SelectOutput::Err(err.into()),
            }
        }

        joined = tasks.next(), if has_pending => {
            match joined.expect("FuturesUnordered stream should never end") {
                Ok(Ok(io)) => SelectOutput::Io(io),
                // The handshake itself failed.
                Ok(Err(err)) => SelectOutput::Err(err),
                // The spawned handshake task panicked or was cancelled.
                Err(join_err) => SelectOutput::Err(join_err.into()),
            }
        }
    }
}
121
/// Outcome of one `select` round in the TLS-enabled accept loop.
#[cfg(feature = "tls")]
enum SelectOutput<IO> {
    /// A raw connection from the incoming stream, not yet wrapped or handshaken.
    Incoming(IO),
    /// A connection whose spawned TLS handshake completed successfully.
    Io(ServerIo<IO>),
    /// An incoming-stream error, a failed handshake, or a handshake task join error.
    Err(crate::Error),
    /// The incoming stream is exhausted.
    Done,
}
129
130pub(crate) struct TcpIncoming {
131    inner: AddrIncoming,
132}
133
134impl TcpIncoming {
135    pub(crate) fn new(
136        addr: SocketAddr,
137        nodelay: bool,
138        keepalive: Option<Duration>,
139    ) -> Result<Self, crate::Error> {
140        let mut inner = AddrIncoming::bind(&addr)?;
141        inner.set_nodelay(nodelay);
142        inner.set_keepalive(keepalive);
143        Ok(TcpIncoming { inner })
144    }
145}
146
147impl Stream for TcpIncoming {
148    type Item = Result<AddrStream, std::io::Error>;
149
150    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
151        Pin::new(&mut self.inner).poll_accept(cx)
152    }
153}