// tower/buffer/service.rs

1use super::{
2    future::ResponseFuture,
3    message::Message,
4    worker::{Handle, Worker},
5};
6
7use futures_core::ready;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore};
11use tokio_util::sync::PollSemaphore;
12use tower_service::Service;
13
14/// Adds an mpsc buffer in front of an inner service.
15///
16/// See the module documentation for more details.
17#[derive(Debug)]
18pub struct Buffer<T, Request>
19where
20    T: Service<Request>,
21{
22    // Note: this actually _is_ bounded, but rather than using Tokio's bounded
23    // channel, we use Tokio's semaphore separately to implement the bound.
24    tx: mpsc::UnboundedSender<Message<Request, T::Future>>,
25    // When the buffer's channel is full, we want to exert backpressure in
26    // `poll_ready`, so that callers such as load balancers could choose to call
27    // another service rather than waiting for buffer capacity.
28    //
29    // Unfortunately, this can't be done easily using Tokio's bounded MPSC
30    // channel, because it doesn't expose a polling-based interface, only an
31    // `async fn ready`, which borrows the sender. Therefore, we implement our
32    // own bounded MPSC on top of the unbounded channel, using a semaphore to
33    // limit how many items are in the channel.
34    semaphore: PollSemaphore,
35    // The current semaphore permit, if one has been acquired.
36    //
37    // This is acquired in `poll_ready` and taken in `call`.
38    permit: Option<OwnedSemaphorePermit>,
39    handle: Handle,
40}
41
42impl<T, Request> Buffer<T, Request>
43where
44    T: Service<Request>,
45    T::Error: Into<crate::BoxError>,
46{
47    /// Creates a new [`Buffer`] wrapping `service`.
48    ///
49    /// `bound` gives the maximal number of requests that can be queued for the service before
50    /// backpressure is applied to callers.
51    ///
52    /// The default Tokio executor is used to run the given service, which means that this method
53    /// must be called while on the Tokio runtime.
54    ///
55    /// # A note on choosing a `bound`
56    ///
57    /// When [`Buffer`]'s implementation of [`poll_ready`] returns [`Poll::Ready`], it reserves a
58    /// slot in the channel for the forthcoming [`call`]. However, if this call doesn't arrive,
59    /// this reserved slot may be held up for a long time. As a result, it's advisable to set
60    /// `bound` to be at least the maximum number of concurrent requests the [`Buffer`] will see.
61    /// If you do not, all the slots in the buffer may be held up by futures that have just called
62    /// [`poll_ready`] but will not issue a [`call`], which prevents other senders from issuing new
63    /// requests.
64    ///
65    /// [`Poll::Ready`]: std::task::Poll::Ready
66    /// [`call`]: crate::Service::call
67    /// [`poll_ready`]: crate::Service::poll_ready
68    pub fn new(service: T, bound: usize) -> Self
69    where
70        T: Send + 'static,
71        T::Future: Send,
72        T::Error: Send + Sync,
73        Request: Send + 'static,
74    {
75        let (service, worker) = Self::pair(service, bound);
76        tokio::spawn(worker);
77        service
78    }
79
80    /// Creates a new [`Buffer`] wrapping `service`, but returns the background worker.
81    ///
82    /// This is useful if you do not want to spawn directly onto the tokio runtime
83    /// but instead want to use your own executor. This will return the [`Buffer`] and
84    /// the background `Worker` that you can then spawn.
85    pub fn pair(service: T, bound: usize) -> (Buffer<T, Request>, Worker<T, Request>)
86    where
87        T: Send + 'static,
88        T::Error: Send + Sync,
89        Request: Send + 'static,
90    {
91        let (tx, rx) = mpsc::unbounded_channel();
92        let semaphore = Arc::new(Semaphore::new(bound));
93        let (handle, worker) = Worker::new(service, rx, &semaphore);
94        let buffer = Buffer {
95            tx,
96            handle,
97            semaphore: PollSemaphore::new(semaphore),
98            permit: None,
99        };
100        (buffer, worker)
101    }
102
103    fn get_worker_error(&self) -> crate::BoxError {
104        self.handle.get_error_on_closed()
105    }
106}
107
108impl<T, Request> Service<Request> for Buffer<T, Request>
109where
110    T: Service<Request>,
111    T::Error: Into<crate::BoxError>,
112{
113    type Response = T::Response;
114    type Error = crate::BoxError;
115    type Future = ResponseFuture<T::Future>;
116
117    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
118        // First, check if the worker is still alive.
119        if self.tx.is_closed() {
120            // If the inner service has errored, then we error here.
121            return Poll::Ready(Err(self.get_worker_error()));
122        }
123
124        // Then, check if we've already acquired a permit.
125        if self.permit.is_some() {
126            // We've already reserved capacity to send a request. We're ready!
127            return Poll::Ready(Ok(()));
128        }
129
130        // Finally, if we haven't already acquired a permit, poll the semaphore
131        // to acquire one. If we acquire a permit, then there's enough buffer
132        // capacity to send a new request. Otherwise, we need to wait for
133        // capacity.
134        let permit =
135            ready!(self.semaphore.poll_acquire(cx)).ok_or_else(|| self.get_worker_error())?;
136        self.permit = Some(permit);
137
138        Poll::Ready(Ok(()))
139    }
140
141    fn call(&mut self, request: Request) -> Self::Future {
142        tracing::trace!("sending request to buffer worker");
143        let _permit = self
144            .permit
145            .take()
146            .expect("buffer full; poll_ready must be called first");
147
148        // get the current Span so that we can explicitly propagate it to the worker
149        // if we didn't do this, events on the worker related to this span wouldn't be counted
150        // towards that span since the worker would have no way of entering it.
151        let span = tracing::Span::current();
152
153        // If we've made it here, then a semaphore permit has already been
154        // acquired, so we can freely allocate a oneshot.
155        let (tx, rx) = oneshot::channel();
156
157        match self.tx.send(Message {
158            request,
159            span,
160            tx,
161            _permit,
162        }) {
163            Err(_) => ResponseFuture::failed(self.get_worker_error()),
164            Ok(_) => ResponseFuture::new(rx),
165        }
166    }
167}
168
169impl<T, Request> Clone for Buffer<T, Request>
170where
171    T: Service<Request>,
172{
173    fn clone(&self) -> Self {
174        Self {
175            tx: self.tx.clone(),
176            handle: self.handle.clone(),
177            semaphore: self.semaphore.clone(),
178            // The new clone hasn't acquired a permit yet. It will when it's
179            // next polled ready.
180            permit: None,
181        }
182    }
183}