objectstore_server/
rate_limits.rs

1//! Admission-based rate limiting for throughput and bandwidth.
2//!
3//! This module provides [`RateLimiter`], which enforces configurable limits at three
4//! levels of granularity — global, per-usecase, and per-scope — for both request
5//! throughput (requests/s) and upload/download bandwidth (bytes/s).
6//!
7//! Throughput is enforced using token buckets; bandwidth is estimated with an
8//! exponentially weighted moving average (EWMA) updated by a background task every
9//! 50 ms. All rate-limit checks are synchronous and non-blocking.
10
11use std::pin::Pin;
12use std::sync::Arc;
13use std::sync::Mutex;
14use std::sync::atomic::AtomicU64;
15use std::task::{Context, Poll};
16use std::time::{Duration, Instant};
17
18use bytes::Bytes;
19use futures_util::Stream;
20use objectstore_service::PayloadStream;
21use objectstore_service::id::ObjectContext;
22use objectstore_types::scope::Scopes;
23use serde::{Deserialize, Serialize};
24
/// Rate limits for objectstore.
///
/// Combines request-throughput and bandwidth limits; both are enforced together
/// by [`RateLimiter`]. All nested limits default to "disabled" (`None`).
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct RateLimits {
    /// Limits the number of requests per second per service instance.
    pub throughput: ThroughputLimits,
    /// Limits the concurrent bandwidth per service instance.
    pub bandwidth: BandwidthLimits,
}
33
/// Request throughput limits applied at global, per-usecase, and per-scope granularity.
///
/// All limits are optional. When a limit is `None`, that level of limiting is not enforced.
/// Per-usecase and per-scope limits are expressed as a percentage of the global limit,
/// so they are only effective when `global_rps` is also set.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct ThroughputLimits {
    /// The overall maximum number of requests per second per service instance.
    ///
    /// Defaults to `None`, meaning no global rate limit is enforced.
    pub global_rps: Option<u32>,

    /// The maximum burst for each rate limit.
    ///
    /// Defaults to `0`, meaning no bursting is allowed. If set to a value greater than `0`,
    /// short spikes above the rate limit are allowed up to the burst size.
    /// Applies to every token bucket (global, per-usecase, per-scope, and per-rule).
    pub burst: u32,

    /// The maximum percentage of the global rate limit that can be used by any usecase.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-usecase limit is enforced.
    /// Only effective when `global_rps` is also set.
    pub usecase_pct: Option<u8>,

    /// The maximum percentage of the global rate limit that can be used by any scope.
    ///
    /// This treats each full scope separately and applies across all use cases:
    ///  - Two requests with exact same scopes count against the same limit regardless of use case.
    ///  - Two requests that share the same top scope but differ in inner scopes count separately.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-scope limit is enforced.
    /// Only effective when `global_rps` is also set.
    pub scope_pct: Option<u8>,

    /// Overrides for specific usecases and scopes.
    ///
    /// Each matching rule is enforced independently via its own token bucket; see
    /// [`ThroughputRule`].
    pub rules: Vec<ThroughputRule>,
}
68
/// An override rule that applies a specific throughput limit to matching request contexts.
///
/// A rule matches when all specified fields match the request context. Fields not set match
/// any value. When multiple rules match, each is enforced independently via its own token
/// bucket — all matching rules must admit the request.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ThroughputRule {
    /// Optional usecase to match.
    ///
    /// If `None`, matches any usecase.
    pub usecase: Option<String>,

    /// Scopes to match.
    ///
    /// If empty, matches any scopes. Additional scopes in the context are ignored, so a rule
    /// matches if all of the specified scopes are present in the request with matching values.
    pub scopes: Vec<(String, String)>,

    /// The rate limit to apply when this rule matches.
    ///
    /// If both a rate and pct are specified, the more restrictive limit applies.
    /// Should be greater than `0`. To block traffic entirely, use killswitches instead.
    pub rps: Option<u32>,

    /// The percentage of the global rate limit to apply when this rule matches.
    ///
    /// If both a rate and pct are specified, the more restrictive limit applies.
    /// Should be greater than `0`. To block traffic entirely, use killswitches instead.
    /// Only effective when a global rate limit is configured.
    pub pct: Option<u8>,
}
99
100impl ThroughputRule {
101    /// Returns `true` if this rule matches the given context.
102    pub fn matches(&self, context: &ObjectContext) -> bool {
103        if let Some(ref rule_usecase) = self.usecase
104            && rule_usecase != &context.usecase
105        {
106            return false;
107        }
108
109        for (scope_name, scope_value) in &self.scopes {
110            match context.scopes.get_value(scope_name) {
111                Some(value) if value == scope_value => (),
112                _ => return false,
113            }
114        }
115
116        true
117    }
118}
119
/// Bandwidth limits applied at global, per-usecase, and per-scope granularity.
///
/// Bandwidth is measured as bytes transferred per second (upload + download combined)
/// and estimated using an EWMA updated every 50 ms. All limits are optional.
/// Per-usecase and per-scope limits are percentages of the global limit, so they are
/// only effective when `global_bps` is also set.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct BandwidthLimits {
    /// The overall maximum bandwidth (in bytes per second) per service instance.
    ///
    /// Defaults to `None`, meaning no global bandwidth limit is enforced.
    pub global_bps: Option<u64>,

    /// The maximum percentage of the global bandwidth limit that can be used by any usecase.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-usecase bandwidth limit is enforced.
    pub usecase_pct: Option<u8>,

    /// The maximum percentage of the global bandwidth limit that can be used by any scope.
    ///
    /// Value from `0` to `100`. Defaults to `None`, meaning no per-scope bandwidth limit is enforced.
    pub scope_pct: Option<u8>,
}
141
/// Combined rate limiter that enforces both bandwidth and throughput limits.
///
/// Checks are synchronous and non-blocking. Bandwidth is checked before
/// throughput so that rejected requests are never counted toward the admitted
/// throughput EWMA. See [`check`](RateLimiter::check) for details.
///
/// Call [`start`](RateLimiter::start) after construction to launch the background
/// estimation tasks. Without it, bandwidth EWMAs remain at zero and bandwidth
/// limits are never triggered.
#[derive(Debug)]
pub struct RateLimiter {
    /// EWMA-based bandwidth limiter (bytes/s).
    bandwidth: BandwidthRateLimiter,
    /// Token-bucket-based throughput limiter (requests/s).
    throughput: ThroughputRateLimiter,
}
156
157impl RateLimiter {
158    /// Creates a new rate limiter from the given configuration.
159    ///
160    /// Background estimation tasks are not started until [`start`](RateLimiter::start) is called.
161    pub fn new(config: RateLimits) -> Self {
162        Self {
163            bandwidth: BandwidthRateLimiter::new(config.bandwidth),
164            throughput: ThroughputRateLimiter::new(config.throughput),
165        }
166    }
167
168    /// Starts background tasks for rate limit estimation and monitoring.
169    ///
170    /// Must be called from within a Tokio runtime.
171    pub fn start(&self) {
172        self.bandwidth.start();
173        self.throughput.start();
174    }
175
176    /// Checks if the given context is within the rate limits.
177    ///
178    /// Returns `true` if the context is within the rate limits, `false` otherwise.
179    pub fn check(&self, context: &ObjectContext) -> bool {
180        // Bandwidth is checked first because it is a pure read (no token consumption).
181        // Throughput increments the EWMA accumulator only on success, so checking it
182        // second ensures rejected requests are never counted toward admitted traffic.
183        self.bandwidth.check(context) && self.throughput.check(context)
184    }
185
186    /// Returns all bandwidth accumulators (global + per-usecase + per-scope) for the given context.
187    ///
188    /// Creates entries in the per-usecase/per-scope maps if they don't exist yet.
189    pub fn bytes_accumulators(&self, context: &ObjectContext) -> Vec<Arc<AtomicU64>> {
190        self.bandwidth.accumulators(context)
191    }
192
193    /// Records bandwidth usage across all accumulators for the given context.
194    ///
195    /// This is used for cases where bytes are known upfront (e.g. batch INSERT) rather than
196    /// streamed through a `MeteredPayloadStream`.
197    pub fn record_bandwidth(&self, context: &ObjectContext, bytes: u64) {
198        for acc in self.bandwidth.accumulators(context) {
199            acc.fetch_add(bytes, std::sync::atomic::Ordering::Relaxed);
200        }
201    }
202
203    /// Returns the current global bandwidth EWMA in bytes per second.
204    pub fn bandwidth_ewma(&self) -> u64 {
205        self.bandwidth
206            .global
207            .estimate
208            .load(std::sync::atomic::Ordering::Relaxed)
209    }
210
211    /// Returns the configured global bandwidth limit in bytes/s, if set.
212    pub fn bandwidth_limit(&self) -> Option<u64> {
213        self.bandwidth.config.global_bps
214    }
215
216    /// Returns the current estimated throughput in requests per second.
217    pub fn throughput_rps(&self) -> u64 {
218        self.throughput
219            .global_estimator
220            .estimate
221            .load(std::sync::atomic::Ordering::Relaxed)
222    }
223
224    /// Returns the configured global throughput limit in requests/s, if set.
225    pub fn throughput_limit(&self) -> Option<u32> {
226        self.throughput.config.global_rps
227    }
228
229    /// Returns total bytes transferred since startup.
230    pub fn bandwidth_total_bytes(&self) -> u64 {
231        self.bandwidth
232            .total_bytes
233            .load(std::sync::atomic::Ordering::Relaxed)
234    }
235
236    /// Returns total admitted requests since startup.
237    pub fn throughput_total_admitted(&self) -> u64 {
238        self.throughput
239            .total_admitted
240            .load(std::sync::atomic::Ordering::Relaxed)
241    }
242}
243
/// Shared EWMA estimator state.
///
/// The accumulator is incremented as events occur, and the estimate is updated
/// periodically by a background task using an exponentially weighted moving average.
#[derive(Debug)]
struct EwmaEstimator {
    accumulator: Arc<AtomicU64>,
    estimate: Arc<AtomicU64>,
}

