// tower/limit/concurrency/service.rs
use super::future::ResponseFuture;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::PollSemaphore;
use tower_service::Service;

use futures_core::ready;
use std::{
    sync::Arc,
    task::{Context, Poll},
};
11
/// Enforces a limit on the number of requests concurrently in flight to the
/// wrapped service.
///
/// Every request must hold a permit drawn from `semaphore`: `poll_ready`
/// reserves one, and `call` hands it to the returned response future.
#[derive(Debug)]
pub struct ConcurrencyLimit<T> {
    /// The wrapped service.
    inner: T,
    /// Source of concurrency permits. The same underlying semaphore may be
    /// shared with other instances (via `with_semaphore` or `Clone`).
    semaphore: PollSemaphore,
    /// Permit reserved by `poll_ready` and consumed by the next `call`;
    /// `None` while no capacity is reserved for this instance.
    permit: Option<OwnedSemaphorePermit>,
}
25
26impl<T> ConcurrencyLimit<T> {
27 pub fn new(inner: T, max: usize) -> Self {
29 Self::with_semaphore(inner, Arc::new(Semaphore::new(max)))
30 }
31
32 pub fn with_semaphore(inner: T, semaphore: Arc<Semaphore>) -> Self {
34 ConcurrencyLimit {
35 inner,
36 semaphore: PollSemaphore::new(semaphore),
37 permit: None,
38 }
39 }
40
41 pub fn get_ref(&self) -> &T {
43 &self.inner
44 }
45
46 pub fn get_mut(&mut self) -> &mut T {
48 &mut self.inner
49 }
50
51 pub fn into_inner(self) -> T {
53 self.inner
54 }
55}
56
57impl<S, Request> Service<Request> for ConcurrencyLimit<S>
58where
59 S: Service<Request>,
60{
61 type Response = S::Response;
62 type Error = S::Error;
63 type Future = ResponseFuture<S::Future>;
64
65 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
66 if self.permit.is_none() {
69 self.permit = ready!(self.semaphore.poll_acquire(cx));
70 debug_assert!(
71 self.permit.is_some(),
72 "ConcurrencyLimit semaphore is never closed, so `poll_acquire` \
73 should never fail",
74 );
75 }
76
77 self.inner.poll_ready(cx)
80 }
81
82 fn call(&mut self, request: Request) -> Self::Future {
83 let permit = self
85 .permit
86 .take()
87 .expect("max requests in-flight; poll_ready must be called first");
88
89 let future = self.inner.call(request);
91
92 ResponseFuture::new(future, permit)
93 }
94}
95
96impl<T: Clone> Clone for ConcurrencyLimit<T> {
97 fn clone(&self) -> Self {
98 Self {
102 inner: self.inner.clone(),
103 semaphore: self.semaphore.clone(),
104 permit: None,
105 }
106 }
107}
108
#[cfg(feature = "load")]
#[cfg_attr(docsrs, doc(cfg(feature = "load")))]
impl<S> crate::load::Load for ConcurrencyLimit<S>
where
    S: crate::load::Load,
{
    type Metric = S::Metric;

    /// Reports the wrapped service's load unchanged. The concurrency limit
    /// adds nothing of its own to the metric.
    fn load(&self) -> Self::Metric {
        S::load(&self.inner)
    }
}