1#![doc(html_root_url = "https://docs.rs/tokio-io-timeout/1")]
7#![warn(missing_docs)]
8
9use pin_project_lite::pin_project;
10use std::future::Future;
11use std::io;
12use std::io::SeekFrom;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use std::time::Duration;
16use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
17use tokio::time::{sleep_until, Instant, Sleep};
18
pin_project! {
    /// Timeout bookkeeping shared by `TimeoutReader` and `TimeoutWriter`.
    ///
    /// The deadline is armed only while the wrapped operation keeps returning
    /// `Poll::Pending`; any other result disarms it again.
    #[derive(Debug)]
    struct TimeoutState {
        // Configured timeout; `None` disables timing out entirely.
        timeout: Option<Duration>,
        // Timer holding the current deadline. Structurally pinned so it can be
        // polled and reset in place.
        #[pin]
        cur: Sleep,
        // Whether `cur` currently holds a live deadline for an operation that
        // is still pending.
        active: bool,
    }
}
28
29impl TimeoutState {
30 #[inline]
31 fn new() -> TimeoutState {
32 TimeoutState {
33 timeout: None,
34 cur: sleep_until(Instant::now()),
35 active: false,
36 }
37 }
38
39 #[inline]
40 fn timeout(&self) -> Option<Duration> {
41 self.timeout
42 }
43
44 #[inline]
45 fn set_timeout(&mut self, timeout: Option<Duration>) {
46 self.timeout = timeout;
48 }
49
50 #[inline]
51 fn set_timeout_pinned(mut self: Pin<&mut Self>, timeout: Option<Duration>) {
52 *self.as_mut().project().timeout = timeout;
53 self.reset();
54 }
55
56 #[inline]
57 fn reset(self: Pin<&mut Self>) {
58 let this = self.project();
59
60 if *this.active {
61 *this.active = false;
62 this.cur.reset(Instant::now());
63 }
64 }
65
66 #[inline]
67 fn poll_check(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Result<()> {
68 let mut this = self.project();
69
70 let timeout = match this.timeout {
71 Some(timeout) => *timeout,
72 None => return Ok(()),
73 };
74
75 if !*this.active {
76 this.cur.as_mut().reset(Instant::now() + timeout);
77 *this.active = true;
78 }
79
80 match this.cur.poll(cx) {
81 Poll::Ready(()) => Err(io::Error::from(io::ErrorKind::TimedOut)),
82 Poll::Pending => Ok(()),
83 }
84 }
85}
86
87pin_project! {
88 #[derive(Debug)]
90 pub struct TimeoutReader<R> {
91 #[pin]
92 reader: R,
93 #[pin]
94 state: TimeoutState,
95 }
96}
97
98impl<R> TimeoutReader<R>
99where
100 R: AsyncRead,
101{
102 pub fn new(reader: R) -> TimeoutReader<R> {
106 TimeoutReader {
107 reader,
108 state: TimeoutState::new(),
109 }
110 }
111
112 pub fn timeout(&self) -> Option<Duration> {
114 self.state.timeout()
115 }
116
117 pub fn set_timeout(&mut self, timeout: Option<Duration>) {
122 self.state.set_timeout(timeout);
123 }
124
125 pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
130 self.project().state.set_timeout_pinned(timeout);
131 }
132
133 pub fn get_ref(&self) -> &R {
135 &self.reader
136 }
137
138 pub fn get_mut(&mut self) -> &mut R {
140 &mut self.reader
141 }
142
143 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
145 self.project().reader
146 }
147
148 pub fn into_inner(self) -> R {
150 self.reader
151 }
152}
153
154impl<R> AsyncRead for TimeoutReader<R>
155where
156 R: AsyncRead,
157{
158 fn poll_read(
159 self: Pin<&mut Self>,
160 cx: &mut Context<'_>,
161 buf: &mut ReadBuf<'_>,
162 ) -> Poll<Result<(), io::Error>> {
163 let this = self.project();
164 let r = this.reader.poll_read(cx, buf);
165 match r {
166 Poll::Pending => this.state.poll_check(cx)?,
167 _ => this.state.reset(),
168 }
169 r
170 }
171}
172
173impl<R> AsyncWrite for TimeoutReader<R>
174where
175 R: AsyncWrite,
176{
177 fn poll_write(
178 self: Pin<&mut Self>,
179 cx: &mut Context,
180 buf: &[u8],
181 ) -> Poll<Result<usize, io::Error>> {
182 self.project().reader.poll_write(cx, buf)
183 }
184
185 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
186 self.project().reader.poll_flush(cx)
187 }
188
189 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
190 self.project().reader.poll_shutdown(cx)
191 }
192
193 fn poll_write_vectored(
194 self: Pin<&mut Self>,
195 cx: &mut Context<'_>,
196 bufs: &[io::IoSlice<'_>],
197 ) -> Poll<io::Result<usize>> {
198 self.project().reader.poll_write_vectored(cx, bufs)
199 }
200
201 fn is_write_vectored(&self) -> bool {
202 self.reader.is_write_vectored()
203 }
204}
205
206impl<R> AsyncSeek for TimeoutReader<R>
207where
208 R: AsyncSeek,
209{
210 fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
211 self.project().reader.start_seek(position)
212 }
213 fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
214 self.project().reader.poll_complete(cx)
215 }
216}
217
218pin_project! {
219 #[derive(Debug)]
221 pub struct TimeoutWriter<W> {
222 #[pin]
223 writer: W,
224 #[pin]
225 state: TimeoutState,
226 }
227}
228
229impl<W> TimeoutWriter<W>
230where
231 W: AsyncWrite,
232{
233 pub fn new(writer: W) -> TimeoutWriter<W> {
237 TimeoutWriter {
238 writer,
239 state: TimeoutState::new(),
240 }
241 }
242
243 pub fn timeout(&self) -> Option<Duration> {
245 self.state.timeout()
246 }
247
248 pub fn set_timeout(&mut self, timeout: Option<Duration>) {
253 self.state.set_timeout(timeout);
254 }
255
256 pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
261 self.project().state.set_timeout_pinned(timeout);
262 }
263
264 pub fn get_ref(&self) -> &W {
266 &self.writer
267 }
268
269 pub fn get_mut(&mut self) -> &mut W {
271 &mut self.writer
272 }
273
274 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
276 self.project().writer
277 }
278
279 pub fn into_inner(self) -> W {
281 self.writer
282 }
283}
284
285impl<W> AsyncWrite for TimeoutWriter<W>
286where
287 W: AsyncWrite,
288{
289 fn poll_write(
290 self: Pin<&mut Self>,
291 cx: &mut Context,
292 buf: &[u8],
293 ) -> Poll<Result<usize, io::Error>> {
294 let this = self.project();
295 let r = this.writer.poll_write(cx, buf);
296 match r {
297 Poll::Pending => this.state.poll_check(cx)?,
298 _ => this.state.reset(),
299 }
300 r
301 }
302
303 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
304 let this = self.project();
305 let r = this.writer.poll_flush(cx);
306 match r {
307 Poll::Pending => this.state.poll_check(cx)?,
308 _ => this.state.reset(),
309 }
310 r
311 }
312
313 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
314 let this = self.project();
315 let r = this.writer.poll_shutdown(cx);
316 match r {
317 Poll::Pending => this.state.poll_check(cx)?,
318 _ => this.state.reset(),
319 }
320 r
321 }
322
323 fn poll_write_vectored(
324 self: Pin<&mut Self>,
325 cx: &mut Context<'_>,
326 bufs: &[io::IoSlice<'_>],
327 ) -> Poll<io::Result<usize>> {
328 let this = self.project();
329 let r = this.writer.poll_write_vectored(cx, bufs);
330 match r {
331 Poll::Pending => this.state.poll_check(cx)?,
332 _ => this.state.reset(),
333 }
334 r
335 }
336
337 fn is_write_vectored(&self) -> bool {
338 self.writer.is_write_vectored()
339 }
340}
341
342impl<W> AsyncRead for TimeoutWriter<W>
343where
344 W: AsyncRead,
345{
346 fn poll_read(
347 self: Pin<&mut Self>,
348 cx: &mut Context<'_>,
349 buf: &mut ReadBuf<'_>,
350 ) -> Poll<Result<(), io::Error>> {
351 self.project().writer.poll_read(cx, buf)
352 }
353}
354
355impl<W> AsyncSeek for TimeoutWriter<W>
356where
357 W: AsyncSeek,
358{
359 fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
360 self.project().writer.start_seek(position)
361 }
362 fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
363 self.project().writer.poll_complete(cx)
364 }
365}
366
367pin_project! {
368 #[derive(Debug)]
370 pub struct TimeoutStream<S> {
371 #[pin]
372 stream: TimeoutReader<TimeoutWriter<S>>
373 }
374}
375
376impl<S> TimeoutStream<S>
377where
378 S: AsyncRead + AsyncWrite,
379{
380 pub fn new(stream: S) -> TimeoutStream<S> {
384 let writer = TimeoutWriter::new(stream);
385 let stream = TimeoutReader::new(writer);
386 TimeoutStream { stream }
387 }
388
389 pub fn read_timeout(&self) -> Option<Duration> {
391 self.stream.timeout()
392 }
393
394 pub fn set_read_timeout(&mut self, timeout: Option<Duration>) {
399 self.stream.set_timeout(timeout)
400 }
401
402 pub fn set_read_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
407 self.project().stream.set_timeout_pinned(timeout)
408 }
409
410 pub fn write_timeout(&self) -> Option<Duration> {
412 self.stream.get_ref().timeout()
413 }
414
415 pub fn set_write_timeout(&mut self, timeout: Option<Duration>) {
420 self.stream.get_mut().set_timeout(timeout)
421 }
422
423 pub fn set_write_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
428 self.project()
429 .stream
430 .get_pin_mut()
431 .set_timeout_pinned(timeout)
432 }
433
434 pub fn get_ref(&self) -> &S {
436 self.stream.get_ref().get_ref()
437 }
438
439 pub fn get_mut(&mut self) -> &mut S {
441 self.stream.get_mut().get_mut()
442 }
443
444 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
446 self.project().stream.get_pin_mut().get_pin_mut()
447 }
448
449 pub fn into_inner(self) -> S {
451 self.stream.into_inner().into_inner()
452 }
453}
454
455impl<S> AsyncRead for TimeoutStream<S>
456where
457 S: AsyncRead + AsyncWrite,
458{
459 fn poll_read(
460 self: Pin<&mut Self>,
461 cx: &mut Context<'_>,
462 buf: &mut ReadBuf<'_>,
463 ) -> Poll<Result<(), io::Error>> {
464 self.project().stream.poll_read(cx, buf)
465 }
466}
467
468impl<S> AsyncWrite for TimeoutStream<S>
469where
470 S: AsyncRead + AsyncWrite,
471{
472 fn poll_write(
473 self: Pin<&mut Self>,
474 cx: &mut Context,
475 buf: &[u8],
476 ) -> Poll<Result<usize, io::Error>> {
477 self.project().stream.poll_write(cx, buf)
478 }
479
480 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
481 self.project().stream.poll_flush(cx)
482 }
483
484 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
485 self.project().stream.poll_shutdown(cx)
486 }
487
488 fn poll_write_vectored(
489 self: Pin<&mut Self>,
490 cx: &mut Context<'_>,
491 bufs: &[io::IoSlice<'_>],
492 ) -> Poll<io::Result<usize>> {
493 self.project().stream.poll_write_vectored(cx, bufs)
494 }
495
496 fn is_write_vectored(&self) -> bool {
497 self.stream.is_write_vectored()
498 }
499}
500
#[cfg(test)]
mod test {
    use super::*;
    use std::io::Write;
    use std::net::TcpListener;
    use std::thread;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;
    use tokio::pin;

