objectstore_server/
rate_limits.rs

1//! Admission-based rate limiting for throughput and bandwidth.
2//!
3//! This module provides [`RateLimiter`], which enforces configurable limits at three
4//! levels of granularity — global, per-usecase, and per-scope — for both request
5//! throughput (requests/s) and upload/download bandwidth (bytes/s).
6//!
7//! Throughput is enforced using token buckets; bandwidth is estimated with an
8//! exponentially weighted moving average (EWMA) updated by a background task every
9//! 50 ms. All rate-limit checks are synchronous and non-blocking.
10
11use std::fmt;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::sync::{Mutex, atomic::AtomicU64};
15use std::task::{Context, Poll};
16use std::time::{Duration, Instant};
17
18use bytes::Bytes;
19use futures_util::Stream;
20use objectstore_service::id::ObjectContext;
21use objectstore_types::scope::Scopes;
22use serde::{Deserialize, Serialize};
23
/// Identifies which rate limit triggered a rejection.
///
/// Each variant corresponds to one enforcement level (global, per-usecase,
/// per-scope, or per-rule) for either bandwidth or throughput. Variants are
/// converted to stable string tags for metrics and logs via `as_str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RateLimitRejection {
    /// Global bandwidth limit exceeded.
    BandwidthGlobal,
    /// Per-usecase bandwidth limit exceeded.
    BandwidthUsecase,
    /// Per-scope bandwidth limit exceeded.
    BandwidthScope,
    /// Global throughput limit exceeded.
    ThroughputGlobal,
    /// Per-usecase throughput limit exceeded.
    ThroughputUsecase,
    /// Per-scope throughput limit exceeded.
    ThroughputScope,
    /// Per-rule throughput limit exceeded.
    ThroughputRule,
}
42
43impl RateLimitRejection {
44    /// Returns a static string identifier suitable for use as a metric tag.
45    pub fn as_str(self) -> &'static str {
46        match self {
47            RateLimitRejection::BandwidthGlobal => "bandwidth_global",
48            RateLimitRejection::BandwidthUsecase => "bandwidth_usecase",
49            RateLimitRejection::BandwidthScope => "bandwidth_scope",
50            RateLimitRejection::ThroughputGlobal => "throughput_global",
51            RateLimitRejection::ThroughputUsecase => "throughput_usecase",
52            RateLimitRejection::ThroughputScope => "throughput_scope",
53            RateLimitRejection::ThroughputRule => "throughput_rule",
54        }
55    }
56}
57
58impl fmt::Display for RateLimitRejection {
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        f.write_str(self.as_str())
61    }
62}
63
/// Rate limits for objectstore.
///
/// Both families are enforced independently by [`RateLimiter`]: throughput via
/// token buckets, bandwidth via EWMA estimates.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct RateLimits {
    /// Limits the number of requests per second per service instance.
    pub throughput: ThroughputLimits,
    /// Limits the concurrent bandwidth per service instance.
    pub bandwidth: BandwidthLimits,
}
72
/// Request throughput limits applied at global, per-usecase, and per-scope granularity.
///
/// All limits are optional. When a limit is `None`, that level of limiting is not enforced.
/// Per-usecase and per-scope limits are expressed as a percentage of the global limit.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct ThroughputLimits {
    /// The overall maximum number of requests per second per service instance.
    ///
    /// Defaults to `None`, meaning no global rate limit is enforced. Per-usecase,
    /// per-scope, and percentage-based rule limits all derive from this value.
    pub global_rps: Option<u32>,

    /// The maximum burst for each rate limit.
    ///
    /// Defaults to `0`, meaning no bursting is allowed. If set to a value greater than `0`,
    /// short spikes above the rate limit are allowed up to the burst size. Applies to every
    /// token bucket created by the limiter (global, per-usecase, per-scope, and per-rule).
    pub burst: u32,

    /// The maximum percentage of the global rate limit that can be used by any usecase.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-usecase limit is enforced.
    /// Only effective when `global_rps` is also set.
    pub usecase_pct: Option<u8>,

    /// The maximum percentage of the global rate limit that can be used by any scope.
    ///
    /// This treats each full scope separately and applies across all use cases:
    ///  - Two requests with exact same scopes count against the same limit regardless of use case.
    ///  - Two requests that share the same top scope but differ in inner scopes count separately.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-scope limit is enforced.
    /// Only effective when `global_rps` is also set.
    pub scope_pct: Option<u8>,

    /// Overrides for specific usecases and scopes.
    pub rules: Vec<ThroughputRule>,
}
107
/// An override rule that applies a specific throughput limit to matching request contexts.
///
/// A rule matches when all specified fields match the request context. Fields not set match
/// any value. When multiple rules match, each is enforced independently via its own token
/// bucket — all matching rules must admit the request.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ThroughputRule {
    /// Optional usecase to match.
    ///
    /// If `None`, matches any usecase.
    pub usecase: Option<String>,

    /// Scopes to match.
    ///
    /// If empty, matches any scopes. Additional scopes in the context are ignored, so a rule
    /// matches if all of the specified scopes are present in the request with matching values.
    pub scopes: Vec<(String, String)>,

    /// The rate limit to apply when this rule matches.
    ///
    /// If both a rate and pct are specified, the more restrictive limit applies.
    /// Should be greater than `0`. To block traffic entirely, use killswitches instead.
    pub rps: Option<u32>,

    /// The percentage of the global rate limit to apply when this rule matches.
    ///
    /// If both a rate and pct are specified, the more restrictive limit applies.
    /// Should be greater than `0`. To block traffic entirely, use killswitches instead.
    /// Ignored when no global limit is configured.
    pub pct: Option<u8>,
}
138
139impl ThroughputRule {
140    /// Returns `true` if this rule matches the given context.
141    pub fn matches(&self, context: &ObjectContext) -> bool {
142        if let Some(ref rule_usecase) = self.usecase
143            && rule_usecase != &context.usecase
144        {
145            return false;
146        }
147
148        for (scope_name, scope_value) in &self.scopes {
149            match context.scopes.get_value(scope_name) {
150                Some(value) if value == scope_value => (),
151                _ => return false,
152            }
153        }
154
155        true
156    }
157}
158
/// Bandwidth limits applied at global, per-usecase, and per-scope granularity.
///
/// Bandwidth is measured as bytes transferred per second (upload + download combined)
/// and estimated using an EWMA updated every 50 ms. All limits are optional.
/// A request is rejected only when the current estimate strictly exceeds the limit.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct BandwidthLimits {
    /// The overall maximum bandwidth (in bytes per second) per service instance.
    ///
    /// Defaults to `None`, meaning no global bandwidth limit is enforced.
    /// Per-usecase and per-scope percentages derive from this value.
    pub global_bps: Option<u64>,

    /// The maximum percentage of the global bandwidth limit that can be used by any usecase.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-usecase bandwidth limit is enforced.
    pub usecase_pct: Option<u8>,

    /// The maximum percentage of the global bandwidth limit that can be used by any scope.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-scope bandwidth limit is enforced.
    pub scope_pct: Option<u8>,
}
180
/// Combined rate limiter that enforces both bandwidth and throughput limits.
///
/// Checks are synchronous and non-blocking. Bandwidth is checked before
/// throughput so that rejected requests are never counted toward the admitted
/// throughput EWMA. See [`check`](RateLimiter::check) for details.
///
/// Call [`start`](RateLimiter::start) after construction to launch the background
/// estimation tasks. Without it, bandwidth EWMAs remain at zero and bandwidth
/// limits are never triggered.
#[derive(Debug)]
pub struct RateLimiter {
    /// EWMA-based bandwidth limiter; checks are pure reads.
    bandwidth: BandwidthRateLimiter,
    /// Token-bucket throughput limiter; checks consume tokens on admission.
    throughput: ThroughputRateLimiter,
}
195
196impl RateLimiter {
197    /// Creates a new rate limiter from the given configuration.
198    ///
199    /// Background estimation tasks are not started until [`start`](RateLimiter::start) is called.
200    pub fn new(config: RateLimits) -> Self {
201        Self {
202            bandwidth: BandwidthRateLimiter::new(config.bandwidth),
203            throughput: ThroughputRateLimiter::new(config.throughput),
204        }
205    }
206
207    /// Starts background tasks for rate limit estimation and monitoring.
208    ///
209    /// Must be called from within a Tokio runtime.
210    pub fn start(&self) {
211        self.bandwidth.start();
212        self.throughput.start();
213    }
214
215    /// Checks if the given context is within the rate limits.
216    ///
217    /// Returns `true` if the request is admitted, `false` if it was rejected. On rejection, emits a
218    /// `server.request.rate_limited` metric counter and a `warn!` log. Bandwidth is checked before
219    /// throughput so that rejected requests are never counted toward admitted traffic.
220    pub fn check(&self, context: &ObjectContext) -> bool {
221        // Bandwidth is checked first because it is a pure read (no token consumption).
222        // Throughput increments the EWMA accumulator only on success, so checking it
223        // second ensures rejected requests are never counted toward admitted traffic.
224        let rejection = self
225            .bandwidth
226            .check(context)
227            .or_else(|| self.throughput.check(context));
228
229        let Some(rejection) = rejection else {
230            return true;
231        };
232
233        objectstore_metrics::count!("server.request.rate_limited", reason = rejection.as_str());
234        objectstore_log::warn!(
235            reason = rejection.as_str(),
236            "Request rejected: rate limit exceeded"
237        );
238        false
239    }
240
241    /// Returns all bandwidth accumulators (global + per-usecase + per-scope) for the given context.
242    ///
243    /// Creates entries in the per-usecase/per-scope maps if they don't exist yet.
244    pub fn bytes_accumulators(&self, context: &ObjectContext) -> Vec<Arc<AtomicU64>> {
245        self.bandwidth.accumulators(context)
246    }
247
248    /// Records bandwidth usage across all accumulators for the given context.
249    ///
250    /// This is used for cases where bytes are known upfront (e.g. batch INSERT) rather than
251    /// streamed through a `MeteredPayloadStream`.
252    pub fn record_bandwidth(&self, context: &ObjectContext, bytes: u64) {
253        for acc in self.bandwidth.accumulators(context) {
254            acc.fetch_add(bytes, std::sync::atomic::Ordering::Relaxed);
255        }
256    }
257
258    /// Returns the current global bandwidth EWMA in bytes per second.
259    pub fn bandwidth_ewma(&self) -> u64 {
260        self.bandwidth
261            .global
262            .estimate
263            .load(std::sync::atomic::Ordering::Relaxed)
264    }
265
266    /// Returns the configured global bandwidth limit in bytes/s, if set.
267    pub fn bandwidth_limit(&self) -> Option<u64> {
268        self.bandwidth.config.global_bps
269    }
270
271    /// Returns the current estimated throughput in requests per second.
272    pub fn throughput_rps(&self) -> u64 {
273        self.throughput
274            .global_estimator
275            .estimate
276            .load(std::sync::atomic::Ordering::Relaxed)
277    }
278
279    /// Returns the configured global throughput limit in requests/s, if set.
280    pub fn throughput_limit(&self) -> Option<u32> {
281        self.throughput.config.global_rps
282    }
283
284    /// Returns total bytes transferred since startup.
285    pub fn bandwidth_total_bytes(&self) -> u64 {
286        self.bandwidth
287            .total_bytes
288            .load(std::sync::atomic::Ordering::Relaxed)
289    }
290
291    /// Returns total admitted requests since startup.
292    pub fn throughput_total_admitted(&self) -> u64 {
293        self.throughput
294            .total_admitted
295            .load(std::sync::atomic::Ordering::Relaxed)
296    }
297}
298
/// Shared EWMA estimator state.
///
/// The accumulator is incremented as events occur, and the estimate is updated
/// periodically by a background task using an exponentially weighted moving average.
#[derive(Debug)]
struct EwmaEstimator {
    accumulator: Arc<AtomicU64>,
    estimate: Arc<AtomicU64>,
}

