relay_threading/
pool.rs

1use std::future::Future;
2use std::io;
3use std::panic::AssertUnwindSafe;
4use std::sync::Arc;
5
6use crate::builder::AsyncPoolBuilder;
7use crate::metrics::AsyncPoolMetrics;
8use crate::multiplexing::Multiplexed;
9use crate::{PanicHandler, ThreadMetrics};
10use futures::FutureExt;
11use futures::future::BoxFuture;
12use relay_system::MonitoredFuture;
13
/// Name given to a pool when the builder does not specify one.
const DEFAULT_POOL_NAME: &str = "unnamed";
16
17/// [`AsyncPool`] is a thread-based executor that runs asynchronous tasks on dedicated worker threads.
18///
19/// The pool collects tasks through a bounded channel and distributes them among threads, each of which runs its own
20/// Tokio executor. This design enables controlled concurrency and efficient use of system resources.
21#[derive(Debug)]
22pub struct AsyncPool<F> {
23    /// Name of the pool.
24    name: &'static str,
25    /// Transmission containing all tasks.
26    tx: flume::Sender<F>,
27    /// The maximum number of tasks that are expected to run concurrently at any point in time.
28    max_tasks: u64,
29    /// Vector containing all the metrics collected individually in each thread.
30    threads_metrics: Arc<Vec<Arc<ThreadMetrics>>>,
31}
32
33impl<F> AsyncPool<F> {
34    /// Returns the `name` of the [`AsyncPool`].
35    pub fn name(&self) -> &'static str {
36        self.name
37    }
38
39    /// Returns the [`AsyncPoolMetrics`] that are updated by the pool.
40    pub fn metrics(&self) -> AsyncPoolMetrics {
41        AsyncPoolMetrics {
42            max_tasks: self.max_tasks,
43            queue_size: self.tx.len() as u64,
44            threads_metrics: &self.threads_metrics,
45        }
46    }
47}
48
49impl<F> AsyncPool<F>
50where
51    F: Future<Output = ()> + Send + 'static,
52{
53    /// Creates a new [`AsyncPool`] based on the configuration specified by [`AsyncPoolBuilder`].
54    ///
55    /// This method initializes the dedicated worker threads and configures each executor with the defined
56    /// concurrency limits.
57    pub fn new<S>(mut builder: AsyncPoolBuilder<S>) -> io::Result<Self>
58    where
59        S: ThreadSpawn,
60    {
61        let pool_name = builder.pool_name.unwrap_or(DEFAULT_POOL_NAME);
62        let (tx, rx) = flume::bounded(builder.num_threads * 2);
63        let mut threads_metrics = Vec::with_capacity(builder.num_threads);
64
65        for thread_id in 0..builder.num_threads {
66            let rx = rx.clone();
67
68            let thread_name: Option<String> = builder.thread_name.as_mut().map(|f| f(thread_id));
69            let metrics = Arc::new(ThreadMetrics::default());
70            let task = MonitoredFuture::wrap_with_metrics(
71                Multiplexed::new(
72                    pool_name,
73                    builder.max_concurrency,
74                    rx.into_stream(),
75                    builder.task_panic_handler.clone(),
76                    metrics.clone(),
77                ),
78                metrics.raw_metrics.clone(),
79            );
80
81            let thread = Thread {
82                id: thread_id,
83                max_concurrency: builder.max_concurrency,
84                name: thread_name.clone(),
85                runtime: builder.runtime.clone(),
86                panic_handler: builder.thread_panic_handler.clone(),
87                task: task.boxed(),
88            };
89
90            threads_metrics.push(metrics);
91
92            builder.spawn_handler.spawn(thread)?;
93        }
94
95        Ok(Self {
96            name: pool_name,
97            tx,
98            max_tasks: (builder.num_threads * builder.max_concurrency) as u64,
99            threads_metrics: Arc::new(threads_metrics),
100        })
101    }
102
103    /// Schedules a future for execution within the [`AsyncPool`].
104    ///
105    /// The task is added to the pool's internal queue to be executed by an available worker thread.
106    ///
107    /// # Panics
108    ///
109    /// This method panics if all receivers have been dropped which can happen when all threads of
110    /// the pool panicked.
111    pub fn spawn(&self, future: F) {
112        assert!(
113            self.tx.send(future).is_ok(),
114            "failed to schedule task: all worker threads have terminated (either none were spawned or all have panicked)"
115        );
116    }
117
118    /// Asynchronously enqueues a future for execution within the [`AsyncPool`].
119    ///
120    /// This method awaits until the task is successfully added to the internal queue.
121    ///
122    /// # Panics
123    ///
124    /// This method panics if all receivers have been dropped which can happen when all threads of
125    /// the pool panicked.
126    pub async fn spawn_async(&self, future: F) {
127        assert!(
128            self.tx.send_async(future).await.is_ok(),
129            "failed to schedule task: all worker threads have terminated (either none were spawned or all have panicked)"
130        );
131    }
132}
133
134impl<F> Clone for AsyncPool<F> {
135    fn clone(&self) -> Self {
136        Self {
137            name: self.name,
138            tx: self.tx.clone(),
139            max_tasks: self.max_tasks,
140            threads_metrics: self.threads_metrics.clone(),
141        }
142    }
143}
144
145/// [`Thread`] represents a dedicated worker thread within an [`AsyncPool`] that executes scheduled tasks.
146pub struct Thread {
147    id: usize,
148    max_concurrency: usize,
149    name: Option<String>,
150    runtime: tokio::runtime::Handle,
151    panic_handler: Option<Arc<PanicHandler>>,
152    task: BoxFuture<'static, ()>,
153}
154
155impl Thread {
156    /// Returns the unique index assigned to this [`Thread`].
157    ///
158    /// The index can help identify the thread during debugging or logging.
159    pub fn id(&self) -> usize {
160        self.id
161    }
162
163    /// Returns the maximum number of concurrent tasks permitted on this [`Thread`].
164    ///
165    /// This reflects the concurrency limit configured via the [`AsyncPoolBuilder`].
166    pub fn max_concurrency(&self) -> usize {
167        self.max_concurrency
168    }
169
170    /// Returns the human-readable name of this [`Thread`], if one was set.
171    ///
172    /// Thread names can assist in monitoring and debugging the execution environment.
173    pub fn name(&self) -> Option<&str> {
174        self.name.as_deref()
175    }
176}
177
178impl Thread {
179    /// Runs the task multiplexer associated with this [`Thread`].
180    ///
181    /// This method drives the execution of tasks on the worker thread.
182    ///
183    /// # Panics
184    ///
185    /// Panics are either handled by the custom handler or propagated if no handler is specified.
186    pub fn run(self) {
187        let result =
188            std::panic::catch_unwind(AssertUnwindSafe(|| self.runtime.block_on(self.task)));
189
190        match (self.panic_handler, result) {
191            // Panic handler and error, we swallow the panic and invoke the callback.
192            (Some(panic_handler), Err(error)) => {
193                panic_handler(error);
194            }
195            // No panic handler and error, we propagate the panic.
196            (None, Err(error)) => {
197                std::panic::resume_unwind(error);
198            }
199            // Otherwise, we do nothing.
200            (_, Ok(())) => {}
201        }
202    }
203}
204
205/// [`ThreadSpawn`] defines how threads are spawned in an [`AsyncPool`].
206///
207/// This trait allows customization of thread creation (for example, setting names or adjusting stack sizes)
208/// without altering the core functionality of the pool.
209pub trait ThreadSpawn {
210    /// Spawns a new thread using the provided configuration.
211    fn spawn(&mut self, thread: Thread) -> io::Result<()>;
212}
213
/// [`DefaultSpawn`] is the [`ThreadSpawn`] implementation that relies on the standard library's
/// thread builder.
///
/// If the [`Thread`] carries a name, that name is applied to the spawned thread.
#[derive(Clone)]
pub struct DefaultSpawn;
220
221impl ThreadSpawn for DefaultSpawn {
222    fn spawn(&mut self, thread: Thread) -> io::Result<()> {
223        let mut b = std::thread::Builder::new();
224        if let Some(name) = thread.name() {
225            b = b.name(name.to_owned());
226        }
227        b.spawn(|| thread.run())?;
228
229        Ok(())
230    }
231}
232
/// [`CustomSpawn`] is a [`ThreadSpawn`] implementation backed by a user-supplied closure.
///
/// The closure receives every [`Thread`] of the pool and is free to start it however the
/// application requires.
#[derive(Clone)]
pub struct CustomSpawn<B>(B);

