Skip to main content

objectstore_server/
rate_limits.rs

1//! Admission-based rate limiting for throughput and bandwidth.
2//!
3//! This module provides [`RateLimiter`], which enforces configurable limits at three
4//! levels of granularity — global, per-usecase, and per-scope — for both request
5//! throughput (requests/s) and upload/download bandwidth (bytes/s).
6//!
7//! Throughput is enforced using token buckets; bandwidth is estimated with an
8//! exponentially weighted moving average (EWMA) updated by a background task every
9//! 50 ms. All rate-limit checks are synchronous and non-blocking.
10
11use std::fmt;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::sync::{Mutex, atomic::AtomicU64};
15use std::task::{Context, Poll};
16use std::time::{Duration, Instant};
17
18use bytes::Bytes;
19use futures_util::Stream;
20use objectstore_service::id::ObjectContext;
21use objectstore_types::scope::Scopes;
22use serde::{Deserialize, Serialize};
23
/// The specific limit that caused a request to be turned away.
///
/// Variants are grouped by limiter kind (bandwidth or throughput) and by the granularity
/// at which the limit applies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RateLimitRejection {
    /// The global bandwidth limit was exceeded.
    BandwidthGlobal,
    /// The per-usecase bandwidth limit was exceeded.
    BandwidthUsecase,
    /// The per-scope bandwidth limit was exceeded.
    BandwidthScope,
    /// The global throughput limit was exceeded.
    ThroughputGlobal,
    /// The per-usecase throughput limit was exceeded.
    ThroughputUsecase,
    /// The per-scope throughput limit was exceeded.
    ThroughputScope,
    /// The throughput limit of a matching override rule was exceeded.
    ThroughputRule,
}

impl RateLimitRejection {
    /// Maps the rejection to a fixed, lowercase identifier that can be used as a metric tag.
    pub fn as_str(self) -> &'static str {
        let tag = match self {
            Self::BandwidthGlobal => "bandwidth_global",
            Self::BandwidthUsecase => "bandwidth_usecase",
            Self::BandwidthScope => "bandwidth_scope",
            Self::ThroughputGlobal => "throughput_global",
            Self::ThroughputUsecase => "throughput_usecase",
            Self::ThroughputScope => "throughput_scope",
            Self::ThroughputRule => "throughput_rule",
        };
        tag
    }
}

