1use std::future::Future;
2use std::io;
3use std::panic::AssertUnwindSafe;
4use std::sync::Arc;
5
6use crate::builder::AsyncPoolBuilder;
7use crate::metrics::AsyncPoolMetrics;
8use crate::multiplexing::Multiplexed;
9use crate::{PanicHandler, ThreadMetrics};
10use futures::FutureExt;
11use futures::future::BoxFuture;
12use relay_system::MonitoredFuture;
13
/// Pool name used when the builder does not specify one.
const DEFAULT_POOL_NAME: &str = "unnamed";
16
/// Handle used to schedule futures of type `F` onto a set of worker threads.
///
/// Futures are handed to the workers through a channel; this struct owns the sending half
/// together with the data needed to report metrics about the pool.
#[derive(Debug)]
pub struct AsyncPool<F> {
    /// Name of the pool, as returned by [`AsyncPool::name`].
    name: &'static str,
    /// Sending half of the channel through which futures reach the worker threads.
    tx: flume::Sender<F>,
    /// Value reported as `max_tasks` by [`AsyncPool::metrics`].
    max_tasks: u64,
    /// Per-thread metrics handles, shared between all clones of this pool.
    threads_metrics: Arc<Vec<Arc<ThreadMetrics>>>,
}
32
33impl<F> AsyncPool<F> {
34 pub fn name(&self) -> &'static str {
36 self.name
37 }
38
39 pub fn metrics(&self) -> AsyncPoolMetrics {
41 AsyncPoolMetrics {
42 max_tasks: self.max_tasks,
43 queue_size: self.tx.len() as u64,
44 threads_metrics: &self.threads_metrics,
45 }
46 }
47}
48
49impl<F> AsyncPool<F>
50where
51 F: Future<Output = ()> + Send + 'static,
52{
53 pub fn new<S>(mut builder: AsyncPoolBuilder<S>) -> io::Result<Self>
58 where
59 S: ThreadSpawn,
60 {
61 let pool_name = builder.pool_name.unwrap_or(DEFAULT_POOL_NAME);
62 let (tx, rx) = flume::bounded(builder.num_threads * 2);
63 let mut threads_metrics = Vec::with_capacity(builder.num_threads);
64
65 for thread_id in 0..builder.num_threads {
66 let rx = rx.clone();
67
68 let thread_name: Option<String> = builder.thread_name.as_mut().map(|f| f(thread_id));
69 let metrics = Arc::new(ThreadMetrics::default());
70 let task = MonitoredFuture::wrap_with_metrics(
71 Multiplexed::new(
72 pool_name,
73 builder.max_concurrency,
74 rx.into_stream(),
75 builder.task_panic_handler.clone(),
76 metrics.clone(),
77 ),
78 metrics.raw_metrics.clone(),
79 );
80
81 let thread = Thread {
82 id: thread_id,
83 max_concurrency: builder.max_concurrency,
84 name: thread_name.clone(),
85 runtime: builder.runtime.clone(),
86 panic_handler: builder.thread_panic_handler.clone(),
87 task: task.boxed(),
88 };
89
90 threads_metrics.push(metrics);
91
92 builder.spawn_handler.spawn(thread)?;
93 }
94
95 Ok(Self {
96 name: pool_name,
97 tx,
98 max_tasks: (builder.num_threads * builder.max_concurrency) as u64,
99 threads_metrics: Arc::new(threads_metrics),
100 })
101 }
102
103 pub fn spawn(&self, future: F) {
112 assert!(
113 self.tx.send(future).is_ok(),
114 "failed to schedule task: all worker threads have terminated (either none were spawned or all have panicked)"
115 );
116 }
117
118 pub async fn spawn_async(&self, future: F) {
127 assert!(
128 self.tx.send_async(future).await.is_ok(),
129 "failed to schedule task: all worker threads have terminated (either none were spawned or all have panicked)"
130 );
131 }
132}
133
134impl<F> Clone for AsyncPool<F> {
135 fn clone(&self) -> Self {
136 Self {
137 name: self.name,
138 tx: self.tx.clone(),
139 max_tasks: self.max_tasks,
140 threads_metrics: self.threads_metrics.clone(),
141 }
142 }
143}
144
/// A unit of work handed to a [`ThreadSpawn`] implementation: everything needed to run one
/// worker of an [`AsyncPool`].
pub struct Thread {
    /// Identifier of the thread, as returned by [`Thread::id`].
    id: usize,
    /// Concurrency limit the thread was created with, as returned by
    /// [`Thread::max_concurrency`].
    max_concurrency: usize,
    /// Optional thread name, as returned by [`Thread::name`].
    name: Option<String>,
    /// Runtime handle on which [`Thread::run`] blocks while driving `task`.
    runtime: tokio::runtime::Handle,
    /// Handler invoked by [`Thread::run`] with the panic payload if driving `task` panics.
    panic_handler: Option<Arc<PanicHandler>>,
    /// The future driven to completion by [`Thread::run`].
    task: BoxFuture<'static, ()>,
}
154
impl Thread {
    /// Returns the identifier of this thread.
    pub fn id(&self) -> usize {
        self.id
    }

