1use crate::loom::sync::{Arc, Condvar, Mutex};
4use crate::loom::thread;
5use crate::runtime::blocking::schedule::BlockingSchedule;
6use crate::runtime::blocking::{shutdown, BlockingTask};
7use crate::runtime::builder::ThreadNameFn;
8use crate::runtime::task::{self, JoinHandle};
9use crate::runtime::{Builder, Callback, Handle, BOX_FUTURE_THRESHOLD};
10use crate::util::metric_atomics::MetricAtomicUsize;
11use crate::util::trace::{blocking_task, SpawnMeta};
12
13use std::collections::{HashMap, VecDeque};
14use std::fmt;
15use std::io;
16use std::sync::atomic::Ordering;
17use std::time::Duration;
18
/// Owner of the blocking thread pool.
///
/// Holds the [`Spawner`] used to submit work and the receiving half of the
/// shutdown channel that [`BlockingPool::shutdown`] waits on before joining
/// worker threads.
pub(crate) struct BlockingPool {
    // Handle used to submit blocking tasks; its `Arc<Inner>` is shared with
    // every worker thread.
    spawner: Spawner,
    // `shutdown` calls `wait(timeout)` on this and only joins worker threads
    // when that returns `true`. Each worker thread owns a clone of the sender
    // and drops it after its work loop returns.
    shutdown_rx: shutdown::Receiver,
}
23
/// Cloneable handle for submitting tasks to the blocking pool.
///
/// Cloning only clones the `Arc`, so every clone refers to the same `Inner`.
#[derive(Clone)]
pub(crate) struct Spawner {
    inner: Arc<Inner>,
}
28
/// Counters maintained by the blocking pool.
///
/// `num_threads` and `num_idle_threads` are read by `Spawner::spawn_task`
/// to decide between waking an idle worker and starting a new one;
/// `queue_depth` is adjusted alongside every push/pop of `Shared::queue`.
#[derive(Default)]
pub(crate) struct SpawnerMetrics {
    // Incremented when a worker thread is spawned, decremented when it exits.
    num_threads: MetricAtomicUsize,
    // Workers waiting for work that have not yet been claimed by a
    // notification from `spawn_task`.
    num_idle_threads: MetricAtomicUsize,
    // Number of tasks currently in `Shared::queue`.
    queue_depth: MetricAtomicUsize,
}
35
36impl SpawnerMetrics {
37 fn num_threads(&self) -> usize {
38 self.num_threads.load(Ordering::Relaxed)
39 }
40
41 fn num_idle_threads(&self) -> usize {
42 self.num_idle_threads.load(Ordering::Relaxed)
43 }
44
45 cfg_unstable_metrics! {
46 fn queue_depth(&self) -> usize {
47 self.queue_depth.load(Ordering::Relaxed)
48 }
49 }
50
51 fn inc_num_threads(&self) {
52 self.num_threads.increment();
53 }
54
55 fn dec_num_threads(&self) {
56 self.num_threads.decrement();
57 }
58
59 fn inc_num_idle_threads(&self) {
60 self.num_idle_threads.increment();
61 }
62
63 fn dec_num_idle_threads(&self) -> usize {
64 self.num_idle_threads.decrement()
65 }
66
67 fn inc_queue_depth(&self) {
68 self.queue_depth.increment();
69 }
70
71 fn dec_queue_depth(&self) {
72 self.queue_depth.decrement();
73 }
74}
75
/// State shared by all `Spawner` handles and every worker thread.
struct Inner {
    // Mutable pool state; taken by spawners and by workers between tasks.
    shared: Mutex<Shared>,

    // Wakes one idle worker per submitted task (`spawn_task`), all workers
    // at the start of shutdown, and one waiter when the last worker exits
    // during shutdown.
    condvar: Condvar,

    // Called once per spawned worker to produce its thread name.
    thread_name: ThreadNameFn,

    // Stack size handed to the thread builder, when configured.
    stack_size: Option<usize>,

    // Invoked on each worker thread before it takes the lock and starts work.
    after_start: Option<Callback>,

    // Invoked on each worker thread after it leaves its work loop and
    // releases the lock.
    before_stop: Option<Callback>,

    // Maximum number of worker threads `spawn_task` will start.
    thread_cap: usize,