impl EwmaEstimator {
    /// Creates an estimator with both counters at zero.
    fn new() -> Self {
        Self {
            accumulator: Arc::new(AtomicU64::new(0)),
            estimate: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Updates the EWMA from the accumulator.
    ///
    /// Swaps the accumulator to zero, converts the count to a per-second rate using
    /// `to_rate` (i.e., `1.0 / tick_duration.as_secs_f64()`), then applies the
    /// EWMA smoothing and stores the floored result in `estimate`.
    fn update_ewma(&self, ewma: &mut f64, to_rate: f64) {
        use std::sync::atomic::Ordering::Relaxed;
        // Smoothing factor: each tick contributes 20% of the new observation.
        const ALPHA: f64 = 0.2;

        let drained = self.accumulator.swap(0, Relaxed);
        let rate = drained as f64 * to_rate;
        *ewma = ALPHA * rate + (1.0 - ALPHA) * *ewma;
        self.estimate.store(ewma.floor() as u64, Relaxed);
    }
}
333
/// EWMA-based bandwidth limiter; see [`BandwidthLimits`] for semantics.
#[derive(Debug)]
struct BandwidthRateLimiter {
    /// Configured limits; `global_bps == None` disables all bandwidth limiting.
    config: BandwidthLimits,
    /// Global accumulator/estimator pair.
    global: Arc<EwmaEstimator>,
    /// Cumulative bytes transferred since startup. Never reset.
    total_bytes: Arc<AtomicU64>,
    // NB: These maps grow unbounded but we accept this as we expect an overall limited
    // number of usecases and scopes. We emit gauge metrics to monitor their size.
    /// Per-usecase estimators, created lazily on first use.
    usecases: Arc<papaya::HashMap<String, Arc<EwmaEstimator>>>,
    /// Per-scope estimators keyed by the full scope set, created lazily.
    scopes: Arc<papaya::HashMap<Scopes, Arc<EwmaEstimator>>>,
}
346
347impl BandwidthRateLimiter {
348    fn new(config: BandwidthLimits) -> Self {
349        Self {
350            config,
351            global: Arc::new(EwmaEstimator::new()),
352            total_bytes: Arc::new(AtomicU64::new(0)),
353            usecases: Arc::new(papaya::HashMap::new()),
354            scopes: Arc::new(papaya::HashMap::new()),
355        }
356    }
357
358    fn start(&self) {
359        let global = Arc::clone(&self.global);
360        let usecases = Arc::clone(&self.usecases);
361        let scopes = Arc::clone(&self.scopes);
362        let global_limit = self.config.global_bps;
363        // NB: This task has no shutdown mechanism — the rate limiter is only created once.
364        // The task is aborted when the Tokio runtime is dropped on process exit.
365        tokio::task::spawn(async move {
366            Self::estimator(global, usecases, scopes, global_limit).await;
367        });
368    }
369
370    /// Estimates the current bandwidth utilization using an exponentially weighted moving average.
371    ///
372    /// Iterates over the global estimator as well as all per-usecase and per-scope estimators
373    /// on each tick, updating their EWMAs.
374    async fn estimator(
375        global: Arc<EwmaEstimator>,
376        usecases: Arc<papaya::HashMap<String, Arc<EwmaEstimator>>>,
377        scopes: Arc<papaya::HashMap<Scopes, Arc<EwmaEstimator>>>,
378        global_limit: Option<u64>,
379    ) {
380        const TICK: Duration = Duration::from_millis(50); // Recompute EWMA on every TICK
381
382        let mut interval = tokio::time::interval(TICK);
383        let to_bps = 1.0 / TICK.as_secs_f64(); // Conversion factor from bytes to bps
384        let mut global_ewma: f64 = 0.0;
385        // Shadow EWMAs for per-usecase/per-scope entries, keyed the same way as the maps.
386        let mut usecase_ewmas: std::collections::HashMap<String, f64> =
387            std::collections::HashMap::new();
388        let mut scope_ewmas: std::collections::HashMap<Scopes, f64> =
389            std::collections::HashMap::new();
390
391        loop {
392            interval.tick().await;
393
394            // Global
395            global.update_ewma(&mut global_ewma, to_bps);
396            objectstore_metrics::gauge!("server.bandwidth.ewma" = global_ewma.floor() as u64);
397            if let Some(limit) = global_limit {
398                objectstore_metrics::gauge!("server.bandwidth.limit" = limit);
399            }
400
401            // Per-usecase
402            {
403                let guard = usecases.pin();
404                for (key, estimator) in guard.iter() {
405                    let ewma = usecase_ewmas.entry(key.clone()).or_insert(0.0);
406                    estimator.update_ewma(ewma, to_bps);
407                }
408            }
409
410            // Per-scope
411            {
412                let guard = scopes.pin();
413                for (key, estimator) in guard.iter() {
414                    let ewma = scope_ewmas.entry(key.clone()).or_insert(0.0);
415                    estimator.update_ewma(ewma, to_bps);
416                }
417            }
418
419            objectstore_metrics::gauge!(
420                "server.rate_limiter.bandwidth.scope_map_size" = scopes.len()
421            );
422            objectstore_metrics::gauge!(
423                "server.rate_limiter.bandwidth.usecase_map_size" = usecases.len()
424            );
425        }
426    }
427
428    fn check(&self, context: &ObjectContext) -> Option<RateLimitRejection> {
429        let global_bps = self.config.global_bps?;
430
431        // Global check
432        if self
433            .global
434            .estimate
435            .load(std::sync::atomic::Ordering::Relaxed)
436            > global_bps
437        {
438            return Some(RateLimitRejection::BandwidthGlobal);
439        }
440
441        // Per-usecase check
442        if let Some(usecase_bps) = self.usecase_bps() {
443            let guard = self.usecases.pin();
444            if let Some(estimator) = guard.get(&context.usecase)
445                && estimator
446                    .estimate
447                    .load(std::sync::atomic::Ordering::Relaxed)
448                    > usecase_bps
449            {
450                return Some(RateLimitRejection::BandwidthUsecase);
451            }
452        }
453
454        // Per-scope check
455        if let Some(scope_bps) = self.scope_bps() {
456            let guard = self.scopes.pin();
457            if let Some(estimator) = guard.get(&context.scopes)
458                && estimator
459                    .estimate
460                    .load(std::sync::atomic::Ordering::Relaxed)
461                    > scope_bps
462            {
463                return Some(RateLimitRejection::BandwidthScope);
464            }
465        }
466
467        None
468    }
469
470    /// Returns all accumulators (global + per-usecase + per-scope) for the given context.
471    ///
472    /// Creates entries in the per-usecase/per-scope maps if they don't exist yet.
473    /// Always includes `total_bytes` (cumulative, never reset) as the first entry.
474    fn accumulators(&self, context: &ObjectContext) -> Vec<Arc<AtomicU64>> {
475        let mut accs = vec![
476            Arc::clone(&self.total_bytes),
477            Arc::clone(&self.global.accumulator),
478        ];
479
480        if self.usecase_bps().is_some() {
481            let guard = self.usecases.pin();
482            let estimator = guard
483                .get_or_insert_with(context.usecase.clone(), || Arc::new(EwmaEstimator::new()));
484            accs.push(Arc::clone(&estimator.accumulator));
485        }
486
487        if self.scope_bps().is_some() {
488            let guard = self.scopes.pin();
489            let estimator =
490                guard.get_or_insert_with(context.scopes.clone(), || Arc::new(EwmaEstimator::new()));
491            accs.push(Arc::clone(&estimator.accumulator));
492        }
493
494        accs
495    }
496
497    /// Returns the effective BPS for per-usecase limiting, if configured.
498    fn usecase_bps(&self) -> Option<u64> {
499        let global_bps = self.config.global_bps?;
500        let pct = self.config.usecase_pct?;
501        Some(((global_bps as f64) * (pct as f64 / 100.0)) as u64)
502    }
503
504    /// Returns the effective BPS for per-scope limiting, if configured.
505    fn scope_bps(&self) -> Option<u64> {
506        let global_bps = self.config.global_bps?;
507        let pct = self.config.scope_pct?;
508        Some(((global_bps as f64) * (pct as f64 / 100.0)) as u64)
509    }
510}
511
/// Token-bucket throughput limiter; see [`ThroughputLimits`] for semantics.
#[derive(Debug)]
struct ThroughputRateLimiter {
    /// Configured limits; `global_rps == None` disables the global bucket and
    /// all percentage-derived limits.
    config: ThroughputLimits,
    /// Global token bucket, present only when `global_rps` is configured.
    global: Option<Mutex<TokenBucket>>,
    /// Global EWMA estimator for admitted request rate.
    global_estimator: Arc<EwmaEstimator>,
    /// Cumulative admitted requests since startup. Never reset.
    total_admitted: Arc<AtomicU64>,
    // NB: These maps grow unbounded but we accept this as we expect an overall limited
    // number of usecases and scopes. We emit gauge metrics to monitor their size.
    usecases: Arc<papaya::HashMap<String, Mutex<TokenBucket>>>,
    scopes: Arc<papaya::HashMap<Scopes, Mutex<TokenBucket>>>,
    /// Per-rule buckets keyed by the rule's index in `config.rules`.
    rules: papaya::HashMap<usize, Mutex<TokenBucket>>,
}
526
527impl ThroughputRateLimiter {
528    fn new(config: ThroughputLimits) -> Self {
529        let global = config
530            .global_rps
531            .map(|rps| Mutex::new(TokenBucket::new(rps, config.burst)));
532
533        Self {
534            config,
535            global,
536            global_estimator: Arc::new(EwmaEstimator::new()),
537            total_admitted: Arc::new(AtomicU64::new(0)),
538            usecases: Arc::new(papaya::HashMap::new()),
539            scopes: Arc::new(papaya::HashMap::new()),
540            rules: papaya::HashMap::new(),
541        }
542    }
543
544    fn start(&self) {
545        let usecases = Arc::clone(&self.usecases);
546        let scopes = Arc::clone(&self.scopes);
547        let global_estimator = Arc::clone(&self.global_estimator);
548        let global_limit = self.config.global_rps;
549        // NB: This task has no shutdown mechanism — the rate limiter is only created once.
550        // The task is aborted when the Tokio runtime is dropped on process exit.
551        tokio::task::spawn(async move {
552            const TICK: Duration = Duration::from_millis(50);
553            let mut interval = tokio::time::interval(TICK);
554            let to_rps = 1.0 / TICK.as_secs_f64();
555            let mut global_ewma: f64 = 0.0;
556            loop {
557                interval.tick().await;
558                global_estimator.update_ewma(&mut global_ewma, to_rps);
559                objectstore_metrics::gauge!("server.throughput.ewma" = global_ewma.floor() as u64);
560                if let Some(limit) = global_limit {
561                    objectstore_metrics::gauge!(
562                        "server.rate_limiter.throughput.limit" = u64::from(limit)
563                    );
564                }
565                objectstore_metrics::gauge!(
566                    "server.rate_limiter.throughput.scope_map_size" = scopes.len()
567                );
568                objectstore_metrics::gauge!(
569                    "server.rate_limiter.throughput.usecase_map_size" = usecases.len()
570                );
571            }
572        });
573    }
574
575    fn check(&self, context: &ObjectContext) -> Option<RateLimitRejection> {
576        // NB: We intentionally use unwrap and crash the server if the mutexes are poisoned.
577
578        // Global check
579        if let Some(ref global) = self.global {
580            let acquired = global.lock().unwrap().try_acquire();
581            if !acquired {
582                return Some(RateLimitRejection::ThroughputGlobal);
583            }
584        }
585
586        // Usecase check - only if both global_rps and usecase_pct are set
587        if let Some(usecase_rps) = self.usecase_rps() {
588            let guard = self.usecases.pin();
589            let bucket = guard
590                .get_or_insert_with(context.usecase.clone(), || self.create_bucket(usecase_rps));
591            if !bucket.lock().unwrap().try_acquire() {
592                return Some(RateLimitRejection::ThroughputUsecase);
593            }
594        }
595
596        // Scope check - only if both global_rps and scope_pct are set
597        if let Some(scope_rps) = self.scope_rps() {
598            let guard = self.scopes.pin();
599            let bucket =
600                guard.get_or_insert_with(context.scopes.clone(), || self.create_bucket(scope_rps));
601            if !bucket.lock().unwrap().try_acquire() {
602                return Some(RateLimitRejection::ThroughputScope);
603            }
604        }
605
606        // Rule checks - each matching rule has its own dedicated bucket
607        for (idx, rule) in self.config.rules.iter().enumerate() {
608            if !rule.matches(context) {
609                continue;
610            }
611            let Some(rule_rps) = self.rule_rps(rule) else {
612                continue;
613            };
614            let guard = self.rules.pin();
615            let bucket = guard.get_or_insert_with(idx, || self.create_bucket(rule_rps));
616            if !bucket.lock().unwrap().try_acquire() {
617                return Some(RateLimitRejection::ThroughputRule);
618            }
619        }
620
621        // Count this admitted request in the EWMA accumulator and the cumulative counter.
622        // NB: u64 wrapping is not a practical concern — at 1M rps it takes ~585k years.
623        // Prometheus irate() also handles counter resets gracefully should it ever occur.
624        self.global_estimator
625            .accumulator
626            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
627        self.total_admitted
628            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
629
630        None
631    }
632
633    fn create_bucket(&self, rps: u32) -> Mutex<TokenBucket> {
634        Mutex::new(TokenBucket::new(rps, self.config.burst))
635    }
636
637    /// Returns the effective RPS for per-usecase limiting, if configured.
638    fn usecase_rps(&self) -> Option<u32> {
639        let global_rps = self.config.global_rps?;
640        let pct = self.config.usecase_pct?;
641        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
642    }
643
644    /// Returns the effective RPS for per-scope limiting, if configured.
645    fn scope_rps(&self) -> Option<u32> {
646        let global_rps = self.config.global_rps?;
647        let pct = self.config.scope_pct?;
648        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
649    }
650
651    /// Returns the effective RPS for a rule, if it has a valid limit.
652    fn rule_rps(&self, rule: &ThroughputRule) -> Option<u32> {
653        let pct_limit = rule.pct.and_then(|p| {
654            self.config
655                .global_rps
656                .map(|g| ((g as f64) * (p as f64 / 100.0)) as u32)
657        });
658
659        match (rule.rps, pct_limit) {
660            (Some(r), Some(p)) => Some(r.min(p)),
661            (Some(r), None) => Some(r),
662            (None, Some(p)) => Some(p),
663            (None, None) => None,
664        }
665    }
666}
667
/// A token bucket rate limiter.
///
/// Tokens refill at a constant rate up to capacity. Each request consumes one token.
/// When no tokens are available, requests are rejected.
///
/// This implementation is not thread-safe on its own. Wrap in a `Mutex` for concurrent access.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens added per second (sustained rate).
    refill_rate: f64,
    /// Maximum tokens the bucket can hold (`rps + burst`).
    capacity: f64,
    /// Currently available tokens.
    tokens: f64,
    /// When the refill was last applied. Only advanced when a refill is
    /// committed, so fractional refill progress accrues across calls.
    last_update: Instant,
}

impl TokenBucket {
    /// Creates a new, full token bucket with the specified rate limit and burst capacity.
    ///
    /// - `rps`: tokens refilled per second (sustained rate limit)
    /// - `burst`: initial tokens and burst allowance above sustained rate
    pub fn new(rps: u32, burst: u32) -> Self {
        // Saturate instead of computing `rps + burst` directly: the raw u32
        // addition panics in debug builds and wraps to a tiny capacity in
        // release builds for extreme configuration values.
        let capacity = f64::from(rps.saturating_add(burst));
        Self {
            refill_rate: f64::from(rps),
            capacity,
            tokens: capacity,
            last_update: Instant::now(),
        }
    }

