relay_system/service/concurrent.rs

1use futures::future::BoxFuture;
2use futures::stream::FuturesUnordered;
3use futures::{FutureExt, StreamExt};
4
5use crate::Service;
6use crate::service::simple::SimpleService;
7use crate::statsd::SystemGauges;
8
9/// A service that handles messages concurrently.
10///
11/// When the service reaches its maximum concurrency, it either drops messages
12/// or keeps them in the input queue.
13///
14/// ```
15/// use relay_system::{Interface, SimpleService, LoadShed, ConcurrentService};
16///
17/// #[derive(Clone)]
18/// struct MyService;
19///
20/// struct MyMessage;
21/// impl Interface for MyMessage {}
22///
23/// impl SimpleService for MyService {
24///     type Interface = MyMessage;
25///     async fn handle_message(&self, message: MyMessage) {
26///         // do your thing
27///     }
28/// }
29///
30/// // `Loadshed` implementation is required but can be empty.
31/// impl LoadShed<MyMessage> for MyService {
32///     fn handle_loadshed(&self, _: MyMessage) {
33///         eprintln!("Dropped a message!");
34///     }
35/// }
36///
37/// let concurrent_service = ConcurrentService::new(MyService).with_concurrency_limit(5);
38/// ```
39pub struct ConcurrentService<S>
40where
41    S: SimpleService + Clone + Send + Sync,
42{
43    inner: S,
44    max_concurrency: usize,
45    max_backlog: usize,
46    pending: FuturesUnordered<BoxFuture<'static, ()>>,
47}
48
49impl<S> ConcurrentService<S>
50where
51    S: SimpleService + Clone + Send + Sync,
52{
53    /// Creates a new concurrent service from a [`SimpleService`].
54    ///
55    /// The default strategy for congestion control is to keep messages in the input queue.
56    pub fn new(inner: S) -> Self {
57        Self {
58            inner,
59            max_concurrency: usize::MAX,
60            max_backlog: usize::MAX,
61            pending: FuturesUnordered::new(),
62        }
63    }
64
65    /// Sets the maximum number of messages that can be handled concurrently.
66    pub fn with_concurrency_limit(mut self, limit: usize) -> Self {
67        self.max_concurrency = limit;
68        self
69    }
70
71    /// Limits the amount of messages that wait in the queue by loadshedding.
72    ///
73    /// Setting this limit will cause message loss.
74    ///
75    /// Note that cleanup of the queue may be deferred until the next pending
76    /// future completes.
77    pub fn with_backlog_limit(mut self, limit: usize) -> Self {
78        self.max_backlog = limit;
79        self
80    }
81}
82
impl<S> Service for ConcurrentService<S>
where
    S: SimpleService + LoadShed<S::Interface> + Clone + Send + Sync + 'static,
{
    type Interface = S::Interface;

    /// Drives the service: pulls messages from `rx` and handles up to
    /// `max_concurrency` of them at a time, shedding load once the input
    /// backlog exceeds `max_backlog`.
    async fn run(mut self, mut rx: super::Receiver<Self::Interface>) {
        loop {
            relay_log::trace!("Concurrent service loop iteration");

            // True while there is room for another in-flight message.
            let has_capacity = self.pending.len() < self.max_concurrency;
            // Consume from the queue either because we can handle the message,
            // or because the backlog limit is exceeded and the message must be
            // dropped via `handle_loadshed`. Relaxed is sufficient: the value
            // is an approximate gauge, not a synchronization point.
            let should_consume = has_capacity || {
                let backlog = rx.queue_size.load(std::sync::atomic::Ordering::Relaxed);
                backlog > self.max_backlog as u64
            };

            tokio::select! {
                // Bias towards handling responses so that there's space for new incoming requests.
                biased;

                // A pending message finished; this frees one slot of capacity.
                Some(_) = self.pending.next() => {},
                Some(message) = rx.recv(), if should_consume => {
                    if has_capacity {
                        // Clone the inner service so the spawned future is 'static.
                        let inner = self.inner.clone();
                        self.pending
                            .push(async move { inner.handle_message(message).await }.boxed());
                    } else {
                        // No capacity and over the backlog limit: shed the message.
                        self.inner.handle_loadshed(message);
                    }
                },
                // Channel closed (or all branches disabled): shut the service down.
                else => break,
            }

            relay_statsd::metric!(
                gauge(SystemGauges::ServiceConcurrency) = self.pending.len() as u64,
                service = Self::name()
            );
        }
    }
}
123
/// A trait describing what to do with a message that was load-shed.
///
/// Required by [`ConcurrentService`] to dispose of messages it drops when the
/// configured backlog limit is exceeded.
pub trait LoadShed<T> {
    /// Gets called for every message that gets dropped by loadshedding.
    ///
    /// Implementations may be empty, but typically log or count the drop.
    fn handle_loadshed(&self, _message: T);
}
129
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    use crate::{FromMessage, Interface, NoResponse};

    use super::*;

    /// Test service that counts completed and load-shed messages.
    ///
    /// Each message takes 2 virtual seconds to handle (see `handle_message`),
    /// which makes the concurrency limit observable under paused tokio time.
    #[derive(Clone)]
    struct CountingService {
        // Messages that completed `handle_message`.
        success: Arc<AtomicUsize>,
        // Messages dropped via `handle_loadshed`.
        fail: Arc<AtomicUsize>,
    }
    struct Incr;
    impl Interface for Incr {}
    impl FromMessage<()> for Incr {
        type Response = NoResponse;
        fn from_message(_message: (), _sender: ()) -> Self {
            Self
        }
    }
    impl SimpleService for CountingService {
        type Interface = Incr;

        async fn handle_message(&self, _message: Incr) {
            // Simulated work; with `start_paused` this only advances virtual time.
            tokio::time::sleep(Duration::from_secs(2)).await;
            self.success.fetch_add(1, Ordering::Relaxed);
        }
    }

    impl LoadShed<Incr> for CountingService {
        fn handle_loadshed(&self, _message: Incr) {
            self.fail.fetch_add(1, Ordering::Relaxed);
        }
    }

    // With a backlog limit of 0, every message beyond the concurrency limit
    // is dropped instead of queued.
    #[tokio::test(start_paused = true)]
    async fn loadshed() {
        let inner = CountingService {
            success: Arc::new(AtomicUsize::new(0)),
            fail: Arc::new(AtomicUsize::new(0)),
        };
        let service = ConcurrentService::new(inner.clone())
            .with_concurrency_limit(5)
            .with_backlog_limit(0);
        let addr = service.start_detached();

        for _ in 0..10 {
            addr.send(());
        }

        // The service task has not been polled yet, so nothing happened.
        assert_eq!(inner.success.load(Ordering::Relaxed), 0);
        assert_eq!(inner.fail.load(Ordering::Relaxed), 0);

        tokio::time::sleep(Duration::from_secs(10)).await;

        // Only half of the items have been processed
        assert_eq!(inner.success.load(Ordering::Relaxed), 5);
        assert_eq!(inner.fail.load(Ordering::Relaxed), 5);
    }

    // Without a backlog limit, excess messages wait in the queue and are
    // handled in a second batch once capacity frees up.
    #[tokio::test(start_paused = true)]
    async fn backpressure() {
        let inner = CountingService {
            success: Arc::new(AtomicUsize::new(0)),
            fail: Arc::new(AtomicUsize::new(0)),
        };
        let service = ConcurrentService::new(inner.clone()).with_concurrency_limit(5);
        let addr = service.start_detached();

        for _ in 0..10 {
            addr.send(());
        }

        // After 3 seconds, only half the messages have been handled:
        tokio::time::sleep(Duration::from_secs(3)).await;
        assert_eq!(inner.success.load(Ordering::Relaxed), 5);
        assert_eq!(inner.fail.load(Ordering::Relaxed), 0);

        // After 5 seconds, everything's been handled:
        tokio::time::sleep(Duration::from_secs(2)).await;
        assert_eq!(inner.success.load(Ordering::Relaxed), 10);
        assert_eq!(inner.fail.load(Ordering::Relaxed), 0);
    }

    // Combined behavior: 5 run immediately, 5 queue up to the backlog limit,
    // and the remaining 3 are shed.
    #[tokio::test(start_paused = true)]
    async fn backpressure_and_loadshed() {
        let inner = CountingService {
            success: Arc::new(AtomicUsize::new(0)),
            fail: Arc::new(AtomicUsize::new(0)),
        };
        let service = ConcurrentService::new(inner.clone())
            .with_concurrency_limit(5)
            .with_backlog_limit(5);
        let addr = service.start_detached();

        for _ in 0..13 {
            addr.send(());
        }

        // After 3 seconds, only 5 messages have been handled.
        // Three have been dropped due to loadshedding.
        tokio::time::sleep(Duration::from_secs(3)).await;
        assert_eq!(inner.success.load(Ordering::Relaxed), 5);
        assert_eq!(inner.fail.load(Ordering::Relaxed), 3);

        // After 5 seconds, another 5 messages got handled:
        tokio::time::sleep(Duration::from_secs(2)).await;
        assert_eq!(inner.success.load(Ordering::Relaxed), 10);
        assert_eq!(inner.fail.load(Ordering::Relaxed), 3);
    }
}