// objectstore_server/rate_limits.rs
1//! Admission-based rate limiting for throughput and bandwidth.
2//!
3//! This module provides [`RateLimiter`], which enforces configurable limits at three
4//! levels of granularity — global, per-usecase, and per-scope — for both request
5//! throughput (requests/s) and upload/download bandwidth (bytes/s).
6//!
7//! Throughput is enforced using token buckets; bandwidth is estimated with an
8//! exponentially weighted moving average (EWMA) updated by a background task every
9//! 50 ms. All rate-limit checks are synchronous and non-blocking.
10
11use std::fmt;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::sync::{Mutex, atomic::AtomicU64};
15use std::task::{Context, Poll};
16use std::time::{Duration, Instant};
17
18use bytes::Bytes;
19use futures_util::Stream;
20use objectstore_service::id::ObjectContext;
21use objectstore_types::scope::Scopes;
22use serde::{Deserialize, Serialize};
23
/// Identifies which rate limit triggered a rejection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RateLimitRejection {
    /// The global bandwidth limit was exceeded.
    BandwidthGlobal,
    /// The per-usecase bandwidth limit was exceeded.
    BandwidthUsecase,
    /// The per-scope bandwidth limit was exceeded.
    BandwidthScope,
    /// The global throughput limit was exceeded.
    ThroughputGlobal,
    /// The per-usecase throughput limit was exceeded.
    ThroughputUsecase,
    /// The per-scope throughput limit was exceeded.
    ThroughputScope,
    /// A per-rule throughput limit was exceeded.
    ThroughputRule,
}

impl RateLimitRejection {
    /// Returns a static string identifier suitable for use as a metric tag.
    pub fn as_str(self) -> &'static str {
        use RateLimitRejection::*;
        match self {
            BandwidthGlobal => "bandwidth_global",
            BandwidthUsecase => "bandwidth_usecase",
            BandwidthScope => "bandwidth_scope",
            ThroughputGlobal => "throughput_global",
            ThroughputUsecase => "throughput_usecase",
            ThroughputScope => "throughput_scope",
            ThroughputRule => "throughput_rule",
        }
    }
}