impl EwmaEstimator {
    /// Creates an estimator with both the accumulator and the estimate at zero.
    fn new() -> Self {
        Self {
            accumulator: Arc::new(AtomicU64::new(0)),
            estimate: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Updates the EWMA from the accumulator.
    ///
    /// Swaps the accumulator to zero, converts the count to a per-second rate using
    /// `to_rate` (i.e., `1.0 / tick_duration.as_secs_f64()`), then applies the
    /// EWMA smoothing and stores the floored result in `estimate`.
    fn update_ewma(&self, ewma: &mut f64, to_rate: f64) {
        use std::sync::atomic::Ordering::Relaxed;
        // Smoothing factor: each tick contributes 20% of the new sample.
        const ALPHA: f64 = 0.2;

        let drained = self.accumulator.swap(0, Relaxed);
        let sample = (drained as f64) * to_rate;
        *ewma = ALPHA * sample + (1.0 - ALPHA) * *ewma;
        self.estimate.store(ewma.floor() as u64, Relaxed);
    }
}
278
#[derive(Debug)]
struct BandwidthRateLimiter {
    /// Configured bandwidth limits (global bytes/s plus percentage-based sub-limits).
    config: BandwidthLimits,
    /// Global accumulator/estimator pair.
    global: Arc<EwmaEstimator>,
    /// Cumulative bytes transferred since startup. Never reset.
    total_bytes: Arc<AtomicU64>,
    // NB: These maps grow unbounded but we accept this as we expect an overall limited
    // number of usecases and scopes. We emit gauge metrics to monitor their size.
    /// Per-usecase estimators, created lazily by `accumulators`.
    usecases: Arc<papaya::HashMap<String, Arc<EwmaEstimator>>>,
    /// Per-scope estimators, created lazily by `accumulators`.
    scopes: Arc<papaya::HashMap<Scopes, Arc<EwmaEstimator>>>,
}
291
292impl BandwidthRateLimiter {
293    fn new(config: BandwidthLimits) -> Self {
294        Self {
295            config,
296            global: Arc::new(EwmaEstimator::new()),
297            total_bytes: Arc::new(AtomicU64::new(0)),
298            usecases: Arc::new(papaya::HashMap::new()),
299            scopes: Arc::new(papaya::HashMap::new()),
300        }
301    }
302
303    fn start(&self) {
304        let global = Arc::clone(&self.global);
305        let usecases = Arc::clone(&self.usecases);
306        let scopes = Arc::clone(&self.scopes);
307        let global_limit = self.config.global_bps;
308        // NB: This task has no shutdown mechanism — the rate limiter is only created once.
309        // The task is aborted when the Tokio runtime is dropped on process exit.
310        tokio::task::spawn(async move {
311            Self::estimator(global, usecases, scopes, global_limit).await;
312        });
313    }
314
315    /// Estimates the current bandwidth utilization using an exponentially weighted moving average.
316    ///
317    /// Iterates over the global estimator as well as all per-usecase and per-scope estimators
318    /// on each tick, updating their EWMAs.
319    async fn estimator(
320        global: Arc<EwmaEstimator>,
321        usecases: Arc<papaya::HashMap<String, Arc<EwmaEstimator>>>,
322        scopes: Arc<papaya::HashMap<Scopes, Arc<EwmaEstimator>>>,
323        global_limit: Option<u64>,
324    ) {
325        const TICK: Duration = Duration::from_millis(50); // Recompute EWMA on every TICK
326
327        let mut interval = tokio::time::interval(TICK);
328        let to_bps = 1.0 / TICK.as_secs_f64(); // Conversion factor from bytes to bps
329        let mut global_ewma: f64 = 0.0;
330        // Shadow EWMAs for per-usecase/per-scope entries, keyed the same way as the maps.
331        let mut usecase_ewmas: std::collections::HashMap<String, f64> =
332            std::collections::HashMap::new();
333        let mut scope_ewmas: std::collections::HashMap<Scopes, f64> =
334            std::collections::HashMap::new();
335
336        loop {
337            interval.tick().await;
338
339            // Global
340            global.update_ewma(&mut global_ewma, to_bps);
341            objectstore_metrics::gauge!("server.bandwidth.ewma"@b: global_ewma.floor() as u64);
342            if let Some(limit) = global_limit {
343                objectstore_metrics::gauge!("server.bandwidth.limit"@b: limit);
344            }
345
346            // Per-usecase
347            {
348                let guard = usecases.pin();
349                for (key, estimator) in guard.iter() {
350                    let ewma = usecase_ewmas.entry(key.clone()).or_insert(0.0);
351                    estimator.update_ewma(ewma, to_bps);
352                }
353            }
354
355            // Per-scope
356            {
357                let guard = scopes.pin();
358                for (key, estimator) in guard.iter() {
359                    let ewma = scope_ewmas.entry(key.clone()).or_insert(0.0);
360                    estimator.update_ewma(ewma, to_bps);
361                }
362            }
363
364            objectstore_metrics::gauge!("server.rate_limiter.bandwidth.scope_map_size": scopes.len());
365            objectstore_metrics::gauge!("server.rate_limiter.bandwidth.usecase_map_size": usecases.len());
366        }
367    }
368
369    fn check(&self, context: &ObjectContext) -> bool {
370        let Some(global_bps) = self.config.global_bps else {
371            return true;
372        };
373
374        // Global check
375        if self
376            .global
377            .estimate
378            .load(std::sync::atomic::Ordering::Relaxed)
379            > global_bps
380        {
381            return false;
382        }
383
384        // Per-usecase check
385        if let Some(usecase_bps) = self.usecase_bps() {
386            let guard = self.usecases.pin();
387            if let Some(estimator) = guard.get(&context.usecase)
388                && estimator
389                    .estimate
390                    .load(std::sync::atomic::Ordering::Relaxed)
391                    > usecase_bps
392            {
393                return false;
394            }
395        }
396
397        // Per-scope check
398        if let Some(scope_bps) = self.scope_bps() {
399            let guard = self.scopes.pin();
400            if let Some(estimator) = guard.get(&context.scopes)
401                && estimator
402                    .estimate
403                    .load(std::sync::atomic::Ordering::Relaxed)
404                    > scope_bps
405            {
406                return false;
407            }
408        }
409
410        true
411    }
412
413    /// Returns all accumulators (global + per-usecase + per-scope) for the given context.
414    ///
415    /// Creates entries in the per-usecase/per-scope maps if they don't exist yet.
416    /// Always includes `total_bytes` (cumulative, never reset) as the first entry.
417    fn accumulators(&self, context: &ObjectContext) -> Vec<Arc<AtomicU64>> {
418        let mut accs = vec![
419            Arc::clone(&self.total_bytes),
420            Arc::clone(&self.global.accumulator),
421        ];
422
423        if self.usecase_bps().is_some() {
424            let guard = self.usecases.pin();
425            let estimator = guard
426                .get_or_insert_with(context.usecase.clone(), || Arc::new(EwmaEstimator::new()));
427            accs.push(Arc::clone(&estimator.accumulator));
428        }
429
430        if self.scope_bps().is_some() {
431            let guard = self.scopes.pin();
432            let estimator =
433                guard.get_or_insert_with(context.scopes.clone(), || Arc::new(EwmaEstimator::new()));
434            accs.push(Arc::clone(&estimator.accumulator));
435        }
436
437        accs
438    }
439
440    /// Returns the effective BPS for per-usecase limiting, if configured.
441    fn usecase_bps(&self) -> Option<u64> {
442        let global_bps = self.config.global_bps?;
443        let pct = self.config.usecase_pct?;
444        Some(((global_bps as f64) * (pct as f64 / 100.0)) as u64)
445    }
446
447    /// Returns the effective BPS for per-scope limiting, if configured.
448    fn scope_bps(&self) -> Option<u64> {
449        let global_bps = self.config.global_bps?;
450        let pct = self.config.scope_pct?;
451        Some(((global_bps as f64) * (pct as f64 / 100.0)) as u64)
452    }
453}
454
#[derive(Debug)]
struct ThroughputRateLimiter {
    /// Configured throughput limits (global RPS, burst, percentages, rules).
    config: ThroughputLimits,
    /// Global token bucket; `None` when no global limit is configured.
    global: Option<Mutex<TokenBucket>>,
    /// Global EWMA estimator for admitted request rate.
    global_estimator: Arc<EwmaEstimator>,
    /// Cumulative admitted requests since startup. Never reset.
    total_admitted: Arc<AtomicU64>,
    // NB: These maps grow unbounded but we accept this as we expect an overall limited
    // number of usecases and scopes. We emit gauge metrics to monitor their size.
    /// Per-usecase token buckets, created lazily during `check`.
    usecases: Arc<papaya::HashMap<String, Mutex<TokenBucket>>>,
    /// Per-scope token buckets, created lazily during `check`.
    scopes: Arc<papaya::HashMap<Scopes, Mutex<TokenBucket>>>,
    /// Per-rule token buckets, keyed by the rule's index in `config.rules`.
    rules: papaya::HashMap<usize, Mutex<TokenBucket>>,
}
469
470impl ThroughputRateLimiter {
471    fn new(config: ThroughputLimits) -> Self {
472        let global = config
473            .global_rps
474            .map(|rps| Mutex::new(TokenBucket::new(rps, config.burst)));
475
476        Self {
477            config,
478            global,
479            global_estimator: Arc::new(EwmaEstimator::new()),
480            total_admitted: Arc::new(AtomicU64::new(0)),
481            usecases: Arc::new(papaya::HashMap::new()),
482            scopes: Arc::new(papaya::HashMap::new()),
483            rules: papaya::HashMap::new(),
484        }
485    }
486
487    fn start(&self) {
488        let usecases = Arc::clone(&self.usecases);
489        let scopes = Arc::clone(&self.scopes);
490        let global_estimator = Arc::clone(&self.global_estimator);
491        let global_limit = self.config.global_rps;
492        // NB: This task has no shutdown mechanism — the rate limiter is only created once.
493        // The task is aborted when the Tokio runtime is dropped on process exit.
494        tokio::task::spawn(async move {
495            const TICK: Duration = Duration::from_millis(50);
496            let mut interval = tokio::time::interval(TICK);
497            let to_rps = 1.0 / TICK.as_secs_f64();
498            let mut global_ewma: f64 = 0.0;
499            loop {
500                interval.tick().await;
501                global_estimator.update_ewma(&mut global_ewma, to_rps);
502                objectstore_metrics::gauge!("server.throughput.ewma": global_ewma.floor() as u64);
503                if let Some(limit) = global_limit {
504                    objectstore_metrics::gauge!("server.rate_limiter.throughput.limit": u64::from(limit));
505                }
506                objectstore_metrics::gauge!("server.rate_limiter.throughput.scope_map_size": scopes.len());
507                objectstore_metrics::gauge!("server.rate_limiter.throughput.usecase_map_size": usecases.len());
508            }
509        });
510    }
511
512    fn check(&self, context: &ObjectContext) -> bool {
513        // NB: We intentionally use unwrap and crash the server if the mutexes are poisoned.
514
515        // Global check
516        if let Some(ref global) = self.global
517            && !global.lock().unwrap().try_acquire()
518        {
519            return false;
520        }
521
522        // Usecase check - only if both global_rps and usecase_pct are set
523        if let Some(usecase_rps) = self.usecase_rps() {
524            let guard = self.usecases.pin();
525            let bucket = guard
526                .get_or_insert_with(context.usecase.clone(), || self.create_bucket(usecase_rps));
527            if !bucket.lock().unwrap().try_acquire() {
528                return false;
529            }
530        }
531
532        // Scope check - only if both global_rps and scope_pct are set
533        if let Some(scope_rps) = self.scope_rps() {
534            let guard = self.scopes.pin();
535            let bucket =
536                guard.get_or_insert_with(context.scopes.clone(), || self.create_bucket(scope_rps));
537            if !bucket.lock().unwrap().try_acquire() {
538                return false;
539            }
540        }
541
542        // Rule checks - each matching rule has its own dedicated bucket
543        for (idx, rule) in self.config.rules.iter().enumerate() {
544            if !rule.matches(context) {
545                continue;
546            }
547            let Some(rule_rps) = self.rule_rps(rule) else {
548                continue;
549            };
550            let guard = self.rules.pin();
551            let bucket = guard.get_or_insert_with(idx, || self.create_bucket(rule_rps));
552            if !bucket.lock().unwrap().try_acquire() {
553                return false;
554            }
555        }
556
557        // Count this admitted request in the EWMA accumulator and the cumulative counter.
558        // NB: u64 wrapping is not a practical concern — at 1M rps it takes ~585k years.
559        // Prometheus irate() also handles counter resets gracefully should it ever occur.
560        self.global_estimator
561            .accumulator
562            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
563        self.total_admitted
564            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
565
566        true
567    }
568
569    fn create_bucket(&self, rps: u32) -> Mutex<TokenBucket> {
570        Mutex::new(TokenBucket::new(rps, self.config.burst))
571    }
572
573    /// Returns the effective RPS for per-usecase limiting, if configured.
574    fn usecase_rps(&self) -> Option<u32> {
575        let global_rps = self.config.global_rps?;
576        let pct = self.config.usecase_pct?;
577        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
578    }
579
580    /// Returns the effective RPS for per-scope limiting, if configured.
581    fn scope_rps(&self) -> Option<u32> {
582        let global_rps = self.config.global_rps?;
583        let pct = self.config.scope_pct?;
584        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
585    }
586
587    /// Returns the effective RPS for a rule, if it has a valid limit.
588    fn rule_rps(&self, rule: &ThroughputRule) -> Option<u32> {
589        let pct_limit = rule.pct.and_then(|p| {
590            self.config
591                .global_rps
592                .map(|g| ((g as f64) * (p as f64 / 100.0)) as u32)
593        });
594
595        match (rule.rps, pct_limit) {
596            (Some(r), Some(p)) => Some(r.min(p)),
597            (Some(r), None) => Some(r),
598            (None, Some(p)) => Some(p),
599            (None, None) => None,
600        }
601    }
602}
603
/// A token bucket rate limiter.
///
/// Tokens refill at a constant rate up to capacity. Each request consumes one token.
/// When no tokens are available, requests are rejected.
///
/// This implementation is not thread-safe on its own. Wrap in a `Mutex` for concurrent access.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens added per second (the sustained rate).
    refill_rate: f64,
    /// Maximum tokens the bucket can hold (`rps + burst`).
    capacity: f64,
    /// Currently available tokens.
    tokens: f64,
    /// When the bucket last committed a refill.
    last_update: Instant,
}

impl TokenBucket {
    /// Creates a new, full token bucket with the specified rate limit and burst capacity.
    ///
    /// - `rps`: tokens refilled per second (sustained rate limit)
    /// - `burst`: initial tokens and burst allowance above sustained rate
    pub fn new(rps: u32, burst: u32) -> Self {
        let capacity = f64::from(rps + burst);
        Self {
            refill_rate: f64::from(rps),
            capacity,
            tokens: capacity,
            last_update: Instant::now(),
        }
    }

