tokio_io_timeout/
lib.rs

1//! Tokio wrappers which apply timeouts to IO operations.
2//!
3//! These timeouts are analogous to the read and write timeouts on traditional blocking sockets. A timeout countdown is
4//! initiated when a read/write operation returns [`Poll::Pending`]. If a read/write does not return successfully before
5//! the countdown expires, an [`io::Error`] with a kind of [`TimedOut`](io::ErrorKind::TimedOut) is returned.
6#![doc(html_root_url = "https://docs.rs/tokio-io-timeout/1")]
7#![warn(missing_docs)]
8
9use pin_project_lite::pin_project;
10use std::future::Future;
11use std::io;
12use std::io::SeekFrom;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use std::time::Duration;
16use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
17use tokio::time::{sleep_until, Instant, Sleep};
18
19pin_project! {
20    #[derive(Debug)]
21    struct TimeoutState {
22        timeout: Option<Duration>,
23        #[pin]
24        cur: Sleep,
25        active: bool,
26    }
27}
28
29impl TimeoutState {
30    #[inline]
31    fn new() -> TimeoutState {
32        TimeoutState {
33            timeout: None,
34            cur: sleep_until(Instant::now()),
35            active: false,
36        }
37    }
38
39    #[inline]
40    fn timeout(&self) -> Option<Duration> {
41        self.timeout
42    }
43
44    #[inline]
45    fn set_timeout(&mut self, timeout: Option<Duration>) {
46        // since this takes &mut self, we can't yet be active
47        self.timeout = timeout;
48    }
49
50    #[inline]
51    fn set_timeout_pinned(mut self: Pin<&mut Self>, timeout: Option<Duration>) {
52        *self.as_mut().project().timeout = timeout;
53        self.reset();
54    }
55
56    #[inline]
57    fn reset(self: Pin<&mut Self>) {
58        let this = self.project();
59
60        if *this.active {
61            *this.active = false;
62            this.cur.reset(Instant::now());
63        }
64    }
65
66    #[inline]
67    fn poll_check(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Result<()> {
68        let mut this = self.project();
69
70        let timeout = match this.timeout {
71            Some(timeout) => *timeout,
72            None => return Ok(()),
73        };
74
75        if !*this.active {
76            this.cur.as_mut().reset(Instant::now() + timeout);
77            *this.active = true;
78        }
79
80        match this.cur.poll(cx) {
81            Poll::Ready(()) => Err(io::Error::from(io::ErrorKind::TimedOut)),
82            Poll::Pending => Ok(()),
83        }
84    }
85}
86
87pin_project! {
88    /// An `AsyncRead`er which applies a timeout to read operations.
89    #[derive(Debug)]
90    pub struct TimeoutReader<R> {
91        #[pin]
92        reader: R,
93        #[pin]
94        state: TimeoutState,
95    }
96}
97
98impl<R> TimeoutReader<R>
99where
100    R: AsyncRead,
101{
102    /// Returns a new `TimeoutReader` wrapping the specified reader.
103    ///
104    /// There is initially no timeout.
105    pub fn new(reader: R) -> TimeoutReader<R> {
106        TimeoutReader {
107            reader,
108            state: TimeoutState::new(),
109        }
110    }
111
112    /// Returns the current read timeout.
113    pub fn timeout(&self) -> Option<Duration> {
114        self.state.timeout()
115    }
116
117    /// Sets the read timeout.
118    ///
119    /// This can only be used before the reader is pinned; use [`set_timeout_pinned`](Self::set_timeout_pinned)
120    /// otherwise.
121    pub fn set_timeout(&mut self, timeout: Option<Duration>) {
122        self.state.set_timeout(timeout);
123    }
124
125    /// Sets the read timeout.
126    ///
127    /// This will reset any pending timeout. Use [`set_timeout`](Self::set_timeout) instead if the reader is not yet
128    /// pinned.
129    pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
130        self.project().state.set_timeout_pinned(timeout);
131    }
132
133    /// Returns a shared reference to the inner reader.
134    pub fn get_ref(&self) -> &R {
135        &self.reader
136    }
137
138    /// Returns a mutable reference to the inner reader.
139    pub fn get_mut(&mut self) -> &mut R {
140        &mut self.reader
141    }
142
143    /// Returns a pinned mutable reference to the inner reader.
144    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
145        self.project().reader
146    }
147
148    /// Consumes the `TimeoutReader`, returning the inner reader.
149    pub fn into_inner(self) -> R {
150        self.reader
151    }
152}
153
154impl<R> AsyncRead for TimeoutReader<R>
155where
156    R: AsyncRead,
157{
158    fn poll_read(
159        self: Pin<&mut Self>,
160        cx: &mut Context<'_>,
161        buf: &mut ReadBuf<'_>,
162    ) -> Poll<Result<(), io::Error>> {
163        let this = self.project();
164        let r = this.reader.poll_read(cx, buf);
165        match r {
166            Poll::Pending => this.state.poll_check(cx)?,
167            _ => this.state.reset(),
168        }
169        r
170    }
171}
172
173impl<R> AsyncWrite for TimeoutReader<R>
174where
175    R: AsyncWrite,
176{
177    fn poll_write(
178        self: Pin<&mut Self>,
179        cx: &mut Context,
180        buf: &[u8],
181    ) -> Poll<Result<usize, io::Error>> {
182        self.project().reader.poll_write(cx, buf)
183    }
184
185    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
186        self.project().reader.poll_flush(cx)
187    }
188
189    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
190        self.project().reader.poll_shutdown(cx)
191    }
192
193    fn poll_write_vectored(
194        self: Pin<&mut Self>,
195        cx: &mut Context<'_>,
196        bufs: &[io::IoSlice<'_>],
197    ) -> Poll<io::Result<usize>> {
198        self.project().reader.poll_write_vectored(cx, bufs)
199    }
200
201    fn is_write_vectored(&self) -> bool {
202        self.reader.is_write_vectored()
203    }
204}
205
206impl<R> AsyncSeek for TimeoutReader<R>
207where
208    R: AsyncSeek,
209{
210    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
211        self.project().reader.start_seek(position)
212    }
213    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
214        self.project().reader.poll_complete(cx)
215    }
216}
217
218pin_project! {
219    /// An `AsyncWrite`er which applies a timeout to write operations.
220    #[derive(Debug)]
221    pub struct TimeoutWriter<W> {
222        #[pin]
223        writer: W,
224        #[pin]
225        state: TimeoutState,
226    }
227}
228
229impl<W> TimeoutWriter<W>
230where
231    W: AsyncWrite,
232{
233    /// Returns a new `TimeoutReader` wrapping the specified reader.
234    ///
235    /// There is initially no timeout.
236    pub fn new(writer: W) -> TimeoutWriter<W> {
237        TimeoutWriter {
238            writer,
239            state: TimeoutState::new(),
240        }
241    }
242
243    /// Returns the current write timeout.
244    pub fn timeout(&self) -> Option<Duration> {
245        self.state.timeout()
246    }
247
248    /// Sets the write timeout.
249    ///
250    /// This can only be used before the writer is pinned; use [`set_timeout_pinned`](Self::set_timeout_pinned)
251    /// otherwise.
252    pub fn set_timeout(&mut self, timeout: Option<Duration>) {
253        self.state.set_timeout(timeout);
254    }
255
256    /// Sets the write timeout.
257    ///
258    /// This will reset any pending timeout. Use [`set_timeout`](Self::set_timeout) instead if the reader is not yet
259    /// pinned.
260    pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
261        self.project().state.set_timeout_pinned(timeout);
262    }
263
264    /// Returns a shared reference to the inner writer.
265    pub fn get_ref(&self) -> &W {
266        &self.writer
267    }
268
269    /// Returns a mutable reference to the inner writer.
270    pub fn get_mut(&mut self) -> &mut W {
271        &mut self.writer
272    }
273
274    /// Returns a pinned mutable reference to the inner writer.
275    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
276        self.project().writer
277    }
278
279    /// Consumes the `TimeoutWriter`, returning the inner writer.
280    pub fn into_inner(self) -> W {
281        self.writer
282    }
283}
284
285impl<W> AsyncWrite for TimeoutWriter<W>
286where
287    W: AsyncWrite,
288{
289    fn poll_write(
290        self: Pin<&mut Self>,
291        cx: &mut Context,
292        buf: &[u8],
293    ) -> Poll<Result<usize, io::Error>> {
294        let this = self.project();
295        let r = this.writer.poll_write(cx, buf);
296        match r {
297            Poll::Pending => this.state.poll_check(cx)?,
298            _ => this.state.reset(),
299        }
300        r
301    }
302
303    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
304        let this = self.project();
305        let r = this.writer.poll_flush(cx);
306        match r {
307            Poll::Pending => this.state.poll_check(cx)?,
308            _ => this.state.reset(),
309        }
310        r
311    }
312
313    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
314        let this = self.project();
315        let r = this.writer.poll_shutdown(cx);
316        match r {
317            Poll::Pending => this.state.poll_check(cx)?,
318            _ => this.state.reset(),
319        }
320        r
321    }
322
323    fn poll_write_vectored(
324        self: Pin<&mut Self>,
325        cx: &mut Context<'_>,
326        bufs: &[io::IoSlice<'_>],
327    ) -> Poll<io::Result<usize>> {
328        let this = self.project();
329        let r = this.writer.poll_write_vectored(cx, bufs);
330        match r {
331            Poll::Pending => this.state.poll_check(cx)?,
332            _ => this.state.reset(),
333        }
334        r
335    }
336
337    fn is_write_vectored(&self) -> bool {
338        self.writer.is_write_vectored()
339    }
340}
341
342impl<W> AsyncRead for TimeoutWriter<W>
343where
344    W: AsyncRead,
345{
346    fn poll_read(
347        self: Pin<&mut Self>,
348        cx: &mut Context<'_>,
349        buf: &mut ReadBuf<'_>,
350    ) -> Poll<Result<(), io::Error>> {
351        self.project().writer.poll_read(cx, buf)
352    }
353}
354
355impl<W> AsyncSeek for TimeoutWriter<W>
356where
357    W: AsyncSeek,
358{
359    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
360        self.project().writer.start_seek(position)
361    }
362    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
363        self.project().writer.poll_complete(cx)
364    }
365}
366
367pin_project! {
368    /// A stream which applies read and write timeouts to an inner stream.
369    #[derive(Debug)]
370    pub struct TimeoutStream<S> {
371        #[pin]
372        stream: TimeoutReader<TimeoutWriter<S>>
373    }
374}
375
376impl<S> TimeoutStream<S>
377where
378    S: AsyncRead + AsyncWrite,
379{
380    /// Returns a new `TimeoutStream` wrapping the specified stream.
381    ///
382    /// There is initially no read or write timeout.
383    pub fn new(stream: S) -> TimeoutStream<S> {
384        let writer = TimeoutWriter::new(stream);
385        let stream = TimeoutReader::new(writer);
386        TimeoutStream { stream }
387    }
388
389    /// Returns the current read timeout.
390    pub fn read_timeout(&self) -> Option<Duration> {
391        self.stream.timeout()
392    }
393
394    /// Sets the read timeout.
395    ///
396    /// This can only be used before the stream is pinned; use
397    /// [`set_read_timeout_pinned`](Self::set_read_timeout_pinned) otherwise.
398    pub fn set_read_timeout(&mut self, timeout: Option<Duration>) {
399        self.stream.set_timeout(timeout)
400    }
401
402    /// Sets the read timeout.
403    ///
404    /// This will reset any pending read timeout. Use [`set_read_timeout`](Self::set_read_timeout) instead if the stream
405    /// has not yet been pinned.
406    pub fn set_read_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
407        self.project().stream.set_timeout_pinned(timeout)
408    }
409
410    /// Returns the current write timeout.
411    pub fn write_timeout(&self) -> Option<Duration> {
412        self.stream.get_ref().timeout()
413    }
414
415    /// Sets the write timeout.
416    ///
417    /// This can only be used before the stream is pinned; use
418    /// [`set_write_timeout_pinned`](Self::set_write_timeout_pinned) otherwise.
419    pub fn set_write_timeout(&mut self, timeout: Option<Duration>) {
420        self.stream.get_mut().set_timeout(timeout)
421    }
422
423    /// Sets the write timeout.
424    ///
425    /// This will reset any pending write timeout. Use [`set_write_timeout`](Self::set_write_timeout) instead if the
426    /// stream has not yet been pinned.
427    pub fn set_write_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
428        self.project()
429            .stream
430            .get_pin_mut()
431            .set_timeout_pinned(timeout)
432    }
433
434    /// Returns a shared reference to the inner stream.
435    pub fn get_ref(&self) -> &S {
436        self.stream.get_ref().get_ref()
437    }
438
439    /// Returns a mutable reference to the inner stream.
440    pub fn get_mut(&mut self) -> &mut S {
441        self.stream.get_mut().get_mut()
442    }
443
444    /// Returns a pinned mutable reference to the inner stream.
445    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
446        self.project().stream.get_pin_mut().get_pin_mut()
447    }
448
449    /// Consumes the stream, returning the inner stream.
450    pub fn into_inner(self) -> S {
451        self.stream.into_inner().into_inner()
452    }
453}
454
455impl<S> AsyncRead for TimeoutStream<S>
456where
457    S: AsyncRead + AsyncWrite,
458{
459    fn poll_read(
460        self: Pin<&mut Self>,
461        cx: &mut Context<'_>,
462        buf: &mut ReadBuf<'_>,
463    ) -> Poll<Result<(), io::Error>> {
464        self.project().stream.poll_read(cx, buf)
465    }
466}
467
468impl<S> AsyncWrite for TimeoutStream<S>
469where
470    S: AsyncRead + AsyncWrite,
471{
472    fn poll_write(
473        self: Pin<&mut Self>,
474        cx: &mut Context,
475        buf: &[u8],
476    ) -> Poll<Result<usize, io::Error>> {
477        self.project().stream.poll_write(cx, buf)
478    }
479
480    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
481        self.project().stream.poll_flush(cx)
482    }
483
484    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
485        self.project().stream.poll_shutdown(cx)
486    }
487
488    fn poll_write_vectored(
489        self: Pin<&mut Self>,
490        cx: &mut Context<'_>,
491        bufs: &[io::IoSlice<'_>],
492    ) -> Poll<io::Result<usize>> {
493        self.project().stream.poll_write_vectored(cx, bufs)
494    }
495
496    fn is_write_vectored(&self) -> bool {
497        self.stream.is_write_vectored()
498    }
499}
500
#[cfg(test)]
mod test {
    use super::*;
    use std::io::Write;
    use std::net::TcpListener;
    use std::thread;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;
    use tokio::pin;