impl fmt::Display for RateLimitRejection {
    /// Formats the rejection using its metric-tag identifier.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
63
/// Rate limits for objectstore.
///
/// Combines the throughput (requests/s) and bandwidth (bytes/s) configurations;
/// each is independently optional via its own `Option` fields.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct RateLimits {
    /// Limits the number of requests per second per service instance.
    pub throughput: ThroughputLimits,
    /// Limits the concurrent bandwidth per service instance.
    pub bandwidth: BandwidthLimits,
}
72
/// Request throughput limits applied at global, per-usecase, and per-scope granularity.
///
/// All limits are optional. When a limit is `None`, that level of limiting is not enforced.
/// Per-usecase and per-scope limits are expressed as a percentage of the global limit.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct ThroughputLimits {
    /// The overall maximum number of requests per second per service instance.
    ///
    /// Defaults to `None`, meaning no global rate limit is enforced. The per-usecase
    /// and per-scope percentage limits below only take effect when this is set.
    pub global_rps: Option<u32>,

    /// The maximum burst for each rate limit.
    ///
    /// Defaults to `0`, meaning no bursting is allowed. If set to a value greater than `0`,
    /// short spikes above the rate limit are allowed up to the burst size. Applies to the
    /// global, per-usecase, per-scope, and per-rule token buckets alike.
    pub burst: u32,

    /// The maximum percentage of the global rate limit that can be used by any usecase.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-usecase limit is enforced.
    pub usecase_pct: Option<u8>,

    /// The maximum percentage of the global rate limit that can be used by any scope.
    ///
    /// This treats each full scope separately and applies across all use cases:
    ///  - Two requests with exact same scopes count against the same limit regardless of use case.
    ///  - Two requests that share the same top scope but differ in inner scopes count separately.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-scope limit is enforced.
    pub scope_pct: Option<u8>,

    /// Overrides for specific usecases and scopes.
    pub rules: Vec<ThroughputRule>,
}
107
/// An override rule that applies a specific throughput limit to matching request contexts.
///
/// A rule matches when all specified fields match the request context. Fields not set match
/// any value. When multiple rules match, each is enforced independently via its own token
/// bucket — all matching rules must admit the request.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ThroughputRule {
    /// Optional usecase to match.
    ///
    /// If `None`, matches any usecase.
    pub usecase: Option<String>,

    /// Scopes to match.
    ///
    /// If empty, matches any scopes. Additional scopes in the context are ignored, so a rule
    /// matches if all of the specified scopes are present in the request with matching values.
    pub scopes: Vec<(String, String)>,

    /// The rate limit to apply when this rule matches.
    ///
    /// If both a rate and pct are specified, the more restrictive limit applies.
    /// Should be greater than `0`. To block traffic entirely, use killswitches instead.
    pub rps: Option<u32>,

    /// The percentage of the global rate limit to apply when this rule matches.
    ///
    /// Only effective when a global limit is configured, since it scales that limit.
    /// If both a rate and pct are specified, the more restrictive limit applies.
    /// Should be greater than `0`. To block traffic entirely, use killswitches instead.
    pub pct: Option<u8>,
}
138
139impl ThroughputRule {
140    /// Returns `true` if this rule matches the given context.
141    pub fn matches(&self, context: &ObjectContext) -> bool {
142        if let Some(ref rule_usecase) = self.usecase
143            && rule_usecase != &context.usecase
144        {
145            return false;
146        }
147
148        for (scope_name, scope_value) in &self.scopes {
149            match context.scopes.get_value(scope_name) {
150                Some(value) if value == scope_value => (),
151                _ => return false,
152            }
153        }
154
155        true
156    }
157}
158
/// Bandwidth limits applied at global, per-usecase, and per-scope granularity.
///
/// Bandwidth is measured as bytes transferred per second (upload + download combined)
/// and estimated using an EWMA updated every 50 ms. All limits are optional.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct BandwidthLimits {
    /// The overall maximum bandwidth (in bytes per second) per service instance.
    ///
    /// Defaults to `None`, meaning no global bandwidth limit is enforced. The
    /// percentage limits below only take effect when this is set.
    pub global_bps: Option<u64>,

    /// The maximum percentage of the global bandwidth limit that can be used by any usecase.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-usecase bandwidth limit is enforced.
    pub usecase_pct: Option<u8>,

    /// The maximum percentage of the global bandwidth limit that can be used by any scope.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-scope bandwidth limit is enforced.
    pub scope_pct: Option<u8>,
}
180
/// Combined rate limiter that enforces both bandwidth and throughput limits.
///
/// Checks are synchronous and non-blocking. Bandwidth is checked before
/// throughput so that rejected requests are never counted toward the admitted
/// throughput EWMA. See [`check`](RateLimiter::check) for details.
///
/// Call [`start`](RateLimiter::start) after construction to launch the background
/// estimation tasks. Without it, bandwidth EWMAs remain at zero and bandwidth
/// limits are never triggered.
#[derive(Debug)]
pub struct RateLimiter {
    /// EWMA-based bandwidth (bytes/s) limiter.
    bandwidth: BandwidthRateLimiter,
    /// Token-bucket-based throughput (requests/s) limiter.
    throughput: ThroughputRateLimiter,
}
195
196impl RateLimiter {
197    /// Creates a new rate limiter from the given configuration.
198    ///
199    /// Background estimation tasks are not started until [`start`](RateLimiter::start) is called.
200    pub fn new(config: RateLimits) -> Self {
201        Self {
202            bandwidth: BandwidthRateLimiter::new(config.bandwidth),
203            throughput: ThroughputRateLimiter::new(config.throughput),
204        }
205    }
206
207    /// Starts background tasks for rate limit estimation and monitoring.
208    ///
209    /// Must be called from within a Tokio runtime.
210    pub fn start(&self) {
211        self.bandwidth.start();
212        self.throughput.start();
213    }
214
215    /// Checks if the given context is within the rate limits.
216    ///
217    /// Returns `true` if the request is admitted, `false` if it was rejected. On rejection, emits a
218    /// `server.request.rate_limited` metric counter and a `warn!` log. Bandwidth is checked before
219    /// throughput so that rejected requests are never counted toward admitted traffic.
220    pub fn check(&self, context: &ObjectContext) -> bool {
221        // Bandwidth is checked first because it is a pure read (no token consumption).
222        // Throughput increments the EWMA accumulator only on success, so checking it
223        // second ensures rejected requests are never counted toward admitted traffic.
224        let rejection = self
225            .bandwidth
226            .check(context)
227            .or_else(|| self.throughput.check(context));
228
229        let Some(rejection) = rejection else {
230            return true;
231        };
232
233        objectstore_metrics::count!("server.request.rate_limited", reason = rejection.as_str());
234        objectstore_log::warn!(
235            reason = rejection.as_str(),
236            "Request rejected: rate limit exceeded"
237        );
238        false
239    }
240
241    /// Returns all bandwidth accumulators (global + per-usecase + per-scope) for the given context.
242    ///
243    /// Creates entries in the per-usecase/per-scope maps if they don't exist yet.
244    pub fn bytes_accumulators(&self, context: &ObjectContext) -> Vec<Arc<AtomicU64>> {
245        self.bandwidth.accumulators(context)
246    }
247
248    /// Records bandwidth usage across all accumulators for the given context.
249    ///
250    /// This is used for cases where bytes are known upfront (e.g. batch INSERT) rather than
251    /// streamed through a `MeteredPayloadStream`.
252    pub fn record_bandwidth(&self, context: &ObjectContext, bytes: u64) {
253        for acc in self.bandwidth.accumulators(context) {
254            acc.fetch_add(bytes, std::sync::atomic::Ordering::Relaxed);
255        }
256    }
257
258    /// Returns the current global bandwidth EWMA in bytes per second.
259    pub fn bandwidth_ewma(&self) -> u64 {
260        self.bandwidth
261            .global
262            .estimate
263            .load(std::sync::atomic::Ordering::Relaxed)
264    }
265
266    /// Returns the configured global bandwidth limit in bytes/s, if set.
267    pub fn bandwidth_limit(&self) -> Option<u64> {
268        self.bandwidth.config.global_bps
269    }
270
271    /// Returns the current estimated throughput in requests per second.
272    pub fn throughput_rps(&self) -> u64 {
273        self.throughput
274            .global_estimator
275            .estimate
276            .load(std::sync::atomic::Ordering::Relaxed)
277    }
278
279    /// Returns the configured global throughput limit in requests/s, if set.
280    pub fn throughput_limit(&self) -> Option<u32> {
281        self.throughput.config.global_rps
282    }
283
284    /// Returns total bytes transferred since startup.
285    pub fn bandwidth_total_bytes(&self) -> u64 {
286        self.bandwidth
287            .total_bytes
288            .load(std::sync::atomic::Ordering::Relaxed)
289    }
290
291    /// Returns total admitted requests since startup.
292    pub fn throughput_total_admitted(&self) -> u64 {
293        self.throughput
294            .total_admitted
295            .load(std::sync::atomic::Ordering::Relaxed)
296    }
297}
298
299/// Shared EWMA estimator state.
300///
301/// The accumulator is incremented as events occur, and the estimate is updated
302/// periodically by a background task using an exponentially weighted moving average.
303#[derive(Debug)]
304struct EwmaEstimator {
305    accumulator: Arc<AtomicU64>,
306    estimate: Arc<AtomicU64>,
307}
308
309impl EwmaEstimator {
310    fn new() -> Self {
311        Self {
312            accumulator: Arc::new(AtomicU64::new(0)),
313            estimate: Arc::new(AtomicU64::new(0)),
314        }
315    }
316
317    /// Updates the EWMA from the accumulator.
318    ///
319    /// Swaps the accumulator to zero, converts the count to a per-second rate using
320    /// `to_rate` (i.e., `1.0 / tick_duration.as_secs_f64()`), then applies the
321    /// EWMA smoothing and stores the floored result in `estimate`.
322    fn update_ewma(&self, ewma: &mut f64, to_rate: f64) {
323        const ALPHA: f64 = 0.2;
324        let current = self
325            .accumulator
326            .swap(0, std::sync::atomic::Ordering::Relaxed);
327        let rate = (current as f64) * to_rate;
328        *ewma = ALPHA * rate + (1.0 - ALPHA) * *ewma;
329        self.estimate
330            .store(ewma.floor() as u64, std::sync::atomic::Ordering::Relaxed);
331    }
332}
333
/// Enforces bandwidth (bytes/s) limits using EWMA estimates of recent traffic.
#[derive(Debug)]
struct BandwidthRateLimiter {
    /// The configured bandwidth limits.
    config: BandwidthLimits,
    /// Global accumulator/estimator pair.
    global: Arc<EwmaEstimator>,
    /// Cumulative bytes transferred since startup. Never reset.
    total_bytes: Arc<AtomicU64>,
    // NB: These maps grow unbounded but we accept this as we expect an overall limited
    // number of usecases and scopes. We emit gauge metrics to monitor their size.
    /// Per-usecase estimators, created lazily on first use.
    usecases: Arc<papaya::HashMap<String, Arc<EwmaEstimator>>>,
    /// Per-scope estimators, created lazily on first use.
    scopes: Arc<papaya::HashMap<Scopes, Arc<EwmaEstimator>>>,
}
346
347impl BandwidthRateLimiter {
348    fn new(config: BandwidthLimits) -> Self {
349        Self {
350            config,
351            global: Arc::new(EwmaEstimator::new()),
352            total_bytes: Arc::new(AtomicU64::new(0)),
353            usecases: Arc::new(papaya::HashMap::new()),
354            scopes: Arc::new(papaya::HashMap::new()),
355        }
356    }
357
358    fn start(&self) {
359        let global = Arc::clone(&self.global);
360        let usecases = Arc::clone(&self.usecases);
361        let scopes = Arc::clone(&self.scopes);
362        let global_limit = self.config.global_bps;
363        // NB: This task has no shutdown mechanism — the rate limiter is only created once.
364        // The task is aborted when the Tokio runtime is dropped on process exit.
365        tokio::task::spawn(async move {
366            Self::estimator(global, usecases, scopes, global_limit).await;
367        });
368    }
369
370    /// Estimates the current bandwidth utilization using an exponentially weighted moving average.
371    ///
372    /// Iterates over the global estimator as well as all per-usecase and per-scope estimators
373    /// on each tick, updating their EWMAs.
374    async fn estimator(
375        global: Arc<EwmaEstimator>,
376        usecases: Arc<papaya::HashMap<String, Arc<EwmaEstimator>>>,
377        scopes: Arc<papaya::HashMap<Scopes, Arc<EwmaEstimator>>>,
378        global_limit: Option<u64>,
379    ) {
380        const TICK: Duration = Duration::from_millis(50); // Recompute EWMA on every TICK
381
382        let mut interval = tokio::time::interval(TICK);
383        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
384        // The first tick of a tokio interval fires immediately. Consume it so the
385        // first real iteration has a full ~50ms of elapsed time.
386        interval.tick().await;
387        let mut last = Instant::now();
388        let mut global_ewma: f64 = 0.0;
389        // Shadow EWMAs for per-usecase/per-scope entries, keyed the same way as the maps.
390        let mut usecase_ewmas: std::collections::HashMap<String, f64> =
391            std::collections::HashMap::new();
392        let mut scope_ewmas: std::collections::HashMap<Scopes, f64> =
393            std::collections::HashMap::new();
394
395        loop {
396            interval.tick().await;
397
398            let now = Instant::now();
399            let to_bps = 1.0 / now.duration_since(last).as_secs_f64();
400            last = now;
401
402            // Global
403            global.update_ewma(&mut global_ewma, to_bps);
404            objectstore_metrics::gauge!("server.bandwidth.ewma" = global_ewma.floor() as u64);
405            if let Some(limit) = global_limit {
406                objectstore_metrics::gauge!("server.bandwidth.limit" = limit);
407            }
408
409            // Per-usecase
410            {
411                let guard = usecases.pin();
412                for (key, estimator) in guard.iter() {
413                    let ewma = usecase_ewmas.entry(key.clone()).or_insert(0.0);
414                    estimator.update_ewma(ewma, to_bps);
415                }
416            }
417
418            // Per-scope
419            {
420                let guard = scopes.pin();
421                for (key, estimator) in guard.iter() {
422                    let ewma = scope_ewmas.entry(key.clone()).or_insert(0.0);
423                    estimator.update_ewma(ewma, to_bps);
424                }
425            }
426
427            objectstore_metrics::gauge!(
428                "server.rate_limiter.bandwidth.scope_map_size" = scopes.len()
429            );
430            objectstore_metrics::gauge!(
431                "server.rate_limiter.bandwidth.usecase_map_size" = usecases.len()
432            );
433        }
434    }
435
436    fn check(&self, context: &ObjectContext) -> Option<RateLimitRejection> {
437        let global_bps = self.config.global_bps?;
438
439        // Global check
440        if self
441            .global
442            .estimate
443            .load(std::sync::atomic::Ordering::Relaxed)
444            > global_bps
445        {
446            return Some(RateLimitRejection::BandwidthGlobal);
447        }
448
449        // Per-usecase check
450        if let Some(usecase_bps) = self.usecase_bps() {
451            let guard = self.usecases.pin();
452            if let Some(estimator) = guard.get(&context.usecase)
453                && estimator
454                    .estimate
455                    .load(std::sync::atomic::Ordering::Relaxed)
456                    > usecase_bps
457            {
458                return Some(RateLimitRejection::BandwidthUsecase);
459            }
460        }
461
462        // Per-scope check
463        if let Some(scope_bps) = self.scope_bps() {
464            let guard = self.scopes.pin();
465            if let Some(estimator) = guard.get(&context.scopes)
466                && estimator
467                    .estimate
468                    .load(std::sync::atomic::Ordering::Relaxed)
469                    > scope_bps
470            {
471                return Some(RateLimitRejection::BandwidthScope);
472            }
473        }
474
475        None
476    }
477
478    /// Returns all accumulators (global + per-usecase + per-scope) for the given context.
479    ///
480    /// Creates entries in the per-usecase/per-scope maps if they don't exist yet.
481    /// Always includes `total_bytes` (cumulative, never reset) as the first entry.
482    fn accumulators(&self, context: &ObjectContext) -> Vec<Arc<AtomicU64>> {
483        let mut accs = vec![
484            Arc::clone(&self.total_bytes),
485            Arc::clone(&self.global.accumulator),
486        ];
487
488        if self.usecase_bps().is_some() {
489            let guard = self.usecases.pin();
490            let estimator = guard
491                .get_or_insert_with(context.usecase.clone(), || Arc::new(EwmaEstimator::new()));
492            accs.push(Arc::clone(&estimator.accumulator));
493        }
494
495        if self.scope_bps().is_some() {
496            let guard = self.scopes.pin();
497            let estimator =
498                guard.get_or_insert_with(context.scopes.clone(), || Arc::new(EwmaEstimator::new()));
499            accs.push(Arc::clone(&estimator.accumulator));
500        }
501
502        accs
503    }
504
505    /// Returns the effective BPS for per-usecase limiting, if configured.
506    fn usecase_bps(&self) -> Option<u64> {
507        let global_bps = self.config.global_bps?;
508        let pct = self.config.usecase_pct?;
509        Some(((global_bps as f64) * (pct as f64 / 100.0)) as u64)
510    }
511
512    /// Returns the effective BPS for per-scope limiting, if configured.
513    fn scope_bps(&self) -> Option<u64> {
514        let global_bps = self.config.global_bps?;
515        let pct = self.config.scope_pct?;
516        Some(((global_bps as f64) * (pct as f64 / 100.0)) as u64)
517    }
518}
519
/// Enforces request-throughput (requests/s) limits using token buckets, and tracks
/// an EWMA of the admitted request rate.
#[derive(Debug)]
struct ThroughputRateLimiter {
    /// The configured throughput limits.
    config: ThroughputLimits,
    /// Global token bucket; `None` when no global limit is configured.
    global: Option<Mutex<TokenBucket>>,
    /// Global EWMA estimator for admitted request rate.
    global_estimator: Arc<EwmaEstimator>,
    /// Cumulative admitted requests since startup. Never reset.
    total_admitted: Arc<AtomicU64>,
    // NB: These maps grow unbounded but we accept this as we expect an overall limited
    // number of usecases and scopes. We emit gauge metrics to monitor their size.
    /// Per-usecase token buckets, created lazily on first use.
    usecases: Arc<papaya::HashMap<String, Mutex<TokenBucket>>>,
    /// Per-scope token buckets, created lazily on first use.
    scopes: Arc<papaya::HashMap<Scopes, Mutex<TokenBucket>>>,
    /// Per-rule token buckets, keyed by the rule's index in `config.rules`.
    rules: papaya::HashMap<usize, Mutex<TokenBucket>>,
}
534
535impl ThroughputRateLimiter {
536    fn new(config: ThroughputLimits) -> Self {
537        let global = config
538            .global_rps
539            .map(|rps| Mutex::new(TokenBucket::new(rps, config.burst)));
540
541        Self {
542            config,
543            global,
544            global_estimator: Arc::new(EwmaEstimator::new()),
545            total_admitted: Arc::new(AtomicU64::new(0)),
546            usecases: Arc::new(papaya::HashMap::new()),
547            scopes: Arc::new(papaya::HashMap::new()),
548            rules: papaya::HashMap::new(),
549        }
550    }
551
552    fn start(&self) {
553        let usecases = Arc::clone(&self.usecases);
554        let scopes = Arc::clone(&self.scopes);
555        let global_estimator = Arc::clone(&self.global_estimator);
556        let global_limit = self.config.global_rps;
557        // NB: This task has no shutdown mechanism — the rate limiter is only created once.
558        // The task is aborted when the Tokio runtime is dropped on process exit.
559        tokio::task::spawn(async move {
560            const TICK: Duration = Duration::from_millis(50);
561            let mut interval = tokio::time::interval(TICK);
562            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
563            interval.tick().await;
564            let mut last = Instant::now();
565            let mut global_ewma: f64 = 0.0;
566            loop {
567                interval.tick().await;
568                let now = Instant::now();
569                let to_rps = 1.0 / now.duration_since(last).as_secs_f64();
570                last = now;
571                global_estimator.update_ewma(&mut global_ewma, to_rps);
572                objectstore_metrics::gauge!("server.throughput.ewma" = global_ewma.floor() as u64);
573                if let Some(limit) = global_limit {
574                    objectstore_metrics::gauge!(
575                        "server.rate_limiter.throughput.limit" = u64::from(limit)
576                    );
577                }
578                objectstore_metrics::gauge!(
579                    "server.rate_limiter.throughput.scope_map_size" = scopes.len()
580                );
581                objectstore_metrics::gauge!(
582                    "server.rate_limiter.throughput.usecase_map_size" = usecases.len()
583                );
584            }
585        });
586    }
587
588    fn check(&self, context: &ObjectContext) -> Option<RateLimitRejection> {
589        // NB: We intentionally use unwrap and crash the server if the mutexes are poisoned.
590
591        // Global check
592        if let Some(ref global) = self.global {
593            let acquired = global.lock().unwrap().try_acquire();
594            if !acquired {
595                return Some(RateLimitRejection::ThroughputGlobal);
596            }
597        }
598
599        // Usecase check - only if both global_rps and usecase_pct are set
600        if let Some(usecase_rps) = self.usecase_rps() {
601            let guard = self.usecases.pin();
602            let bucket = guard
603                .get_or_insert_with(context.usecase.clone(), || self.create_bucket(usecase_rps));
604            if !bucket.lock().unwrap().try_acquire() {
605                return Some(RateLimitRejection::ThroughputUsecase);
606            }
607        }
608
609        // Scope check - only if both global_rps and scope_pct are set
610        if let Some(scope_rps) = self.scope_rps() {
611            let guard = self.scopes.pin();
612            let bucket =
613                guard.get_or_insert_with(context.scopes.clone(), || self.create_bucket(scope_rps));
614            if !bucket.lock().unwrap().try_acquire() {
615                return Some(RateLimitRejection::ThroughputScope);
616            }
617        }
618
619        // Rule checks - each matching rule has its own dedicated bucket
620        for (idx, rule) in self.config.rules.iter().enumerate() {
621            if !rule.matches(context) {
622                continue;
623            }
624            let Some(rule_rps) = self.rule_rps(rule) else {
625                continue;
626            };
627            let guard = self.rules.pin();
628            let bucket = guard.get_or_insert_with(idx, || self.create_bucket(rule_rps));
629            if !bucket.lock().unwrap().try_acquire() {
630                return Some(RateLimitRejection::ThroughputRule);
631            }
632        }
633
634        // Count this admitted request in the EWMA accumulator and the cumulative counter.
635        // NB: u64 wrapping is not a practical concern — at 1M rps it takes ~585k years.
636        // Prometheus irate() also handles counter resets gracefully should it ever occur.
637        self.global_estimator
638            .accumulator
639            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
640        self.total_admitted
641            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
642
643        None
644    }
645
646    fn create_bucket(&self, rps: u32) -> Mutex<TokenBucket> {
647        Mutex::new(TokenBucket::new(rps, self.config.burst))
648    }
649
650    /// Returns the effective RPS for per-usecase limiting, if configured.
651    fn usecase_rps(&self) -> Option<u32> {
652        let global_rps = self.config.global_rps?;
653        let pct = self.config.usecase_pct?;
654        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
655    }
656
657    /// Returns the effective RPS for per-scope limiting, if configured.
658    fn scope_rps(&self) -> Option<u32> {
659        let global_rps = self.config.global_rps?;
660        let pct = self.config.scope_pct?;
661        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
662    }
663
664    /// Returns the effective RPS for a rule, if it has a valid limit.
665    fn rule_rps(&self, rule: &ThroughputRule) -> Option<u32> {
666        let pct_limit = rule.pct.and_then(|p| {
667            self.config
668                .global_rps
669                .map(|g| ((g as f64) * (p as f64 / 100.0)) as u32)
670        });
671
672        match (rule.rps, pct_limit) {
673            (Some(r), Some(p)) => Some(r.min(p)),
674            (Some(r), None) => Some(r),
675            (None, Some(p)) => Some(p),
676            (None, None) => None,
677        }
678    }
679}
680
/// A token bucket rate limiter.
///
/// Tokens refill at a constant rate up to capacity. Each request consumes one token.
/// When no tokens are available, requests are rejected.
///
/// This implementation is not thread-safe on its own. Wrap in a `Mutex` for concurrent access.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens added per second (the sustained rate limit).
    refill_rate: f64,
    /// Maximum tokens the bucket can hold (`rps + burst`).
    capacity: f64,
    /// Currently available tokens; fractional while refilling.
    tokens: f64,
    /// Time of the last applied refill.
    last_update: Instant,
}