    /// Returns the `max_concurrency` value this thread was created with.
    pub fn max_concurrency(&self) -> usize {
        self.max_concurrency
    }

    /// Returns the name of this thread, or `None` if no name was assigned.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }
}
177
178impl Thread {
179 pub fn run(self) {
187 let result =
188 std::panic::catch_unwind(AssertUnwindSafe(|| self.runtime.block_on(self.task)));
189
190 match (self.panic_handler, result) {
191 (Some(panic_handler), Err(error)) => {
193 panic_handler(error);
194 }
195 (None, Err(error)) => {
197 std::panic::resume_unwind(error);
198 }
199 (_, Ok(())) => {}
201 }
202 }
203}
204
/// Strategy used by [`AsyncPool::new`] to start each worker [`Thread`].
pub trait ThreadSpawn {
    /// Takes ownership of `thread` and starts it.
    ///
    /// An error returned here is propagated out of [`AsyncPool::new`].
    fn spawn(&mut self, thread: Thread) -> io::Result<()>;
}
213
/// [`ThreadSpawn`] implementation that runs each [`Thread`] on a new thread created with
/// [`std::thread::Builder`].
#[derive(Clone)]
pub struct DefaultSpawn;
220
221impl ThreadSpawn for DefaultSpawn {
222 fn spawn(&mut self, thread: Thread) -> io::Result<()> {
223 let mut b = std::thread::Builder::new();
224 if let Some(name) = thread.name() {
225 b = b.name(name.to_owned());
226 }
227 b.spawn(|| thread.run())?;
228
229 Ok(())
230 }
231}
232
/// [`ThreadSpawn`] implementation that delegates to a user-provided callback.
#[derive(Clone)]
pub struct CustomSpawn<B>(B);
239
240impl<B> CustomSpawn<B> {
241 pub fn new(spawn_handler: B) -> Self {
243 CustomSpawn(spawn_handler)
244 }
245}
246
247impl<B> ThreadSpawn for CustomSpawn<B>
248where
249 B: FnMut(Thread) -> io::Result<()>,
250{
251 fn spawn(&mut self, thread: Thread) -> io::Result<()> {
253 self.0(thread)
254 }
255}
256
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::panic::AssertUnwindSafe;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::time::{Duration, Instant};

    use futures::FutureExt;
    use futures::future::BoxFuture;
    use tokio::runtime::{Handle, Runtime};
    use tokio::sync::Semaphore;
    use tokio::time::sleep;

    use crate::builder::AsyncPoolBuilder;
    use crate::{AsyncPool, Thread};

    /// Semaphore-backed helper that lets a test wait until every task it spawned has finished.
    ///
    /// Each spawned task holds one permit until it completes. `wait` acquires all `count`
    /// permits at once, which only succeeds after every task has released its permit.
    struct TestBarrier {
        semaphore: Arc<Semaphore>,
        count: u32,
    }

    impl TestBarrier {
        /// Creates a barrier with `count` permits.
        async fn new(count: u32) -> Self {
            let permits = count as usize;
            Self {
                semaphore: Arc::new(Semaphore::new(permits)),
                count,
            }
        }

        /// Acquires one permit, then schedules `f` on `pool`; the permit is released once the
        /// future returned by `f` has completed.
        async fn spawn<F, Fut>(&self, pool: &AsyncPool<BoxFuture<'static, ()>>, f: F)
        where
            F: FnOnce() -> Fut + Send + 'static,
            Fut: Future<Output = ()> + Send + 'static,
        {
            let permit = Arc::clone(&self.semaphore)
                .acquire_owned()
                .await
                .unwrap();

            let task = async move {
                f().await;
                drop(permit);
            };

            pool.spawn_async(task.boxed()).await;
        }

        /// Waits until all `count` permits are available, then releases them again.
        async fn wait(&self) {
            drop(self.semaphore.acquire_many(self.count).await.unwrap());
        }
    }

    /// Spawns two 200ms sleeps on `pool` and asserts that both finish within 250ms, which
    /// is only possible if the pool runs them concurrently.
    async fn assert_two_sleeps_overlap(pool: &AsyncPool<BoxFuture<'static, ()>>) {
        let started = Instant::now();
        let barrier = TestBarrier::new(2).await;

        for _ in 0..2 {
            barrier
                .spawn(pool, || async {
                    sleep(Duration::from_millis(200)).await;
                })
                .await;
        }

        barrier.wait().await;

        let elapsed = started.elapsed();
        assert!(
            elapsed < Duration::from_millis(250),
            "Elapsed time was too high: {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn test_async_pool_executes_all_tasks() {
        const TASKS: u32 = 20;

        let pool = AsyncPoolBuilder::new(Handle::current())
            .num_threads(1)
            .max_concurrency(2)
            .build()
            .unwrap();
        let completed = Arc::new(AtomicUsize::new(0));
        let barrier = TestBarrier::new(TASKS).await;

        for _ in 0..TASKS {
            let completed = Arc::clone(&completed);
            barrier
                .spawn(&pool, move || async move {
                    sleep(Duration::from_millis(50)).await;
                    completed.fetch_add(1, Ordering::SeqCst);
                })
                .await;
        }

        barrier.wait().await;
        assert_eq!(completed.load(Ordering::SeqCst), TASKS as usize);
    }

    #[tokio::test]
    async fn test_async_pool_executes_all_tasks_concurrently_with_single_thread() {
        let pool = AsyncPoolBuilder::new(Handle::current())
            .num_threads(1)
            .max_concurrency(2)
            .build()
            .unwrap();

        assert_two_sleeps_overlap(&pool).await;
    }

    #[tokio::test]
    async fn test_async_pool_executes_all_tasks_concurrently_with_multiple_threads() {
        let pool = AsyncPoolBuilder::new(Handle::current())
            .num_threads(2)
            .max_concurrency(1)
            .build()
            .unwrap();

        assert_two_sleeps_overlap(&pool).await;
    }

    #[test]
    fn test_thread_panic_handling() {
        let runtime = Runtime::new().unwrap();
        let handler_called = Arc::new(AtomicBool::new(false));
        let handler_flag = Arc::clone(&handler_called);
        let panic_handler = move |_| {
            handler_flag.store(true, Ordering::SeqCst);
        };

        let thread = Thread {
            id: 0,
            max_concurrency: 1,
            name: Some("test-thread".into()),
            runtime: runtime.handle().clone(),
            panic_handler: Some(Arc::new(panic_handler)),
            task: async move {
                panic!("panicked");
            }
            .boxed(),
        };

        // The panic must be routed to the handler instead of unwinding out of `run`.
        thread.run();

        assert!(handler_called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_spawn_panics_if_no_threads_are_available() {
        let pool = AsyncPoolBuilder::new(Handle::current())
            .num_threads(0)
            .max_concurrency(1)
            .build()
            .unwrap();

        // With zero worker threads nothing receives from the channel, so `spawn` must panic.
        let outcome = std::panic::catch_unwind(AssertUnwindSafe(|| {
            pool.spawn(async move {});
        }));

        assert!(outcome.is_err());
    }
}