tokio/runtime/task/
core.rs

1//! Core task module.
2//!
3//! # Safety
4//!
5//! The functions in this module are private to the `task` module. All of them
6//! should be considered `unsafe` to use, but are not marked as such since it
7//! would be too noisy.
8//!
9//! Make sure to consult the relevant safety section of each function before
10//! use.
11
12use crate::future::Future;
13use crate::loom::cell::UnsafeCell;
14use crate::runtime::context;
15use crate::runtime::task::raw::{self, Vtable};
16use crate::runtime::task::state::State;
17use crate::runtime::task::{Id, Schedule, TaskHarnessScheduleHooks};
18use crate::util::linked_list;
19
20use std::num::NonZeroU64;
21#[cfg(tokio_unstable)]
22use std::panic::Location;
23use std::pin::Pin;
24use std::ptr::NonNull;
25use std::task::{Context, Poll, Waker};
26
27/// The task cell. Contains the components of the task.
28///
29/// It is critical for `Header` to be the first field as the task structure will
30/// be referenced by both *mut Cell and *mut Header.
31///
32/// Any changes to the layout of this struct _must_ also be reflected in the
33/// `const` fns in raw.rs.
34///
35// # This struct should be cache padded to avoid false sharing. The cache padding rules are copied
36// from crossbeam-utils/src/cache_padded.rs
37//
38// Starting from Intel's Sandy Bridge, spatial prefetcher is now pulling pairs of 64-byte cache
39// lines at a time, so we have to align to 128 bytes rather than 64.
40//
41// Sources:
42// - https://www.intel.com/content/dam/www/public/us/en/documents/manuals/64-ia-32-architectures-optimization-manual.pdf
43// - https://github.com/facebook/folly/blob/1b5288e6eea6df074758f877c849b6e73bbb9fbb/folly/lang/Align.h#L107
44//
45// ARM's big.LITTLE architecture has asymmetric cores and "big" cores have 128-byte cache line size.
46//
47// Sources:
48// - https://www.mono-project.com/news/2016/09/12/arm64-icache/
49//
50// powerpc64 has 128-byte cache line size.
51//
52// Sources:
53// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_ppc64x.go#L9
54#[cfg_attr(
55    any(
56        target_arch = "x86_64",
57        target_arch = "aarch64",
58        target_arch = "powerpc64",
59    ),
60    repr(align(128))
61)]
62// arm, mips, mips64, sparc, and hexagon have 32-byte cache line size.
63//
64// Sources:
65// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_arm.go#L7
66// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_mips.go#L7
67// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_mipsle.go#L7
68// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_mips64x.go#L9
69// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/sparc/include/asm/cache.h#L17
70// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/hexagon/include/asm/cache.h#L12
71#[cfg_attr(
72    any(
73        target_arch = "arm",
74        target_arch = "mips",
75        target_arch = "mips64",
76        target_arch = "sparc",
77        target_arch = "hexagon",
78    ),
79    repr(align(32))
80)]
81// m68k has 16-byte cache line size.
82//
83// Sources:
84// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/m68k/include/asm/cache.h#L9
85#[cfg_attr(target_arch = "m68k", repr(align(16)))]
86// s390x has 256-byte cache line size.
87//
88// Sources:
89// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_s390x.go#L7
90// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/s390/include/asm/cache.h#L13
91#[cfg_attr(target_arch = "s390x", repr(align(256)))]
92// x86, riscv, wasm, and sparc64 have 64-byte cache line size.
93//
94// Sources:
95// - https://github.com/golang/go/blob/dda2991c2ea0c5914714469c4defc2562a907230/src/internal/cpu/cpu_x86.go#L9
96// - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_wasm.go#L7
97// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/sparc/include/asm/cache.h#L19
98// - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/riscv/include/asm/cache.h#L10
99//
100// All others are assumed to have 64-byte cache line size.
101#[cfg_attr(
102    not(any(
103        target_arch = "x86_64",
104        target_arch = "aarch64",
105        target_arch = "powerpc64",
106        target_arch = "arm",
107        target_arch = "mips",
108        target_arch = "mips64",
109        target_arch = "sparc",
110        target_arch = "hexagon",
111        target_arch = "m68k",
112        target_arch = "s390x",
113    )),
114    repr(align(64))
115)]
116#[repr(C)]
117pub(super) struct Cell<T: Future, S> {
118    /// Hot task state data
119    pub(super) header: Header,
120
121    /// Either the future or output, depending on the execution stage.
122    pub(super) core: Core<T, S>,
123
124    /// Cold data
125    pub(super) trailer: Trailer,
126}
127
128pub(super) struct CoreStage<T: Future> {
129    stage: UnsafeCell<Stage<T>>,
130}
131
132/// The core of the task.
133///
134/// Holds the future or output, depending on the stage of execution.
135///
136/// Any changes to the layout of this struct _must_ also be reflected in the
137/// `const` fns in raw.rs.
138#[repr(C)]
139pub(super) struct Core<T: Future, S> {
140    /// Scheduler used to drive this future.
141    pub(super) scheduler: S,
142
143    /// The task's ID, used for populating `JoinError`s.
144    pub(super) task_id: Id,
145
146    /// The source code location where the task was spawned.
147    ///
148    /// This is used for populating the `TaskMeta` passed to the task runtime
149    /// hooks.
150    #[cfg(tokio_unstable)]
151    pub(super) spawned_at: &'static Location<'static>,
152
153    /// Either the future or the output.
154    pub(super) stage: CoreStage<T>,
155}
156
157/// Crate public as this is also needed by the pool.
158#[repr(C)]
159pub(crate) struct Header {
160    /// Task state.
161    pub(super) state: State,
162
163    /// Pointer to next task, used with the injection queue.
164    pub(super) queue_next: UnsafeCell<Option<NonNull<Header>>>,
165
166    /// Table of function pointers for executing actions on the task.
167    pub(super) vtable: &'static Vtable,
168
169    /// This integer contains the id of the `OwnedTasks` or `LocalOwnedTasks`
170    /// that this task is stored in. If the task is not in any list, should be
171    /// the id of the list that it was previously in, or `None` if it has never
172    /// been in any list.
173    ///
174    /// Once a task has been bound to a list, it can never be bound to another
175    /// list, even if removed from the first list.
176    ///
177    /// The id is not unset when removed from a list because we want to be able
178    /// to read the id without synchronization, even if it is concurrently being
179    /// removed from the list.
180    pub(super) owner_id: UnsafeCell<Option<NonZeroU64>>,
181
182    /// The tracing ID for this instrumented task.
183    #[cfg(all(tokio_unstable, feature = "tracing"))]
184    pub(super) tracing_id: Option<tracing::Id>,
185}
186
187unsafe impl Send for Header {}
188unsafe impl Sync for Header {}
189
190/// Cold data is stored after the future. Data is considered cold if it is only
191/// used during creation or shutdown of the task.
192pub(super) struct Trailer {
193    /// Pointers for the linked list in the `OwnedTasks` that owns this task.
194    pub(super) owned: linked_list::Pointers<Header>,
195    /// Consumer task waiting on completion of this task.
196    pub(super) waker: UnsafeCell<Option<Waker>>,
197    /// Optional hooks needed in the harness.
198    pub(super) hooks: TaskHarnessScheduleHooks,
199}
200
201generate_addr_of_methods! {
202    impl<> Trailer {
203        pub(super) unsafe fn addr_of_owned(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Header>> {
204            &self.owned
205        }
206    }
207}
208
209/// Either the future or the output.
210#[repr(C)] // https://github.com/rust-lang/miri/issues/3780
211pub(super) enum Stage<T: Future> {
212    Running(T),
213    Finished(super::Result<T::Output>),
214    Consumed,
215}
216
217impl<T: Future, S: Schedule> Cell<T, S> {
218    /// Allocates a new task cell, containing the header, trailer, and core
219    /// structures.
220    pub(super) fn new(
221        future: T,
222        scheduler: S,
223        state: State,
224        task_id: Id,
225        #[cfg(tokio_unstable)] spawned_at: &'static Location<'static>,
226    ) -> Box<Cell<T, S>> {
227        // Separated into a non-generic function to reduce LLVM codegen
228        fn new_header(
229            state: State,
230            vtable: &'static Vtable,
231            #[cfg(all(tokio_unstable, feature = "tracing"))] tracing_id: Option<tracing::Id>,
232        ) -> Header {
233            Header {
234                state,
235                queue_next: UnsafeCell::new(None),
236                vtable,
237                owner_id: UnsafeCell::new(None),
238                #[cfg(all(tokio_unstable, feature = "tracing"))]
239                tracing_id,
240            }
241        }
242
243        #[cfg(all(tokio_unstable, feature = "tracing"))]
244        let tracing_id = future.id();
245        let vtable = raw::vtable::<T, S>();
246        let result = Box::new(Cell {
247            trailer: Trailer::new(scheduler.hooks()),
248            header: new_header(
249                state,
250                vtable,
251                #[cfg(all(tokio_unstable, feature = "tracing"))]
252                tracing_id,
253            ),
254            core: Core {
255                scheduler,
256                stage: CoreStage {
257                    stage: UnsafeCell::new(Stage::Running(future)),
258                },
259                task_id,
260                #[cfg(tokio_unstable)]
261                spawned_at,
262            },
263        });
264
265        #[cfg(debug_assertions)]
266        {
267            // Using a separate function for this code avoids instantiating it separately for every `T`.
268            unsafe fn check<S>(
269                header: &Header,
270                trailer: &Trailer,
271                scheduler: &S,
272                task_id: &Id,
273                #[cfg(tokio_unstable)] spawn_location: &&'static Location<'static>,
274            ) {
275                let trailer_addr = trailer as *const Trailer as usize;
276                let trailer_ptr = unsafe { Header::get_trailer(NonNull::from(header)) };
277                assert_eq!(trailer_addr, trailer_ptr.as_ptr() as usize);
278
279                let scheduler_addr = scheduler as *const S as usize;
280                let scheduler_ptr = unsafe { Header::get_scheduler::<S>(NonNull::from(header)) };
281                assert_eq!(scheduler_addr, scheduler_ptr.as_ptr() as usize);
282
283                let id_addr = task_id as *const Id as usize;
284                let id_ptr = unsafe { Header::get_id_ptr(NonNull::from(header)) };
285                assert_eq!(id_addr, id_ptr.as_ptr() as usize);
286
287                #[cfg(tokio_unstable)]
288                {
289                    let spawn_location_addr =
290                        spawn_location as *const &'static Location<'static> as usize;
291                    let spawn_location_ptr =
292                        unsafe { Header::get_spawn_location_ptr(NonNull::from(header)) };
293                    assert_eq!(spawn_location_addr, spawn_location_ptr.as_ptr() as usize);
294                }
295            }
296            unsafe {
297                check(
298                    &result.header,
299                    &result.trailer,
300                    &result.core.scheduler,
301                    &result.core.task_id,
302                    #[cfg(tokio_unstable)]
303                    &result.core.spawned_at,
304                );
305            }
306        }
307
308        result
309    }
310}
311
312impl<T: Future> CoreStage<T> {
313    pub(super) fn with_mut<R>(&self, f: impl FnOnce(*mut Stage<T>) -> R) -> R {
314        self.stage.with_mut(f)
315    }
316}
317
318/// Set and clear the task id in the context when the future is executed or
319/// dropped, or when the output produced by the future is dropped.
320pub(crate) struct TaskIdGuard {
321    parent_task_id: Option<Id>,
322}
323
324impl TaskIdGuard {
325    fn enter(id: Id) -> Self {
326        TaskIdGuard {
327            parent_task_id: context::set_current_task_id(Some(id)),
328        }
329    }
330}
331
332impl Drop for TaskIdGuard {
333    fn drop(&mut self) {
334        context::set_current_task_id(self.parent_task_id);
335    }
336}
337
338impl<T: Future, S: Schedule> Core<T, S> {
339    /// Polls the future.
340    ///
341    /// # Safety
342    ///
343    /// The caller must ensure it is safe to mutate the `state` field. This
344    /// requires ensuring mutual exclusion between any concurrent thread that
345    /// might modify the future or output field.
346    ///
347    /// The mutual exclusion is implemented by `Harness` and the `Lifecycle`
348    /// component of the task state.
349    ///
350    /// `self` must also be pinned. This is handled by storing the task on the
351    /// heap.
352    pub(super) fn poll(&self, mut cx: Context<'_>) -> Poll<T::Output> {
353        let res = {
354            self.stage.stage.with_mut(|ptr| {
355                // Safety: The caller ensures mutual exclusion to the field.
356                let future = match unsafe { &mut *ptr } {
357                    Stage::Running(future) => future,
358                    _ => unreachable!("unexpected stage"),
359                };
360
361                // Safety: The caller ensures the future is pinned.
362                let future = unsafe { Pin::new_unchecked(future) };
363
364                let _guard = TaskIdGuard::enter(self.task_id);
365                future.poll(&mut cx)
366            })
367        };
368
369        if res.is_ready() {
370            self.drop_future_or_output();
371        }
372
373        res
374    }
375
376    /// Drops the future.
377    ///
378    /// # Safety
379    ///
380    /// The caller must ensure it is safe to mutate the `stage` field.
381    pub(super) fn drop_future_or_output(&self) {
382        // Safety: the caller ensures mutual exclusion to the field.
383        unsafe {
384            self.set_stage(Stage::Consumed);
385        }
386    }
387
388    /// Stores the task output.
389    ///
390    /// # Safety
391    ///
392    /// The caller must ensure it is safe to mutate the `stage` field.
393    pub(super) fn store_output(&self, output: super::Result<T::Output>) {
394        // Safety: the caller ensures mutual exclusion to the field.
395        unsafe {
396            self.set_stage(Stage::Finished(output));
397        }
398    }
399
400    /// Takes the task output.
401    ///
402    /// # Safety
403    ///
404    /// The caller must ensure it is safe to mutate the `stage` field.
405    pub(super) fn take_output(&self) -> super::Result<T::Output> {
406        use std::mem;
407
408        self.stage.stage.with_mut(|ptr| {
409            // Safety:: the caller ensures mutual exclusion to the field.
410            match mem::replace(unsafe { &mut *ptr }, Stage::Consumed) {
411                Stage::Finished(output) => output,
412                _ => panic!("JoinHandle polled after completion"),
413            }
414        })
415    }
416
417    unsafe fn set_stage(&self, stage: Stage<T>) {
418        let _guard = TaskIdGuard::enter(self.task_id);
419        self.stage.stage.with_mut(|ptr| *ptr = stage);
420    }
421}
422
423impl Header {
424    pub(super) unsafe fn set_next(&self, next: Option<NonNull<Header>>) {
425        self.queue_next.with_mut(|ptr| *ptr = next);
426    }
427
428    // safety: The caller must guarantee exclusive access to this field, and
429    // must ensure that the id is either `None` or the id of the OwnedTasks
430    // containing this task.
431    pub(super) unsafe fn set_owner_id(&self, owner: NonZeroU64) {
432        self.owner_id.with_mut(|ptr| *ptr = Some(owner));
433    }
434
435    pub(super) fn get_owner_id(&self) -> Option<NonZeroU64> {
436        // safety: If there are concurrent writes, then that write has violated
437        // the safety requirements on `set_owner_id`.
438        unsafe { self.owner_id.with(|ptr| *ptr) }
439    }
440
441    /// Gets a pointer to the `Trailer` of the task containing this `Header`.
442    ///
443    /// # Safety
444    ///
445    /// The provided raw pointer must point at the header of a task.
446    pub(super) unsafe fn get_trailer(me: NonNull<Header>) -> NonNull<Trailer> {
447        let offset = me.as_ref().vtable.trailer_offset;
448        let trailer = me.as_ptr().cast::<u8>().add(offset).cast::<Trailer>();
449        NonNull::new_unchecked(trailer)
450    }
451
452    /// Gets a pointer to the scheduler of the task containing this `Header`.
453    ///
454    /// # Safety
455    ///
456    /// The provided raw pointer must point at the header of a task.
457    ///
458    /// The generic type S must be set to the correct scheduler type for this
459    /// task.
460    pub(super) unsafe fn get_scheduler<S>(me: NonNull<Header>) -> NonNull<S> {
461        let offset = me.as_ref().vtable.scheduler_offset;
462        let scheduler = me.as_ptr().cast::<u8>().add(offset).cast::<S>();
463        NonNull::new_unchecked(scheduler)
464    }
465
466    /// Gets a pointer to the id of the task containing this `Header`.
467    ///
468    /// # Safety
469    ///
470    /// The provided raw pointer must point at the header of a task.
471    pub(super) unsafe fn get_id_ptr(me: NonNull<Header>) -> NonNull<Id> {
472        let offset = me.as_ref().vtable.id_offset;
473        let id = me.as_ptr().cast::<u8>().add(offset).cast::<Id>();
474        NonNull::new_unchecked(id)
475    }
476
477    /// Gets the id of the task containing this `Header`.
478    ///
479    /// # Safety
480    ///
481    /// The provided raw pointer must point at the header of a task.
482    pub(super) unsafe fn get_id(me: NonNull<Header>) -> Id {
483        let ptr = Header::get_id_ptr(me).as_ptr();
484        *ptr
485    }
486
487    /// Gets a pointer to the source code location where the task containing
488    /// this `Header` was spawned.
489    ///
490    /// # Safety
491    ///
492    /// The provided raw pointer must point at the header of a task.
493    #[cfg(tokio_unstable)]
494    pub(super) unsafe fn get_spawn_location_ptr(
495        me: NonNull<Header>,
496    ) -> NonNull<&'static Location<'static>> {
497        let offset = me.as_ref().vtable.spawn_location_offset;
498        let spawned_at = me
499            .as_ptr()
500            .cast::<u8>()
501            .add(offset)
502            .cast::<&'static Location<'static>>();
503        NonNull::new_unchecked(spawned_at)
504    }
505
506    /// Gets the source code location where the task containing
507    /// this `Header` was spawned
508    ///
509    /// # Safety
510    ///
511    /// The provided raw pointer must point at the header of a task.
512    #[cfg(tokio_unstable)]
513    pub(super) unsafe fn get_spawn_location(me: NonNull<Header>) -> &'static Location<'static> {
514        let ptr = Header::get_spawn_location_ptr(me).as_ptr();
515        *ptr
516    }
517
518    /// Gets the tracing id of the task containing this `Header`.
519    ///
520    /// # Safety
521    ///
522    /// The provided raw pointer must point at the header of a task.
523    #[cfg(all(tokio_unstable, feature = "tracing"))]
524    pub(super) unsafe fn get_tracing_id(me: &NonNull<Header>) -> Option<&tracing::Id> {
525        me.as_ref().tracing_id.as_ref()
526    }
527}
528
529impl Trailer {
530    fn new(hooks: TaskHarnessScheduleHooks) -> Self {
531        Trailer {
532            waker: UnsafeCell::new(None),
533            owned: linked_list::Pointers::new(),
534            hooks,
535        }
536    }
537
538    pub(super) unsafe fn set_waker(&self, waker: Option<Waker>) {
539        self.waker.with_mut(|ptr| {
540            *ptr = waker;
541        });
542    }
543
544    pub(super) unsafe fn will_wake(&self, waker: &Waker) -> bool {
545        self.waker
546            .with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker))
547    }
548
549    pub(super) fn wake_join(&self) {
550        self.waker.with(|ptr| match unsafe { &*ptr } {
551            Some(waker) => waker.wake_by_ref(),
552            None => panic!("waker missing"),
553        });
554    }
555}
556
/// Checks that `Header` is no larger than eight pointers (64 bytes on a
/// target with 8-byte pointers).
#[test]
#[cfg(not(loom))]
fn header_lte_cache_line() {
    // The assertion is kept on one line so its failure message stays stable.
    assert!(std::mem::size_of::<Header>() <= 8 * std::mem::size_of::<*const ()>());
}