relay_server/utils/thread_pool.rs

1use std::future::Future;
2use std::{io, thread};
3
4use tokio::runtime::Handle;
5
6use relay_threading::{AsyncPool, AsyncPoolBuilder};
7
8/// A thread kind.
9///
10/// The thread kind has an effect on how threads are prioritized and scheduled.
/// A thread kind.
///
/// The kind of a thread determines how it is prioritized and scheduled by the
/// operating system.
#[derive(Clone, Copy, Debug, Default)]
pub enum ThreadKind {
    /// The default kind: a regular thread without any special configuration.
    #[default]
    Default,
    /// A worker thread runs CPU intensive work and is scheduled with a lower
    /// priority than [`Self::Default`] threads.
    Worker,
}
19
20/// Used to create a new [`AsyncPool`] thread pool.
21pub struct ThreadPoolBuilder {
22    name: &'static str,
23    runtime: Handle,
24    num_threads: usize,
25    max_concurrency: usize,
26    kind: ThreadKind,
27}
28
29impl ThreadPoolBuilder {
30    /// Creates a new named thread pool builder.
31    pub fn new(name: &'static str, runtime: Handle) -> Self {
32        Self {
33            name,
34            runtime,
35            num_threads: 0,
36            max_concurrency: 1,
37            kind: ThreadKind::Default,
38        }
39    }
40
41    /// Sets the number of threads to be used in the pool.
42    ///
43    /// See also [`AsyncPoolBuilder::num_threads`].
44    pub fn num_threads(mut self, num_threads: usize) -> Self {
45        self.num_threads = num_threads;
46        self
47    }
48
49    /// Sets the maximum number of tasks that can run concurrently per thread.
50    ///
51    /// See also [`AsyncPoolBuilder::max_concurrency`].
52    pub fn max_concurrency(mut self, max_concurrency: usize) -> Self {
53        self.max_concurrency = max_concurrency;
54        self
55    }
56
57    /// Configures the [`ThreadKind`] for all threads spawned in the pool.
58    pub fn thread_kind(mut self, kind: ThreadKind) -> Self {
59        self.kind = kind;
60        self
61    }
62
63    /// Creates and returns the thread pool.
64    pub fn build<F>(self) -> Result<AsyncPool<F>, io::Error>
65    where
66        F: Future<Output = ()> + Send + 'static,
67    {
68        AsyncPoolBuilder::new(self.runtime)
69            .pool_name(self.name)
70            .thread_name(move |id| format!("pool-{name}-{id}", name = self.name))
71            .num_threads(self.num_threads)
72            .max_concurrency(self.max_concurrency)
73            // In case of panic in a task sent to the pool, we catch it to continue the remaining
74            // work and just log an error.
75            .task_panic_handler(move |_panic| {
76                relay_log::error!(
77                    "task in pool {name} panicked, other tasks will continue execution",
78                    name = self.name
79                );
80            })
81            // In case of panic in the thread, log it. After a panic in the thread, it will stop.
82            .thread_panic_handler(move |panic| {
83                relay_log::error!("thread in pool {name} panicked", name = self.name);
84                std::panic::resume_unwind(panic);
85            })
86            .spawn_handler(|thread| {
87                let mut b = thread::Builder::new();
88                if let Some(name) = thread.name() {
89                    b = b.name(name.to_owned());
90                }
91                b.spawn(move || {
92                    set_current_thread_priority(self.kind);
93                    thread.run()
94                })?;
95
96                Ok(())
97            })
98            .build()
99    }
100}
101
/// Applies the niceness corresponding to `kind` to the calling thread.
///
/// Called from each pool thread right after it is spawned. On failure the OS
/// error is only logged; the thread keeps running at its inherited priority.
#[cfg(unix)]
fn set_current_thread_priority(kind: ThreadKind) {
    // Lower priorities cause more favorable scheduling.
    // Higher priorities cause less favorable scheduling.
    //
    // The relative niceness between threads determines their relative
    // priority. The formula to map a nice value to a weight is approximately
    // `1024 / (1.25 ^ nice)`.
    //
    // More information can be found:
    //  - https://www.kernel.org/doc/Documentation/scheduler/sched-nice-design.txt
    //  - https://oakbytes.wordpress.com/2012/06/06/linux-scheduler-cfs-and-nice/
    //  - `man setpriority(2)`
    let prio = match kind {
        // The default priority needs no change, and defaults to `0`.
        ThreadKind::Default => return,
        // Set a priority of `10` for worker threads.
        ThreadKind::Worker => 10,
    };
    // With `who == 0`, `PRIO_PROCESS` targets the calling thread (Linux treats
    // each thread as its own scheduling entity). `setpriority` returns 0 on
    // success and nonzero on error.
    if unsafe { libc::setpriority(libc::PRIO_PROCESS, 0, prio) } != 0 {
        // Read the OS error (`errno`) and log it; this is best-effort only.
        let error = std::io::Error::last_os_error();
        relay_log::warn!(
            error = &error as &dyn std::error::Error,
            "failed to set thread priority for a {kind:?} thread: {error:?}"
        );
    };
}
130
/// No-op fallback: thread priorities are only adjusted on Unix platforms.
#[cfg(not(unix))]
fn set_current_thread_priority(_kind: ThreadKind) {
    // Ignored for non-Unix platforms.
}
135
#[cfg(test)]
mod tests {
    use crate::utils::{ThreadKind, ThreadPoolBuilder};
    use futures::FutureExt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicI32, Ordering};
    use tokio::runtime::Handle;
    use tokio::sync::Barrier;

    // A panicking task must not take the pool down: a task spawned afterwards
    // still runs to completion on the same single worker thread.
    #[tokio::test]
    async fn test_thread_pool_panic() {
        let pool = ThreadPoolBuilder::new("s", Handle::current())
            .num_threads(1)
            .build()
            .unwrap();
        let barrier = Arc::new(Barrier::new(2));

        // First task: synchronize with the test body, then panic inside the pool.
        let barrier_clone = barrier.clone();
        pool.spawn(
            async move {
                barrier_clone.wait().await;
                panic!();
            }
            .boxed(),
        );
        barrier.wait().await;

        // Second task: reaching this barrier proves the pool survived the
        // panic above and keeps executing new tasks.
        let barrier_clone = barrier.clone();
        pool.spawn(
            async move {
                barrier_clone.wait().await;
            }
            .boxed(),
        );
        barrier.wait().await;
    }

    #[tokio::test]
    #[cfg(unix)]
    async fn test_thread_pool_priority() {
        // Returns the nice value of the calling thread (thread-scoped on Linux).
        fn get_current_priority() -> i32 {
            unsafe { libc::getpriority(libc::PRIO_PROCESS, 0) }
        }

        let default_priority = get_current_priority();

        {
            let pool = ThreadPoolBuilder::new("s", Handle::current())
                .num_threads(1)
                .build()
                .unwrap();

            // Capture the priority as observed from inside a pool thread; the
            // barrier guarantees the store happened before the assertion.
            let barrier = Arc::new(Barrier::new(2));
            let priority = Arc::new(AtomicI32::new(0));
            let barrier_clone = barrier.clone();
            let priority_clone = priority.clone();
            pool.spawn(async move {
                priority_clone.store(get_current_priority(), Ordering::SeqCst);
                barrier_clone.wait().await;
            });
            barrier.wait().await;

            // Default pool priority must match current priority.
            assert_eq!(priority.load(Ordering::SeqCst), default_priority);
        }

        {
            let pool = ThreadPoolBuilder::new("s", Handle::current())
                .num_threads(1)
                .thread_kind(ThreadKind::Worker)
                .build()
                .unwrap();

            // Same capture scheme as above, but for a `Worker` pool.
            let barrier = Arc::new(Barrier::new(2));
            let priority = Arc::new(AtomicI32::new(0));
            let barrier_clone = barrier.clone();
            let priority_clone = priority.clone();
            pool.spawn(async move {
                priority_clone.store(get_current_priority(), Ordering::SeqCst);
                barrier_clone.wait().await;
            });
            barrier.wait().await;

            // Worker must be higher than the default priority (higher number = lower priority).
            assert!(priority.load(Ordering::SeqCst) > default_priority);
        }
    }
}
223}