impl TokenBucket {
    /// Creates a new, full token bucket with the specified rate limit and burst capacity.
    ///
    /// - `rps`: tokens refilled per second (sustained rate limit)
    /// - `burst`: initial tokens and burst allowance above sustained rate
    pub fn new(rps: u32, burst: u32) -> Self {
        // Sum in f64 so `rps + burst` cannot overflow u32 (the previous
        // `(rps + burst) as f64` panics in debug builds for large configs).
        let capacity = rps as f64 + burst as f64;
        Self {
            refill_rate: rps as f64,
            capacity,
            tokens: capacity,
            last_update: Instant::now(),
        }
    }

    /// Attempts to acquire a token from the bucket.
    ///
    /// Returns `true` if a token was acquired, `false` if no tokens available.
    pub fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let refill = now.duration_since(self.last_update).as_secs_f64() * self.refill_rate;
        let refilled = (self.tokens + refill).min(self.capacity);

        // Only apply refill if we'd gain at least 1 whole token. Otherwise keep
        // `last_update` untouched so sub-token refill time keeps accruing instead
        // of being lost.
        if refilled.floor() > self.tokens.floor() {
            self.last_update = now;
            self.tokens = refilled;
        }

        // Try to consume one token
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
732
/// A wrapper around a byte stream that measures bandwidth usage.
///
/// Each successfully polled chunk increments every accumulator by the chunk's
/// length. Generic over both the stream type `S` and its error type.
pub(crate) struct MeteredPayloadStream<S> {
    inner: S,
    accumulators: Vec<Arc<AtomicU64>>,
}

impl<S> MeteredPayloadStream<S> {
    /// Wraps `inner`, metering polled bytes into `accumulators`.
    pub fn new(inner: S, accumulators: Vec<Arc<AtomicU64>>) -> Self {
        MeteredPayloadStream {
            inner,
            accumulators,
        }
    }
}

// Manual impl: `S` need not be `Debug`, so derive is unavailable; the inner
// stream is omitted from the output.
impl<S> fmt::Debug for MeteredPayloadStream<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("MeteredPayloadStream");
        dbg.field("accumulators", &self.accumulators);
        dbg.finish()
    }
}
758
759impl<S, E> Stream for MeteredPayloadStream<S>
760where
761    S: Stream<Item = Result<Bytes, E>> + Unpin,
762{
763    type Item = Result<Bytes, E>;
764
765    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
766        let this = self.get_mut();
767        let res = Pin::new(&mut this.inner).poll_next(cx);
768        if let Poll::Ready(Some(Ok(ref bytes))) = res {
769            let len = bytes.len() as u64;
770            for acc in &this.accumulators {
771                acc.fetch_add(len, std::sync::atomic::Ordering::Relaxed);
772            }
773        }
774        res
775    }
776}
777
778#[cfg(test)]
779mod tests {
780    use objectstore_service::id::ObjectContext;
781    use objectstore_types::scope::{Scope, Scopes};
782
783    use super::*;
784
    /// Builds a minimal test context: usecase `"testing"` with a single
    /// `org = 1` scope.
    fn make_context() -> ObjectContext {
        ObjectContext {
            usecase: "testing".into(),
            scopes: Scopes::from_iter([Scope::create("org", "1").unwrap()]),
        }
    }
791
792    #[test]
793    fn ewma_estimator_update_applies_alpha() {
794        let estimator = EwmaEstimator::new();
795        const TICK: f64 = 0.05; // 50ms
796        let to_rate = 1.0 / TICK;
797        let mut ewma: f64 = 0.0;
798
799        // Simulate 10 events in one 50ms tick → 200 /s raw rate.
800        // After one step: 0.2 * 200 + 0.8 * 0 = 40.
801        estimator
802            .accumulator
803            .store(10, std::sync::atomic::Ordering::Relaxed);
804        estimator.update_ewma(&mut ewma, to_rate);
805        assert_eq!(
806            estimator
807                .estimate
808                .load(std::sync::atomic::Ordering::Relaxed),
809            40
810        );
811
812        // Accumulator must have been zeroed.
813        assert_eq!(
814            estimator
815                .accumulator
816                .load(std::sync::atomic::Ordering::Relaxed),
817            0
818        );
819    }
820
821    #[test]
822    fn throughput_check_increments_accumulator() {
823        let limiter = ThroughputRateLimiter::new(ThroughputLimits {
824            global_rps: Some(1000),
825            ..Default::default()
826        });
827
828        assert_eq!(
829            limiter
830                .global_estimator
831                .accumulator
832                .load(std::sync::atomic::Ordering::Relaxed),
833            0
834        );
835
836        let context = make_context();
837        assert!(limiter.check(&context).is_none());
838        assert!(limiter.check(&context).is_none());
839
840        assert_eq!(
841            limiter
842                .global_estimator
843                .accumulator
844                .load(std::sync::atomic::Ordering::Relaxed),
845            2
846        );
847    }
848
849    #[test]
850    fn throughput_rejected_does_not_increment_accumulator() {
851        let limiter = ThroughputRateLimiter::new(ThroughputLimits {
852            global_rps: Some(1),
853            burst: 0,
854            ..Default::default()
855        });
856
857        let context = make_context();
858        // First call admitted (consumes the one token), second rejected.
859        assert!(limiter.check(&context).is_none());
860        assert!(limiter.check(&context).is_some());
861
862        assert_eq!(
863            limiter
864                .global_estimator
865                .accumulator
866                .load(std::sync::atomic::Ordering::Relaxed),
867            1 // only the admitted request
868        );
869    }
870
871    #[test]
872    fn bandwidth_rejection_does_not_increment_throughput_accumulator() {
873        // global_bps of 1 means the estimate (0 initially) is not > 1, so the first
874        // call passes the bandwidth check. Use 0 to guarantee an immediate reject.
875        // BandwidthRateLimiter::check rejects when estimate > global_bps, so set
876        // global_bps = 0 to make the bandwidth check always reject.
877        let limiter = RateLimiter::new(RateLimits {
878            throughput: ThroughputLimits {
879                global_rps: Some(1000),
880                ..Default::default()
881            },
882            bandwidth: BandwidthLimits {
883                global_bps: Some(0),
884                ..Default::default()
885            },
886        });
887
888        // Prime the bandwidth EWMA so it exceeds the limit.
889        limiter
890            .bandwidth
891            .global
892            .estimate
893            .store(1, std::sync::atomic::Ordering::Relaxed);
894
895        let context = make_context();
896        assert!(!limiter.check(&context));
897
898        // The throughput accumulator must still be 0.
899        assert_eq!(
900            limiter
901                .throughput
902                .global_estimator
903                .accumulator
904                .load(std::sync::atomic::Ordering::Relaxed),
905            0
906        );
907    }
908
909    #[test]
910    fn rate_limiter_accessors_with_config() {
911        let rate_limiter = RateLimiter::new(RateLimits {
912            throughput: ThroughputLimits {
913                global_rps: Some(500),
914                ..Default::default()
915            },
916            bandwidth: BandwidthLimits {
917                global_bps: Some(1_000_000),
918                ..Default::default()
919            },
920        });
921
922        assert_eq!(rate_limiter.bandwidth_limit(), Some(1_000_000));
923        assert_eq!(rate_limiter.throughput_limit(), Some(500));
924        // Estimates start at 0 before any background tick.
925        assert_eq!(rate_limiter.bandwidth_ewma(), 0);
926        assert_eq!(rate_limiter.throughput_rps(), 0);
927    }
928
929    #[test]
930    fn rate_limiter_accessors_no_limits() {
931        let rate_limiter = RateLimiter::new(RateLimits::default());
932
933        assert_eq!(rate_limiter.bandwidth_limit(), None);
934        assert_eq!(rate_limiter.throughput_limit(), None);
935        assert_eq!(rate_limiter.bandwidth_ewma(), 0);
936        assert_eq!(rate_limiter.throughput_rps(), 0);
937    }
938}