relay_server/utils/
thread_pool.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
use std::sync::Arc;
use std::thread;
use tokio::runtime::Handle;

pub use rayon::{ThreadPool, ThreadPoolBuildError};
use tokio::sync::Semaphore;

/// A thread kind.
///
/// The thread kind has an effect on how threads are prioritized and scheduled.
///
/// On Unix this maps to a `nice` value set via `setpriority(2)`; on other
/// platforms the kind currently has no effect.
#[derive(Default, Debug, Clone, Copy)]
pub enum ThreadKind {
    /// The default kind, just a thread like any other without any special configuration.
    #[default]
    Default,
    /// A worker thread is a CPU intensive task with a lower priority than the [`Self::Default`] kind.
    Worker,
}

/// Used to create a new [`ThreadPool`] thread pool.
pub struct ThreadPoolBuilder {
    // Static name used to label spawned threads and log messages.
    name: &'static str,
    // Optional Tokio runtime handle entered by every pool thread.
    runtime: Option<Handle>,
    // Number of threads; `0` lets rayon pick its default.
    num_threads: usize,
    // Scheduling priority class applied to every spawned thread.
    kind: ThreadKind,
}

impl ThreadPoolBuilder {
    /// Creates a new named thread pool builder.
    pub fn new(name: &'static str) -> Self {
        Self {
            name,
            runtime: None,
            num_threads: 0,
            kind: ThreadKind::Default,
        }
    }

    /// Sets the number of threads to be used in the rayon thread-pool.
    ///
    /// See also [`rayon::ThreadPoolBuilder::num_threads`].
    pub fn num_threads(mut self, num_threads: usize) -> Self {
        self.num_threads = num_threads;
        self
    }

    /// Configures the [`ThreadKind`] for all threads spawned in the pool.
    pub fn thread_kind(mut self, kind: ThreadKind) -> Self {
        self.kind = kind;
        self
    }

    /// Sets the Tokio runtime which will be made available in the workers.
    pub fn runtime(mut self, runtime: Handle) -> Self {
        self.runtime = Some(runtime);
        self
    }

    /// Creates and returns the thread pool.
    pub fn build(self) -> Result<ThreadPool, ThreadPoolBuildError> {
        rayon::ThreadPoolBuilder::new()
            .num_threads(self.num_threads)
            .thread_name(move |id| format!("pool-{name}-{id}", name = self.name))
            // In case of panic, log that there was a panic but keep the thread alive and don't
            // exist.
            .panic_handler(move |_panic| {
                relay_log::error!("thread in pool {name} paniced!", name = self.name)
            })
            .spawn_handler(|thread| {
                let mut b = thread::Builder::new();
                if let Some(name) = thread.name() {
                    b = b.name(name.to_owned());
                }
                if let Some(stack_size) = thread.stack_size() {
                    b = b.stack_size(stack_size);
                }
                let runtime = self.runtime.clone();
                b.spawn(move || {
                    set_current_thread_priority(self.kind);
                    let _guard = runtime.as_ref().map(|runtime| runtime.enter());
                    thread.run()
                })?;
                Ok(())
            })
            .build()
    }
}

/// A [`WorkerGroup`] adds an async back-pressure mechanism to a [`ThreadPool`].
pub struct WorkerGroup {
    // The underlying pool tasks are spawned on.
    pool: ThreadPool,
    // Limits the number of in-flight tasks; a permit is held per spawned task
    // and released when the task finishes.
    semaphore: Arc<Semaphore>,
}

impl WorkerGroup {
    /// Creates a new worker group from a thread pool.
    pub fn new(pool: ThreadPool) -> Self {
        // Use `current_num_threads() * 2` to guarantee all threads immediately have a new item to work on.
        let semaphore = Arc::new(Semaphore::new(pool.current_num_threads() * 2));
        Self { pool, semaphore }
    }

    /// Spawns an asynchronous task on the thread pool.
    ///
    /// If the thread pool is saturated the returned future is pending until
    /// the thread pool has capacity to work on the task.
    ///
    /// # Examples:
    ///
    /// ```ignore
    /// # async fn test(mut messages: tokio::sync::mpsc::Receiver<()>) {
    /// # use relay_server::utils::{WorkerGroup, ThreadPoolBuilder};
    /// # use std::thread;
    /// # use std::time::Duration;
    /// # let pool = ThreadPoolBuilder::new("test").num_threads(1).build().unwrap();
    /// let workers = WorkerGroup::new(pool);
    ///
    /// while let Some(message) = messages.recv().await {
    ///     workers.spawn(move || {
    ///         thread::sleep(Duration::from_secs(1));
    ///         println!("worked on message {message:?}")
    ///     }).await;
    /// }
    /// # }
    /// ```
    pub async fn spawn(&self, op: impl FnOnce() + Send + 'static) {
        let semaphore = Arc::clone(&self.semaphore);
        let permit = semaphore
            .acquire_owned()
            .await
            .expect("the semaphore is never closed");

        self.pool.spawn(move || {
            op();
            drop(permit);
        });
    }
}

