//! Thread pool utilities (`relay_server/utils/thread_pool.rs`).
use std::sync::Arc;
use std::thread;
use tokio::runtime::Handle;
pub use rayon::{ThreadPool, ThreadPoolBuildError};
use tokio::sync::Semaphore;
/// A thread kind.
///
/// The thread kind has an effect on how threads are prioritized and scheduled.
/// It is applied to every thread of a pool via [`ThreadPoolBuilder::thread_kind`].
#[derive(Default, Debug, Clone, Copy)]
pub enum ThreadKind {
    /// The default kind, just a thread like any other without any special configuration.
    ///
    /// Threads of this kind keep the priority they inherit from the process.
    #[default]
    Default,
    /// A worker thread is a CPU intensive task with a lower priority than the [`Self::Default`] kind.
    ///
    /// On Unix this is implemented by raising the thread's nice value; on other
    /// platforms it currently has no effect.
    Worker,
}
/// Used to create a new [`ThreadPool`] thread pool.
pub struct ThreadPoolBuilder {
    /// Name used for thread names and log messages emitted by the pool.
    name: &'static str,
    /// Optional Tokio runtime handle entered by every pool thread.
    runtime: Option<Handle>,
    /// Number of threads; `0` lets rayon pick its default.
    num_threads: usize,
    /// Kind (scheduling priority class) applied to every spawned thread.
    kind: ThreadKind,
}
impl ThreadPoolBuilder {
/// Creates a new named thread pool builder.
pub fn new(name: &'static str) -> Self {
Self {
name,
runtime: None,
num_threads: 0,
kind: ThreadKind::Default,
}
}
/// Sets the number of threads to be used in the rayon thread-pool.
///
/// See also [`rayon::ThreadPoolBuilder::num_threads`].
pub fn num_threads(mut self, num_threads: usize) -> Self {
self.num_threads = num_threads;
self
}
/// Configures the [`ThreadKind`] for all threads spawned in the pool.
pub fn thread_kind(mut self, kind: ThreadKind) -> Self {
self.kind = kind;
self
}
/// Sets the Tokio runtime which will be made available in the workers.
pub fn runtime(mut self, runtime: Handle) -> Self {
self.runtime = Some(runtime);
self
}
/// Creates and returns the thread pool.
pub fn build(self) -> Result<ThreadPool, ThreadPoolBuildError> {
rayon::ThreadPoolBuilder::new()
.num_threads(self.num_threads)
.thread_name(move |id| format!("pool-{name}-{id}", name = self.name))
// In case of panic, log that there was a panic but keep the thread alive and don't
// exist.
.panic_handler(move |_panic| {
relay_log::error!("thread in pool {name} paniced!", name = self.name)
})
.spawn_handler(|thread| {
let mut b = thread::Builder::new();
if let Some(name) = thread.name() {
b = b.name(name.to_owned());
}
if let Some(stack_size) = thread.stack_size() {
b = b.stack_size(stack_size);
}
let runtime = self.runtime.clone();
b.spawn(move || {
set_current_thread_priority(self.kind);
let _guard = runtime.as_ref().map(|runtime| runtime.enter());
thread.run()
})?;
Ok(())
})
.build()
}
}
/// A [`WorkerGroup`] adds an async back-pressure mechanism to a [`ThreadPool`].
pub struct WorkerGroup {
    /// The underlying rayon pool tasks are spawned on.
    pool: ThreadPool,
    /// Limits the number of in-flight tasks; permits are released when a task
    /// finishes running on the pool.
    semaphore: Arc<Semaphore>,
}
impl WorkerGroup {
/// Creates a new worker group from a thread pool.
pub fn new(pool: ThreadPool) -> Self {
// Use `current_num_threads() * 2` to guarantee all threads immediately have a new item to work on.
let semaphore = Arc::new(Semaphore::new(pool.current_num_threads() * 2));
Self { pool, semaphore }
}
/// Spawns an asynchronous task on the thread pool.
///
/// If the thread pool is saturated the returned future is pending until
/// the thread pool has capacity to work on the task.
///
/// # Examples:
///
/// ```ignore
/// # async fn test(mut messages: tokio::sync::mpsc::Receiver<()>) {
/// # use relay_server::utils::{WorkerGroup, ThreadPoolBuilder};
/// # use std::thread;
/// # use std::time::Duration;
/// # let pool = ThreadPoolBuilder::new("test").num_threads(1).build().unwrap();
/// let workers = WorkerGroup::new(pool);
///
/// while let Some(message) = messages.recv().await {
/// workers.spawn(move || {
/// thread::sleep(Duration::from_secs(1));
/// println!("worked on message {message:?}")
/// }).await;
/// }
/// # }
/// ```
pub async fn spawn(&self, op: impl FnOnce() + Send + 'static) {
let semaphore = Arc::clone(&self.semaphore);
let permit = semaphore
.acquire_owned()
.await
.expect("the semaphore is never closed");
self.pool.spawn(move || {
op();
drop(permit);
});
}
}
/// Applies the scheduling priority implied by `kind` to the calling thread.
///
/// Failures are logged but otherwise ignored; the thread keeps running with
/// its inherited priority.
#[cfg(unix)]
fn set_current_thread_priority(kind: ThreadKind) {
    // Lower priorities cause more favorable scheduling.
    // Higher priorities cause less favorable scheduling.
    //
    // The relative niceness between threads determines their relative
    // priority. The formula to map a nice value to a weight is approximately
    // `1024 / (1.25 ^ nice)`.
    //
    // More information can be found:
    //  - https://www.kernel.org/doc/Documentation/scheduler/sched-nice-design.txt
    //  - https://oakbytes.wordpress.com/2012/06/06/linux-scheduler-cfs-and-nice/
    //  - `man setpriority(2)`
    let prio = match kind {
        // The default priority needs no change, and defaults to `0`.
        ThreadKind::Default => return,
        // Set a priority of `10` for worker threads.
        ThreadKind::Worker => 10,
    };
    // SAFETY: `setpriority` is a plain libc call with no pointer arguments;
    // passing `who = 0` with `PRIO_PROCESS` targets the calling thread.
    if unsafe { libc::setpriority(libc::PRIO_PROCESS, 0, prio) } != 0 {
        // Read the `errno` set by `setpriority` and log it.
        let error = std::io::Error::last_os_error();
        relay_log::warn!(
            error = &error as &dyn std::error::Error,
            "failed to set thread priority for a {kind:?} thread: {error:?}"
        );
    };
}
/// No-op fallback: thread priorities are only adjusted on Unix platforms.
#[cfg(not(unix))]
fn set_current_thread_priority(_kind: ThreadKind) {
    // Ignored for non-Unix platforms.
}
#[cfg(test)]
mod tests {
    use std::sync::Barrier;
    use std::time::Duration;

