objectstore_server/rate_limits.rs

use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::AtomicU64;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};

use bytes::Bytes;
use futures_util::Stream;
use objectstore_service::PayloadStream;
use objectstore_service::id::ObjectContext;
use objectstore_types::scope::Scopes;
use serde::{Deserialize, Serialize};

/// Rate limits for objectstore.
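///
/// For illustration, a JSON representation of these limits could look as follows
/// (all values are hypothetical; field names follow the `serde` derives in this module):
///
/// ```json
/// {
///   "throughput": {
///     "global_rps": 1000,
///     "burst": 100,
///     "usecase_pct": 50,
///     "scope_pct": 10,
///     "rules": [{ "usecase": "attachments", "scopes": [], "rps": 100, "pct": null }]
///   },
///   "bandwidth": { "global_bps": 104857600 }
/// }
/// ```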
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct RateLimits {
    /// Limits the number of requests per second per service instance.
    pub throughput: ThroughputLimits,
    /// Limits the concurrent bandwidth per service instance.
    pub bandwidth: BandwidthLimits,
}

#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct ThroughputLimits {
    /// The overall maximum number of requests per second per service instance.
    ///
    /// Defaults to `None`, meaning no global rate limit is enforced.
    pub global_rps: Option<u32>,

    /// The maximum burst for each rate limit.
    ///
    /// Defaults to `0`, meaning no bursting is allowed. If set to a value greater than `0`,
    /// short spikes above the rate limit are allowed up to the burst size.
    pub burst: u32,

    /// The maximum percentage of the global rate limit that can be used by any usecase.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-usecase limit is enforced.
    pub usecase_pct: Option<u8>,

    /// The maximum percentage of the global rate limit that can be used by any scope.
    ///
    /// This treats each full scope separately and applies across all use cases:
    ///  - Two requests with the exact same scopes count against the same limit, regardless of use case.
    ///  - Two requests that share the same top scope but differ in inner scopes count separately.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-scope limit is enforced.
    pub scope_pct: Option<u8>,

    /// Overrides for specific usecases and scopes.
    pub rules: Vec<ThroughputRule>,
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ThroughputRule {
    /// Optional usecase to match.
    ///
    /// If `None`, matches any usecase.
    pub usecase: Option<String>,

    /// Scopes to match.
    ///
    /// If empty, matches any scopes. Additional scopes in the context are ignored, so a rule
    /// matches if all of the specified scopes are present in the request with matching values.
    pub scopes: Vec<(String, String)>,

    /// The rate limit to apply when this rule matches.
    ///
    /// If both `rps` and `pct` are specified, the more restrictive limit applies.
    /// Should be greater than `0`. To block traffic entirely, use killswitches instead.
    pub rps: Option<u32>,

    /// The percentage of the global rate limit to apply when this rule matches.
    ///
    /// If both `rps` and `pct` are specified, the more restrictive limit applies.
    /// Should be greater than `0`. To block traffic entirely, use killswitches instead.
    pub pct: Option<u8>,
}

impl ThroughputRule {
    /// Returns `true` if this rule matches the given context.
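    ///
    /// For example (hypothetical values): a rule with `usecase: Some("attachments")` and
    /// `scopes: [("org", "1")]` matches a request for usecase `attachments` whose scopes
    /// contain `org=1`, regardless of any additional scopes on the request.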
    pub fn matches(&self, context: &ObjectContext) -> bool {
        if let Some(ref rule_usecase) = self.usecase
            && rule_usecase != &context.usecase
        {
            return false;
        }

        for (scope_name, scope_value) in &self.scopes {
            match context.scopes.get_value(scope_name) {
                Some(value) if value == scope_value => (),
                _ => return false,
            }
        }

        true
    }
}

#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct BandwidthLimits {
    /// The overall maximum bandwidth (in bytes per second) per service instance.
    ///
    /// Defaults to `None`, meaning no global bandwidth limit is enforced.
    pub global_bps: Option<u64>,
}

#[derive(Debug)]
pub struct RateLimiter {
    bandwidth: BandwidthRateLimiter,
    throughput: ThroughputRateLimiter,
}

impl RateLimiter {
    pub fn new(config: RateLimits) -> Self {
        Self {
            bandwidth: BandwidthRateLimiter::new(config.bandwidth),
            throughput: ThroughputRateLimiter::new(config.throughput),
        }
    }

    /// Checks if the given context is within the rate limits.
    ///
    /// Returns `true` if the context is within the rate limits, `false` otherwise.
    pub fn check(&self, context: &ObjectContext) -> bool {
        self.throughput.check(context) && self.bandwidth.check()
    }

    /// Returns a reference to the shared bytes accumulator, used for bandwidth-based rate-limiting.
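    ///
    /// This accumulator is intended to be passed to [`MeteredPayloadStream::from`], which
    /// increments it as payload chunks are polled, feeding the bandwidth estimate.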
    pub fn bytes_accumulator(&self) -> Arc<AtomicU64> {
        Arc::clone(&self.bandwidth.accumulator)
    }
}

#[derive(Debug)]
struct BandwidthRateLimiter {
    config: BandwidthLimits,
    /// Accumulator that's incremented every time an operation that uses bandwidth is executed.
    accumulator: Arc<AtomicU64>,
    /// An estimate of the bandwidth that's currently being utilized, in bytes per second.
    estimate: Arc<AtomicU64>,
}

impl BandwidthRateLimiter {
    fn new(config: BandwidthLimits) -> Self {
        let accumulator = Arc::new(AtomicU64::new(0));
        let estimate = Arc::new(AtomicU64::new(0));

        let accumulator_clone = Arc::clone(&accumulator);
        let estimate_clone = Arc::clone(&estimate);
        tokio::task::spawn(async move {
            Self::estimator(accumulator_clone, estimate_clone).await;
        });

        Self {
            config,
            accumulator,
            estimate,
        }
    }

    /// Estimates the current bandwidth utilization using an exponentially weighted moving average.
    ///
    /// The calculation is based on the increments of `accumulator` that happened during the last `TICK`.
    /// The estimate is stored in `estimate`, which can be queried for bandwidth-based rate-limiting.
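    ///
    /// Each tick computes `ewma = ALPHA * sample_bps + (1.0 - ALPHA) * ewma`, where
    /// `sample_bps` is the accumulator delta since the last tick, scaled to bytes per second.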
    async fn estimator(accumulator: Arc<AtomicU64>, estimate: Arc<AtomicU64>) {
        const TICK: Duration = Duration::from_millis(50); // Recompute the EWMA on every TICK
        const ALPHA: f64 = 0.2; // EWMA alpha parameter: 20% weight to the new sample, 80% to the previous average

        let mut interval = tokio::time::interval(TICK);
        let to_bps = 1.0 / TICK.as_secs_f64(); // Conversion factor from bytes per tick to bytes per second
        let mut ewma: f64 = 0.0;
        loop {
            interval.tick().await;
            let current = accumulator.swap(0, std::sync::atomic::Ordering::Relaxed);
            let bps = (current as f64) * to_bps;
            ewma = ALPHA * bps + (1.0 - ALPHA) * ewma;

            let ewma_int = ewma.floor() as u64;
            estimate.store(ewma_int, std::sync::atomic::Ordering::Relaxed);
            merni::gauge!("server.bandwidth.ewma"@b: ewma_int);
        }
    }

    fn check(&self) -> bool {
        let Some(bps) = self.config.global_bps else {
            return true;
        };
        self.estimate.load(std::sync::atomic::Ordering::Relaxed) <= bps
    }
}

#[derive(Debug)]
struct ThroughputRateLimiter {
    config: ThroughputLimits,
    global: Option<Mutex<TokenBucket>>,
    // NB: These maps grow unbounded, but we accept this as we expect an overall limited
    // number of usecases and scopes. We emit gauge metrics to monitor their size.
    usecases: papaya::HashMap<String, Mutex<TokenBucket>>,
    scopes: papaya::HashMap<Scopes, Mutex<TokenBucket>>,
    rules: papaya::HashMap<usize, Mutex<TokenBucket>>,
}

impl ThroughputRateLimiter {
    fn new(config: ThroughputLimits) -> Self {
        let global = config
            .global_rps
            .map(|rps| Mutex::new(TokenBucket::new(rps, config.burst)));

        Self {
            config,
            global,
            usecases: papaya::HashMap::new(),
            scopes: papaya::HashMap::new(),
            rules: papaya::HashMap::new(),
        }
    }

    /// Checks the given context against the configured token buckets: global first, then
    /// per-usecase, per-scope, and finally one dedicated bucket per matching rule.
    ///
    /// Fails fast, returning `false` as soon as any bucket is out of tokens.
    fn check(&self, context: &ObjectContext) -> bool {
        // NB: We intentionally use unwrap and crash the server if the mutexes are poisoned.

        // Global check
        if let Some(ref global) = self.global
            && !global.lock().unwrap().try_acquire()
        {
            return false;
        }

        // Usecase check - only if both global_rps and usecase_pct are set
        if let Some(usecase_rps) = self.usecase_rps() {
            let guard = self.usecases.pin();
            let bucket = guard
                .get_or_insert_with(context.usecase.clone(), || self.create_bucket(usecase_rps));
            if !bucket.lock().unwrap().try_acquire() {
                return false;
            }
        }

        // Scope check - only if both global_rps and scope_pct are set
        if let Some(scope_rps) = self.scope_rps() {
            let guard = self.scopes.pin();
            let bucket =
                guard.get_or_insert_with(context.scopes.clone(), || self.create_bucket(scope_rps));
            if !bucket.lock().unwrap().try_acquire() {
                return false;
            }
        }

        // Rule checks - each matching rule has its own dedicated bucket
        for (idx, rule) in self.config.rules.iter().enumerate() {
            if !rule.matches(context) {
                continue;
            }
            let Some(rule_rps) = self.rule_rps(rule) else {
                continue;
            };
            let guard = self.rules.pin();
            let bucket = guard.get_or_insert_with(idx, || self.create_bucket(rule_rps));
            if !bucket.lock().unwrap().try_acquire() {
                return false;
            }
        }

        true
    }

    fn create_bucket(&self, rps: u32) -> Mutex<TokenBucket> {
        Mutex::new(TokenBucket::new(rps, self.config.burst))
    }

    /// Returns the effective RPS for per-usecase limiting, if configured.
    fn usecase_rps(&self) -> Option<u32> {
        let global_rps = self.config.global_rps?;
        let pct = self.config.usecase_pct?;
        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
    }

    /// Returns the effective RPS for per-scope limiting, if configured.
    fn scope_rps(&self) -> Option<u32> {
        let global_rps = self.config.global_rps?;
        let pct = self.config.scope_pct?;
        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
    }

    /// Returns the effective RPS for a rule, if it has a valid limit.
    fn rule_rps(&self, rule: &ThroughputRule) -> Option<u32> {
        let pct_limit = rule.pct.and_then(|p| {
            self.config
                .global_rps
                .map(|g| ((g as f64) * (p as f64 / 100.0)) as u32)
        });

        match (rule.rps, pct_limit) {
            (Some(r), Some(p)) => Some(r.min(p)),
            (Some(r), None) => Some(r),
            (None, Some(p)) => Some(p),
            (None, None) => None,
        }
    }
}

/// A token bucket rate limiter.
///
/// Tokens refill at a constant rate up to capacity. Each request consumes one token.
/// When no tokens are available, requests are rejected.
///
/// This implementation is not thread-safe on its own. Wrap it in a `Mutex` for concurrent access.
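///
/// For example, `TokenBucket::new(100, 20)` creates a bucket that starts with 120 tokens,
/// admits bursts of up to 120 back-to-back requests, and refills at 100 tokens per second.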
#[derive(Debug)]
struct TokenBucket {
    refill_rate: f64,
    capacity: f64,
    tokens: f64,
    last_update: Instant,
}

impl TokenBucket {
    /// Creates a new, full token bucket with the specified rate limit and burst capacity.
    ///
    /// - `rps`: tokens refilled per second (the sustained rate limit)
    /// - `burst`: initial extra tokens and burst allowance above the sustained rate
    pub fn new(rps: u32, burst: u32) -> Self {
        Self {
            refill_rate: rps as f64,
            capacity: (rps + burst) as f64,
            tokens: (rps + burst) as f64,
            last_update: Instant::now(),
        }
    }

    /// Attempts to acquire a token from the bucket.
    ///
    /// Returns `true` if a token was acquired, `false` if no tokens are available.
    pub fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let refill = now.duration_since(self.last_update).as_secs_f64() * self.refill_rate;
        let refilled = (self.tokens + refill).min(self.capacity);

        // Only apply the refill if we'd gain at least 1 whole token
        if refilled.floor() > self.tokens.floor() {
            self.last_update = now;
            self.tokens = refilled;
        }

        // Try to consume one token
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}

/// A wrapper around a `PayloadStream` that measures bandwidth usage.
///
/// This behaves exactly like a `PayloadStream`, except that every time an item is polled,
/// the accumulator is incremented by the size of the returned `Bytes` chunk.
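///
/// The accumulator is typically the one returned by [`RateLimiter::bytes_accumulator`],
/// so that metered bytes feed the bandwidth estimator.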
pub(crate) struct MeteredPayloadStream {
    inner: PayloadStream,
    accumulator: Arc<AtomicU64>,
}

impl MeteredPayloadStream {
    pub fn from(inner: PayloadStream, accumulator: Arc<AtomicU64>) -> Self {
        Self { inner, accumulator }
    }
}

impl std::fmt::Debug for MeteredPayloadStream {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MeteredPayloadStream")
            .field("accumulator", &self.accumulator)
            .finish()
    }
}

impl Stream for MeteredPayloadStream {
    type Item = std::io::Result<Bytes>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let res = this.inner.as_mut().poll_next(cx);
        if let Poll::Ready(Some(Ok(ref bytes))) = res {
            this.accumulator
                .fetch_add(bytes.len() as u64, std::sync::atomic::Ordering::Relaxed);
        }
        res
    }
}
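
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal sketch of the bucket semantics: with `rps = 2` and `burst = 1` the
    /// bucket starts full at 3 tokens, so three immediate acquisitions succeed and the
    /// fourth is rejected (no meaningful refill happens between back-to-back calls).
    #[test]
    fn token_bucket_exhausts_initial_capacity() {
        let mut bucket = TokenBucket::new(2, 1);
        assert!(bucket.try_acquire());
        assert!(bucket.try_acquire());
        assert!(bucket.try_acquire());
        assert!(!bucket.try_acquire());
    }
}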