    /// Attempts to acquire a token from the bucket.
    ///
    /// Returns `true` if a token was acquired, `false` if no tokens available.
    pub fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_update).as_secs_f64();
        let replenished = (self.tokens + elapsed * self.refill_rate).min(self.capacity);

        // Commit the refill only once it yields at least one whole token; otherwise
        // leave `last_update` untouched so fractional progress keeps accruing.
        if replenished.floor() > self.tokens.floor() {
            self.tokens = replenished;
            self.last_update = now;
        }

        // Consume one token if available.
        if self.tokens < 1.0 {
            return false;
        }
        self.tokens -= 1.0;
        true
    }
}
655
/// A wrapper around a `PayloadStream` that measures bandwidth usage.
///
/// This behaves exactly as a `PayloadStream`, except that every time an item is polled,
/// all accumulators are incremented by the size of the returned `Bytes` chunk.
pub(crate) struct MeteredPayloadStream {
    /// The wrapped stream; polled through unchanged.
    inner: PayloadStream,
    /// Byte counters incremented by each successfully polled chunk's length
    /// (e.g. obtained via `RateLimiter::bytes_accumulators`).
    accumulators: Vec<Arc<AtomicU64>>,
}
664
665impl MeteredPayloadStream {
666    pub fn new(inner: PayloadStream, accumulators: Vec<Arc<AtomicU64>>) -> Self {
667        Self {
668            inner,
669            accumulators,
670        }
671    }
672}
673
impl std::fmt::Debug for MeteredPayloadStream {
    // Manual impl that shows only the accumulators and skips `inner`
    // (NOTE(review): presumably `PayloadStream` does not implement `Debug`,
    // and the stream contents would not be useful in logs anyway — confirm).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MeteredPayloadStream")
            .field("accumulators", &self.accumulators)
            .finish()
    }
}
681
682impl Stream for MeteredPayloadStream {
683    type Item = std::io::Result<Bytes>;
684
685    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
686        let this = self.get_mut();
687        let res = this.inner.as_mut().poll_next(cx);
688        if let Poll::Ready(Some(Ok(ref bytes))) = res {
689            let len = bytes.len() as u64;
690            for acc in &this.accumulators {
691                acc.fetch_add(len, std::sync::atomic::Ordering::Relaxed);
692            }
693        }
694        res
695    }
696}
697
#[cfg(test)]
mod tests {
    use objectstore_service::id::ObjectContext;
    use objectstore_types::scope::{Scope, Scopes};