impl fmt::Display for RateLimitRejection {
    /// Writes the same identifier that [`RateLimitRejection::as_str`] returns.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let tag = self.as_str();
        f.write_str(tag)
    }
}
63
/// Rate limits for objectstore.
///
/// Groups the request-throughput and bandwidth settings that [`RateLimiter::new`] consumes.
/// The derived `Default` leaves every optional limit unset, so nothing is enforced.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct RateLimits {
    /// Limits the number of requests per second per service instance.
    pub throughput: ThroughputLimits,
    /// Limits the bandwidth (bytes per second) per service instance.
    pub bandwidth: BandwidthLimits,
}
72
/// Request throughput limits applied at global, per-usecase, and per-scope granularity.
///
/// All limits are optional. When a limit is `None`, that level of limiting is not enforced.
/// Per-usecase and per-scope limits are expressed as a percentage of the global limit, so
/// they only take effect when `global_rps` is set as well.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct ThroughputLimits {
    /// The overall maximum number of requests per second per service instance.
    ///
    /// Defaults to `None`, meaning no global rate limit is enforced.
    pub global_rps: Option<u32>,

    /// The maximum burst for each rate limit.
    ///
    /// Defaults to `0`, meaning no bursting is allowed. If set to a value greater than `0`,
    /// short spikes above the rate limit are allowed up to the burst size.
    ///
    /// Every token bucket created by the limiter (global, per-usecase, per-scope and per-rule)
    /// uses this same value; its capacity is the bucket's rate plus the burst.
    pub burst: u32,

    /// The maximum percentage of the global rate limit that can be used by any usecase.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-usecase limit is enforced.
    ///
    /// NOTE(review): the type admits values up to `255` and this module does not clamp them;
    /// confirm whether configuration loading validates the range.
    pub usecase_pct: Option<u8>,

    /// The maximum percentage of the global rate limit that can be used by any scope.
    ///
    /// This treats each full scope separately and applies across all use cases:
    ///  - Two requests with exact same scopes count against the same limit regardless of use case.
    ///  - Two requests that share the same top scope but differ in inner scopes count separately.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-scope limit is enforced.
    pub scope_pct: Option<u8>,

    /// Overrides for specific usecases and scopes.
    ///
    /// Rules are evaluated in addition to the global, per-usecase and per-scope limits.
    pub rules: Vec<ThroughputRule>,
}
107
/// An override rule that applies a specific throughput limit to matching request contexts.
///
/// A rule matches when all specified fields match the request context. Fields not set match
/// any value. When multiple rules match, each is enforced independently via its own token
/// bucket — all matching rules must admit the request.
///
/// A matching rule that yields no effective limit — neither `rps` nor `pct` is set, or only
/// `pct` is set while no global limit is configured — is skipped.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ThroughputRule {
    /// Optional usecase to match.
    ///
    /// If `None`, matches any usecase.
    pub usecase: Option<String>,

    /// Scopes to match, as `(name, value)` pairs.
    ///
    /// If empty, matches any scopes. Additional scopes in the context are ignored, so a rule
    /// matches if all of the specified scopes are present in the request with matching values.
    pub scopes: Vec<(String, String)>,

    /// The rate limit to apply when this rule matches.
    ///
    /// If both a rate and pct are specified, the more restrictive limit applies.
    /// Should be greater than `0`. To block traffic entirely, use killswitches instead.
    pub rps: Option<u32>,

    /// The percentage of the global rate limit to apply when this rule matches.
    ///
    /// If both a rate and pct are specified, the more restrictive limit applies.
    /// Should be greater than `0`. To block traffic entirely, use killswitches instead.
    pub pct: Option<u8>,
}
138
139impl ThroughputRule {
140    /// Returns `true` if this rule matches the given context.
141    pub fn matches(&self, context: &ObjectContext) -> bool {
142        if let Some(ref rule_usecase) = self.usecase
143            && rule_usecase != &context.usecase
144        {
145            return false;
146        }
147
148        for (scope_name, scope_value) in &self.scopes {
149            match context.scopes.get_value(scope_name) {
150                Some(value) if value == scope_value => (),
151                _ => return false,
152            }
153        }
154
155        true
156    }
157}
158
/// Bandwidth limits applied at global, per-usecase, and per-scope granularity.
///
/// Bandwidth is measured as bytes transferred per second (upload + download combined)
/// and estimated using an EWMA updated every 50 ms. All limits are optional.
///
/// The per-usecase and per-scope limits are percentages of `global_bps`, so they only take
/// effect when `global_bps` is set as well.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct BandwidthLimits {
    /// The overall maximum bandwidth (in bytes per second) per service instance.
    ///
    /// Defaults to `None`, meaning no global bandwidth limit is enforced.
    pub global_bps: Option<u64>,

    /// The maximum percentage of the global bandwidth limit that can be used by any usecase.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-usecase bandwidth limit is enforced.
    pub usecase_pct: Option<u8>,

    /// The maximum percentage of the global bandwidth limit that can be used by any scope.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-scope bandwidth limit is enforced.
    pub scope_pct: Option<u8>,
}
180
/// Combined rate limiter that enforces both bandwidth and throughput limits.
///
/// Checks are synchronous and non-blocking. Bandwidth is checked before
/// throughput so that rejected requests are never counted toward the admitted
/// throughput EWMA. See [`check`](RateLimiter::check) for details.
///
/// Call [`start`](RateLimiter::start) after construction to launch the background
/// estimation tasks. Without it, bandwidth EWMAs remain at zero and bandwidth
/// limits are never triggered.
#[derive(Debug)]
pub struct RateLimiter {
    /// EWMA-based limiter for bytes per second.
    bandwidth: BandwidthRateLimiter,
    /// Token-bucket limiter for requests per second.
    throughput: ThroughputRateLimiter,
}
195
196impl RateLimiter {
197    /// Creates a new rate limiter from the given configuration.
198    ///
199    /// Background estimation tasks are not started until [`start`](RateLimiter::start) is called.
200    pub fn new(config: RateLimits) -> Self {
201        Self {
202            bandwidth: BandwidthRateLimiter::new(config.bandwidth),
203            throughput: ThroughputRateLimiter::new(config.throughput),
204        }
205    }
206
207    /// Starts background tasks for rate limit estimation and monitoring.
208    ///
209    /// Must be called from within a Tokio runtime.
210    pub fn start(&self) {
211        self.bandwidth.start();
212        self.throughput.start();
213    }
214
215    /// Checks if the given context is within the rate limits.
216    ///
217    /// Returns `true` if the request is admitted, `false` if it was rejected. On rejection, emits a
218    /// `server.request.rate_limited` metric counter and a `warn!` log. Bandwidth is checked before
219    /// throughput so that rejected requests are never counted toward admitted traffic.
220    pub fn check(&self, context: &ObjectContext, key: Option<&str>) -> bool {
221        // Bandwidth is checked first because it is a pure read (no token consumption).
222        // Throughput increments the EWMA accumulator only on success, so checking it
223        // second ensures rejected requests are never counted toward admitted traffic.
224        let rejection = self
225            .bandwidth
226            .check(context)
227            .or_else(|| self.throughput.check(context));
228
229        let Some(rejection) = rejection else {
230            return true;
231        };
232
233        objectstore_metrics::count!(
234            "server.request.rate_limited",
235            reason = rejection.as_str(),
236            usecase = context.usecase.clone()
237        );
238        objectstore_log::warn!(
239            reason = rejection.as_str(),
240            usecase = &context.usecase,
241            scopes = %context.scopes.as_api_path(),
242            key,
243            "Request rejected: rate limit exceeded"
244        );
245        false
246    }
247
248    /// Returns all bandwidth accumulators (global + per-usecase + per-scope) for the given context.
249    ///
250    /// Creates entries in the per-usecase/per-scope maps if they don't exist yet.
251    pub fn bytes_accumulators(&self, context: &ObjectContext) -> Vec<Arc<AtomicU64>> {
252        self.bandwidth.accumulators(context)
253    }
254
255    /// Records bandwidth usage across all accumulators for the given context.
256    ///
257    /// This is used for cases where bytes are known upfront (e.g. batch INSERT) rather than
258    /// streamed through a `MeteredPayloadStream`.
259    pub fn record_bandwidth(&self, context: &ObjectContext, bytes: u64) {
260        for acc in self.bandwidth.accumulators(context) {
261            acc.fetch_add(bytes, std::sync::atomic::Ordering::Relaxed);
262        }
263    }
264
265    /// Returns the current global bandwidth EWMA in bytes per second.
266    pub fn bandwidth_ewma(&self) -> u64 {
267        self.bandwidth
268            .global
269            .estimate
270            .load(std::sync::atomic::Ordering::Relaxed)
271    }
272
273    /// Returns the configured global bandwidth limit in bytes/s, if set.
274    pub fn bandwidth_limit(&self) -> Option<u64> {
275        self.bandwidth.config.global_bps
276    }
277
278    /// Returns the current estimated throughput in requests per second.
279    pub fn throughput_rps(&self) -> u64 {
280        self.throughput
281            .global_estimator
282            .estimate
283            .load(std::sync::atomic::Ordering::Relaxed)
284    }
285
286    /// Returns the configured global throughput limit in requests/s, if set.
287    pub fn throughput_limit(&self) -> Option<u32> {
288        self.throughput.config.global_rps
289    }
290
291    /// Returns total bytes transferred since startup.
292    pub fn bandwidth_total_bytes(&self) -> u64 {
293        self.bandwidth
294            .total_bytes
295            .load(std::sync::atomic::Ordering::Relaxed)
296    }
297
298    /// Returns total admitted requests since startup.
299    pub fn throughput_total_admitted(&self) -> u64 {
300        self.throughput
301            .total_admitted
302            .load(std::sync::atomic::Ordering::Relaxed)
303    }
304}
305
/// Shared EWMA estimator state.
///
/// The accumulator is incremented as events occur, and the estimate is updated
/// periodically by a background task using an exponentially weighted moving average.
#[derive(Debug)]
struct EwmaEstimator {
    /// Events (or bytes) counted since the last update; drained on every update.
    accumulator: Arc<AtomicU64>,
    /// Latest smoothed per-second rate, floored to an integer.
    estimate: Arc<AtomicU64>,
}

