relay_server/utils/
thread_pool.rs1use std::future::Future;
2use std::{io, thread};
3
4use tokio::runtime::Handle;
5
6use relay_threading::{AsyncPool, AsyncPoolBuilder};
7
/// Scheduling class for the threads of a pool built by [`ThreadPoolBuilder`].
///
/// The kind is applied inside each newly spawned pool thread, before that thread starts
/// running pool work.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ThreadKind {
    /// No priority adjustment is made; threads run at whatever priority they start with.
    #[default]
    Default,
    /// On Unix targets the thread sets its nice value to `10` (a lower priority than the
    /// usual default nice value of `0`). No adjustment is made on non-Unix targets.
    Worker,
}
19
20pub struct ThreadPoolBuilder {
22 name: &'static str,
23 runtime: Handle,
24 num_threads: usize,
25 max_concurrency: usize,
26 kind: ThreadKind,
27}
28
29impl ThreadPoolBuilder {
30 pub fn new(name: &'static str, runtime: Handle) -> Self {
32 Self {
33 name,
34 runtime,
35 num_threads: 0,
36 max_concurrency: 1,
37 kind: ThreadKind::Default,
38 }
39 }
40
41 pub fn num_threads(mut self, num_threads: usize) -> Self {
45 self.num_threads = num_threads;
46 self
47 }
48
49 pub fn max_concurrency(mut self, max_concurrency: usize) -> Self {
53 self.max_concurrency = max_concurrency;
54 self
55 }
56
57 pub fn thread_kind(mut self, kind: ThreadKind) -> Self {
59 self.kind = kind;
60 self
61 }
62
63 pub fn build<F>(self) -> Result<AsyncPool<F>, io::Error>
65 where
66 F: Future<Output = ()> + Send + 'static,
67 {
68 AsyncPoolBuilder::new(self.runtime)
69 .pool_name(self.name)
70 .thread_name(move |id| format!("pool-{name}-{id}", name = self.name))
71 .num_threads(self.num_threads)
72 .max_concurrency(self.max_concurrency)
73 .task_panic_handler(move |_panic| {
76 relay_log::error!(
77 "task in pool {name} panicked, other tasks will continue execution",
78 name = self.name
79 );
80 })
81 .thread_panic_handler(move |panic| {
83 relay_log::error!("thread in pool {name} panicked", name = self.name);
84 std::panic::resume_unwind(panic);
85 })
86 .spawn_handler(|thread| {
87 let mut b = thread::Builder::new();
88 if let Some(name) = thread.name() {
89 b = b.name(name.to_owned());
90 }
91 b.spawn(move || {
92 set_current_thread_priority(self.kind);
93 thread.run()
94 })?;
95
96 Ok(())
97 })
98 .build()
99 }
100}
101
102#[cfg(unix)]
103fn set_current_thread_priority(kind: ThreadKind) {
104 let prio = match kind {
116 ThreadKind::Default => return,
118 ThreadKind::Worker => 10,
120 };
121 if unsafe { libc::setpriority(libc::PRIO_PROCESS, 0, prio) } != 0 {
122 let error = std::io::Error::last_os_error();
124 relay_log::warn!(
125 error = &error as &dyn std::error::Error,
126 "failed to set thread priority for a {kind:?} thread: {error:?}"
127 );
128 };
129}
130
/// Fallback for non-Unix targets: thread priorities are left unchanged for every
/// [`ThreadKind`].
#[cfg(not(unix))]
fn set_current_thread_priority(_kind: ThreadKind) {}
135
#[cfg(test)]
mod tests {
    use crate::utils::{ThreadKind, ThreadPoolBuilder};
    use futures::FutureExt;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicI32, Ordering};
    use tokio::runtime::Handle;
    use tokio::sync::Barrier;

    /// Reads the nice value reported by `getpriority(PRIO_PROCESS, 0)` for the caller.
    #[cfg(unix)]
    fn current_priority() -> i32 {
        // SAFETY: `getpriority` takes plain integers and dereferences no caller memory.
        unsafe { libc::getpriority(libc::PRIO_PROCESS, 0) }
    }

    /// Builds a one-thread pool of the given `kind`, runs one task on it, and returns the
    /// priority that task observed on its pool thread.
    #[cfg(unix)]
    async fn priority_observed_in_pool(kind: ThreadKind) -> i32 {
        let pool = ThreadPoolBuilder::new("s", Handle::current())
            .num_threads(1)
            .thread_kind(kind)
            .build()
            .unwrap();

        let barrier = Arc::new(Barrier::new(2));
        let observed = Arc::new(AtomicI32::new(0));

        let task_barrier = Arc::clone(&barrier);
        let task_observed = Arc::clone(&observed);
        pool.spawn(async move {
            task_observed.store(current_priority(), Ordering::SeqCst);
            task_barrier.wait().await;
        });
        // Once both sides pass the barrier, the store above has happened.
        barrier.wait().await;

        observed.load(Ordering::SeqCst)
    }

    #[tokio::test]
    async fn test_thread_pool_panic() {
        let pool = ThreadPoolBuilder::new("s", Handle::current())
            .num_threads(1)
            .build()
            .unwrap();
        let barrier = Arc::new(Barrier::new(2));

        // The first task panics right after meeting the test at the barrier.
        let panicking = Arc::clone(&barrier);
        pool.spawn(
            async move {
                panicking.wait().await;
                panic!();
            }
            .boxed(),
        );
        barrier.wait().await;

        // A follow-up task must still be scheduled and reach the barrier, i.e. the pool
        // survived the panic of the previous task.
        let survivor = Arc::clone(&barrier);
        pool.spawn(
            async move {
                survivor.wait().await;
            }
            .boxed(),
        );
        barrier.wait().await;
    }

    #[tokio::test]
    #[cfg(unix)]
    async fn test_thread_pool_priority() {
        let default_priority = current_priority();

        let default_pool_priority = priority_observed_in_pool(ThreadKind::Default).await;
        assert_eq!(default_pool_priority, default_priority);

        let worker_pool_priority = priority_observed_in_pool(ThreadKind::Worker).await;
        assert!(worker_pool_priority > default_priority);
    }
}