tokio/runtime/task/list.rs

//! This module has containers for storing the tasks spawned on a scheduler. The
//! `OwnedTasks` container is thread-safe but can only store tasks that
//! implement `Send`. The `LocalOwnedTasks` container is not thread-safe, but can
//! store non-`Send` tasks.
//!
//! The collections can be closed to prevent new tasks from being added while
//! the scheduler that owns the collection shuts down.
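//!
//! A rough sketch of the intended lifecycle (illustrative only; `MyScheduler`,
//! `num_cores`, `future`, `id`, and `spawned_at` are stand-ins for values
//! provided by the runtime, not items defined in this module):
//!
//! ```ignore
//! let owned: OwnedTasks<MyScheduler> = OwnedTasks::new(num_cores);
//!
//! // Spawn: bind the future to the collection. `None` means the collection
//! // was already closed and the task was shut down instead of stored.
//! let (join, maybe_notified) = owned.bind(future, scheduler, id, spawned_at);
//! if let Some(notified) = maybe_notified {
//!     // Hand the task to the scheduler for execution.
//! }
//!
//! // Shutdown: close the collection and shut down every stored task.
//! owned.close_and_shutdown_all(0);
//! ```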

use crate::future::Future;
use crate::loom::cell::UnsafeCell;
use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, SpawnLocation, Task};
use crate::util::linked_list::{Link, LinkedList};
use crate::util::sharded_list;

use crate::loom::sync::atomic::{AtomicBool, Ordering};
use std::marker::PhantomData;
use std::num::NonZeroU64;
// The id generated below is used to verify whether a given task is stored in
// this OwnedTasks or in some other list. The counter starts at one so we can
// use `None` for tasks not owned by any list.
//
// The safety checks in this file can technically be violated if the counter
// overflows, but the checks are not supposed to ever fail unless there is a
// bug in Tokio, so we accept that certain bugs would not be caught if two
// mixed-up runtimes happen to have the same id.
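//
// For example, `remove` further down asserts the task's owner id against
// `self.id` before touching the intrusive list, and `assert_owner` checks it
// before granting the current thread permission to poll the task.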

cfg_has_atomic_u64! {
    use std::sync::atomic::AtomicU64;

    static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);

    fn get_next_id() -> NonZeroU64 {
        loop {
            let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
            if let Some(id) = NonZeroU64::new(id) {
                return id;
            }
        }
    }
}

cfg_not_has_atomic_u64! {
    use std::sync::atomic::AtomicU32;

    static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);

    fn get_next_id() -> NonZeroU64 {
        loop {
            let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
            if let Some(id) = NonZeroU64::new(u64::from(id)) {
                return id;
            }
        }
    }
}
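
// Note: in both variants above, a wrapped counter makes `fetch_add` return
// zero exactly once per wrap; `NonZeroU64::new` rejects it and the loop
// retries, so the id of zero (reserved for "not owned by any list") is never
// handed out.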

pub(crate) struct OwnedTasks<S: 'static> {
    list: List<S>,
    pub(crate) id: NonZeroU64,
    closed: AtomicBool,
}

type List<S> = sharded_list::ShardedList<Task<S>, <Task<S> as Link>::Target>;

pub(crate) struct LocalOwnedTasks<S: 'static> {
    inner: UnsafeCell<OwnedTasksInner<S>>,
    pub(crate) id: NonZeroU64,
    _not_send_or_sync: PhantomData<*const ()>,
}

struct OwnedTasksInner<S: 'static> {
    list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
    closed: bool,
}

impl<S: 'static> OwnedTasks<S> {
    pub(crate) fn new(num_cores: usize) -> Self {
        let shard_size = Self::gen_shared_list_size(num_cores);
        Self {
            list: List::new(shard_size),
            closed: AtomicBool::new(false),
            id: get_next_id(),
        }
    }

    /// Binds the provided task to this `OwnedTasks` instance. This fails if the
    /// `OwnedTasks` has been closed.
    pub(crate) fn bind<T>(
        &self,
        task: T,
        scheduler: S,
        id: super::Id,
        spawned_at: SpawnLocation,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + Send + 'static,
        T::Output: Send + 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);
        let notified = unsafe { self.bind_inner(task, notified) };
        (join, notified)
    }

    /// Binds a task that isn't safe to transfer across thread boundaries.
    ///
    /// # Safety
    ///
    /// Only use this in `LocalRuntime`, where the task cannot be moved to
    /// another thread.
    pub(crate) unsafe fn bind_local<T>(
        &self,
        task: T,
        scheduler: S,
        id: super::Id,
        spawned_at: SpawnLocation,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + 'static,
        T::Output: 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);
        let notified = unsafe { self.bind_inner(task, notified) };
        (join, notified)
    }

    /// The part of `bind` that's the same for every type of future.
    unsafe fn bind_inner(&self, task: Task<S>, notified: Notified<S>) -> Option<Notified<S>>
    where
        S: Schedule,
    {
        unsafe {
            // safety: We just created the task, so we have exclusive access
            // to the field.
            task.header().set_owner_id(self.id);
        }

        let shard = self.list.lock_shard(&task);
        // Check the closed flag while holding the shard lock. This ensures
        // that every task is shut down once the OwnedTasks has been closed:
        // a task pushed here is either drained by `close_and_shutdown_all`
        // or, if the flag is already set, shut down immediately below.
        if self.closed.load(Ordering::Acquire) {
            drop(shard);
            task.shutdown();
            return None;
        }
        shard.push(task);
        Some(notified)
    }

    /// Asserts that the given task is owned by this `OwnedTasks` and converts
    /// it to a `LocalNotified`, giving the thread permission to poll this task.
    #[inline]
    pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
        debug_assert_eq!(task.header().get_owner_id(), Some(self.id));
        // safety: All tasks bound to this OwnedTasks are Send, so it is safe
        // to poll them on this thread no matter what thread we are on.
        LocalNotified {
            task: task.0,
            _not_send: PhantomData,
        }
    }

    /// Shuts down all tasks in the collection. This call also closes the
    /// collection, preventing new items from being added.
    ///
    /// The parameter `start` determines which shard this method will start at.
    /// Using different values for each worker thread reduces contention.
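    ///
    /// Shard indices are masked to the shard count inside `sharded_list`, so
    /// `start` values beyond the number of shards wrap around rather than
    /// going out of bounds.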
    pub(crate) fn close_and_shutdown_all(&self, start: usize)
    where
        S: Schedule,
    {
        self.closed.store(true, Ordering::Release);
        for i in start..self.get_shard_size() + start {
            while let Some(task) = self.list.pop_back(i) {
                task.shutdown();
            }
        }
    }

    #[inline]
    pub(crate) fn get_shard_size(&self) -> usize {
        self.list.shard_size()
    }

    pub(crate) fn num_alive_tasks(&self) -> usize {
        self.list.len()
    }

    cfg_64bit_metrics! {
        pub(crate) fn spawned_tasks_count(&self) -> u64 {
            self.list.added()
        }
    }

    pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
        // If the task's owner ID is `None` then it is not part of any list and
        // doesn't need removing.
        let task_id = task.header().get_owner_id()?;

        assert_eq!(task_id, self.id);

        // safety: We just checked that the provided task is not in some other
        // linked list.
        unsafe { self.list.remove(task.header_ptr()) }
    }

    pub(crate) fn is_empty(&self) -> bool {
        self.list.is_empty()
    }

    /// Generates the size of the sharded list based on the number of worker threads.
    ///
    /// Sharding the lock effectively alleviates the lock contention caused by
    /// high concurrency.
    ///
    /// However, as the number of shards grows, memory locality between nodes
    /// in the intrusive linked list diminishes, and constructing the sharded
    /// list takes longer.
    ///
    /// For these reasons, we cap the number of shards at `MAX_SHARED_LIST_SIZE`.
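    ///
    /// For example, 6 worker threads yield `6.next_power_of_two() * 4 = 32`
    /// shards, while very large core counts are capped at `1 << 16` shards.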
    fn gen_shared_list_size(num_cores: usize) -> usize {
        const MAX_SHARED_LIST_SIZE: usize = 1 << 16;
        usize::min(MAX_SHARED_LIST_SIZE, num_cores.next_power_of_two() * 4)
    }
}

cfg_taskdump! {
    impl<S: 'static> OwnedTasks<S> {
        /// Locks the shards and calls `f` on each task in the collection.
        pub(crate) fn for_each<F>(&self, f: F)
        where
            F: FnMut(&Task<S>),
        {
            self.list.for_each(f);
        }
    }
}

impl<S: 'static> LocalOwnedTasks<S> {
    pub(crate) fn new() -> Self {
        Self {
            inner: UnsafeCell::new(OwnedTasksInner {
                list: LinkedList::new(),
                closed: false,
            }),
            id: get_next_id(),
            _not_send_or_sync: PhantomData,
        }
    }

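    /// Binds the provided task to this `LocalOwnedTasks` instance, returning
    /// `None` for the `Notified` if the collection has been closed. Unlike
    /// `OwnedTasks::bind`, no lock or atomic is needed here: the collection is
    /// neither `Send` nor `Sync`, so it is only ever touched from one thread.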
    pub(crate) fn bind<T>(
        &self,
        task: T,
        scheduler: S,
        id: super::Id,
        spawned_at: SpawnLocation,
    ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
    where
        S: Schedule,
        T: Future + 'static,
        T::Output: 'static,
    {
        let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);

        unsafe {
            // safety: We just created the task, so we have exclusive access
            // to the field.
            task.header().set_owner_id(self.id);
        }

        if self.is_closed() {
            drop(notified);
            task.shutdown();
            (join, None)
        } else {
            self.with_inner(|inner| {
                inner.list.push_front(task);
            });
            (join, Some(notified))
        }
    }

    /// Shuts down all tasks in the collection. This call also closes the
    /// collection, preventing new items from being added.
    pub(crate) fn close_and_shutdown_all(&self)
    where
        S: Schedule,
    {
        self.with_inner(|inner| inner.closed = true);

        while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) {
            task.shutdown();
        }
    }

    pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
        // If the task's owner ID is `None` then it is not part of any list and
        // doesn't need removing.
        let task_id = task.header().get_owner_id()?;

        assert_eq!(task_id, self.id);

        self.with_inner(|inner|
            // safety: We just checked that the provided task is not in some
            // other linked list.
            unsafe { inner.list.remove(task.header_ptr()) })
    }

    /// Asserts that the given task is owned by this `LocalOwnedTasks` and
    /// converts it to a `LocalNotified`, giving the thread permission to poll
    /// this task.
    #[inline]
    pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
        assert_eq!(task.header().get_owner_id(), Some(self.id));

        // safety: The task was bound to this LocalOwnedTasks, and the
        // LocalOwnedTasks is not Send or Sync, so we are on the right thread
        // for polling this task.
        LocalNotified {
            task: task.0,
            _not_send: PhantomData,
        }
    }

    #[inline]
    fn with_inner<F, T>(&self, f: F) -> T
    where
        F: FnOnce(&mut OwnedTasksInner<S>) -> T,
    {
        // safety: This type is not Sync, so concurrent calls of this method
        // can't happen. Furthermore, all uses of this method in this file make
        // sure that they don't call `with_inner` recursively.
        self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) })
    }

    pub(crate) fn is_closed(&self) -> bool {
        self.with_inner(|inner| inner.closed)
    }

    pub(crate) fn is_empty(&self) -> bool {
        self.with_inner(|inner| inner.list.is_empty())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // This test may run in parallel with other tests, so we only test that ids
    // come in increasing order.
    #[test]
    fn test_id_not_broken() {
        let mut last_id = get_next_id();

        for _ in 0..1000 {
            let next_id = get_next_id();
            assert!(last_id < next_id);
            last_id = next_id;
        }
    }
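
    // A small sanity check, added for illustration, of the shard-size
    // heuristic above; the expected values follow directly from the arithmetic
    // in `gen_shared_list_size`. `()` stands in for the scheduler type, which
    // that function never uses.
    #[test]
    fn test_shared_list_size_capped() {
        // 1 core -> 1.next_power_of_two() * 4 = 4 shards.
        assert_eq!(OwnedTasks::<()>::gen_shared_list_size(1), 4);
        // 6 cores -> 6.next_power_of_two() * 4 = 32 shards.
        assert_eq!(OwnedTasks::<()>::gen_shared_list_size(6), 32);
        // Very large core counts are capped at 1 << 16 shards.
        assert_eq!(OwnedTasks::<()>::gen_shared_list_size(1_000_000), 1 << 16);
    }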
}