    // How long an idle worker waits on `condvar` before exiting.
    keep_alive: Duration,

    metrics: SpawnerMetrics,
}
104
/// Pool state protected by `Inner::shared`.
struct Shared {
    // Tasks waiting for a worker.
    queue: VecDeque<Task>,
    // Wakeups issued by `spawn_task` and not yet consumed by a worker.
    // Workers leave their idle wait only when this is non-zero (or on
    // shutdown / keep-alive timeout), so spurious condvar wakeups are ignored.
    num_notify: u32,
    // Set once by `BlockingPool::shutdown`; afterwards `spawn_task` rejects
    // new tasks with `SpawnError::ShuttingDown`.
    shutdown: bool,
    // Sender half of the shutdown channel; a clone is moved into each new
    // worker thread. Cleared when shutdown begins.
    shutdown_tx: Option<shutdown::Sender>,
    // Join handle of the most recent worker that exited on keep-alive
    // timeout. The next worker to time out joins it; `shutdown` joins
    // whatever is left.
    last_exiting_thread: Option<thread::JoinHandle<()>>,
    // Join handles of live workers, keyed by the id passed to `Inner::run`.
    worker_threads: HashMap<usize, thread::JoinHandle<()>>,
    // Id assigned to the next successfully spawned worker.
    worker_thread_index: usize,
}
123
/// A queued blocking task together with its shutdown policy.
pub(crate) struct Task {
    task: task::UnownedTask<BlockingSchedule>,
    // Whether the task is still run if it is in the queue when shutdown is observed.
    mandatory: Mandatory,
}
128
/// Whether a queued task must still run when the pool shuts down.
///
/// Tasks still in the queue when a worker observes shutdown are run if
/// `Mandatory` and shut down without running if `NonMandatory`. A task
/// submitted after shutdown has begun is shut down regardless of this value.
#[derive(PartialEq, Eq)]
pub(crate) enum Mandatory {
    // Only constructed by the `cfg_fs!`-gated spawn paths, hence unused without the `fs` feature.
    #[cfg_attr(not(feature = "fs"), allow(dead_code))]
    Mandatory,
    NonMandatory,
}
135
/// Why `Spawner::spawn_task` could not guarantee a worker for a task.
pub(crate) enum SpawnError {
    /// The pool had already begun shutting down; the task was shut down
    /// instead of being queued.
    ShuttingDown,
    /// A worker thread was needed but could not be spawned; carries the OS
    /// error returned by the thread builder.
    NoThreads(io::Error),
}
143
144impl From<SpawnError> for io::Error {
145 fn from(e: SpawnError) -> Self {
146 match e {
147 SpawnError::ShuttingDown => {
148 io::Error::new(io::ErrorKind::Other, "blocking pool shutting down")
149 }
150 SpawnError::NoThreads(e) => e,
151 }
152 }
153}
154
155impl Task {
156 pub(crate) fn new(task: task::UnownedTask<BlockingSchedule>, mandatory: Mandatory) -> Task {
157 Task { task, mandatory }
158 }
159
160 fn run(self) {
161 self.task.run();
162 }
163
164 fn shutdown_or_run_if_mandatory(self) {
165 match self.mandatory {
166 Mandatory::NonMandatory => self.task.shutdown(),
167 Mandatory::Mandatory => self.task.run(),
168 }
169 }
170}
171
/// Idle time after which a worker thread exits, used when the runtime
/// builder does not configure its own keep-alive (ten seconds).
const KEEP_ALIVE: Duration = Duration::from_millis(10_000);
173
174#[track_caller]
178#[cfg_attr(target_os = "wasi", allow(dead_code))]
179pub(crate) fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
180where
181 F: FnOnce() -> R + Send + 'static,
182 R: Send + 'static,
183{
184 let rt = Handle::current();
185 rt.spawn_blocking(func)
186}
187
188cfg_fs! {
189 #[cfg_attr(any(
190 all(loom, not(test)), test
192 ), allow(dead_code))]
193 pub(crate) fn spawn_mandatory_blocking<F, R>(func: F) -> Option<JoinHandle<R>>
198 where
199 F: FnOnce() -> R + Send + 'static,
200 R: Send + 'static,
201 {
202 let rt = Handle::current();
203 rt.inner.blocking_spawner().spawn_mandatory_blocking(&rt, func)
204 }
205}
206
207impl BlockingPool {
210 pub(crate) fn new(builder: &Builder, thread_cap: usize) -> BlockingPool {
211 let (shutdown_tx, shutdown_rx) = shutdown::channel();
212 let keep_alive = builder.keep_alive.unwrap_or(KEEP_ALIVE);
213
214 BlockingPool {
215 spawner: Spawner {
216 inner: Arc::new(Inner {
217 shared: Mutex::new(Shared {
218 queue: VecDeque::new(),
219 num_notify: 0,
220 shutdown: false,
221 shutdown_tx: Some(shutdown_tx),
222 last_exiting_thread: None,
223 worker_threads: HashMap::new(),
224 worker_thread_index: 0,
225 }),
226 condvar: Condvar::new(),
227 thread_name: builder.thread_name.clone(),
228 stack_size: builder.thread_stack_size,
229 after_start: builder.after_start.clone(),
230 before_stop: builder.before_stop.clone(),
231 thread_cap,
232 keep_alive,
233 metrics: SpawnerMetrics::default(),
234 }),
235 },
236 shutdown_rx,
237 }
238 }
239
240 pub(crate) fn spawner(&self) -> &Spawner {
241 &self.spawner
242 }
243
244 pub(crate) fn shutdown(&mut self, timeout: Option<Duration>) {
245 let mut shared = self.spawner.inner.shared.lock();
246
247 if shared.shutdown {
251 return;
252 }
253
254 shared.shutdown = true;
255 shared.shutdown_tx = None;
256 self.spawner.inner.condvar.notify_all();
257
258 let last_exited_thread = std::mem::take(&mut shared.last_exiting_thread);
259 let workers = std::mem::take(&mut shared.worker_threads);
260
261 drop(shared);
262
263 if self.shutdown_rx.wait(timeout) {
264 let _ = last_exited_thread.map(thread::JoinHandle::join);
265
266 #[cfg(loom)]
269 let workers: Vec<(usize, thread::JoinHandle<()>)> = {
270 let mut workers: Vec<_> = workers.into_iter().collect();
271 workers.sort_by_key(|(id, _)| *id);
272 workers
273 };
274
275 for (_id, handle) in workers {
276 let _ = handle.join();
277 }
278 }
279 }
280}
281
282impl Drop for BlockingPool {
283 fn drop(&mut self) {
284 self.shutdown(None);
285 }
286}
287
288impl fmt::Debug for BlockingPool {
289 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
290 fmt.debug_struct("BlockingPool").finish()
291 }
292}
293
294impl Spawner {
297 #[track_caller]
298 pub(crate) fn spawn_blocking<F, R>(&self, rt: &Handle, func: F) -> JoinHandle<R>
299 where
300 F: FnOnce() -> R + Send + 'static,
301 R: Send + 'static,
302 {
303 let fn_size = std::mem::size_of::<F>();
304 let (join_handle, spawn_result) = if fn_size > BOX_FUTURE_THRESHOLD {
305 self.spawn_blocking_inner(
306 Box::new(func),
307 Mandatory::NonMandatory,
308 SpawnMeta::new_unnamed(fn_size),
309 rt,
310 )
311 } else {
312 self.spawn_blocking_inner(
313 func,
314 Mandatory::NonMandatory,
315 SpawnMeta::new_unnamed(fn_size),
316 rt,
317 )
318 };
319
320 match spawn_result {
321 Ok(()) => join_handle,
322 Err(SpawnError::ShuttingDown) => join_handle,
324 Err(SpawnError::NoThreads(e)) => {
325 panic!("OS can't spawn worker thread: {e}")
326 }
327 }
328 }
329
330 cfg_fs! {
331 #[track_caller]
332 #[cfg_attr(any(
333 all(loom, not(test)), test
335 ), allow(dead_code))]
336 pub(crate) fn spawn_mandatory_blocking<F, R>(&self, rt: &Handle, func: F) -> Option<JoinHandle<R>>
337 where
338 F: FnOnce() -> R + Send + 'static,
339 R: Send + 'static,
340 {
341 let fn_size = std::mem::size_of::<F>();
342 let (join_handle, spawn_result) = if fn_size > BOX_FUTURE_THRESHOLD {
343 self.spawn_blocking_inner(
344 Box::new(func),
345 Mandatory::Mandatory,
346 SpawnMeta::new_unnamed(fn_size),
347 rt,
348 )
349 } else {
350 self.spawn_blocking_inner(
351 func,
352 Mandatory::Mandatory,
353 SpawnMeta::new_unnamed(fn_size),
354 rt,
355 )
356 };
357
358 if spawn_result.is_ok() {
359 Some(join_handle)
360 } else {
361 None
362 }
363 }
364 }
365
366 #[track_caller]
367 pub(crate) fn spawn_blocking_inner<F, R>(
368 &self,
369 func: F,
370 is_mandatory: Mandatory,
371 spawn_meta: SpawnMeta<'_>,
372 rt: &Handle,
373 ) -> (JoinHandle<R>, Result<(), SpawnError>)
374 where
375 F: FnOnce() -> R + Send + 'static,
376 R: Send + 'static,
377 {
378 let id = task::Id::next();
379 let fut =
380 blocking_task::<F, BlockingTask<F>>(BlockingTask::new(func), spawn_meta, id.as_u64());
381
382 let (task, handle) = task::unowned(
383 fut,
384 BlockingSchedule::new(rt),
385 id,
386 task::SpawnLocation::capture(),
387 );
388
389 let spawned = self.spawn_task(Task::new(task, is_mandatory), rt);
390 (handle, spawned)
391 }
392
393 fn spawn_task(&self, task: Task, rt: &Handle) -> Result<(), SpawnError> {
394 let mut shared = self.inner.shared.lock();
395
396 if shared.shutdown {
397 task.task.shutdown();
401
402 return Err(SpawnError::ShuttingDown);
404 }
405
406 shared.queue.push_back(task);
407 self.inner.metrics.inc_queue_depth();
408
409 if self.inner.metrics.num_idle_threads() == 0 {
410 if self.inner.metrics.num_threads() == self.inner.thread_cap {
413 } else {
415 assert!(shared.shutdown_tx.is_some());
416 let shutdown_tx = shared.shutdown_tx.clone();
417
418 if let Some(shutdown_tx) = shutdown_tx {
419 let id = shared.worker_thread_index;
420
421 match self.spawn_thread(shutdown_tx, rt, id) {
422 Ok(handle) => {
423 self.inner.metrics.inc_num_threads();
424 shared.worker_thread_index += 1;
425 shared.worker_threads.insert(id, handle);
426 }
427 Err(ref e)
428 if is_temporary_os_thread_error(e)
429 && self.inner.metrics.num_threads() > 0 =>
430 {
431 }
435 Err(e) => {
436 return Err(SpawnError::NoThreads(e));
439 }
440 }
441 }
442 }
443 } else {
444 self.inner.metrics.dec_num_idle_threads();
450 shared.num_notify += 1;
451 self.inner.condvar.notify_one();
452 }
453
454 Ok(())
455 }
456
457 fn spawn_thread(
458 &self,
459 shutdown_tx: shutdown::Sender,
460 rt: &Handle,
461 id: usize,
462 ) -> io::Result<thread::JoinHandle<()>> {
463 let mut builder = thread::Builder::new().name((self.inner.thread_name)());
464
465 if let Some(stack_size) = self.inner.stack_size {
466 builder = builder.stack_size(stack_size);
467 }
468
469 let rt = rt.clone();
470
471 builder.spawn(move || {
472 let _enter = rt.enter();
474 rt.inner.blocking_spawner().inner.run(id);
475 drop(shutdown_tx);
476 })
477 }
478}
479
480cfg_unstable_metrics! {
481 impl Spawner {
482 pub(crate) fn num_threads(&self) -> usize {
483 self.inner.metrics.num_threads()
484 }
485
486 pub(crate) fn num_idle_threads(&self) -> usize {
487 self.inner.metrics.num_idle_threads()
488 }
489
490 pub(crate) fn queue_depth(&self) -> usize {
491 self.inner.metrics.queue_depth()
492 }
493 }
494}
495
/// Reports whether a failure to spawn a worker thread is considered
/// temporary, i.e. its `io::ErrorKind` is `WouldBlock`.
///
/// `spawn_task` tolerates such failures as long as other workers already exist.
#[inline]
fn is_temporary_os_thread_error(error: &io::Error) -> bool {
    error.kind() == io::ErrorKind::WouldBlock
}
501
impl Inner {
    /// Body of every worker thread; `worker_thread_id` is this thread's key in
    /// `Shared::worker_threads`.
    ///
    /// The worker alternates between two states:
    /// - BUSY: pop and run queued tasks, releasing the lock while each runs;
    /// - IDLE: wait on the condvar (for up to `keep_alive` per wait) until a
    ///   notification, the keep-alive timeout, or shutdown.
    ///
    /// It exits the loop after a keep-alive timeout with no pending
    /// notification (outside shutdown), or once shutdown is observed. In the
    /// shutdown case it first drains the queue, running only mandatory tasks.
    fn run(&self, worker_thread_id: usize) {
        if let Some(f) = &self.after_start {
            f();
        }

        let mut shared = self.shared.lock();
        // Handle of the previously timed-out worker; joined after the lock is released.
        let mut join_on_thread = None;

        'main: loop {
            // BUSY: run every queued task without holding the lock.
            while let Some(task) = shared.queue.pop_front() {
                self.metrics.dec_queue_depth();
                drop(shared);
                task.run();

                shared = self.shared.lock();
            }

            // IDLE: advertise availability; `spawn_task` decrements this
            // when it claims a worker with a notification.
            self.metrics.inc_num_idle_threads();

            while !shared.shutdown {
                let lock_result = self.condvar.wait_timeout(shared, self.keep_alive).unwrap();

                shared = lock_result.0;
                let timeout_result = lock_result.1;

                if shared.num_notify != 0 {
                    // Legitimate wakeup from `spawn_task`: consume it and
                    // return to BUSY.
                    shared.num_notify -= 1;
                    break;
                }

                // Keep-alive expired with no work. Skipped once shutdown has
                // begun so the shutdown cleanup below runs instead;
                // `BlockingPool::shutdown` takes and joins worker handles itself.
                if !shared.shutdown && timeout_result.timed_out() {
                    // Leave our own handle for the next exiting worker (or
                    // `shutdown`) and take the previous one so we can join it
                    // after dropping the lock.
                    let my_handle = shared.worker_threads.remove(&worker_thread_id);
                    join_on_thread = std::mem::replace(&mut shared.last_exiting_thread, my_handle);

                    break 'main;
                }

                // Spurious wakeup: go back to waiting.
            }

            if shared.shutdown {
                // Drain the queue: mandatory tasks run, the rest are shut down.
                while let Some(task) = shared.queue.pop_front() {
                    self.metrics.dec_queue_depth();
                    drop(shared);

                    task.shutdown_or_run_if_mandatory();

                    shared = self.shared.lock();
                }

                // Compensates for the idle decrement `spawn_task` made when it
                // notified a worker, since this thread stays idle until exit
                // and the exit path below decrements once unconditionally.
                // NOTE(review): if we got here because shutdown was observed
                // without consuming a notification, the count was never
                // decremented for us, so this leaves it one higher at exit —
                // appears harmless during shutdown, but confirm.
                self.metrics.inc_num_idle_threads();
                break;
            }
        }

        // Thread exit.
        self.metrics.dec_num_threads();

        // Detects underflow of the idle counter. This assumes `decrement`
        // returns the previous value and wraps on underflow — TODO confirm
        // against `MetricAtomicUsize`.
        let prev_idle = self.metrics.dec_num_idle_threads();
        assert!(
            prev_idle >= self.metrics.num_idle_threads(),
            "num_idle_threads underflowed on thread exit"
        );

        // The last worker to exit during shutdown wakes one condvar waiter.
        if shared.shutdown && self.metrics.num_threads() == 0 {
            self.condvar.notify_one();
        }

        drop(shared);

        if let Some(f) = &self.before_stop {
            f();
        }

        // Join the previously exited worker outside the lock.
        if let Some(handle) = join_on_thread {
            let _ = handle.join();
        }
    }
}
601
602impl fmt::Debug for Spawner {
603 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
604 fmt.debug_struct("blocking::Spawner").finish()
605 }
606}