tower/buffer/service.rs
use super::{
    future::ResponseFuture,
    message::Message,
    worker::{Handle, Worker},
};

use futures_core::ready;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::PollSemaphore;
use tower_service::Service;

/// Adds an mpsc buffer in front of an inner service.
///
/// See the module documentation for more details.
#[derive(Debug)]
pub struct Buffer<T, Request>
where
    T: Service<Request>,
{
    // Note: this actually _is_ bounded, but rather than using Tokio's bounded
    // channel, we use Tokio's semaphore separately to implement the bound.
    tx: mpsc::UnboundedSender<Message<Request, T::Future>>,
    // When the buffer's channel is full, we want to exert backpressure in
    // `poll_ready`, so that callers such as load balancers can choose to call
    // another service rather than waiting for buffer capacity.
    //
    // Unfortunately, this can't be done easily using Tokio's bounded MPSC
    // channel, because it doesn't expose a polling-based interface, only an
    // `async fn ready`, which borrows the sender. Therefore, we implement our
    // own bounded MPSC on top of the unbounded channel, using a semaphore to
    // limit how many items are in the channel.
    semaphore: PollSemaphore,
    // The current semaphore permit, if one has been acquired.
    //
    // This is acquired in `poll_ready` and taken in `call`.
    permit: Option<OwnedSemaphorePermit>,
    handle: Handle,
}

impl<T, Request> Buffer<T, Request>
where
    T: Service<Request>,
    T::Error: Into<crate::BoxError>,
{
    /// Creates a new [`Buffer`] wrapping `service`.
    ///
    /// `bound` gives the maximum number of requests that can be queued for the service before
    /// backpressure is applied to callers.
    ///
    /// The default Tokio executor is used to run the given service, which means that this method
    /// must be called while on the Tokio runtime.
    ///
    /// # A note on choosing a `bound`
    ///
    /// When [`Buffer`]'s implementation of [`poll_ready`] returns [`Poll::Ready`], it reserves a
    /// slot in the channel for the forthcoming [`call`]. However, if this call doesn't arrive,
    /// this reserved slot may be held up for a long time. As a result, it's advisable to set
    /// `bound` to be at least the maximum number of concurrent requests the [`Buffer`] will see.
    /// If you do not, all the slots in the buffer may be held up by futures that have just called
    /// [`poll_ready`] but will not issue a [`call`], which prevents other senders from issuing new
    /// requests.
    ///
    /// [`Poll::Ready`]: std::task::Poll::Ready
    /// [`call`]: crate::Service::call
    /// [`poll_ready`]: crate::Service::poll_ready
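    ///
    /// # Example
    ///
    /// A minimal usage sketch. It assumes the crate's `util` feature, which
    /// provides `ServiceExt::ready` and `service_fn` (neither is part of this
    /// module):
    ///
    /// ```
    /// use tower::{Service, ServiceExt, service_fn};
    /// use tower::buffer::Buffer;
    ///
    /// # #[tokio::main(flavor = "current_thread")]
    /// # async fn main() -> Result<(), tower::BoxError> {
    /// // An inner service that simply echoes its request.
    /// let service = service_fn(|req: String| async move { Ok::<_, tower::BoxError>(req) });
    ///
    /// // Queue at most 64 requests; beyond that, `poll_ready` exerts backpressure.
    /// let mut buffered = Buffer::new(service, 64);
    ///
    /// // Reserve a buffer slot, then send the request to the background worker.
    /// let response = buffered.ready().await?.call("hello".to_string()).await?;
    /// assert_eq!(response, "hello");
    /// # Ok(())
    /// # }
    /// ```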
    pub fn new(service: T, bound: usize) -> Self
    where
        T: Send + 'static,
        T::Future: Send,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (service, worker) = Self::pair(service, bound);
        tokio::spawn(worker);
        service
    }

    /// Creates a new [`Buffer`] wrapping `service`, but returns the background worker.
    ///
    /// This is useful if you do not want to spawn directly onto the Tokio runtime
    /// but instead want to use your own executor. This returns the [`Buffer`] and
    /// the background `Worker` that you can then spawn.
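    ///
    /// # Example
    ///
    /// A sketch of driving the worker on an explicitly constructed runtime
    /// rather than the ambient one (the runtime setup and echo service here
    /// are illustrative, and `service_fn` again assumes the `util` feature):
    ///
    /// ```
    /// use tower::buffer::Buffer;
    /// use tower::service_fn;
    ///
    /// # fn main() -> Result<(), tower::BoxError> {
    /// let runtime = tokio::runtime::Builder::new_multi_thread()
    ///     .enable_all()
    ///     .build()?;
    ///
    /// let service = service_fn(|req: String| async move { Ok::<_, tower::BoxError>(req) });
    /// let (buffer, worker) = Buffer::pair(service, 64);
    ///
    /// // We are now responsible for driving the worker; if it is never
    /// // spawned, requests sent through `buffer` will never complete.
    /// runtime.spawn(worker);
    /// # Ok(())
    /// # }
    /// ```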
    pub fn pair(service: T, bound: usize) -> (Buffer<T, Request>, Worker<T, Request>)
    where
        T: Send + 'static,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (tx, rx) = mpsc::unbounded_channel();
        let semaphore = Arc::new(Semaphore::new(bound));
        let (handle, worker) = Worker::new(service, rx, &semaphore);
        let buffer = Buffer {
            tx,
            handle,
            semaphore: PollSemaphore::new(semaphore),
            permit: None,
        };
        (buffer, worker)
    }

    fn get_worker_error(&self) -> crate::BoxError {
        self.handle.get_error_on_closed()
    }
}

impl<T, Request> Service<Request> for Buffer<T, Request>
where
    T: Service<Request>,
    T::Error: Into<crate::BoxError>,
{
    type Response = T::Response;
    type Error = crate::BoxError;
    type Future = ResponseFuture<T::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // First, check if the worker is still alive.
        if self.tx.is_closed() {
            // If the inner service has errored, then we error here.
            return Poll::Ready(Err(self.get_worker_error()));
        }

        // Then, check if we've already acquired a permit.
        if self.permit.is_some() {
            // We've already reserved capacity to send a request. We're ready!
            return Poll::Ready(Ok(()));
        }

        // Finally, if we haven't already acquired a permit, poll the semaphore
        // to acquire one. If we acquire a permit, then there's enough buffer
        // capacity to send a new request. Otherwise, we need to wait for
        // capacity.
        let permit =
            ready!(self.semaphore.poll_acquire(cx)).ok_or_else(|| self.get_worker_error())?;
        self.permit = Some(permit);

        Poll::Ready(Ok(()))
    }

    fn call(&mut self, request: Request) -> Self::Future {
        tracing::trace!("sending request to buffer worker");
        let _permit = self
            .permit
            .take()
            .expect("buffer full; poll_ready must be called first");

        // Get the current Span so that we can explicitly propagate it to the worker.
        // If we didn't do this, events on the worker related to this span wouldn't be
        // counted towards that span, since the worker would have no way of entering it.
        let span = tracing::Span::current();

        // If we've made it here, then a semaphore permit has already been
        // acquired, so we can freely allocate a oneshot.
        let (tx, rx) = oneshot::channel();

        match self.tx.send(Message {
            request,
            span,
            tx,
            _permit,
        }) {
            Err(_) => ResponseFuture::failed(self.get_worker_error()),
            Ok(_) => ResponseFuture::new(rx),
        }
    }
}

impl<T, Request> Clone for Buffer<T, Request>
where
    T: Service<Request>,
{
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
            handle: self.handle.clone(),
            semaphore: self.semaphore.clone(),
            // The new clone hasn't acquired a permit yet. It will when it's
            // next polled ready.
            permit: None,
        }
    }
}
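
#[cfg(test)]
mod tests {
    use super::*;
    // A minimal sanity-check sketch, not part of the original file. It
    // assumes the crate's `util` feature (for `service_fn` and `ServiceExt`)
    // and a Tokio runtime provided by `#[tokio::test]`.
    use crate::{service_fn, ServiceExt};

    #[tokio::test]
    async fn request_round_trips_through_buffer() {
        // An inner service that echoes its request back.
        let service =
            service_fn(|req: &'static str| async move { Ok::<_, crate::BoxError>(req) });

        // `poll_ready` (driven here by `ready()`) must reserve a buffer slot
        // before `call` may be invoked.
        let mut buffer = Buffer::new(service, 1);
        let response = buffer
            .ready()
            .await
            .expect("worker should be running")
            .call("hello")
            .await
            .expect("inner service never fails");

        assert_eq!(response, "hello");
    }
}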