tower/limit/rate/
service.rs1use super::Rate;
2use futures_core::ready;
3use std::{
4 future::Future,
5 pin::Pin,
6 task::{Context, Poll},
7};
8use tokio::time::{Instant, Sleep};
9use tower_service::Service;
10
11#[derive(Debug)]
14pub struct RateLimit<T> {
15 inner: T,
16 rate: Rate,
17 state: State,
18 sleep: Pin<Box<Sleep>>,
19}
20
21#[derive(Debug)]
22enum State {
23 Limited,
25 Ready { until: Instant, rem: u64 },
26}
27
28impl<T> RateLimit<T> {
29 pub fn new(inner: T, rate: Rate) -> Self {
31 let until = Instant::now();
32 let state = State::Ready {
33 until,
34 rem: rate.num(),
35 };
36
37 RateLimit {
38 inner,
39 rate,
40 state,
41 sleep: Box::pin(tokio::time::sleep_until(until)),
45 }
46 }
47
48 pub fn get_ref(&self) -> &T {
50 &self.inner
51 }
52
53 pub fn get_mut(&mut self) -> &mut T {
55 &mut self.inner
56 }
57
58 pub fn into_inner(self) -> T {
60 self.inner
61 }
62}
63
64impl<S, Request> Service<Request> for RateLimit<S>
65where
66 S: Service<Request>,
67{
68 type Response = S::Response;
69 type Error = S::Error;
70 type Future = S::Future;
71
72 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
73 match self.state {
74 State::Ready { .. } => return Poll::Ready(ready!(self.inner.poll_ready(cx))),
75 State::Limited => {
76 if Pin::new(&mut self.sleep).poll(cx).is_pending() {
77 tracing::trace!("rate limit exceeded; sleeping.");
78 return Poll::Pending;
79 }
80 }
81 }
82
83 self.state = State::Ready {
84 until: Instant::now() + self.rate.per(),
85 rem: self.rate.num(),
86 };
87
88 Poll::Ready(ready!(self.inner.poll_ready(cx)))
89 }
90
91 fn call(&mut self, request: Request) -> Self::Future {
92 match self.state {
93 State::Ready { mut until, mut rem } => {
94 let now = Instant::now();
95
96 if now >= until {
98 until = now + self.rate.per();
99 rem = self.rate.num();
100 }
101
102 if rem > 1 {
103 rem -= 1;
104 self.state = State::Ready { until, rem };
105 } else {
106 self.sleep.as_mut().reset(until);
110 self.state = State::Limited;
111 }
112
113 self.inner.call(request)
115 }
116 State::Limited => panic!("service not ready; poll_ready must be called first"),
117 }
118 }
119}
120
#[cfg(feature = "load")]
#[cfg_attr(docsrs, doc(cfg(feature = "load")))]
impl<S> crate::load::Load for RateLimit<S>
where
    S: crate::load::Load,
{
    type Metric = S::Metric;

    /// Reports the wrapped service's load unchanged; the limiter contributes
    /// no load metric of its own.
    fn load(&self) -> Self::Metric {
        crate::load::Load::load(&self.inner)
    }
}