    use super::*;

    /// Builds a minimal context with one usecase and one scope for all tests.
    fn make_context() -> ObjectContext {
        ObjectContext {
            usecase: "testing".into(),
            scopes: Scopes::from_iter([Scope::create("org", "1").unwrap()]),
        }
    }

    #[test]
    fn ewma_estimator_update_applies_alpha() {
        let estimator = EwmaEstimator::new();
        const TICK: f64 = 0.05; // 50ms
        let to_rate = 1.0 / TICK;
        let mut ewma: f64 = 0.0;

        // Simulate 10 events in one 50ms tick → 200 /s raw rate.
        // After one step: 0.2 * 200 + 0.8 * 0 = 40.
        estimator
            .accumulator
            .store(10, std::sync::atomic::Ordering::Relaxed);
        estimator.update_ewma(&mut ewma, to_rate);
        assert_eq!(
            estimator
                .estimate
                .load(std::sync::atomic::Ordering::Relaxed),
            40
        );

        // Accumulator must have been zeroed.
        assert_eq!(
            estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            0
        );
    }

    #[test]
    fn throughput_check_increments_accumulator() {
        // A generous limit admits both requests; each admission must be counted.
        let limiter = ThroughputRateLimiter::new(ThroughputLimits {
            global_rps: Some(1000),
            ..Default::default()
        });

        assert_eq!(
            limiter
                .global_estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            0
        );

        let context = make_context();
        assert!(limiter.check(&context));
        assert!(limiter.check(&context));

        assert_eq!(
            limiter
                .global_estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            2
        );
    }