    use futures::FutureExt;

    use super::*;

    // The builder must create exactly the requested number of threads.
    #[test]
    fn test_thread_pool_num_threads() {
        let pool = ThreadPoolBuilder::new("s").num_threads(3).build().unwrap();
        assert_eq!(pool.current_num_threads(), 3);
    }

    // With a runtime configured, pool threads are inside the runtime context.
    #[test]
    fn test_thread_pool_runtime() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let pool = ThreadPoolBuilder::new("s")
            .num_threads(1)
            .runtime(rt.handle().clone())
            .build()
            .unwrap();
        let has_runtime = pool.install(|| tokio::runtime::Handle::try_current().is_ok());
        assert!(has_runtime);
    }

    // Without a runtime configured, pool threads see no runtime context.
    #[test]
    fn test_thread_pool_no_runtime() {
        let pool = ThreadPoolBuilder::new("s").num_threads(1).build().unwrap();
        let has_runtime = pool.install(|| tokio::runtime::Handle::try_current().is_ok());
        assert!(!has_runtime);
    }

    // A panicking task must not kill the pool: a follow-up task still runs.
    #[test]
    fn test_thread_pool_panic() {
        let pool = ThreadPoolBuilder::new("s").num_threads(1).build().unwrap();
        let barrier = Arc::new(Barrier::new(2));

        // First task panics after rendezvousing with the test thread.
        pool.spawn({
            let barrier = Arc::clone(&barrier);
            move || {
                barrier.wait();
                panic!();
            }
        });
        barrier.wait();

        // Second task completing proves the pool thread survived the panic.
        pool.spawn({
            let barrier = Arc::clone(&barrier);
            move || {
                barrier.wait();
            }
        });
        barrier.wait();
    }

    // Worker threads must run at a lower priority (higher nice value) than
    // default threads; only meaningful on Unix.
    #[test]
    #[cfg(unix)]
    fn test_thread_pool_priority() {
        fn get_current_priority() -> i32 {
            unsafe { libc::getpriority(libc::PRIO_PROCESS, 0) }
        }

        let default_prio = get_current_priority();

        {
            let pool = ThreadPoolBuilder::new("s").num_threads(1).build().unwrap();
            let prio = pool.install(get_current_priority);
            // Default pool priority must match current priority.
            assert_eq!(prio, default_prio);
        }

        {
            let pool = ThreadPoolBuilder::new("s")
                .num_threads(1)
                .thread_kind(ThreadKind::Worker)
                .build()
                .unwrap();
            let prio = pool.install(get_current_priority);
            // Worker must be higher than the default priority (higher number = lower priority).
            assert!(prio > default_prio);
        }
    }

    // With one thread the group holds 2 permits; a third spawn must be pending
    // (`now_or_never()` returns `None`) until earlier tasks release permits.
    #[test]
    fn test_worker_group_backpressure() {
        let pool = ThreadPoolBuilder::new("s").num_threads(1).build().unwrap();
        let workers = WorkerGroup::new(pool);

        // Num Threads * 2 is the limit after backpressure kicks in
        let barrier = Arc::new(Barrier::new(2));

        let spawn = || {
            let barrier = Arc::clone(&barrier);
            workers
                .spawn(move || {
                    barrier.wait();
                })
                .now_or_never()
                .is_some()
        };

        for _ in 0..15 {
            // Pool should accept two immediately.
            assert!(spawn());
            assert!(spawn());

            // Pool should reject because there are already 2 tasks active.
            assert!(!spawn());

            // Unblock the barrier
            barrier.wait(); // first spawn
            barrier.wait(); // second spawn

            // wait a tiny bit to make sure the semaphore handle is dropped
            thread::sleep(Duration::from_millis(50));
        }
    }
}