tokio_io_timeout/
lib.rs

1//! Tokio wrappers which apply timeouts to IO operations.
2//!
3//! These timeouts are analogous to the read and write timeouts on traditional blocking sockets. A timeout countdown is
4//! initiated when a read/write operation returns [`Poll::Pending`]. If a read/write does not return successfully before
5//! the countdown expires, an [`io::Error`] with a kind of [`TimedOut`](io::ErrorKind::TimedOut) is returned.
6#![doc(html_root_url = "https://docs.rs/tokio-io-timeout/1")]
7#![warn(missing_docs)]
8
9use pin_project_lite::pin_project;
10use std::future::Future;
11use std::io;
12use std::io::SeekFrom;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use std::time::Duration;
16use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
17use tokio::time::{sleep_until, Instant, Sleep};
18
19pin_project! {
20    #[derive(Debug)]
21    struct TimeoutState {
22        timeout: Option<Duration>,
23        #[pin]
24        cur: Sleep,
25        active: bool,
26    }
27}
28
29impl TimeoutState {
30    #[inline]
31    fn new() -> TimeoutState {
32        TimeoutState {
33            timeout: None,
34            cur: sleep_until(Instant::now()),
35            active: false,
36        }
37    }
38
39    #[inline]
40    fn timeout(&self) -> Option<Duration> {
41        self.timeout
42    }
43
44    #[inline]
45    fn set_timeout(&mut self, timeout: Option<Duration>) {
46        // since this takes &mut self, we can't yet be active
47        self.timeout = timeout;
48    }
49
50    #[inline]
51    fn set_timeout_pinned(mut self: Pin<&mut Self>, timeout: Option<Duration>) {
52        *self.as_mut().project().timeout = timeout;
53        self.reset();
54    }
55
56    #[inline]
57    fn reset(self: Pin<&mut Self>) {
58        let this = self.project();
59
60        if *this.active {
61            *this.active = false;
62            this.cur.reset(Instant::now());
63        }
64    }
65
66    #[inline]
67    fn poll_check(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Result<()> {
68        let mut this = self.project();
69
70        let timeout = match this.timeout {
71            Some(timeout) => *timeout,
72            None => return Ok(()),
73        };
74
75        if !*this.active {
76            this.cur.as_mut().reset(Instant::now() + timeout);
77            *this.active = true;
78        }
79
80        match this.cur.poll(cx) {
81            Poll::Ready(()) => {
82                *this.active = false;
83                Err(io::Error::from(io::ErrorKind::TimedOut))
84            }
85            Poll::Pending => Ok(()),
86        }
87    }
88}
89
90pin_project! {
91    /// An `AsyncRead`er which applies a timeout to read operations.
92    #[derive(Debug)]
93    pub struct TimeoutReader<R> {
94        #[pin]
95        reader: R,
96        #[pin]
97        state: TimeoutState,
98    }
99}
100
101impl<R> TimeoutReader<R>
102where
103    R: AsyncRead,
104{
105    /// Returns a new `TimeoutReader` wrapping the specified reader.
106    ///
107    /// There is initially no timeout.
108    pub fn new(reader: R) -> TimeoutReader<R> {
109        TimeoutReader {
110            reader,
111            state: TimeoutState::new(),
112        }
113    }
114
115    /// Returns the current read timeout.
116    pub fn timeout(&self) -> Option<Duration> {
117        self.state.timeout()
118    }
119
120    /// Sets the read timeout.
121    ///
122    /// This can only be used before the reader is pinned; use [`set_timeout_pinned`](Self::set_timeout_pinned)
123    /// otherwise.
124    pub fn set_timeout(&mut self, timeout: Option<Duration>) {
125        self.state.set_timeout(timeout);
126    }
127
128    /// Sets the read timeout.
129    ///
130    /// This will reset any pending timeout. Use [`set_timeout`](Self::set_timeout) instead if the reader is not yet
131    /// pinned.
132    pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
133        self.project().state.set_timeout_pinned(timeout);
134    }
135
136    /// Returns a shared reference to the inner reader.
137    pub fn get_ref(&self) -> &R {
138        &self.reader
139    }
140
141    /// Returns a mutable reference to the inner reader.
142    pub fn get_mut(&mut self) -> &mut R {
143        &mut self.reader
144    }
145
146    /// Returns a pinned mutable reference to the inner reader.
147    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
148        self.project().reader
149    }
150
151    /// Consumes the `TimeoutReader`, returning the inner reader.
152    pub fn into_inner(self) -> R {
153        self.reader
154    }
155}
156
157impl<R> AsyncRead for TimeoutReader<R>
158where
159    R: AsyncRead,
160{
161    fn poll_read(
162        self: Pin<&mut Self>,
163        cx: &mut Context<'_>,
164        buf: &mut ReadBuf<'_>,
165    ) -> Poll<Result<(), io::Error>> {
166        let this = self.project();
167        let r = this.reader.poll_read(cx, buf);
168        match r {
169            Poll::Pending => this.state.poll_check(cx)?,
170            _ => this.state.reset(),
171        }
172        r
173    }
174}
175
176impl<R> AsyncWrite for TimeoutReader<R>
177where
178    R: AsyncWrite,
179{
180    fn poll_write(
181        self: Pin<&mut Self>,
182        cx: &mut Context,
183        buf: &[u8],
184    ) -> Poll<Result<usize, io::Error>> {
185        self.project().reader.poll_write(cx, buf)
186    }
187
188    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
189        self.project().reader.poll_flush(cx)
190    }
191
192    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
193        self.project().reader.poll_shutdown(cx)
194    }
195
196    fn poll_write_vectored(
197        self: Pin<&mut Self>,
198        cx: &mut Context<'_>,
199        bufs: &[io::IoSlice<'_>],
200    ) -> Poll<io::Result<usize>> {
201        self.project().reader.poll_write_vectored(cx, bufs)
202    }
203
204    fn is_write_vectored(&self) -> bool {
205        self.reader.is_write_vectored()
206    }
207}
208
209impl<R> AsyncSeek for TimeoutReader<R>
210where
211    R: AsyncSeek,
212{
213    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
214        self.project().reader.start_seek(position)
215    }
216    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
217        self.project().reader.poll_complete(cx)
218    }
219}
220
221pin_project! {
222    /// An `AsyncWrite`er which applies a timeout to write operations.
223    #[derive(Debug)]
224    pub struct TimeoutWriter<W> {
225        #[pin]
226        writer: W,
227        #[pin]
228        state: TimeoutState,
229    }
230}
231
232impl<W> TimeoutWriter<W>
233where
234    W: AsyncWrite,
235{
236    /// Returns a new `TimeoutReader` wrapping the specified reader.
237    ///
238    /// There is initially no timeout.
239    pub fn new(writer: W) -> TimeoutWriter<W> {
240        TimeoutWriter {
241            writer,
242            state: TimeoutState::new(),
243        }
244    }
245
246    /// Returns the current write timeout.
247    pub fn timeout(&self) -> Option<Duration> {
248        self.state.timeout()
249    }
250
251    /// Sets the write timeout.
252    ///
253    /// This can only be used before the writer is pinned; use [`set_timeout_pinned`](Self::set_timeout_pinned)
254    /// otherwise.
255    pub fn set_timeout(&mut self, timeout: Option<Duration>) {
256        self.state.set_timeout(timeout);
257    }
258
259    /// Sets the write timeout.
260    ///
261    /// This will reset any pending timeout. Use [`set_timeout`](Self::set_timeout) instead if the reader is not yet
262    /// pinned.
263    pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
264        self.project().state.set_timeout_pinned(timeout);
265    }
266
267    /// Returns a shared reference to the inner writer.
268    pub fn get_ref(&self) -> &W {
269        &self.writer
270    }
271
272    /// Returns a mutable reference to the inner writer.
273    pub fn get_mut(&mut self) -> &mut W {
274        &mut self.writer
275    }
276
277    /// Returns a pinned mutable reference to the inner writer.
278    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
279        self.project().writer
280    }
281
282    /// Consumes the `TimeoutWriter`, returning the inner writer.
283    pub fn into_inner(self) -> W {
284        self.writer
285    }
286}
287
288impl<W> AsyncWrite for TimeoutWriter<W>
289where
290    W: AsyncWrite,
291{
292    fn poll_write(
293        self: Pin<&mut Self>,
294        cx: &mut Context,
295        buf: &[u8],
296    ) -> Poll<Result<usize, io::Error>> {
297        let this = self.project();
298        let r = this.writer.poll_write(cx, buf);
299        match r {
300            Poll::Pending => this.state.poll_check(cx)?,
301            _ => this.state.reset(),
302        }
303        r
304    }
305
306    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
307        let this = self.project();
308        let r = this.writer.poll_flush(cx);
309        match r {
310            Poll::Pending => this.state.poll_check(cx)?,
311            _ => this.state.reset(),
312        }
313        r
314    }
315
316    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
317        let this = self.project();
318        let r = this.writer.poll_shutdown(cx);
319        match r {
320            Poll::Pending => this.state.poll_check(cx)?,
321            _ => this.state.reset(),
322        }
323        r
324    }
325
326    fn poll_write_vectored(
327        self: Pin<&mut Self>,
328        cx: &mut Context<'_>,
329        bufs: &[io::IoSlice<'_>],
330    ) -> Poll<io::Result<usize>> {
331        let this = self.project();
332        let r = this.writer.poll_write_vectored(cx, bufs);
333        match r {
334            Poll::Pending => this.state.poll_check(cx)?,
335            _ => this.state.reset(),
336        }
337        r
338    }
339
340    fn is_write_vectored(&self) -> bool {
341        self.writer.is_write_vectored()
342    }
343}
344
345impl<W> AsyncRead for TimeoutWriter<W>
346where
347    W: AsyncRead,
348{
349    fn poll_read(
350        self: Pin<&mut Self>,
351        cx: &mut Context<'_>,
352        buf: &mut ReadBuf<'_>,
353    ) -> Poll<Result<(), io::Error>> {
354        self.project().writer.poll_read(cx, buf)
355    }
356}
357
358impl<W> AsyncSeek for TimeoutWriter<W>
359where
360    W: AsyncSeek,
361{
362    fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
363        self.project().writer.start_seek(position)
364    }
365    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
366        self.project().writer.poll_complete(cx)
367    }
368}
369
370pin_project! {
371    /// A stream which applies read and write timeouts to an inner stream.
372    #[derive(Debug)]
373    pub struct TimeoutStream<S> {
374        #[pin]
375        stream: TimeoutReader<TimeoutWriter<S>>
376    }
377}
378
379impl<S> TimeoutStream<S>
380where
381    S: AsyncRead + AsyncWrite,
382{
383    /// Returns a new `TimeoutStream` wrapping the specified stream.
384    ///
385    /// There is initially no read or write timeout.
386    pub fn new(stream: S) -> TimeoutStream<S> {
387        let writer = TimeoutWriter::new(stream);
388        let stream = TimeoutReader::new(writer);
389        TimeoutStream { stream }
390    }
391
392    /// Returns the current read timeout.
393    pub fn read_timeout(&self) -> Option<Duration> {
394        self.stream.timeout()
395    }
396
397    /// Sets the read timeout.
398    ///
399    /// This can only be used before the stream is pinned; use
400    /// [`set_read_timeout_pinned`](Self::set_read_timeout_pinned) otherwise.
401    pub fn set_read_timeout(&mut self, timeout: Option<Duration>) {
402        self.stream.set_timeout(timeout)
403    }
404
405    /// Sets the read timeout.
406    ///
407    /// This will reset any pending read timeout. Use [`set_read_timeout`](Self::set_read_timeout) instead if the stream
408    /// has not yet been pinned.
409    pub fn set_read_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
410        self.project().stream.set_timeout_pinned(timeout)
411    }
412
413    /// Returns the current write timeout.
414    pub fn write_timeout(&self) -> Option<Duration> {
415        self.stream.get_ref().timeout()
416    }
417
418    /// Sets the write timeout.
419    ///
420    /// This can only be used before the stream is pinned; use
421    /// [`set_write_timeout_pinned`](Self::set_write_timeout_pinned) otherwise.
422    pub fn set_write_timeout(&mut self, timeout: Option<Duration>) {
423        self.stream.get_mut().set_timeout(timeout)
424    }
425
426    /// Sets the write timeout.
427    ///
428    /// This will reset any pending write timeout. Use [`set_write_timeout`](Self::set_write_timeout) instead if the
429    /// stream has not yet been pinned.
430    pub fn set_write_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
431        self.project()
432            .stream
433            .get_pin_mut()
434            .set_timeout_pinned(timeout)
435    }
436
437    /// Returns a shared reference to the inner stream.
438    pub fn get_ref(&self) -> &S {
439        self.stream.get_ref().get_ref()
440    }
441
442    /// Returns a mutable reference to the inner stream.
443    pub fn get_mut(&mut self) -> &mut S {
444        self.stream.get_mut().get_mut()
445    }
446
447    /// Returns a pinned mutable reference to the inner stream.
448    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
449        self.project().stream.get_pin_mut().get_pin_mut()
450    }
451
452    /// Consumes the stream, returning the inner stream.
453    pub fn into_inner(self) -> S {
454        self.stream.into_inner().into_inner()
455    }
456}
457
458impl<S> AsyncRead for TimeoutStream<S>
459where
460    S: AsyncRead + AsyncWrite,
461{
462    fn poll_read(
463        self: Pin<&mut Self>,
464        cx: &mut Context<'_>,
465        buf: &mut ReadBuf<'_>,
466    ) -> Poll<Result<(), io::Error>> {
467        self.project().stream.poll_read(cx, buf)
468    }
469}
470
471impl<S> AsyncWrite for TimeoutStream<S>
472where
473    S: AsyncRead + AsyncWrite,
474{
475    fn poll_write(
476        self: Pin<&mut Self>,
477        cx: &mut Context,
478        buf: &[u8],
479    ) -> Poll<Result<usize, io::Error>> {
480        self.project().stream.poll_write(cx, buf)
481    }
482
483    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
484        self.project().stream.poll_flush(cx)
485    }
486
487    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
488        self.project().stream.poll_shutdown(cx)
489    }
490
491    fn poll_write_vectored(
492        self: Pin<&mut Self>,
493        cx: &mut Context<'_>,
494        bufs: &[io::IoSlice<'_>],
495    ) -> Poll<io::Result<usize>> {
496        self.project().stream.poll_write_vectored(cx, bufs)
497    }
498
499    fn is_write_vectored(&self) -> bool {
500        self.stream.is_write_vectored()
501    }
502}
503
#[cfg(test)]
mod test {
    use super::*;
    use std::io::Write;
    use std::net::TcpListener;
    use std::thread;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;
    use tokio::pin;