impl EwmaEstimator {
    /// Creates an estimator with an empty accumulator and an estimate of zero.
    fn new() -> Self {
        Self {
            accumulator: Arc::new(AtomicU64::new(0)),
            estimate: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Updates the EWMA from the accumulator.
    ///
    /// Swaps the accumulator to zero, converts the count to a per-second rate using
    /// `to_rate` (i.e., `1.0 / tick_duration.as_secs_f64()`), then applies the
    /// EWMA smoothing and stores the floored result in `estimate`.
    ///
    /// `ewma` is the caller-owned floating-point state carried from one update to the next.
    ///
    /// If `to_rate` is not finite (for example because the measured tick duration was zero),
    /// the update is skipped entirely: the accumulator keeps its count for the next tick and
    /// both `ewma` and `estimate` are left untouched. Without this guard a single such tick
    /// would turn `ewma` into infinity or NaN, and the smoothing could never recover from it.
    fn update_ewma(&self, ewma: &mut f64, to_rate: f64) {
        const ALPHA: f64 = 0.2;

        if !to_rate.is_finite() {
            return;
        }

        let current = self
            .accumulator
            .swap(0, std::sync::atomic::Ordering::Relaxed);
        let rate = (current as f64) * to_rate;
        *ewma = ALPHA * rate + (1.0 - ALPHA) * *ewma;
        self.estimate
            .store(ewma.floor() as u64, std::sync::atomic::Ordering::Relaxed);
    }
}
340
/// Bandwidth limiter state: the configured limits plus one EWMA estimator per tracked level.
#[derive(Debug)]
struct BandwidthRateLimiter {
    /// Configured bandwidth limits.
    config: BandwidthLimits,
    /// Global accumulator/estimator pair.
    global: Arc<EwmaEstimator>,
    /// Cumulative bytes transferred since startup. Never reset.
    total_bytes: Arc<AtomicU64>,
    // NB: These maps grow unbounded but we accept this as we expect an overall limited
    // number of usecases and scopes. We emit gauge metrics to monitor their size.
    /// Per-usecase estimators, keyed by usecase name.
    usecases: Arc<papaya::HashMap<String, Arc<EwmaEstimator>>>,
    /// Per-scope estimators, keyed by the full set of scopes of a request.
    scopes: Arc<papaya::HashMap<Scopes, Arc<EwmaEstimator>>>,
}
353
354impl BandwidthRateLimiter {
355    fn new(config: BandwidthLimits) -> Self {
356        Self {
357            config,
358            global: Arc::new(EwmaEstimator::new()),
359            total_bytes: Arc::new(AtomicU64::new(0)),
360            usecases: Arc::new(papaya::HashMap::new()),
361            scopes: Arc::new(papaya::HashMap::new()),
362        }
363    }
364
365    fn start(&self) {
366        let global = Arc::clone(&self.global);
367        let usecases = Arc::clone(&self.usecases);
368        let scopes = Arc::clone(&self.scopes);
369        let global_limit = self.config.global_bps;
370        // NB: This task has no shutdown mechanism — the rate limiter is only created once.
371        // The task is aborted when the Tokio runtime is dropped on process exit.
372        tokio::task::spawn(async move {
373            Self::estimator(global, usecases, scopes, global_limit).await;
374        });
375    }
376
377    /// Estimates the current bandwidth utilization using an exponentially weighted moving average.
378    ///
379    /// Iterates over the global estimator as well as all per-usecase and per-scope estimators
380    /// on each tick, updating their EWMAs.
381    async fn estimator(
382        global: Arc<EwmaEstimator>,
383        usecases: Arc<papaya::HashMap<String, Arc<EwmaEstimator>>>,
384        scopes: Arc<papaya::HashMap<Scopes, Arc<EwmaEstimator>>>,
385        global_limit: Option<u64>,
386    ) {
387        const TICK: Duration = Duration::from_millis(50); // Recompute EWMA on every TICK
388
389        let mut interval = tokio::time::interval(TICK);
390        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
391        // The first tick of a tokio interval fires immediately. Consume it so the
392        // first real iteration has a full ~50ms of elapsed time.
393        interval.tick().await;
394        let mut last = Instant::now();
395        let mut global_ewma: f64 = 0.0;
396        // Shadow EWMAs for per-usecase/per-scope entries, keyed the same way as the maps.
397        let mut usecase_ewmas: std::collections::HashMap<String, f64> =
398            std::collections::HashMap::new();
399        let mut scope_ewmas: std::collections::HashMap<Scopes, f64> =
400            std::collections::HashMap::new();
401
402        loop {
403            interval.tick().await;
404
405            let now = Instant::now();
406            let to_bps = 1.0 / now.duration_since(last).as_secs_f64();
407            last = now;
408
409            // Global
410            global.update_ewma(&mut global_ewma, to_bps);
411            objectstore_metrics::gauge!("server.bandwidth.ewma" = global_ewma.floor() as u64);
412            if let Some(limit) = global_limit {
413                objectstore_metrics::gauge!("server.bandwidth.limit" = limit);
414            }
415
416            // Per-usecase
417            {
418                let guard = usecases.pin();
419                for (key, estimator) in guard.iter() {
420                    let ewma = usecase_ewmas.entry(key.clone()).or_insert(0.0);
421                    estimator.update_ewma(ewma, to_bps);
422                }
423            }
424
425            // Per-scope
426            {
427                let guard = scopes.pin();
428                for (key, estimator) in guard.iter() {
429                    let ewma = scope_ewmas.entry(key.clone()).or_insert(0.0);
430                    estimator.update_ewma(ewma, to_bps);
431                }
432            }
433
434            objectstore_metrics::gauge!(
435                "server.rate_limiter.bandwidth.scope_map_size" = scopes.len()
436            );
437            objectstore_metrics::gauge!(
438                "server.rate_limiter.bandwidth.usecase_map_size" = usecases.len()
439            );
440        }
441    }
442
443    fn check(&self, context: &ObjectContext) -> Option<RateLimitRejection> {
444        let global_bps = self.config.global_bps?;
445
446        // Global check
447        if self
448            .global
449            .estimate
450            .load(std::sync::atomic::Ordering::Relaxed)
451            > global_bps
452        {
453            return Some(RateLimitRejection::BandwidthGlobal);
454        }
455
456        // Per-usecase check
457        if let Some(usecase_bps) = self.usecase_bps() {
458            let guard = self.usecases.pin();
459            if let Some(estimator) = guard.get(&context.usecase)
460                && estimator
461                    .estimate
462                    .load(std::sync::atomic::Ordering::Relaxed)
463                    > usecase_bps
464            {
465                return Some(RateLimitRejection::BandwidthUsecase);
466            }
467        }
468
469        // Per-scope check
470        if let Some(scope_bps) = self.scope_bps() {
471            let guard = self.scopes.pin();
472            if let Some(estimator) = guard.get(&context.scopes)
473                && estimator
474                    .estimate
475                    .load(std::sync::atomic::Ordering::Relaxed)
476                    > scope_bps
477            {
478                return Some(RateLimitRejection::BandwidthScope);
479            }
480        }
481
482        None
483    }
484
485    /// Returns all accumulators (global + per-usecase + per-scope) for the given context.
486    ///
487    /// Creates entries in the per-usecase/per-scope maps if they don't exist yet.
488    /// Always includes `total_bytes` (cumulative, never reset) as the first entry.
489    fn accumulators(&self, context: &ObjectContext) -> Vec<Arc<AtomicU64>> {
490        let mut accs = vec![
491            Arc::clone(&self.total_bytes),
492            Arc::clone(&self.global.accumulator),
493        ];
494
495        if self.usecase_bps().is_some() {
496            let guard = self.usecases.pin();
497            let estimator = guard
498                .get_or_insert_with(context.usecase.clone(), || Arc::new(EwmaEstimator::new()));
499            accs.push(Arc::clone(&estimator.accumulator));
500        }
501
502        if self.scope_bps().is_some() {
503            let guard = self.scopes.pin();
504            let estimator =
505                guard.get_or_insert_with(context.scopes.clone(), || Arc::new(EwmaEstimator::new()));
506            accs.push(Arc::clone(&estimator.accumulator));
507        }
508
509        accs
510    }
511
512    /// Returns the effective BPS for per-usecase limiting, if configured.
513    fn usecase_bps(&self) -> Option<u64> {
514        let global_bps = self.config.global_bps?;
515        let pct = self.config.usecase_pct?;
516        Some(((global_bps as f64) * (pct as f64 / 100.0)) as u64)
517    }
518
519    /// Returns the effective BPS for per-scope limiting, if configured.
520    fn scope_bps(&self) -> Option<u64> {
521        let global_bps = self.config.global_bps?;
522        let pct = self.config.scope_pct?;
523        Some(((global_bps as f64) * (pct as f64 / 100.0)) as u64)
524    }
525}
526
/// Throughput limiter state: the configured limits plus one token bucket per tracked level.
#[derive(Debug)]
struct ThroughputRateLimiter {
    /// Configured throughput limits, including the override rules.
    config: ThroughputLimits,
    /// Global token bucket; `None` when no global limit is configured.
    global: Option<Mutex<TokenBucket>>,
    /// Global EWMA estimator for admitted request rate.
    global_estimator: Arc<EwmaEstimator>,
    /// Cumulative admitted requests since startup. Never reset.
    total_admitted: Arc<AtomicU64>,
    // NB: These maps grow unbounded but we accept this as we expect an overall limited
    // number of usecases and scopes. We emit gauge metrics to monitor their size.
    /// Per-usecase token buckets, keyed by usecase name.
    usecases: Arc<papaya::HashMap<String, Mutex<TokenBucket>>>,
    /// Per-scope token buckets, keyed by the full set of scopes of a request.
    scopes: Arc<papaya::HashMap<Scopes, Mutex<TokenBucket>>>,
    /// Per-rule token buckets, keyed by the rule's index in `config.rules`.
    rules: papaya::HashMap<usize, Mutex<TokenBucket>>,
}
541
542impl ThroughputRateLimiter {
543    fn new(config: ThroughputLimits) -> Self {
544        let global = config
545            .global_rps
546            .map(|rps| Mutex::new(TokenBucket::new(rps, config.burst)));
547
548        Self {
549            config,
550            global,
551            global_estimator: Arc::new(EwmaEstimator::new()),
552            total_admitted: Arc::new(AtomicU64::new(0)),
553            usecases: Arc::new(papaya::HashMap::new()),
554            scopes: Arc::new(papaya::HashMap::new()),
555            rules: papaya::HashMap::new(),
556        }
557    }
558
559    fn start(&self) {
560        let usecases = Arc::clone(&self.usecases);
561        let scopes = Arc::clone(&self.scopes);
562        let global_estimator = Arc::clone(&self.global_estimator);
563        let global_limit = self.config.global_rps;
564        // NB: This task has no shutdown mechanism — the rate limiter is only created once.
565        // The task is aborted when the Tokio runtime is dropped on process exit.
566        tokio::task::spawn(async move {
567            const TICK: Duration = Duration::from_millis(50);
568            let mut interval = tokio::time::interval(TICK);
569            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
570            interval.tick().await;
571            let mut last = Instant::now();
572            let mut global_ewma: f64 = 0.0;
573            loop {
574                interval.tick().await;
575                let now = Instant::now();
576                let to_rps = 1.0 / now.duration_since(last).as_secs_f64();
577                last = now;
578                global_estimator.update_ewma(&mut global_ewma, to_rps);
579                objectstore_metrics::gauge!("server.throughput.ewma" = global_ewma.floor() as u64);
580                if let Some(limit) = global_limit {
581                    objectstore_metrics::gauge!(
582                        "server.rate_limiter.throughput.limit" = u64::from(limit)
583                    );
584                }
585                objectstore_metrics::gauge!(
586                    "server.rate_limiter.throughput.scope_map_size" = scopes.len()
587                );
588                objectstore_metrics::gauge!(
589                    "server.rate_limiter.throughput.usecase_map_size" = usecases.len()
590                );
591            }
592        });
593    }
594
595    fn check(&self, context: &ObjectContext) -> Option<RateLimitRejection> {
596        // NB: We intentionally use unwrap and crash the server if the mutexes are poisoned.
597
598        // Global check
599        if let Some(ref global) = self.global {
600            let acquired = global.lock().unwrap().try_acquire();
601            if !acquired {
602                return Some(RateLimitRejection::ThroughputGlobal);
603            }
604        }
605
606        // Usecase check - only if both global_rps and usecase_pct are set
607        if let Some(usecase_rps) = self.usecase_rps() {
608            let guard = self.usecases.pin();
609            let bucket = guard
610                .get_or_insert_with(context.usecase.clone(), || self.create_bucket(usecase_rps));
611            if !bucket.lock().unwrap().try_acquire() {
612                return Some(RateLimitRejection::ThroughputUsecase);
613            }
614        }
615
616        // Scope check - only if both global_rps and scope_pct are set
617        if let Some(scope_rps) = self.scope_rps() {
618            let guard = self.scopes.pin();
619            let bucket =
620                guard.get_or_insert_with(context.scopes.clone(), || self.create_bucket(scope_rps));
621            if !bucket.lock().unwrap().try_acquire() {
622                return Some(RateLimitRejection::ThroughputScope);
623            }
624        }
625
626        // Rule checks - each matching rule has its own dedicated bucket
627        for (idx, rule) in self.config.rules.iter().enumerate() {
628            if !rule.matches(context) {
629                continue;
630            }
631            let Some(rule_rps) = self.rule_rps(rule) else {
632                continue;
633            };
634            let guard = self.rules.pin();
635            let bucket = guard.get_or_insert_with(idx, || self.create_bucket(rule_rps));
636            if !bucket.lock().unwrap().try_acquire() {
637                return Some(RateLimitRejection::ThroughputRule);
638            }
639        }
640
641        // Count this admitted request in the EWMA accumulator and the cumulative counter.
642        // NB: u64 wrapping is not a practical concern — at 1M rps it takes ~585k years.
643        // Prometheus irate() also handles counter resets gracefully should it ever occur.
644        self.global_estimator
645            .accumulator
646            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
647        self.total_admitted
648            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
649
650        None
651    }
652
653    fn create_bucket(&self, rps: u32) -> Mutex<TokenBucket> {
654        Mutex::new(TokenBucket::new(rps, self.config.burst))
655    }
656
657    /// Returns the effective RPS for per-usecase limiting, if configured.
658    fn usecase_rps(&self) -> Option<u32> {
659        let global_rps = self.config.global_rps?;
660        let pct = self.config.usecase_pct?;
661        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
662    }
663
664    /// Returns the effective RPS for per-scope limiting, if configured.
665    fn scope_rps(&self) -> Option<u32> {
666        let global_rps = self.config.global_rps?;
667        let pct = self.config.scope_pct?;
668        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
669    }
670
671    /// Returns the effective RPS for a rule, if it has a valid limit.
672    fn rule_rps(&self, rule: &ThroughputRule) -> Option<u32> {
673        let pct_limit = rule.pct.and_then(|p| {
674            self.config
675                .global_rps
676                .map(|g| ((g as f64) * (p as f64 / 100.0)) as u32)
677        });
678
679        match (rule.rps, pct_limit) {
680            (Some(r), Some(p)) => Some(r.min(p)),
681            (Some(r), None) => Some(r),
682            (None, Some(p)) => Some(p),
683            (None, None) => None,
684        }
685    }
686}
687
/// A token bucket rate limiter.
///
/// Tokens refill at a constant rate up to capacity. Each request consumes one token.
/// When no tokens are available, requests are rejected.
///
/// This implementation is not thread-safe on its own. Wrap in a `Mutex` for concurrent access.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens added per second.
    refill_rate: f64,
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Tokens currently available; may be fractional.
    tokens: f64,
    /// Point in time up to which refill has been credited to `tokens`.
    last_update: Instant,
}

impl TokenBucket {
    /// Creates a new, full token bucket with the specified rate limit and burst capacity.
    ///
    /// - `rps`: tokens refilled per second (sustained rate limit)
    /// - `burst`: initial tokens and burst allowance above sustained rate
    ///
    /// The capacity is `rps + burst`, computed in floating point so that large values cannot
    /// overflow `u32`.
    pub fn new(rps: u32, burst: u32) -> Self {
        let capacity = f64::from(rps) + f64::from(burst);
        Self {
            refill_rate: f64::from(rps),
            capacity,
            tokens: capacity,
            last_update: Instant::now(),
        }
    }

