use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::task::{Context, Poll};

use futures::FutureExt;
use futures::future::CatchUnwind;
use futures::stream::{FusedStream, FuturesUnordered, Stream};
use pin_project_lite::pin_project;
use tokio::task::Unconstrained;

use crate::{PanicHandler, ThreadMetrics};

pin_project! {
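    /// The set of futures currently driven on behalf of a [`Multiplexed`] instance.
    ///
    /// Each future is wrapped in `AssertUnwindSafe` + `catch_unwind` so a panicking task can be
    /// reported to the optional panic handler instead of unwinding through the multiplexer, and
    /// in `tokio::task::unconstrained` so Tokio's cooperative budgeting does not force yields.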
    struct Tasks<F> {
        #[pin]
        futures: FuturesUnordered<Unconstrained<CatchUnwind<AssertUnwindSafe<F>>>>,
        panic_handler: Option<Arc<PanicHandler>>,
    }
}

impl<F> Tasks<F> {
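    /// Creates an empty task set that reports panics to `panic_handler`, if one is provided.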
    fn new(panic_handler: Option<Arc<PanicHandler>>) -> Self {
        Self {
            futures: FuturesUnordered::new(),
            panic_handler,
        }
    }

    fn len(&self) -> usize {
        self.futures.len()
    }

    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl<F> Tasks<F>
where
    F: Future<Output = ()>,
{
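    /// Adds a future to the set of tasks to drive.
    ///
    /// The future is wrapped to catch panics and to opt out of Tokio's cooperative budgeting,
    /// so it only yields when it is genuinely pending.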
    fn push(&mut self, future: F) {
        let future = AssertUnwindSafe(future).catch_unwind();
        self.futures.push(tokio::task::unconstrained(future));
    }

    /// Polls the tracked tasks until none of them can make progress anymore.
    ///
    /// Panics from completed tasks are forwarded to the panic handler if one is installed;
    /// otherwise the panic is resumed on the current thread.
    fn poll_tasks_until_pending(self: Pin<&mut Self>, cx: &mut Context<'_>) {
        let mut this = self.project();

        loop {
            if this.futures.is_terminated() {
                return;
            }

            // Stop as soon as no task is ready; pending tasks have registered their wakers.
            let Poll::Ready(Some(result)) = this.futures.as_mut().poll_next(cx) else {
                return;
            };

            match (this.panic_handler.as_ref(), result) {
                // A task panicked and a handler is installed: hand over the panic payload.
                (Some(panic_handler), Err(error)) => {
                    panic_handler(error);
                }
                // A task panicked and no handler is installed: propagate the panic.
                (None, Err(error)) => {
                    std::panic::resume_unwind(error);
                }
                // The task completed normally.
                (_, Ok(())) => {}
            }
        }
    }
}

pin_project! {
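    /// A future that drives tasks received from a stream with bounded concurrency.
    ///
    /// `Multiplexed` pulls futures out of `rx` and polls at most `max_concurrency` of them
    /// concurrently, recording progress in [`ThreadMetrics`]. It resolves once the stream is
    /// exhausted and every accepted task has completed.
    ///
    /// A minimal usage sketch, assuming a `flume` channel of boxed futures and default metrics
    /// as in the tests below:
    ///
    /// ```ignore
    /// use futures::{FutureExt, future::BoxFuture};
    ///
    /// let (tx, rx) = flume::bounded::<BoxFuture<'static, ()>>(10);
    /// tx.send(async { /* task body */ }.boxed()).unwrap();
    /// drop(tx);
    ///
    /// futures::executor::block_on(Multiplexed::new(
    ///     "my_pool",
    ///     4,
    ///     rx.into_stream(),
    ///     None, // no panic handler
    ///     Arc::new(ThreadMetrics::default()),
    /// ));
    /// ```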
    pub struct Multiplexed<S, F> {
        pool_name: &'static str,
        max_concurrency: usize,
        #[pin]
        rx: S,
        #[pin]
        tasks: Tasks<F>,
        metrics: Arc<ThreadMetrics>,
    }
}

impl<S, F> Multiplexed<S, F>
where
    S: Stream<Item = F>,
{
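    /// Creates a [`Multiplexed`] future that drives tasks received on `rx`, running at most
    /// `max_concurrency` of them at any one time.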
    pub fn new(
        pool_name: &'static str,
        max_concurrency: usize,
        rx: S,
        panic_handler: Option<Arc<PanicHandler>>,
        metrics: Arc<ThreadMetrics>,
    ) -> Self {
        Self {
            pool_name,
            max_concurrency,
            rx,
            tasks: Tasks::new(panic_handler),
            metrics,
        }
    }
}

impl<S, F> Future for Multiplexed<S, F>
where
    S: FusedStream<Item = F>,
    F: Future<Output = ()>,
{
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
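        // Alternate between driving the tasks that are already running and admitting new ones
        // from the stream, until either everything is done or no further progress can be made.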
        let mut this = self.project();

        loop {
            // Track how many tasks are in flight before and after polling them; the difference
            // is the number of tasks that finished in this iteration.
            let before_len = this.tasks.len() as u64;
            this.metrics
                .active_tasks
                .store(before_len, Ordering::Relaxed);

            this.tasks.as_mut().poll_tasks_until_pending(cx);

            let after_len = this.tasks.len() as u64;
            this.metrics
                .active_tasks
                .store(after_len, Ordering::Relaxed);

            if let Some(finished_tasks) = before_len.checked_sub(after_len) {
                this.metrics
                    .finished_tasks
                    .fetch_add(finished_tasks, Ordering::Relaxed);
            }

            // Once the stream is exhausted no new tasks can arrive: finish when the remaining
            // tasks are done, otherwise wait to be woken by one of them.
            if this.tasks.is_empty() && this.rx.is_terminated() {
                return Poll::Ready(());
            } else if this.rx.is_terminated() {
                return Poll::Pending;
            }

            // At the concurrency limit, don't pull more work; a finishing task will wake us.
            if this.tasks.len() >= *this.max_concurrency {
                return Poll::Pending;
            }

            // There is spare capacity: try to admit the next task from the stream.
            match this.rx.as_mut().poll_next(cx) {
                Poll::Ready(Some(task)) => {
                    this.tasks.push(task);
                }
                Poll::Ready(None) if this.tasks.is_empty() => return Poll::Ready(()),
                Poll::Ready(None) => return Poll::Pending,
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use futures::{FutureExt, future::BoxFuture};
    use std::future;
    use std::sync::atomic::AtomicBool;
    use std::sync::{
        Arc, Mutex,
        atomic::{AtomicUsize, Ordering},
    };
    use std::time::Duration;

    use super::*;

    fn future_with(block: impl FnOnce() + Send + 'static) -> BoxFuture<'static, ()> {
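        // Yield once so the future is not ready on its very first poll, then run `block`.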
        let fut = async {
            tokio::task::yield_now().await;
            block();
        };

        fut.boxed()
    }

    fn mock_metrics() -> Arc<ThreadMetrics> {
        Arc::new(ThreadMetrics::default())
    }

    #[test]
    fn test_multiplexer_with_no_futures() {
        let (_, rx) = flume::bounded::<BoxFuture<'static, _>>(10);
        futures::executor::block_on(Multiplexed::new(
            "my_pool",
            1,
            rx.into_stream(),
            None,
            mock_metrics(),
        ));
    }

    #[test]
    fn test_multiplexer_with_panic_handler_panicking_future() {
        let panic_handler_called = Arc::new(AtomicBool::new(false));
        let count = Arc::new(AtomicUsize::new(0));
        let (tx, rx) = flume::bounded(10);

        let count_clone = count.clone();
        tx.send(future_with(move || {
            count_clone.fetch_add(1, Ordering::SeqCst);
            panic!("panicked");
        }))
        .unwrap();

        drop(tx);

        let panic_handler_called_clone = panic_handler_called.clone();
        let panic_handler = move |_| {
            panic_handler_called_clone.store(true, Ordering::SeqCst);
        };
        futures::executor::block_on(Multiplexed::new(
            "my_pool",
            1,
            rx.into_stream(),
            Some(Arc::new(panic_handler)),
            mock_metrics(),
        ));

        assert_eq!(count.load(Ordering::SeqCst), 1);
        assert!(panic_handler_called.load(Ordering::SeqCst));
    }

    #[test]
    fn test_multiplexer_with_no_panic_handler_panicking_future() {
        let count = Arc::new(AtomicUsize::new(0));
        let (tx, rx) = flume::bounded(10);

        let count_clone = count.clone();
        tx.send(future_with(move || {
            count_clone.fetch_add(1, Ordering::SeqCst);
            panic!("panicked");
        }))
        .unwrap();

        drop(tx);

        let result = std::panic::catch_unwind(AssertUnwindSafe(|| {
            futures::executor::block_on(Multiplexed::new(
                "my_pool",
                1,
                rx.into_stream(),
                None,
                mock_metrics(),
            ))
        }));

        assert_eq!(count.load(Ordering::SeqCst), 1);
        assert!(result.is_err());
    }

    #[test]
    fn test_multiplexer_with_one_concurrency_and_one_future() {
        let count = Arc::new(AtomicUsize::new(0));
        let (tx, rx) = flume::bounded(10);

        let count_clone = count.clone();
        tx.send(future_with(move || {
            count_clone.fetch_add(1, Ordering::SeqCst);
        }))
        .unwrap();

        drop(tx);

        futures::executor::block_on(Multiplexed::new(
            "my_pool",
            1,
            rx.into_stream(),
            None,
            mock_metrics(),
        ));

        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_multiplexer_with_one_concurrency_and_multiple_futures() {
        let entries = Arc::new(Mutex::new(Vec::new()));
        let (tx, rx) = flume::bounded(10);

        for i in 0..5 {
            let entries_clone = entries.clone();
            tx.send(future_with(move || {
                entries_clone.lock().unwrap().push(i);
            }))
            .unwrap();
        }

        drop(tx);

        futures::executor::block_on(Multiplexed::new(
            "my_pool",
            1,
            rx.into_stream(),
            None,
            mock_metrics(),
        ));

        assert_eq!(*entries.lock().unwrap(), (0..5).collect::<Vec<_>>());
    }

    #[test]
    fn test_multiplexer_with_multiple_concurrency_and_one_future() {
        let count = Arc::new(AtomicUsize::new(0));
        let (tx, rx) = flume::bounded(10);

        let count_clone = count.clone();
        tx.send(future_with(move || {
            count_clone.fetch_add(1, Ordering::SeqCst);
        }))
        .unwrap();

        drop(tx);

        futures::executor::block_on(Multiplexed::new(
            "my_pool",
            5,
            rx.into_stream(),
            None,
            mock_metrics(),
        ));

        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_multiplexer_with_multiple_concurrency_and_multiple_futures() {
        let entries = Arc::new(Mutex::new(Vec::new()));
        let (tx, rx) = flume::bounded(10);

        for i in 0..5 {
            let entries_clone = entries.clone();
            tx.send(future_with(move || {
                entries_clone.lock().unwrap().push(i);
            }))
            .unwrap();
        }

        drop(tx);

        futures::executor::block_on(Multiplexed::new(
            "my_pool",
            5,
            rx.into_stream(),
            None,
            mock_metrics(),
        ));

        assert_eq!(*entries.lock().unwrap(), (0..5).collect::<Vec<_>>());
    }

    #[test]
    fn test_multiplexer_with_multiple_concurrency_and_less_multiple_futures() {
        let entries = Arc::new(Mutex::new(Vec::new()));
        let (tx, rx) = flume::bounded(10);

        for i in 0..3 {
            let entries_clone = entries.clone();
            tx.send(future_with(move || {
                entries_clone.lock().unwrap().push(i);
            }))
            .unwrap();
        }

        drop(tx);

        futures::executor::block_on(Multiplexed::new(
            "my_pool",
            5,
            rx.into_stream(),
            None,
            mock_metrics(),
        ));

        assert_eq!(*entries.lock().unwrap(), (0..3).collect::<Vec<_>>());
    }

    #[test]
    fn test_multiplexer_with_multiple_concurrency_and_multiple_futures_from_multiple_threads() {
        let entries = Arc::new(Mutex::new(Vec::new()));
        let (tx, rx) = flume::bounded(10);

        let mut handles = vec![];
        for i in 0..5 {
            let entries_clone = entries.clone();
            let tx_clone = tx.clone();
            handles.push(std::thread::spawn(move || {
                tx_clone
                    .send(future_with(move || {
                        entries_clone.lock().unwrap().push(i);
                    }))
                    .unwrap();
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        drop(tx);

        futures::executor::block_on(Multiplexed::new(
            "my_pool",
            5,
            rx.into_stream(),
            None,
            mock_metrics(),
        ));

        let mut entries = entries.lock().unwrap();
        entries.sort();
        assert_eq!(*entries, (0..5).collect::<Vec<_>>());
    }

    #[test]
    fn test_catch_unwind_future_handles_panics() {
        let future = AssertUnwindSafe(async {
            panic!("panicked");
        })
        .catch_unwind();

        assert!(futures::executor::block_on(future).is_err());

        let future = AssertUnwindSafe(async {}).catch_unwind();

        assert!(futures::executor::block_on(future).is_ok());
    }

    #[tokio::test]
    async fn test_multiplexer_emits_metrics() {
        let (tx, rx) = flume::bounded::<BoxFuture<'static, _>>(10);
        let metrics = mock_metrics();

        tx.send(future::pending().boxed()).unwrap();

        drop(tx);

        // Run the multiplexer in the background so its metrics can be observed while the
        // pending task keeps it busy.
        #[allow(clippy::disallowed_methods)]
        tokio::spawn(Multiplexed::new(
            "my_pool",
            1,
            rx.into_stream(),
            None,
            metrics.clone(),
        ));

        // Give the spawned multiplexer a chance to poll the pending task.
        tokio::time::sleep(Duration::from_millis(1)).await;

        // The pending task counts as active and nothing has finished yet.
        assert_eq!(metrics.active_tasks.load(Ordering::Relaxed), 1);
        assert_eq!(metrics.finished_tasks.load(Ordering::Relaxed), 0);
    }
}