/// Applies the scheduling priority for `kind` to the current thread.
///
/// Only has an effect for [`ThreadKind::Worker`]; the default kind keeps the
/// inherited priority. Failures are logged and otherwise ignored.
#[cfg(unix)]
fn set_current_thread_priority(kind: ThreadKind) {
    // Lower priorities cause more favorable scheduling.
    // Higher priorities cause less favorable scheduling.
    //
    // The relative niceness between threads determines their relative
    // priority. The formula to map a nice value to a weight is approximately
    // `1024 / (1.25 ^ nice)`.
    //
    // More information can be found:
    //  - https://www.kernel.org/doc/Documentation/scheduler/sched-nice-design.txt
    //  - https://oakbytes.wordpress.com/2012/06/06/linux-scheduler-cfs-and-nice/
    //  - `man setpriority(2)`
    let prio = match kind {
        // The default priority needs no change, and defaults to `0`.
        ThreadKind::Default => return,
        // Set a priority of `10` for worker threads.
        ThreadKind::Worker => 10,
    };
    // `PRIO_PROCESS` with a `who` of `0` targets the calling thread on Linux.
    if unsafe { libc::setpriority(libc::PRIO_PROCESS, 0, prio) } != 0 {
        // Read the `errno` set by `setpriority` and log it.
        let error = std::io::Error::last_os_error();
        relay_log::warn!(
            error = &error as &dyn std::error::Error,
            "failed to set thread priority for a {kind:?} thread: {error:?}"
        );
    };
}

/// No-op fallback for platforms without `setpriority(2)`.
#[cfg(not(unix))]
fn set_current_thread_priority(_kind: ThreadKind) {
    // Ignored for non-Unix platforms.
}

#[cfg(test)]
mod tests {
    use std::sync::Barrier;
    use std::time::Duration;

    use futures::FutureExt;

    use super::*;

    // The builder must honor the exact requested thread count.
    #[test]
    fn test_thread_pool_num_threads() {
        let pool = ThreadPoolBuilder::new("s").num_threads(3).build().unwrap();
        assert_eq!(pool.current_num_threads(), 3);
    }

    // With a runtime configured, pool threads run inside its context.
    #[test]
    fn test_thread_pool_runtime() {
        let rt = tokio::runtime::Runtime::new().unwrap();

        let pool = ThreadPoolBuilder::new("s")
            .num_threads(1)
            .runtime(rt.handle().clone())
            .build()
            .unwrap();

        let has_runtime = pool.install(|| tokio::runtime::Handle::try_current().is_ok());
        assert!(has_runtime);
    }

    // Without a runtime configured, pool threads have no runtime context.
    #[test]
    fn test_thread_pool_no_runtime() {
        let pool = ThreadPoolBuilder::new("s").num_threads(1).build().unwrap();

        let has_runtime = pool.install(|| tokio::runtime::Handle::try_current().is_ok());
        assert!(!has_runtime);
    }

    // A panicking task must not kill the pool: a subsequent task still runs.
    #[test]
    fn test_thread_pool_panic() {
        let pool = ThreadPoolBuilder::new("s").num_threads(1).build().unwrap();
        let barrier = Arc::new(Barrier::new(2));

        // First task panics after both sides reach the barrier.
        pool.spawn({
            let barrier = Arc::clone(&barrier);
            move || {
                barrier.wait();
                panic!();
            }
        });
        barrier.wait();

        // Second task completing proves the worker thread survived the panic.
        pool.spawn({
            let barrier = Arc::clone(&barrier);
            move || {
                barrier.wait();
            }
        });
        barrier.wait();
    }

    #[test]
    #[cfg(unix)]
    fn test_thread_pool_priority() {
        fn get_current_priority() -> i32 {
            unsafe { libc::getpriority(libc::PRIO_PROCESS, 0) }
        }

        let default_prio = get_current_priority();

        {
            let pool = ThreadPoolBuilder::new("s").num_threads(1).build().unwrap();
            let prio = pool.install(get_current_priority);
            // Default pool priority must match current priority.
            assert_eq!(prio, default_prio);
        }

        {
            let pool = ThreadPoolBuilder::new("s")
                .num_threads(1)
                .thread_kind(ThreadKind::Worker)
                .build()
                .unwrap();
            let prio = pool.install(get_current_priority);
            // Worker must be higher than the default priority (higher number = lower priority).
            assert!(prio > default_prio);
        }
    }

    // With 1 thread the group grants 2 permits; a 3rd spawn must not resolve
    // immediately (`now_or_never()` returns `None`) until permits free up.
    #[test]
    fn test_worker_group_backpressure() {
        let pool = ThreadPoolBuilder::new("s").num_threads(1).build().unwrap();
        let workers = WorkerGroup::new(pool);

        // Num Threads * 2 is the limit after backpressure kicks in
        let barrier = Arc::new(Barrier::new(2));

        let spawn = || {
            let barrier = Arc::clone(&barrier);
            workers
                .spawn(move || {
                    barrier.wait();
                })
                .now_or_never()
                .is_some()
        };

        for _ in 0..15 {
            // Pool should accept two immediately.
            assert!(spawn());
            assert!(spawn());
            // Pool should reject because there are already 2 tasks active.
            assert!(!spawn());

            // Unblock the barrier
            barrier.wait(); // first spawn
            barrier.wait(); // second spawn

            // wait a tiny bit to make sure the semaphore handle is dropped
            thread::sleep(Duration::from_millis(50));
        }
    }
}