    /// Attempts to acquire a token from the bucket.
    ///
    /// Returns `true` if a token was acquired, `false` if no tokens available.
    pub fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let refill = now.duration_since(self.last_update).as_secs_f64() * self.refill_rate;
        let refilled = (self.tokens + refill).min(self.capacity);

        if refilled >= self.capacity {
            // The bucket is full, so time spent beyond this point earns no credit. Moving
            // the timestamp forward prevents an idle period from instantly refilling the
            // next consumed token, which would admit `capacity + 1` requests in one burst.
            self.tokens = self.capacity;
            self.last_update = now;
        } else if refilled.floor() > self.tokens.floor() {
            // Only apply refill if we'd gain at least 1 whole token. Otherwise the
            // timestamp is kept so the partial credit keeps accumulating.
            self.last_update = now;
            self.tokens = refilled;
        }

        // Try to consume one token
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
739
740/// A wrapper around a byte stream that measures bandwidth usage.
741///
742/// Every time a chunk is polled successfully, all accumulators are incremented
743/// by its size. Generic over both the stream type `S` and its error type.
744pub(crate) struct MeteredPayloadStream<S> {
745    inner: S,
746    accumulators: Vec<Arc<AtomicU64>>,
747}
748
749impl<S> MeteredPayloadStream<S> {
750    pub fn new(inner: S, accumulators: Vec<Arc<AtomicU64>>) -> Self {
751        Self {
752            inner,
753            accumulators,
754        }
755    }
756}
757
758impl<S> fmt::Debug for MeteredPayloadStream<S> {
759    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
760        f.debug_struct("MeteredPayloadStream")
761            .field("accumulators", &self.accumulators)
762            .finish()
763    }
764}
765
766impl<S, E> Stream for MeteredPayloadStream<S>
767where
768    S: Stream<Item = Result<Bytes, E>> + Unpin,
769{
770    type Item = Result<Bytes, E>;
771
772    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
773        let this = self.get_mut();
774        let res = Pin::new(&mut this.inner).poll_next(cx);
775        if let Poll::Ready(Some(Ok(ref bytes))) = res {
776            let len = bytes.len() as u64;
777            for acc in &this.accumulators {
778                acc.fetch_add(len, std::sync::atomic::Ordering::Relaxed);
779            }
780        }
781        res
782    }
783}
784
#[cfg(test)]
mod tests {
    use std::sync::atomic::Ordering;

