1use crate::future::Future;
10use crate::loom::cell::UnsafeCell;
11use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, SpawnLocation, Task};
12use crate::util::linked_list::{Link, LinkedList};
13use crate::util::sharded_list;
14
15use crate::loom::sync::atomic::{AtomicBool, Ordering};
16use std::marker::PhantomData;
17use std::num::NonZeroU64;
18
19cfg_has_atomic_u64! {
29 use std::sync::atomic::AtomicU64;
30
31 static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);
32
33 fn get_next_id() -> NonZeroU64 {
34 loop {
35 let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
36 if let Some(id) = NonZeroU64::new(id) {
37 return id;
38 }
39 }
40 }
41}
42
43cfg_not_has_atomic_u64! {
44 use std::sync::atomic::AtomicU32;
45
46 static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);
47
48 fn get_next_id() -> NonZeroU64 {
49 loop {
50 let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
51 if let Some(id) = NonZeroU64::new(u64::from(id)) {
52 return id;
53 }
54 }
55 }
56}
57
/// Tracks the tasks bound to this collection through `bind` / `bind_local`.
///
/// The tasks live in a sharded list (`List<S>`); binding a task locks only
/// the shard selected by `lock_shard`, rather than the whole collection.
pub(crate) struct OwnedTasks<S: 'static> {
    /// Sharded list holding every bound task that has not yet been removed
    /// or shut down.
    list: List<S>,
    /// Unique id written into each bound task's header with `set_owner_id`.
    /// `remove` and `assert_owner` compare a task's owner id against it.
    pub(crate) id: NonZeroU64,
    /// Set to `true` by `close_and_shutdown_all`. Afterwards, newly bound
    /// tasks are shut down instead of stored.
    closed: AtomicBool,
}
63
/// Task storage used by `OwnedTasks`: a `ShardedList` of `Task<S>` whose
/// element type is tied to the task through its `Link` implementation.
type List<S> = sharded_list::ShardedList<Task<S>, <Task<S> as Link>::Target>;
65
/// Single-threaded counterpart of `OwnedTasks`.
///
/// The mutable state sits in an `UnsafeCell` and is only reached through
/// `with_inner`. The `PhantomData<*const ()>` marker makes this type neither
/// `Send` nor `Sync`, which is what rules out concurrent access to that cell.
pub(crate) struct LocalOwnedTasks<S: 'static> {
    /// Task list plus closed flag; see `OwnedTasksInner`.
    inner: UnsafeCell<OwnedTasksInner<S>>,
    /// Unique id written into each bound task's header with `set_owner_id`.
    pub(crate) id: NonZeroU64,
    /// `*const ()` is neither `Send` nor `Sync`, so this struct isn't either.
    _not_send_or_sync: PhantomData<*const ()>,
}
71
/// Mutable state of `LocalOwnedTasks`, stored behind its `UnsafeCell`.
struct OwnedTasksInner<S: 'static> {
    /// Tasks that have been bound and not yet removed or shut down.
    list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
    /// Set by `close_and_shutdown_all`. Once `true`, `bind` shuts new tasks
    /// down instead of pushing them.
    closed: bool,
}
76
77impl<S: 'static> OwnedTasks<S> {
78 pub(crate) fn new(num_cores: usize) -> Self {
79 let shard_size = Self::gen_shared_list_size(num_cores);
80 Self {
81 list: List::new(shard_size),
82 closed: AtomicBool::new(false),
83 id: get_next_id(),
84 }
85 }
86
87 pub(crate) fn bind<T>(
90 &self,
91 task: T,
92 scheduler: S,
93 id: super::Id,
94 spawned_at: SpawnLocation,
95 ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
96 where
97 S: Schedule,
98 T: Future + Send + 'static,
99 T::Output: Send + 'static,
100 {
101 let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);
102 let notified = unsafe { self.bind_inner(task, notified) };
103 (join, notified)
104 }
105
106 pub(crate) unsafe fn bind_local<T>(
111 &self,
112 task: T,
113 scheduler: S,
114 id: super::Id,
115 spawned_at: SpawnLocation,
116 ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
117 where
118 S: Schedule,
119 T: Future + 'static,
120 T::Output: 'static,
121 {
122 let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);
123 let notified = unsafe { self.bind_inner(task, notified) };
124 (join, notified)
125 }
126
127 unsafe fn bind_inner(&self, task: Task<S>, notified: Notified<S>) -> Option<Notified<S>>
129 where
130 S: Schedule,
131 {
132 unsafe {
133 task.header().set_owner_id(self.id);
136 }
137
138 let shard = self.list.lock_shard(&task);
139 if self.closed.load(Ordering::Acquire) {
142 drop(shard);
143 task.shutdown();
144 return None;
145 }
146 shard.push(task);
147 Some(notified)
148 }
149
150 #[inline]
153 pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
154 debug_assert_eq!(task.header().get_owner_id(), Some(self.id));
155 LocalNotified {
158 task: task.0,
159 _not_send: PhantomData,
160 }
161 }
162
163 pub(crate) fn close_and_shutdown_all(&self, start: usize)
169 where
170 S: Schedule,
171 {
172 self.closed.store(true, Ordering::Release);
173 for i in start..self.get_shard_size() + start {
174 loop {
175 let task = self.list.pop_back(i);
176 match task {
177 Some(task) => {
178 task.shutdown();
179 }
180 None => break,
181 }
182 }
183 }
184 }
185
186 #[inline]
187 pub(crate) fn get_shard_size(&self) -> usize {
188 self.list.shard_size()
189 }
190
191 pub(crate) fn num_alive_tasks(&self) -> usize {
192 self.list.len()
193 }
194
195 cfg_64bit_metrics! {
196 pub(crate) fn spawned_tasks_count(&self) -> u64 {
197 self.list.added()
198 }
199 }
200
201 pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
202 let task_id = task.header().get_owner_id()?;
205
206 assert_eq!(task_id, self.id);
207
208 unsafe { self.list.remove(task.header_ptr()) }
211 }
212
213 pub(crate) fn is_empty(&self) -> bool {
214 self.list.is_empty()
215 }
216
217 fn gen_shared_list_size(num_cores: usize) -> usize {
229 const MAX_SHARED_LIST_SIZE: usize = 1 << 16;
230 usize::min(MAX_SHARED_LIST_SIZE, num_cores.next_power_of_two() * 4)
231 }
232}
233
234cfg_taskdump! {
235 impl<S: 'static> OwnedTasks<S> {
236 pub(crate) fn for_each<F>(&self, f: F)
238 where
239 F: FnMut(&Task<S>),
240 {
241 self.list.for_each(f);
242 }
243 }
244}
245
246impl<S: 'static> LocalOwnedTasks<S> {
247 pub(crate) fn new() -> Self {
248 Self {
249 inner: UnsafeCell::new(OwnedTasksInner {
250 list: LinkedList::new(),
251 closed: false,
252 }),
253 id: get_next_id(),
254 _not_send_or_sync: PhantomData,
255 }
256 }
257
258 pub(crate) fn bind<T>(
259 &self,
260 task: T,
261 scheduler: S,
262 id: super::Id,
263 spawned_at: SpawnLocation,
264 ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
265 where
266 S: Schedule,
267 T: Future + 'static,
268 T::Output: 'static,
269 {
270 let (task, notified, join) = super::new_task(task, scheduler, id, spawned_at);
271
272 unsafe {
273 task.header().set_owner_id(self.id);
276 }
277
278 if self.is_closed() {
279 drop(notified);
280 task.shutdown();
281 (join, None)
282 } else {
283 self.with_inner(|inner| {
284 inner.list.push_front(task);
285 });
286 (join, Some(notified))
287 }
288 }
289
290 pub(crate) fn close_and_shutdown_all(&self)
293 where
294 S: Schedule,
295 {
296 self.with_inner(|inner| inner.closed = true);
297
298 while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) {
299 task.shutdown();
300 }
301 }
302
303 pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
304 let task_id = task.header().get_owner_id()?;
307
308 assert_eq!(task_id, self.id);
309
310 self.with_inner(|inner|
311 unsafe { inner.list.remove(task.header_ptr()) })
314 }
315
316 #[inline]
319 pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
320 assert_eq!(task.header().get_owner_id(), Some(self.id));
321
322 LocalNotified {
326 task: task.0,
327 _not_send: PhantomData,
328 }
329 }
330
331 #[inline]
332 fn with_inner<F, T>(&self, f: F) -> T
333 where
334 F: FnOnce(&mut OwnedTasksInner<S>) -> T,
335 {
336 self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) })
340 }
341
342 pub(crate) fn is_closed(&self) -> bool {
343 self.with_inner(|inner| inner.closed)
344 }
345
346 pub(crate) fn is_empty(&self) -> bool {
347 self.with_inner(|inner| inner.list.is_empty())
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    // Other tests may draw ids concurrently, so only strict monotonicity of
    // consecutive draws is checked, not exact values.
    #[test]
    fn test_id_not_broken() {
        let first = get_next_id();
        let _ = (0..1000).fold(first, |last_id, _| {
            let next_id = get_next_id();
            assert!(last_id < next_id);
            next_id
        });
    }
}