    pin_project! {
        // Mock I/O object that stays pending until `sleep` fires. After that,
        // reads complete without filling the buffer and writes report the
        // whole buffer as written.
        struct DelayStream {
            #[pin]
            sleep: Sleep,
        }
    }

    impl DelayStream {
        fn new(until: Instant) -> Self {
            let sleep = sleep_until(until);
            DelayStream { sleep }
        }
    }

    impl AsyncRead for DelayStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            _buf: &mut ReadBuf,
        ) -> Poll<Result<(), io::Error>> {
            self.project().sleep.poll(cx).map(Ok)
        }
    }

    impl AsyncWrite for DelayStream {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            self.project().sleep.poll(cx).map(|()| Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    #[tokio::test]
    async fn read_timeout() {
        // The inner read takes 500ms, well past the 100ms timeout.
        let ready_at = Instant::now() + Duration::from_millis(500);
        let mut reader = TimeoutReader::new(DelayStream::new(ready_at));
        reader.set_timeout(Some(Duration::from_millis(100)));
        pin!(reader);

        let err = reader.read(&mut [0]).await.err().unwrap();
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }

    #[tokio::test]
    async fn read_ok() {
        // The inner read finishes after 100ms, inside the 500ms timeout.
        let ready_at = Instant::now() + Duration::from_millis(100);
        let mut reader = TimeoutReader::new(DelayStream::new(ready_at));
        reader.set_timeout(Some(Duration::from_millis(500)));
        pin!(reader);

        reader.read(&mut [0]).await.unwrap();
    }

    #[tokio::test]
    async fn write_timeout() {
        // The inner write takes 500ms, well past the 100ms timeout.
        let ready_at = Instant::now() + Duration::from_millis(500);
        let mut writer = TimeoutWriter::new(DelayStream::new(ready_at));
        writer.set_timeout(Some(Duration::from_millis(100)));
        pin!(writer);

        let err = writer.write(&[0]).await.err().unwrap();
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }

    #[tokio::test]
    async fn write_ok() {
        // The inner write finishes after 100ms, inside the 500ms timeout.
        let ready_at = Instant::now() + Duration::from_millis(100);
        let mut writer = TimeoutWriter::new(DelayStream::new(ready_at));
        writer.set_timeout(Some(Duration::from_millis(500)));
        pin!(writer);

        writer.write(&[0]).await.unwrap();
    }

    #[tokio::test]
    async fn tcp_read() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        // Peer: send one byte promptly, then stall well past the 100ms read
        // timeout before sending a second one.
        thread::spawn(move || {
            let (mut socket, _) = listener.accept().unwrap();
            thread::sleep(Duration::from_millis(10));
            socket.write_all(b"f").unwrap();
            thread::sleep(Duration::from_millis(500));
            let _ = socket.write_all(b"f");
        });

        let s = TcpStream::connect(&addr).await.unwrap();
        let mut s = TimeoutStream::new(s);
        s.set_read_timeout(Some(Duration::from_millis(100)));
        pin!(s);
        s.read(&mut [0]).await.unwrap();

        let err = match s.read(&mut [0]).await {
            Ok(_) => panic!("unexpected success"),
            Err(e) => e,
        };
        if err.kind() != io::ErrorKind::TimedOut {
            panic!("{:?}", err);
        }
    }
}