// tower/buffer/worker.rs

1use super::{
2    error::{Closed, ServiceError},
3    message::Message,
4};
5use futures_core::ready;
6use std::sync::{Arc, Mutex, Weak};
7use std::{
8    future::Future,
9    pin::Pin,
10    task::{Context, Poll},
11};
12use tokio::sync::{mpsc, Semaphore};
13use tower_service::Service;
14
15pin_project_lite::pin_project! {
16    /// Task that handles processing the buffer. This type should not be used
17    /// directly, instead `Buffer` requires an `Executor` that can accept this task.
18    ///
19    /// The struct is `pub` in the private module and the type is *not* re-exported
20    /// as part of the public API. This is the "sealed" pattern to include "private"
21    /// types in public traits that are not meant for consumers of the library to
22    /// implement (only call).
23    #[derive(Debug)]
24    pub struct Worker<T, Request>
25    where
26        T: Service<Request>,
27    {
28        current_message: Option<Message<Request, T::Future>>,
29        rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
30        service: T,
31        finish: bool,
32        failed: Option<ServiceError>,
33        handle: Handle,
34        close: Option<Weak<Semaphore>>,
35    }
36
37    impl<T: Service<Request>, Request> PinnedDrop for Worker<T, Request>
38    {
39        fn drop(mut this: Pin<&mut Self>) {
40            this.as_mut().close_semaphore();
41        }
42    }
43}
44
45/// Get the error out
46#[derive(Debug)]
47pub(crate) struct Handle {
48    inner: Arc<Mutex<Option<ServiceError>>>,
49}
50
51impl<T, Request> Worker<T, Request>
52where
53    T: Service<Request>,
54{
55    /// Closes the buffer's semaphore if it is still open, waking any pending
56    /// tasks.
57    fn close_semaphore(&mut self) {
58        if let Some(close) = self.close.take().as_ref().and_then(Weak::upgrade) {
59            tracing::debug!("buffer closing; waking pending tasks");
60            close.close();
61        } else {
62            tracing::trace!("buffer already closed");
63        }
64    }
65}
66
67impl<T, Request> Worker<T, Request>
68where
69    T: Service<Request>,
70    T::Error: Into<crate::BoxError>,
71{
72    pub(crate) fn new(
73        service: T,
74        rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
75        semaphore: &Arc<Semaphore>,
76    ) -> (Handle, Worker<T, Request>) {
77        let handle = Handle {
78            inner: Arc::new(Mutex::new(None)),
79        };
80
81        let semaphore = Arc::downgrade(semaphore);
82        let worker = Worker {
83            current_message: None,
84            finish: false,
85            failed: None,
86            rx,
87            service,
88            handle: handle.clone(),
89            close: Some(semaphore),
90        };
91
92        (handle, worker)
93    }
94
95    /// Return the next queued Message that hasn't been canceled.
96    ///
97    /// If a `Message` is returned, the `bool` is true if this is the first time we received this
98    /// message, and false otherwise (i.e., we tried to forward it to the backing service before).
99    fn poll_next_msg(
100        &mut self,
101        cx: &mut Context<'_>,
102    ) -> Poll<Option<(Message<Request, T::Future>, bool)>> {
103        if self.finish {
104            // We've already received None and are shutting down
105            return Poll::Ready(None);
106        }
107
108        tracing::trace!("worker polling for next message");
109        if let Some(msg) = self.current_message.take() {
110            // If the oneshot sender is closed, then the receiver is dropped,
111            // and nobody cares about the response. If this is the case, we
112            // should continue to the next request.
113            if !msg.tx.is_closed() {
114                tracing::trace!("resuming buffered request");
115                return Poll::Ready(Some((msg, false)));
116            }
117
118            tracing::trace!("dropping cancelled buffered request");
119        }
120
121        // Get the next request
122        while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
123            if !msg.tx.is_closed() {
124                tracing::trace!("processing new request");
125                return Poll::Ready(Some((msg, true)));
126            }
127            // Otherwise, request is canceled, so pop the next one.
128            tracing::trace!("dropping cancelled request");
129        }
130
131        Poll::Ready(None)
132    }
133
134    fn failed(&mut self, error: crate::BoxError) {
135        // The underlying service failed when we called `poll_ready` on it with the given `error`. We
136        // need to communicate this to all the `Buffer` handles. To do so, we wrap up the error in
137        // an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
138        // requests will also fail with the same error.
139
140        // Note that we need to handle the case where some handle is concurrently trying to send us
141        // a request. We need to make sure that *either* the send of the request fails *or* it
142        // receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
143        // case where we send errors to all outstanding requests, and *then* the caller sends its
144        // request. We do this by *first* exposing the error, *then* closing the channel used to
145        // send more requests (so the client will see the error when the send fails), and *then*
146        // sending the error to all outstanding requests.
147        let error = ServiceError::new(error);
148
149        let mut inner = self.handle.inner.lock().unwrap();
150
151        if inner.is_some() {
152            // Future::poll was called after we've already errored out!
153            return;
154        }
155
156        *inner = Some(error.clone());
157        drop(inner);
158
159        self.rx.close();
160
161        // By closing the mpsc::Receiver, we know that poll_next_msg will soon return Ready(None),
162        // which will trigger the `self.finish == true` phase. We just need to make sure that any
163        // requests that we receive before we've exhausted the receiver receive the error:
164        self.failed = Some(error);
165    }
166}
167
168impl<T, Request> Future for Worker<T, Request>
169where
170    T: Service<Request>,
171    T::Error: Into<crate::BoxError>,
172{
173    type Output = ();
174
175    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
176        if self.finish {
177            return Poll::Ready(());
178        }
179
180        loop {
181            match ready!(self.poll_next_msg(cx)) {
182                Some((msg, first)) => {
183                    let _guard = msg.span.enter();
184                    if let Some(ref failed) = self.failed {
185                        tracing::trace!("notifying caller about worker failure");
186                        let _ = msg.tx.send(Err(failed.clone()));
187                        continue;
188                    }
189
190                    // Wait for the service to be ready
191                    tracing::trace!(
192                        resumed = !first,
193                        message = "worker received request; waiting for service readiness"
194                    );
195                    match self.service.poll_ready(cx) {
196                        Poll::Ready(Ok(())) => {
197                            tracing::debug!(service.ready = true, message = "processing request");
198                            let response = self.service.call(msg.request);
199
200                            // Send the response future back to the sender.
201                            //
202                            // An error means the request had been canceled in-between
203                            // our calls, the response future will just be dropped.
204                            tracing::trace!("returning response future");
205                            let _ = msg.tx.send(Ok(response));
206                        }
207                        Poll::Pending => {
208                            tracing::trace!(service.ready = false, message = "delay");
209                            // Put out current message back in its slot.
210                            drop(_guard);
211                            self.current_message = Some(msg);
212                            return Poll::Pending;
213                        }
214                        Poll::Ready(Err(e)) => {
215                            let error = e.into();
216                            tracing::debug!({ %error }, "service failed");
217                            drop(_guard);
218                            self.failed(error);
219                            let _ = msg.tx.send(Err(self
220                                .failed
221                                .as_ref()
222                                .expect("Worker::failed did not set self.failed?")
223                                .clone()));
224                            // Wake any tasks waiting on channel capacity.
225                            self.close_semaphore();
226                        }
227                    }
228                }
229                None => {
230                    // No more more requests _ever_.
231                    self.finish = true;
232                    return Poll::Ready(());
233                }
234            }
235        }
236    }
237}
238
239impl Handle {
240    pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
241        self.inner
242            .lock()
243            .unwrap()
244            .as_ref()
245            .map(|svc_err| svc_err.clone().into())
246            .unwrap_or_else(|| Closed::new().into())
247    }
248}
249
250impl Clone for Handle {
251    fn clone(&self) -> Handle {
252        Handle {
253            inner: self.inner.clone(),
254        }
255    }
256}