#![doc(html_root_url = "https://docs.rs/tokio-io-timeout/1")]
//! Tokio wrappers which apply timeouts to IO operations.
//!
//! These timeouts are analogous to the read and write timeouts on traditional blocking
//! sockets. A countdown starts when a wrapped read/write operation returns
//! `Poll::Pending`, and is cleared whenever the operation returns `Poll::Ready`. If the
//! operation stays pending until the countdown expires, it fails with an `io::Error` of
//! kind `TimedOut`.
//!
//! [`TimeoutReader`] times reads, [`TimeoutWriter`] times writes (including flush and
//! shutdown), and [`TimeoutStream`] combines both for duplex streams.
#![warn(missing_docs)]
8
9use pin_project_lite::pin_project;
10use std::future::Future;
11use std::io;
12use std::io::SeekFrom;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use std::time::Duration;
16use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
17use tokio::time::{sleep_until, Instant, Sleep};
18
pin_project! {
    /// Timeout bookkeeping shared by the reader and writer wrappers.
    #[derive(Debug)]
    struct TimeoutState {
        // Configured timeout; `None` means operations never time out.
        timeout: Option<Duration>,
        // Timer for the current pending operation; only meaningful while `active` is true.
        #[pin]
        cur: Sleep,
        // Whether `cur` has been armed for the operation that is currently pending.
        active: bool,
    }
}
28
29impl TimeoutState {
30 #[inline]
31 fn new() -> TimeoutState {
32 TimeoutState {
33 timeout: None,
34 cur: sleep_until(Instant::now()),
35 active: false,
36 }
37 }
38
39 #[inline]
40 fn timeout(&self) -> Option<Duration> {
41 self.timeout
42 }
43
44 #[inline]
45 fn set_timeout(&mut self, timeout: Option<Duration>) {
46 self.timeout = timeout;
48 }
49
50 #[inline]
51 fn set_timeout_pinned(mut self: Pin<&mut Self>, timeout: Option<Duration>) {
52 *self.as_mut().project().timeout = timeout;
53 self.reset();
54 }
55
56 #[inline]
57 fn reset(self: Pin<&mut Self>) {
58 let this = self.project();
59
60 if *this.active {
61 *this.active = false;
62 this.cur.reset(Instant::now());
63 }
64 }
65
66 #[inline]
67 fn poll_check(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Result<()> {
68 let mut this = self.project();
69
70 let timeout = match this.timeout {
71 Some(timeout) => *timeout,
72 None => return Ok(()),
73 };
74
75 if !*this.active {
76 this.cur.as_mut().reset(Instant::now() + timeout);
77 *this.active = true;
78 }
79
80 match this.cur.poll(cx) {
81 Poll::Ready(()) => {
82 *this.active = false;
83 Err(io::Error::from(io::ErrorKind::TimedOut))
84 }
85 Poll::Pending => Ok(()),
86 }
87 }
88}
89
pin_project! {
    /// An `AsyncRead`er which applies a timeout to read operations.
    ///
    /// The countdown starts when the wrapped reader returns `Poll::Pending` and is cleared
    /// whenever it returns `Poll::Ready`. A read that stays pending past the timeout fails
    /// with `io::ErrorKind::TimedOut`. Write and seek operations are passed through untimed.
    #[derive(Debug)]
    pub struct TimeoutReader<R> {
        // The wrapped reader.
        #[pin]
        reader: R,
        // Timeout configuration and the timer for the current pending read.
        #[pin]
        state: TimeoutState,
    }
}
100
101impl<R> TimeoutReader<R>
102where
103 R: AsyncRead,
104{
105 pub fn new(reader: R) -> TimeoutReader<R> {
109 TimeoutReader {
110 reader,
111 state: TimeoutState::new(),
112 }
113 }
114
115 pub fn timeout(&self) -> Option<Duration> {
117 self.state.timeout()
118 }
119
120 pub fn set_timeout(&mut self, timeout: Option<Duration>) {
125 self.state.set_timeout(timeout);
126 }
127
128 pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
133 self.project().state.set_timeout_pinned(timeout);
134 }
135
136 pub fn get_ref(&self) -> &R {
138 &self.reader
139 }
140
141 pub fn get_mut(&mut self) -> &mut R {
143 &mut self.reader
144 }
145
146 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
148 self.project().reader
149 }
150
151 pub fn into_inner(self) -> R {
153 self.reader
154 }
155}
156
157impl<R> AsyncRead for TimeoutReader<R>
158where
159 R: AsyncRead,
160{
161 fn poll_read(
162 self: Pin<&mut Self>,
163 cx: &mut Context<'_>,
164 buf: &mut ReadBuf<'_>,
165 ) -> Poll<Result<(), io::Error>> {
166 let this = self.project();
167 let r = this.reader.poll_read(cx, buf);
168 match r {
169 Poll::Pending => this.state.poll_check(cx)?,
170 _ => this.state.reset(),
171 }
172 r
173 }
174}
175
176impl<R> AsyncWrite for TimeoutReader<R>
177where
178 R: AsyncWrite,
179{
180 fn poll_write(
181 self: Pin<&mut Self>,
182 cx: &mut Context,
183 buf: &[u8],
184 ) -> Poll<Result<usize, io::Error>> {
185 self.project().reader.poll_write(cx, buf)
186 }
187
188 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
189 self.project().reader.poll_flush(cx)
190 }
191
192 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
193 self.project().reader.poll_shutdown(cx)
194 }
195
196 fn poll_write_vectored(
197 self: Pin<&mut Self>,
198 cx: &mut Context<'_>,
199 bufs: &[io::IoSlice<'_>],
200 ) -> Poll<io::Result<usize>> {
201 self.project().reader.poll_write_vectored(cx, bufs)
202 }
203
204 fn is_write_vectored(&self) -> bool {
205 self.reader.is_write_vectored()
206 }
207}
208
209impl<R> AsyncSeek for TimeoutReader<R>
210where
211 R: AsyncSeek,
212{
213 fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
214 self.project().reader.start_seek(position)
215 }
216 fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
217 self.project().reader.poll_complete(cx)
218 }
219}
220
pin_project! {
    /// An `AsyncWrite`r which applies a timeout to write operations.
    ///
    /// Writes, vectored writes, flushes and shutdowns are all timed. The countdown starts
    /// when the wrapped writer returns `Poll::Pending` and is cleared whenever it returns
    /// `Poll::Ready`. An operation that stays pending past the timeout fails with
    /// `io::ErrorKind::TimedOut`. Read and seek operations are passed through untimed.
    #[derive(Debug)]
    pub struct TimeoutWriter<W> {
        // The wrapped writer.
        #[pin]
        writer: W,
        // Timeout configuration and the timer for the current pending write-side operation.
        #[pin]
        state: TimeoutState,
    }
}
231
232impl<W> TimeoutWriter<W>
233where
234 W: AsyncWrite,
235{
236 pub fn new(writer: W) -> TimeoutWriter<W> {
240 TimeoutWriter {
241 writer,
242 state: TimeoutState::new(),
243 }
244 }
245
246 pub fn timeout(&self) -> Option<Duration> {
248 self.state.timeout()
249 }
250
251 pub fn set_timeout(&mut self, timeout: Option<Duration>) {
256 self.state.set_timeout(timeout);
257 }
258
259 pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
264 self.project().state.set_timeout_pinned(timeout);
265 }
266
267 pub fn get_ref(&self) -> &W {
269 &self.writer
270 }
271
272 pub fn get_mut(&mut self) -> &mut W {
274 &mut self.writer
275 }
276
277 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
279 self.project().writer
280 }
281
282 pub fn into_inner(self) -> W {
284 self.writer
285 }
286}
287
288impl<W> AsyncWrite for TimeoutWriter<W>
289where
290 W: AsyncWrite,
291{
292 fn poll_write(
293 self: Pin<&mut Self>,
294 cx: &mut Context,
295 buf: &[u8],
296 ) -> Poll<Result<usize, io::Error>> {
297 let this = self.project();
298 let r = this.writer.poll_write(cx, buf);
299 match r {
300 Poll::Pending => this.state.poll_check(cx)?,
301 _ => this.state.reset(),
302 }
303 r
304 }
305
306 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
307 let this = self.project();
308 let r = this.writer.poll_flush(cx);
309 match r {
310 Poll::Pending => this.state.poll_check(cx)?,
311 _ => this.state.reset(),
312 }
313 r
314 }
315
316 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
317 let this = self.project();
318 let r = this.writer.poll_shutdown(cx);
319 match r {
320 Poll::Pending => this.state.poll_check(cx)?,
321 _ => this.state.reset(),
322 }
323 r
324 }
325
326 fn poll_write_vectored(
327 self: Pin<&mut Self>,
328 cx: &mut Context<'_>,
329 bufs: &[io::IoSlice<'_>],
330 ) -> Poll<io::Result<usize>> {
331 let this = self.project();
332 let r = this.writer.poll_write_vectored(cx, bufs);
333 match r {
334 Poll::Pending => this.state.poll_check(cx)?,
335 _ => this.state.reset(),
336 }
337 r
338 }
339
340 fn is_write_vectored(&self) -> bool {
341 self.writer.is_write_vectored()
342 }
343}
344
345impl<W> AsyncRead for TimeoutWriter<W>
346where
347 W: AsyncRead,
348{
349 fn poll_read(
350 self: Pin<&mut Self>,
351 cx: &mut Context<'_>,
352 buf: &mut ReadBuf<'_>,
353 ) -> Poll<Result<(), io::Error>> {
354 self.project().writer.poll_read(cx, buf)
355 }
356}
357
358impl<W> AsyncSeek for TimeoutWriter<W>
359where
360 W: AsyncSeek,
361{
362 fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
363 self.project().writer.start_seek(position)
364 }
365 fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
366 self.project().writer.poll_complete(cx)
367 }
368}
369
pin_project! {
    /// A stream which applies read and write timeouts to an inner stream.
    ///
    /// The read timeout is held by an outer [`TimeoutReader`] and the write timeout by an
    /// inner [`TimeoutWriter`]; each can be configured independently.
    #[derive(Debug)]
    pub struct TimeoutStream<S> {
        // Layered wrappers: reads are timed by the outer layer, writes by the inner one.
        #[pin]
        stream: TimeoutReader<TimeoutWriter<S>>
    }
}
378
379impl<S> TimeoutStream<S>
380where
381 S: AsyncRead + AsyncWrite,
382{
383 pub fn new(stream: S) -> TimeoutStream<S> {
387 let writer = TimeoutWriter::new(stream);
388 let stream = TimeoutReader::new(writer);
389 TimeoutStream { stream }
390 }
391
392 pub fn read_timeout(&self) -> Option<Duration> {
394 self.stream.timeout()
395 }
396
397 pub fn set_read_timeout(&mut self, timeout: Option<Duration>) {
402 self.stream.set_timeout(timeout)
403 }
404
405 pub fn set_read_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
410 self.project().stream.set_timeout_pinned(timeout)
411 }
412
413 pub fn write_timeout(&self) -> Option<Duration> {
415 self.stream.get_ref().timeout()
416 }
417
418 pub fn set_write_timeout(&mut self, timeout: Option<Duration>) {
423 self.stream.get_mut().set_timeout(timeout)
424 }
425
426 pub fn set_write_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
431 self.project()
432 .stream
433 .get_pin_mut()
434 .set_timeout_pinned(timeout)
435 }
436
437 pub fn get_ref(&self) -> &S {
439 self.stream.get_ref().get_ref()
440 }
441
442 pub fn get_mut(&mut self) -> &mut S {
444 self.stream.get_mut().get_mut()
445 }
446
447 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
449 self.project().stream.get_pin_mut().get_pin_mut()
450 }
451
452 pub fn into_inner(self) -> S {
454 self.stream.into_inner().into_inner()
455 }
456}
457
458impl<S> AsyncRead for TimeoutStream<S>
459where
460 S: AsyncRead + AsyncWrite,
461{
462 fn poll_read(
463 self: Pin<&mut Self>,
464 cx: &mut Context<'_>,
465 buf: &mut ReadBuf<'_>,
466 ) -> Poll<Result<(), io::Error>> {
467 self.project().stream.poll_read(cx, buf)
468 }
469}
470
471impl<S> AsyncWrite for TimeoutStream<S>
472where
473 S: AsyncRead + AsyncWrite,
474{
475 fn poll_write(
476 self: Pin<&mut Self>,
477 cx: &mut Context,
478 buf: &[u8],
479 ) -> Poll<Result<usize, io::Error>> {
480 self.project().stream.poll_write(cx, buf)
481 }
482
483 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
484 self.project().stream.poll_flush(cx)
485 }
486
487 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
488 self.project().stream.poll_shutdown(cx)
489 }
490
491 fn poll_write_vectored(
492 self: Pin<&mut Self>,
493 cx: &mut Context<'_>,
494 bufs: &[io::IoSlice<'_>],
495 ) -> Poll<io::Result<usize>> {
496 self.project().stream.poll_write_vectored(cx, bufs)
497 }
498
499 fn is_write_vectored(&self) -> bool {
500 self.stream.is_write_vectored()
501 }
502}
503
#[cfg(test)]
mod test {
    use super::*;
    use std::io::Write;
    use std::net::TcpListener;
    use std::thread;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;
    use tokio::pin;