    #[test]
    fn throughput_rejected_does_not_increment_accumulator() {
        let limiter = ThroughputRateLimiter::new(ThroughputLimits {
            global_rps: Some(1),
            burst: 0,
            ..Default::default()
        });

        let context = make_context();
        // First call admitted (consumes the one token), second rejected.
        assert!(limiter.check(&context));
        assert!(!limiter.check(&context));

        assert_eq!(
            limiter
                .global_estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            1 // only the admitted request
        );
    }

    #[test]
    fn bandwidth_rejection_does_not_increment_throughput_accumulator() {
        // global_bps of 1 means the estimate (0 initially) is not > 1, so the first
        // call passes the bandwidth check. Use 0 to guarantee an immediate reject.
        // BandwidthRateLimiter::check rejects when estimate > global_bps, so set
        // global_bps = 0 to make the bandwidth check always reject.
        let limiter = RateLimiter::new(RateLimits {
            throughput: ThroughputLimits {
                global_rps: Some(1000),
                ..Default::default()
            },
            bandwidth: BandwidthLimits {
                global_bps: Some(0),
                ..Default::default()
            },
        });

        // Prime the bandwidth EWMA so it exceeds the limit.
        limiter
            .bandwidth
            .global
            .estimate
            .store(1, std::sync::atomic::Ordering::Relaxed);

        let context = make_context();
        assert!(!limiter.check(&context));

        // The throughput accumulator must still be 0.
        assert_eq!(
            limiter
                .throughput
                .global_estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            0
        );
    }

    #[test]
    fn rate_limiter_accessors_with_config() {
        let rate_limiter = RateLimiter::new(RateLimits {
            throughput: ThroughputLimits {
                global_rps: Some(500),
                ..Default::default()
            },
            bandwidth: BandwidthLimits {
                global_bps: Some(1_000_000),
                ..Default::default()
            },
        });

        assert_eq!(rate_limiter.bandwidth_limit(), Some(1_000_000));
        assert_eq!(rate_limiter.throughput_limit(), Some(500));
        // Estimates start at 0 before any background tick.
        assert_eq!(rate_limiter.bandwidth_ewma(), 0);
        assert_eq!(rate_limiter.throughput_rps(), 0);
    }

    #[test]
    fn rate_limiter_accessors_no_limits() {
        // With a default (empty) config, no limits are reported and estimates are 0.
        let rate_limiter = RateLimiter::new(RateLimits::default());

        assert_eq!(rate_limiter.bandwidth_limit(), None);
        assert_eq!(rate_limiter.throughput_limit(), None);
        assert_eq!(rate_limiter.bandwidth_ewma(), 0);
        assert_eq!(rate_limiter.throughput_rps(), 0);
    }
}