    pin_project! {
        // Test double whose reads and writes stay pending until a fixed deadline has passed.
        struct DelayStream {
            #[pin]
            sleep: Sleep,
        }
    }

    impl DelayStream {
        fn new(until: Instant) -> Self {
            let sleep = sleep_until(until);
            DelayStream { sleep }
        }
    }

    impl AsyncRead for DelayStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            _buf: &mut ReadBuf,
        ) -> Poll<Result<(), io::Error>> {
            // Completes with zero bytes read once the deadline has passed.
            self.project().sleep.poll(cx).map(Ok)
        }
    }

    impl AsyncWrite for DelayStream {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            // Reports the whole buffer as written once the deadline has passed.
            let written = buf.len();
            self.project().sleep.poll(cx).map(|()| Ok(written))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    #[tokio::test]
    async fn read_timeout() {
        // Ready after 150ms, but the timeout is 100ms.
        let ready_at = Instant::now() + Duration::from_millis(150);
        let mut reader = TimeoutReader::new(DelayStream::new(ready_at));
        reader.set_timeout(Some(Duration::from_millis(100)));
        pin!(reader);

        let first = reader.read(&mut [0]).await;
        assert_eq!(first.err().unwrap().kind(), io::ErrorKind::TimedOut);

        // The retry starts a fresh countdown, and the delay elapses before it does.
        reader.read(&mut [0]).await.unwrap();
    }

    #[tokio::test]
    async fn read_ok() {
        let ready_at = Instant::now() + Duration::from_millis(100);
        let mut reader = TimeoutReader::new(DelayStream::new(ready_at));
        reader.set_timeout(Some(Duration::from_millis(500)));
        pin!(reader);

        reader.read(&mut [0]).await.unwrap();
    }

    #[tokio::test]
    async fn write_timeout() {
        // Ready after 150ms, but the timeout is 100ms.
        let ready_at = Instant::now() + Duration::from_millis(150);
        let mut writer = TimeoutWriter::new(DelayStream::new(ready_at));
        writer.set_timeout(Some(Duration::from_millis(100)));
        pin!(writer);

        let first = writer.write(&[0]).await;
        assert_eq!(first.err().unwrap().kind(), io::ErrorKind::TimedOut);

        // The retry starts a fresh countdown, and the delay elapses before it does.
        writer.write(&[0]).await.unwrap();
    }

    #[tokio::test]
    async fn write_ok() {
        let ready_at = Instant::now() + Duration::from_millis(100);
        let mut writer = TimeoutWriter::new(DelayStream::new(ready_at));
        writer.set_timeout(Some(Duration::from_millis(500)));
        pin!(writer);

        writer.write(&[0]).await.unwrap();
    }

    #[tokio::test]
    async fn tcp_read() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        // Peer: send one byte almost immediately, then stall well past the read timeout.
        thread::spawn(move || {
            let (mut socket, _) = listener.accept().unwrap();
            thread::sleep(Duration::from_millis(10));
            socket.write_all(b"f").unwrap();
            thread::sleep(Duration::from_millis(500));
            let _ = socket.write_all(b"f"); // this may hit an eof
        });

        let mut s = TimeoutStream::new(TcpStream::connect(&addr).await.unwrap());
        s.set_read_timeout(Some(Duration::from_millis(100)));
        pin!(s);
        s.read(&mut [0]).await.unwrap();

        let err = match s.read(&mut [0]).await {
            Ok(_) => panic!("unexpected success"),
            Err(e) => e,
        };
        if err.kind() != io::ErrorKind::TimedOut {
            panic!("{:?}", err);
        }
    }
}