impl<B> CustomSpawn<B> {
    /// Wraps `spawn_handler` into a new [`CustomSpawn`].
    pub fn new(spawn_handler: B) -> Self {
        Self(spawn_handler)
    }
}
246
247impl<B> ThreadSpawn for CustomSpawn<B>
248where
249    B: FnMut(Thread) -> io::Result<()>,
250{
251    /// Applies the custom configuration closure when spawning a new thread.
252    fn spawn(&mut self, thread: Thread) -> io::Result<()> {
253        self.0(thread)
254    }
255}
256
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::panic::AssertUnwindSafe;
    use std::sync::atomic::AtomicBool;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };
    use std::time::{Duration, Instant};

    use futures::FutureExt;
    use futures::future::BoxFuture;
    use tokio::runtime::Runtime;
    use tokio::sync::Semaphore;
    use tokio::{runtime::Handle, time::sleep};

    use crate::builder::AsyncPoolBuilder;
    use crate::{AsyncPool, Thread};

    /// Semaphore-based helper to wait until all spawned tasks have finished.
    ///
    /// Every spawned task holds one permit until it completes; `wait` returns once all permits
    /// can be acquired at the same time.
    struct TestBarrier {
        semaphore: Arc<Semaphore>,
        count: u32,
    }

    impl TestBarrier {
        /// Creates a barrier that tracks up to `count` tasks.
        async fn new(count: u32) -> Self {
            let semaphore = Arc::new(Semaphore::new(count as usize));
            Self { semaphore, count }
        }

        /// Takes one permit and schedules `f` on `pool`; the permit is released after `f` ran.
        async fn spawn<F, Fut>(&self, pool: &AsyncPool<BoxFuture<'static, ()>>, f: F)
        where
            F: FnOnce() -> Fut + Send + 'static,
            Fut: Future<Output = ()> + Send + 'static,
        {
            let permit = Arc::clone(&self.semaphore)
                .acquire_owned()
                .await
                .unwrap();
            let task = async move {
                f().await;
                drop(permit);
            };

            pool.spawn_async(task.boxed()).await;
        }

        /// Waits until all permits are available again, i.e. no task holds one anymore.
        async fn wait(&self) {
            drop(self.semaphore.acquire_many(self.count).await.unwrap());
        }
    }

    /// Spawns two tasks that sleep for 200ms each and asserts that they finish in under 250ms.
    async fn assert_two_sleeps_overlap(pool: &AsyncPool<BoxFuture<'static, ()>>) {
        let start = Instant::now();
        let barrier = TestBarrier::new(2).await;

        // Spawn 2 tasks that each sleep for 200ms.
        for _ in 0..2 {
            barrier
                .spawn(pool, || async {
                    sleep(Duration::from_millis(200)).await;
                })
                .await;
        }

        barrier.wait().await;

        let elapsed = start.elapsed();
        // If running concurrently, the overall time should be near 200ms (with some allowance).
        assert!(
            elapsed < Duration::from_millis(250),
            "Elapsed time was too high: {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn test_async_pool_executes_all_tasks() {
        const TASKS: u32 = 20;

        let pool = AsyncPoolBuilder::new(Handle::current())
            .num_threads(1)
            .max_concurrency(2)
            .build()
            .unwrap();
        let completed = Arc::new(AtomicUsize::new(0));
        let barrier = TestBarrier::new(TASKS).await;

        // Spawn 20 tasks that wait briefly and then update the counter.
        for _ in 0..TASKS {
            let completed = Arc::clone(&completed);
            barrier
                .spawn(&pool, move || async move {
                    sleep(Duration::from_millis(50)).await;
                    completed.fetch_add(1, Ordering::SeqCst);
                })
                .await;
        }

        barrier.wait().await;
        assert_eq!(completed.load(Ordering::SeqCst), TASKS as usize);
    }

    #[tokio::test]
    async fn test_async_pool_executes_all_tasks_concurrently_with_single_thread() {
        let pool = AsyncPoolBuilder::new(Handle::current())
            .num_threads(1)
            .max_concurrency(2)
            .build()
            .unwrap();

        assert_two_sleeps_overlap(&pool).await;
    }

    #[tokio::test]
    async fn test_async_pool_executes_all_tasks_concurrently_with_multiple_threads() {
        let pool = AsyncPoolBuilder::new(Handle::current())
            .num_threads(2)
            .max_concurrency(1)
            .build()
            .unwrap();

        assert_two_sleeps_overlap(&pool).await;
    }

    #[test]
    fn test_thread_panic_handling() {
        let runtime = Runtime::new().unwrap();
        let has_panicked = Arc::new(AtomicBool::new(false));
        let handler_flag = has_panicked.clone();
        let panic_handler = move |_| {
            handler_flag.store(true, Ordering::SeqCst);
        };

        let thread = Thread {
            id: 0,
            max_concurrency: 1,
            name: Some("test-thread".into()),
            runtime: runtime.handle().clone(),
            panic_handler: Some(Arc::new(panic_handler)),
            task: async move {
                panic!("panicked");
            }
            .boxed(),
        };
        thread.run();

        assert!(has_panicked.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_spawn_panics_if_no_threads_are_available() {
        let pool = AsyncPoolBuilder::new(Handle::current())
            .num_threads(0)
            .max_concurrency(1)
            .build()
            .unwrap();

        // Without worker threads there is no receiver left, so scheduling must panic.
        let outcome = std::panic::catch_unwind(AssertUnwindSafe(|| {
            pool.spawn(async move {});
        }));

        assert!(outcome.is_err());
    }
}