    pin_project! {
        // Test double whose reads and writes stay pending until a fixed instant.
        struct DelayStream {
            #[pin]
            sleep: Sleep,
        }
    }

    impl DelayStream {
        // Creates a stream whose operations become ready at `until`.
        fn new(until: Instant) -> Self {
            DelayStream {
                sleep: sleep_until(until),
            }
        }
    }

    impl AsyncRead for DelayStream {
        // Pending until the deadline, then completes without filling `_buf` (a zero-byte read).
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            _buf: &mut ReadBuf,
        ) -> Poll<Result<(), io::Error>> {
            match self.project().sleep.poll(cx) {
                Poll::Ready(()) => Poll::Ready(Ok(())),
                Poll::Pending => Poll::Pending,
            }
        }
    }

    impl AsyncWrite for DelayStream {
        // Pending until the deadline, then reports the whole buffer as written.
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            match self.project().sleep.poll(cx) {
                Poll::Ready(()) => Poll::Ready(Ok(buf.len())),
                Poll::Pending => Poll::Pending,
            }
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    // The 100ms timeout fires before the 150ms delay elapses. The retry then re-arms a fresh
    // countdown and succeeds once the delay has passed.
    #[tokio::test]
    async fn read_timeout() {
        let reader = DelayStream::new(Instant::now() + Duration::from_millis(150));
        let mut reader = TimeoutReader::new(reader);
        reader.set_timeout(Some(Duration::from_millis(100)));
        pin!(reader);

        let r = reader.read(&mut [0]).await;
        assert_eq!(r.err().unwrap().kind(), io::ErrorKind::TimedOut);

        let _ = reader.read(&mut [0]).await.unwrap();
    }

    // The 100ms delay completes well within the 500ms timeout.
    #[tokio::test]
    async fn read_ok() {
        let reader = DelayStream::new(Instant::now() + Duration::from_millis(100));
        let mut reader = TimeoutReader::new(reader);
        reader.set_timeout(Some(Duration::from_millis(500)));
        pin!(reader);

        let _ = reader.read(&mut [0]).await.unwrap();
    }

    // Write-side counterpart of `read_timeout`: time out first, then succeed on retry.
    #[tokio::test]
    async fn write_timeout() {
        let writer = DelayStream::new(Instant::now() + Duration::from_millis(150));
        let mut writer = TimeoutWriter::new(writer);
        writer.set_timeout(Some(Duration::from_millis(100)));
        pin!(writer);

        let r = writer.write(&[0]).await;
        assert_eq!(r.err().unwrap().kind(), io::ErrorKind::TimedOut);

        let _ = writer.write(&[0]).await.unwrap();
    }

    // Write-side counterpart of `read_ok`.
    #[tokio::test]
    async fn write_ok() {
        let writer = DelayStream::new(Instant::now() + Duration::from_millis(100));
        let mut writer = TimeoutWriter::new(writer);
        writer.set_timeout(Some(Duration::from_millis(500)));
        pin!(writer);

        let _ = writer.write(&[0]).await.unwrap();
    }

    // Over a real TCP connection: the first byte arrives ~10ms after accept, and the second
    // is delayed ~500ms, so the 100ms read timeout should fire on the second read.
    #[tokio::test]
    async fn tcp_read() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        thread::spawn(move || {
            let mut socket = listener.accept().unwrap().0;
            thread::sleep(Duration::from_millis(10));
            socket.write_all(b"f").unwrap();
            thread::sleep(Duration::from_millis(500));
            // Result ignored: presumably the client may already have timed out and closed
            // the connection by this point.
            let _ = socket.write_all(b"f");
        });

        let s = TcpStream::connect(&addr).await.unwrap();
        let mut s = TimeoutStream::new(s);
        s.set_read_timeout(Some(Duration::from_millis(100)));
        pin!(s);
        let _ = s.read(&mut [0]).await.unwrap();
        let r = s.read(&mut [0]).await;

        match r {
            Ok(_) => panic!("unexpected success"),
            Err(ref e) if e.kind() == io::ErrorKind::TimedOut => (),
            Err(e) => panic!("{:?}", e),
        }
    }
}