    use objectstore_service::id::ObjectContext;
    use objectstore_types::scope::{Scope, Scopes};

    use super::*;

    /// Builds a minimal request context: the `testing` usecase with a single
    /// `org` = `1` scope.
    fn make_context() -> ObjectContext {
        let scopes = Scopes::from_iter([Scope::create("org", "1").unwrap()]);
        ObjectContext {
            usecase: "testing".into(),
            scopes,
        }
    }

    #[test]
    fn ewma_estimator_update_applies_alpha() {
        const TICK: f64 = 0.05; // 50ms
        let to_rate = 1.0 / TICK;
        let estimator = EwmaEstimator::new();
        let mut ewma: f64 = 0.0;

        // Ten events within a single 50ms tick correspond to a raw rate of 200 /s,
        // so one smoothing step yields 0.2 * 200 + 0.8 * 0 = 40.
        estimator.accumulator.store(10, Ordering::Relaxed);
        estimator.update_ewma(&mut ewma, to_rate);

        let estimate = estimator.estimate.load(Ordering::Relaxed);
        assert_eq!(estimate, 40);

        // The update must also have reset the accumulator to zero.
        let pending = estimator.accumulator.load(Ordering::Relaxed);
        assert_eq!(pending, 0);
    }

    #[test]
    fn throughput_check_increments_accumulator() {
        let limits = ThroughputLimits {
            global_rps: Some(1000),
            ..Default::default()
        };
        let limiter = ThroughputRateLimiter::new(limits);
        let admitted_count = || limiter.global_estimator.accumulator.load(Ordering::Relaxed);

        // Nothing has been admitted yet.
        assert_eq!(admitted_count(), 0);

        let context = make_context();
        assert!(limiter.check(&context).is_none());
        assert!(limiter.check(&context).is_none());

        // Both admitted requests must have been counted.
        assert_eq!(admitted_count(), 2);
    }

