relay_threading/multiplexing.rs

use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::task::{Context, Poll};

use futures::FutureExt;
use futures::future::CatchUnwind;
use futures::stream::{FusedStream, FuturesUnordered, Stream};
use pin_project_lite::pin_project;
use tokio::task::Unconstrained;

use crate::{PanicHandler, ThreadMetrics};

pin_project! {
    /// Manages concurrent execution of asynchronous tasks.
    ///
    /// This internal structure collects and drives futures concurrently, invoking a panic handler (if provided)
    /// when a task encounters a panic.
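    ///
    /// A minimal sketch of the intended flow (hypothetical, since `Tasks` is private and is
    /// normally driven from [`Multiplexed`]'s `poll`):
    ///
    /// ```ignore
    /// let mut tasks = Tasks::new(None);
    /// tasks.push(async { /* unit of work */ });
    /// // `poll_tasks_until_pending` is then called from a `poll` implementation with a
    /// // real `Context`, driving every collected future until all of them are pending.
    /// ```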
    struct Tasks<F> {
        #[pin]
        futures: FuturesUnordered<Unconstrained<CatchUnwind<AssertUnwindSafe<F>>>>,
        panic_handler: Option<Arc<PanicHandler>>,
    }
}

impl<F> Tasks<F> {
    /// Creates a new task manager.
    ///
    /// This internal constructor initializes a new collection for tracking asynchronous tasks.
    fn new(panic_handler: Option<Arc<PanicHandler>>) -> Self {
        Self {
            futures: FuturesUnordered::new(),
            panic_handler,
        }
    }

    /// Returns the number of tasks currently scheduled for execution.
    fn len(&self) -> usize {
        self.futures.len()
    }

    /// Returns whether there are no tasks scheduled.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl<F> Tasks<F>
where
    F: Future<Output = ()>,
{
    /// Adds a future to the collection for concurrent execution.
    fn push(&mut self, future: F) {
        let future = AssertUnwindSafe(future).catch_unwind();
        self.futures.push(tokio::task::unconstrained(future));
    }

    /// Drives the execution of collected tasks until a pending state is encountered.
    ///
    /// If a future panics and a panic handler is provided, the handler is invoked.
    /// Otherwise, the panic is propagated.
    ///
    /// # Panics
    ///
    /// Panics are either handled by the custom handler or propagated if no handler is specified.
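    ///
    /// # Example
    ///
    /// Panic capture relies on [`catch_unwind`](FutureExt::catch_unwind); a minimal sketch
    /// of the underlying mechanism, mirroring `test_catch_unwind_future_handles_panics` in
    /// this module:
    ///
    /// ```ignore
    /// let future = AssertUnwindSafe(async { panic!("boom") }).catch_unwind();
    /// // The panic payload surfaces as an `Err` instead of unwinding through the executor.
    /// assert!(futures::executor::block_on(future).is_err());
    /// ```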
    fn poll_tasks_until_pending(self: Pin<&mut Self>, cx: &mut Context<'_>) {
        let mut this = self.project();

        loop {
            // If the unordered pool of futures is terminated, we stop polling.
            if this.futures.is_terminated() {
                return;
            }

            // If we don't get `Ready(Some(_))`, either all remaining futures are pending or
            // the stream has ended; in both cases we return.
            let Poll::Ready(Some(result)) = this.futures.as_mut().poll_next(cx) else {
                return;
            };

            // An `Err` result means the future panicked; decide how to surface it.
            match (this.panic_handler.as_ref(), result) {
                // Panic handler and error: we swallow the panic and invoke the callback.
                (Some(panic_handler), Err(error)) => {
                    panic_handler(error);
                }
                // No panic handler and error: we propagate the panic.
                (None, Err(error)) => {
                    std::panic::resume_unwind(error);
                }
                // Otherwise, we do nothing.
                (_, Ok(())) => {}
            }
        }
    }
}

pin_project! {
    /// [`Multiplexed`] is a future that concurrently schedules asynchronous tasks from a stream while ensuring that
    /// the number of concurrently executing tasks does not exceed a specified limit.
    ///
    /// This multiplexer is primarily used by the [`AsyncPool`] to manage task execution on worker threads.
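    ///
    /// A minimal usage sketch (assuming a `flume` channel as the task stream, as in the
    /// tests in this module):
    ///
    /// ```ignore
    /// use futures::FutureExt;
    ///
    /// let (tx, rx) = flume::bounded(10);
    /// tx.send(async { /* unit of work */ }.boxed()).unwrap();
    /// drop(tx);
    ///
    /// futures::executor::block_on(Multiplexed::new(
    ///     "my_pool",
    ///     4,
    ///     rx.into_stream(),
    ///     None, // no panic handler: a task panic would be propagated
    ///     Arc::new(ThreadMetrics::default()),
    /// ));
    /// ```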
    pub struct Multiplexed<S, F> {
        pool_name: &'static str,
        max_concurrency: usize,
        #[pin]
        rx: S,
        #[pin]
        tasks: Tasks<F>,
        metrics: Arc<ThreadMetrics>,
    }
}

impl<S, F> Multiplexed<S, F>
where
    S: Stream<Item = F>,
{
    /// Creates a new [`Multiplexed`] instance with a defined concurrency limit and a stream of tasks.
    ///
    /// Tasks from the stream will be scheduled for execution concurrently, and an optional panic handler
    /// can be provided to manage errors during task execution.
    pub fn new(
        pool_name: &'static str,
        max_concurrency: usize,
        rx: S,
        panic_handler: Option<Arc<PanicHandler>>,
        metrics: Arc<ThreadMetrics>,
    ) -> Self {
        Self {
            pool_name,
            max_concurrency,
            rx,
            tasks: Tasks::new(panic_handler),
            metrics,
        }
    }
}

impl<S, F> Future for Multiplexed<S, F>
where
    S: FusedStream<Item = F>,
    F: Future<Output = ()>,
{
    type Output = ();

    /// Polls the [`Multiplexed`] future to drive task execution.
    ///
    /// This method repeatedly schedules new tasks from the stream while enforcing the concurrency limit.
    /// It completes when the stream is exhausted and no active tasks remain.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        loop {
            // Report before polling: if the tasks only do blocking work, they all complete
            // within `poll_tasks_until_pending`, so a measurement taken only afterwards
            // would always read 0.
            let before_len = this.tasks.len() as u64;
            this.metrics
                .active_tasks
                .store(before_len, Ordering::Relaxed);

            this.tasks.as_mut().poll_tasks_until_pending(cx);

            // Report again after polling, since some futures may have completed while
            // others are still pending.
            let after_len = this.tasks.len() as u64;
            this.metrics
                .active_tasks
                .store(after_len, Ordering::Relaxed);

            // Count how many tasks were driven to completion in this iteration.
            if let Some(finished_tasks) = before_len.checked_sub(after_len) {
                this.metrics
                    .finished_tasks
                    .fetch_add(finished_tasks, Ordering::Relaxed);
            }

            // If the stream can't yield any more tasks and nothing is left to process, we
            // report ready. If the stream is exhausted but tasks remain, we report pending
            // and keep driving them on subsequent polls.
            if this.tasks.is_empty() && this.rx.is_terminated() {
                return Poll::Ready(());
            } else if this.rx.is_terminated() {
                return Poll::Pending;
            }

            // The stream could yield more tasks, but we have no capacity left, so we
            // report pending.
            if this.tasks.len() >= *this.max_concurrency {
                return Poll::Pending;
            }

            // At this point, we are free to start driving another future.
            match this.rx.as_mut().poll_next(cx) {
                Poll::Ready(Some(task)) => {
                    this.tasks.push(task);
                }
                // The stream is exhausted and there are no remaining tasks.
                Poll::Ready(None) if this.tasks.is_empty() => return Poll::Ready(()),
                // The stream is exhausted but tasks remain active: stop polling the stream
                // and keep processing the remaining tasks.
                Poll::Ready(None) => return Poll::Pending,
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use futures::{FutureExt, future::BoxFuture};
    use std::future;
    use std::sync::atomic::AtomicBool;
    use std::sync::{
        Arc, Mutex,
        atomic::{AtomicUsize, Ordering},
    };
    use std::time::Duration;

    use super::*;

    fn future_with(block: impl FnOnce() + Send + 'static) -> BoxFuture<'static, ()> {
        let fut = async {
            // Yield to allow a pending state during polling.
            tokio::task::yield_now().await;
            block();
        };

        fut.boxed()
    }

    fn mock_metrics() -> Arc<ThreadMetrics> {
        Arc::new(ThreadMetrics::default())
    }

    #[test]
    fn test_multiplexer_with_no_futures() {
        let (_, rx) = flume::bounded::<BoxFuture<'static, _>>(10);
        futures::executor::block_on(Multiplexed::new(
            "my_pool",
            1,
            rx.into_stream(),
            None,
            mock_metrics(),
        ));
    }

    #[test]
    fn test_multiplexer_with_panic_handler_panicking_future() {
        let panic_handler_called = Arc::new(AtomicBool::new(false));
        let count = Arc::new(AtomicUsize::new(0));
        let (tx, rx) = flume::bounded(10);

        let count_clone = count.clone();
        tx.send(future_with(move || {
            count_clone.fetch_add(1, Ordering::SeqCst);
            panic!("panicked");
        }))
        .unwrap();

        drop(tx);

        let panic_handler_called_clone = panic_handler_called.clone();
        let panic_handler = move |_| {
            panic_handler_called_clone.store(true, Ordering::SeqCst);
        };
        futures::executor::block_on(Multiplexed::new(
            "my_pool",
            1,
            rx.into_stream(),
            Some(Arc::new(panic_handler)),
            mock_metrics(),
        ));

        // The count is expected to have been incremented and the handler called.
        assert_eq!(count.load(Ordering::SeqCst), 1);
        assert!(panic_handler_called.load(Ordering::SeqCst));
    }

    #[test]
    fn test_multiplexer_with_no_panic_handler_panicking_future() {
        let count = Arc::new(AtomicUsize::new(0));
        let (tx, rx) = flume::bounded(10);

        let count_clone = count.clone();
        tx.send(future_with(move || {
            count_clone.fetch_add(1, Ordering::SeqCst);
            panic!("panicked");
        }))
        .unwrap();

        drop(tx);

        let result = std::panic::catch_unwind(AssertUnwindSafe(|| {
            futures::executor::block_on(Multiplexed::new(
                "my_pool",
                1,
                rx.into_stream(),
                None,
                mock_metrics(),
            ))
        }));

        // The count is expected to have been incremented and the panic propagated.
        assert_eq!(count.load(Ordering::SeqCst), 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_multiplexer_with_one_concurrency_and_one_future() {
        let count = Arc::new(AtomicUsize::new(0));
        let (tx, rx) = flume::bounded(10);

        let count_clone = count.clone();
        tx.send(future_with(move || {
            count_clone.fetch_add(1, Ordering::SeqCst);
        }))
        .unwrap();

        drop(tx);

        futures::executor::block_on(Multiplexed::new(
            "my_pool",
            1,
            rx.into_stream(),
            None,
            mock_metrics(),
        ));

        // The count is expected to have been incremented.
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_multiplexer_with_one_concurrency_and_multiple_futures() {
        let entries = Arc::new(Mutex::new(Vec::new()));
        let (tx, rx) = flume::bounded(10);

        for i in 0..5 {
            let entries_clone = entries.clone();
            tx.send(future_with(move || {
                entries_clone.lock().unwrap().push(i);
            }))
            .unwrap();
        }

        drop(tx);

        futures::executor::block_on(Multiplexed::new(
            "my_pool",
            1,
            rx.into_stream(),
            None,
            mock_metrics(),
        ));

        // The order of completion is expected to match the order of submission.
        assert_eq!(*entries.lock().unwrap(), (0..5).collect::<Vec<_>>());
    }

    #[test]
    fn test_multiplexer_with_multiple_concurrency_and_one_future() {
        let count = Arc::new(AtomicUsize::new(0));
        let (tx, rx) = flume::bounded(10);

        let count_clone = count.clone();
        tx.send(future_with(move || {
            count_clone.fetch_add(1, Ordering::SeqCst);
        }))
        .unwrap();

        drop(tx);

        futures::executor::block_on(Multiplexed::new(
            "my_pool",
            5,
            rx.into_stream(),
            None,
            mock_metrics(),
        ));

        // The count is expected to have been incremented.
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_multiplexer_with_multiple_concurrency_and_multiple_futures() {
        let entries = Arc::new(Mutex::new(Vec::new()));
        let (tx, rx) = flume::bounded(10);

        for i in 0..5 {
            let entries_clone = entries.clone();
            tx.send(future_with(move || {
                entries_clone.lock().unwrap().push(i);
            }))
            .unwrap();
        }

        drop(tx);

        futures::executor::block_on(Multiplexed::new(
            "my_pool",
            5,
            rx.into_stream(),
            None,
            mock_metrics(),
        ));

        // The order of completion is expected to be the same as the order of submission.
        assert_eq!(*entries.lock().unwrap(), (0..5).collect::<Vec<_>>());
    }

    #[test]
    fn test_multiplexer_with_multiple_concurrency_and_less_multiple_futures() {
        let entries = Arc::new(Mutex::new(Vec::new()));
        let (tx, rx) = flume::bounded(10);

        // We send 3 futures with a concurrency of 5, to make sure that if the stream returns
        // `Poll::Ready(None)` the system will stop polling from the stream and continue driving
        // the remaining futures.
        for i in 0..3 {
            let entries_clone = entries.clone();
            tx.send(future_with(move || {
                entries_clone.lock().unwrap().push(i);
            }))
            .unwrap();
        }

        drop(tx);

        futures::executor::block_on(Multiplexed::new(
            "my_pool",
            5,
            rx.into_stream(),
            None,
            mock_metrics(),
        ));

        // The order of completion is expected to be the same as the order of submission.
        assert_eq!(*entries.lock().unwrap(), (0..3).collect::<Vec<_>>());
    }

    #[test]
    fn test_multiplexer_with_multiple_concurrency_and_multiple_futures_from_multiple_threads() {
        let entries = Arc::new(Mutex::new(Vec::new()));
        let (tx, rx) = flume::bounded(10);

        let mut handles = vec![];
        for i in 0..5 {
            let entries_clone = entries.clone();
            let tx_clone = tx.clone();
            handles.push(std::thread::spawn(move || {
                tx_clone
                    .send(future_with(move || {
                        entries_clone.lock().unwrap().push(i);
                    }))
                    .unwrap();
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        drop(tx);

        futures::executor::block_on(Multiplexed::new(
            "my_pool",
            5,
            rx.into_stream(),
            None,
            mock_metrics(),
        ));

        // The order of completion may vary; verify that all expected elements are present.
        let mut entries = entries.lock().unwrap();
        entries.sort();
        assert_eq!(*entries, (0..5).collect::<Vec<_>>());
    }

    #[test]
    fn test_catch_unwind_future_handles_panics() {
        let future = AssertUnwindSafe(async {
            panic!("panicked");
        })
        .catch_unwind();

        // The future should complete without propagating the panic, returning the payload as an error instead.
        assert!(futures::executor::block_on(future).is_err());

        // Verify that non-panicking tasks complete normally.
        let future = AssertUnwindSafe(async {
            // A normal future that completes.
        })
        .catch_unwind();

        // The future should successfully complete.
        assert!(futures::executor::block_on(future).is_ok());
    }

    #[tokio::test]
    async fn test_multiplexer_emits_metrics() {
        let (tx, rx) = flume::bounded::<BoxFuture<'static, _>>(10);
        let metrics = mock_metrics();

        tx.send(future::pending().boxed()).unwrap();

        drop(tx);

        // We spawn the future, which will be indefinitely pending since it's never woken up.
        #[allow(clippy::disallowed_methods)]
        tokio::spawn(Multiplexed::new(
            "my_pool",
            1,
            rx.into_stream(),
            None,
            metrics.clone(),
        ));

        // We sleep to let `Multiplexed` poll the pending task so that the metric is
        // emitted correctly.
        tokio::time::sleep(Duration::from_millis(1)).await;

        // We now expect 1 active task: the one that is indefinitely pending.
        assert_eq!(metrics.active_tasks.load(Ordering::Relaxed), 1);
        // An indefinitely pending task is never finished.
        assert_eq!(metrics.finished_tasks.load(Ordering::Relaxed), 0);
    }
}