    pin_project! {
        // Test double: reads and writes stay pending until a fixed instant, then succeed.
        struct DelayStream {
            #[pin]
            sleep: Sleep,
        }
    }

    impl DelayStream {
        fn new(until: Instant) -> Self {
            let sleep = sleep_until(until);
            DelayStream { sleep }
        }
    }

    impl AsyncRead for DelayStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<Result<(), io::Error>> {
            // Completes without writing anything into the buffer once the delay has elapsed.
            self.project().sleep.poll(cx).map(Ok)
        }
    }

    impl AsyncWrite for DelayStream {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            // Claims to have written the whole buffer once the delay has elapsed.
            self.project().sleep.poll(cx).map(|()| Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    #[tokio::test]
    async fn read_timeout() {
        let delayed = DelayStream::new(Instant::now() + Duration::from_millis(500));
        let mut reader = TimeoutReader::new(delayed);
        reader.set_timeout(Some(Duration::from_millis(100)));
        pin!(reader);

        let err = reader.read(&mut [0]).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }

    #[tokio::test]
    async fn read_ok() {
        let delayed = DelayStream::new(Instant::now() + Duration::from_millis(100));
        let mut reader = TimeoutReader::new(delayed);
        reader.set_timeout(Some(Duration::from_millis(500)));
        pin!(reader);

        reader.read(&mut [0]).await.unwrap();
    }

    #[tokio::test]
    async fn write_timeout() {
        let delayed = DelayStream::new(Instant::now() + Duration::from_millis(500));
        let mut writer = TimeoutWriter::new(delayed);
        writer.set_timeout(Some(Duration::from_millis(100)));
        pin!(writer);

        let err = writer.write(&[0]).await.unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }

    #[tokio::test]
    async fn write_ok() {
        let delayed = DelayStream::new(Instant::now() + Duration::from_millis(100));
        let mut writer = TimeoutWriter::new(delayed);
        writer.set_timeout(Some(Duration::from_millis(500)));
        pin!(writer);

        writer.write(&[0]).await.unwrap();
    }

    #[tokio::test]
    async fn tcp_read() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        thread::spawn(move || {
            let (mut socket, _) = listener.accept().unwrap();
            thread::sleep(Duration::from_millis(10));
            socket.write_all(b"f").unwrap();
            thread::sleep(Duration::from_millis(500));
            // The client may already be gone by now, so a failure here is ignored.
            let _ = socket.write_all(b"f");
        });

        let conn = TcpStream::connect(&addr).await.unwrap();
        let mut s = TimeoutStream::new(conn);
        s.set_read_timeout(Some(Duration::from_millis(100)));
        pin!(s);

        // The first byte arrives well within the timeout; the second is delayed past it.
        s.read(&mut [0]).await.unwrap();
        match s.read(&mut [0]).await {
            Ok(_) => panic!("unexpected success"),
            Err(ref e) if e.kind() == io::ErrorKind::TimedOut => {}
            Err(e) => panic!("{:?}", e),
        }
    }
}