    /// Attempts to acquire a token from the bucket.
    ///
    /// Returns `true` if a token was acquired, `false` if no tokens available.
    pub fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let refill = now.duration_since(self.last_update).as_secs_f64() * self.refill_rate;
        let refilled = (self.tokens + refill).min(self.capacity);

        // Only commit the refill once at least one whole token has accrued;
        // otherwise `last_update` stays put so fractional progress keeps
        // accumulating instead of being discarded by repeated tiny refills.
        if refilled.floor() > self.tokens.floor() {
            self.last_update = now;
            self.tokens = refilled;
        }

        // Try to consume one token
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
719
/// A wrapper around a byte stream that measures bandwidth usage.
///
/// Every time a chunk is polled successfully, all accumulators are incremented
/// by its size. The `Stream` impl is generic over the inner stream's error type.
pub(crate) struct MeteredPayloadStream<S> {
    /// The wrapped byte stream.
    inner: S,
    /// Byte counters bumped by each polled chunk (e.g. total + global + per-usecase/per-scope).
    accumulators: Vec<Arc<AtomicU64>>,
}
728
729impl<S> MeteredPayloadStream<S> {
730    pub fn new(inner: S, accumulators: Vec<Arc<AtomicU64>>) -> Self {
731        Self {
732            inner,
733            accumulators,
734        }
735    }
736}
737
738impl<S> fmt::Debug for MeteredPayloadStream<S> {
739    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
740        f.debug_struct("MeteredPayloadStream")
741            .field("accumulators", &self.accumulators)
742            .finish()
743    }
744}
745
746impl<S, E> Stream for MeteredPayloadStream<S>
747where
748    S: Stream<Item = Result<Bytes, E>> + Unpin,
749{
750    type Item = Result<Bytes, E>;
751
752    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
753        let this = self.get_mut();
754        let res = Pin::new(&mut this.inner).poll_next(cx);
755        if let Poll::Ready(Some(Ok(ref bytes))) = res {
756            let len = bytes.len() as u64;
757            for acc in &this.accumulators {
758                acc.fetch_add(len, std::sync::atomic::Ordering::Relaxed);
759            }
760        }
761        res
762    }
763}
764
765#[cfg(test)]
766mod tests {
767    use objectstore_service::id::ObjectContext;
768    use objectstore_types::scope::{Scope, Scopes};
769
770    use super::*;
771
    /// Builds a minimal context (usecase `testing`, single `org=1` scope) for tests.
    fn make_context() -> ObjectContext {
        ObjectContext {
            usecase: "testing".into(),
            scopes: Scopes::from_iter([Scope::create("org", "1").unwrap()]),
        }
    }
778
    /// One EWMA step from a zero estimate must yield `ALPHA * raw_rate` and
    /// drain the accumulator.
    #[test]
    fn ewma_estimator_update_applies_alpha() {
        let estimator = EwmaEstimator::new();
        const TICK: f64 = 0.05; // 50ms
        let to_rate = 1.0 / TICK;
        let mut ewma: f64 = 0.0;

        // Simulate 10 events in one 50ms tick → 200 /s raw rate.
        // After one step: 0.2 * 200 + 0.8 * 0 = 40.
        estimator
            .accumulator
            .store(10, std::sync::atomic::Ordering::Relaxed);
        estimator.update_ewma(&mut ewma, to_rate);
        assert_eq!(
            estimator
                .estimate
                .load(std::sync::atomic::Ordering::Relaxed),
            40
        );

        // Accumulator must have been zeroed.
        assert_eq!(
            estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            0
        );
    }
807
808    #[test]
809    fn throughput_check_increments_accumulator() {
810        let limiter = ThroughputRateLimiter::new(ThroughputLimits {
811            global_rps: Some(1000),
812            ..Default::default()
813        });
814
815        assert_eq!(
816            limiter
817                .global_estimator
818                .accumulator
819                .load(std::sync::atomic::Ordering::Relaxed),
820            0
821        );
822
823        let context = make_context();
824        assert!(limiter.check(&context).is_none());
825        assert!(limiter.check(&context).is_none());
826
827        assert_eq!(
828            limiter
829                .global_estimator
830                .accumulator
831                .load(std::sync::atomic::Ordering::Relaxed),
832            2
833        );
834    }
835
836    #[test]
837    fn throughput_rejected_does_not_increment_accumulator() {
838        let limiter = ThroughputRateLimiter::new(ThroughputLimits {
839            global_rps: Some(1),
840            burst: 0,
841            ..Default::default()
842        });
843
844        let context = make_context();
845        // First call admitted (consumes the one token), second rejected.
846        assert!(limiter.check(&context).is_none());
847        assert!(limiter.check(&context).is_some());
848
849        assert_eq!(
850            limiter
851                .global_estimator
852                .accumulator
853                .load(std::sync::atomic::Ordering::Relaxed),
854            1 // only the admitted request
855        );
856    }
857
858    #[test]
859    fn bandwidth_rejection_does_not_increment_throughput_accumulator() {
860        // global_bps of 1 means the estimate (0 initially) is not > 1, so the first
861        // call passes the bandwidth check. Use 0 to guarantee an immediate reject.
862        // BandwidthRateLimiter::check rejects when estimate > global_bps, so set
863        // global_bps = 0 to make the bandwidth check always reject.
864        let limiter = RateLimiter::new(RateLimits {
865            throughput: ThroughputLimits {
866                global_rps: Some(1000),
867                ..Default::default()
868            },
869            bandwidth: BandwidthLimits {
870                global_bps: Some(0),
871                ..Default::default()
872            },
873        });
874
875        // Prime the bandwidth EWMA so it exceeds the limit.
876        limiter
877            .bandwidth
878            .global
879            .estimate
880            .store(1, std::sync::atomic::Ordering::Relaxed);
881
882        let context = make_context();
883        assert!(!limiter.check(&context));
884
885        // The throughput accumulator must still be 0.
886        assert_eq!(
887            limiter
888                .throughput
889                .global_estimator
890                .accumulator
891                .load(std::sync::atomic::Ordering::Relaxed),
892            0
893        );
894    }
895
896    #[test]
897    fn rate_limiter_accessors_with_config() {
898        let rate_limiter = RateLimiter::new(RateLimits {
899            throughput: ThroughputLimits {
900                global_rps: Some(500),
901                ..Default::default()
902            },
903            bandwidth: BandwidthLimits {
904                global_bps: Some(1_000_000),
905                ..Default::default()
906            },
907        });
908
909        assert_eq!(rate_limiter.bandwidth_limit(), Some(1_000_000));
910        assert_eq!(rate_limiter.throughput_limit(), Some(500));
911        // Estimates start at 0 before any background tick.
912        assert_eq!(rate_limiter.bandwidth_ewma(), 0);
913        assert_eq!(rate_limiter.throughput_rps(), 0);
914    }
915
916    #[test]
917    fn rate_limiter_accessors_no_limits() {
918        let rate_limiter = RateLimiter::new(RateLimits::default());
919
920        assert_eq!(rate_limiter.bandwidth_limit(), None);
921        assert_eq!(rate_limiter.throughput_limit(), None);
922        assert_eq!(rate_limiter.bandwidth_ewma(), 0);
923        assert_eq!(rate_limiter.throughput_rps(), 0);
924    }
925}