    #[test]
    fn throughput_rejected_does_not_increment_accumulator() {
        let limits = ThroughputLimits {
            global_rps: Some(1),
            burst: 0,
            ..Default::default()
        };
        let limiter = ThroughputRateLimiter::new(limits);
        let context = make_context();

        // The single available token admits the first request; the second one is
        // rejected.
        assert!(limiter.check(&context).is_none());
        assert!(limiter.check(&context).is_some());

        // Only the admitted request may show up in the accumulator.
        let counted = limiter.global_estimator.accumulator.load(Ordering::Relaxed);
        assert_eq!(counted, 1);
    }

    #[test]
    fn bandwidth_rejection_does_not_increment_throughput_accumulator() {
        // The throughput limit is generous, so any rejection below has to come from
        // the bandwidth check.
        let throughput = ThroughputLimits {
            global_rps: Some(1000),
            ..Default::default()
        };
        // The bandwidth check is expected to reject once the estimate exceeds
        // `global_bps`. A limit of 0 alone is not enough because the estimate also
        // starts at 0, hence the estimate is primed to a non-zero value below.
        let bandwidth = BandwidthLimits {
            global_bps: Some(0),
            ..Default::default()
        };
        let limiter = RateLimiter::new(RateLimits {
            throughput,
            bandwidth,
        });

        // Prime the bandwidth EWMA so it exceeds the limit.
        limiter.bandwidth.global.estimate.store(1, Ordering::Relaxed);

        let context = make_context();
        let admitted = limiter.check(&context, None);
        assert!(!admitted);

        // The rejected request must not have been counted as throughput.
        let counted = limiter.throughput.global_estimator.accumulator.load(Ordering::Relaxed);
        assert_eq!(counted, 0);
    }

    #[test]
    fn rate_limiter_accessors_with_config() {
        let throughput = ThroughputLimits {
            global_rps: Some(500),
            ..Default::default()
        };
        let bandwidth = BandwidthLimits {
            global_bps: Some(1_000_000),
            ..Default::default()
        };
        let rate_limiter = RateLimiter::new(RateLimits {
            throughput,
            bandwidth,
        });

        assert_eq!(rate_limiter.bandwidth_limit(), Some(1_000_000));
        assert_eq!(rate_limiter.throughput_limit(), Some(500));
        // No background tick has run yet, so both estimates are still 0.
        assert_eq!(rate_limiter.bandwidth_ewma(), 0);
        assert_eq!(rate_limiter.throughput_rps(), 0);
    }

    #[test]
    fn rate_limiter_accessors_no_limits() {
        let limits = RateLimits::default();
        let rate_limiter = RateLimiter::new(limits);

        // Without configuration there are no limits, and the estimates read 0.
        assert_eq!(rate_limiter.bandwidth_limit(), None);
        assert_eq!(rate_limiter.throughput_limit(), None);
        assert_eq!(rate_limiter.bandwidth_ewma(), 0);
        assert_eq!(rate_limiter.throughput_rps(), 0);
    }
}