1use std::pin::Pin;
12use std::sync::Arc;
13use std::sync::Mutex;
14use std::sync::atomic::AtomicU64;
15use std::task::{Context, Poll};
16use std::time::{Duration, Instant};
17
18use bytes::Bytes;
19use futures_util::Stream;
20use objectstore_service::PayloadStream;
21use objectstore_service::id::ObjectContext;
22use objectstore_types::scope::Scopes;
23use serde::{Deserialize, Serialize};
24
/// Top-level rate-limiting configuration for the server.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct RateLimits {
    /// Requests-per-second limits.
    pub throughput: ThroughputLimits,
    /// Bytes-per-second limits.
    pub bandwidth: BandwidthLimits,
}
33
/// Requests-per-second limits, enforced with token buckets.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct ThroughputLimits {
    /// Server-wide requests-per-second cap; `None` disables throughput
    /// limiting (the derived per-usecase/per-scope caps as well).
    pub global_rps: Option<u32>,

    /// Extra tokens on top of the steady rate: each bucket's capacity is
    /// `rps + burst`.
    pub burst: u32,

    /// Per-usecase cap, expressed as a percentage of `global_rps`.
    pub usecase_pct: Option<u8>,

    /// Per-scope-set cap, expressed as a percentage of `global_rps`.
    pub scope_pct: Option<u8>,

    /// Targeted limits applied on top of the caps above; every matching
    /// rule must admit a request.
    pub rules: Vec<ThroughputRule>,
}
68
/// A targeted throughput limit that applies only to matching requests.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ThroughputRule {
    /// When set, the request's usecase must equal this value for the rule
    /// to match.
    pub usecase: Option<String>,

    /// `(scope name, scope value)` pairs; each must be present with the
    /// exact value in the request's scopes for the rule to match.
    pub scopes: Vec<(String, String)>,

    /// Absolute requests-per-second cap for matching requests.
    pub rps: Option<u32>,

    /// Cap as a percentage of the global RPS limit; when both `rps` and
    /// `pct` resolve to a value, the smaller limit wins (see `rule_rps`).
    pub pct: Option<u8>,
}
99
100impl ThroughputRule {
101 pub fn matches(&self, context: &ObjectContext) -> bool {
103 if let Some(ref rule_usecase) = self.usecase
104 && rule_usecase != &context.usecase
105 {
106 return false;
107 }
108
109 for (scope_name, scope_value) in &self.scopes {
110 match context.scopes.get_value(scope_name) {
111 Some(value) if value == scope_value => (),
112 _ => return false,
113 }
114 }
115
116 true
117 }
118}
119
/// Bytes-per-second limits, enforced against smoothed (EWMA) estimates.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct BandwidthLimits {
    /// Server-wide bandwidth cap in bytes/second; `None` disables
    /// bandwidth limiting (including the derived percentages below).
    pub global_bps: Option<u64>,

    /// Per-usecase cap, expressed as a percentage of `global_bps`.
    pub usecase_pct: Option<u8>,

    /// Per-scope-set cap, expressed as a percentage of `global_bps`.
    pub scope_pct: Option<u8>,
}
141
/// Combined admission control: a request is admitted only when both the
/// bandwidth (bytes/sec) limiter and the throughput (requests/sec)
/// limiter allow it.
#[derive(Debug)]
pub struct RateLimiter {
    bandwidth: BandwidthRateLimiter,
    throughput: ThroughputRateLimiter,
}
156
157impl RateLimiter {
158 pub fn new(config: RateLimits) -> Self {
162 Self {
163 bandwidth: BandwidthRateLimiter::new(config.bandwidth),
164 throughput: ThroughputRateLimiter::new(config.throughput),
165 }
166 }
167
168 pub fn start(&self) {
172 self.bandwidth.start();
173 self.throughput.start();
174 }
175
176 pub fn check(&self, context: &ObjectContext) -> bool {
180 self.bandwidth.check(context) && self.throughput.check(context)
184 }
185
186 pub fn bytes_accumulators(&self, context: &ObjectContext) -> Vec<Arc<AtomicU64>> {
190 self.bandwidth.accumulators(context)
191 }
192
193 pub fn record_bandwidth(&self, context: &ObjectContext, bytes: u64) {
198 for acc in self.bandwidth.accumulators(context) {
199 acc.fetch_add(bytes, std::sync::atomic::Ordering::Relaxed);
200 }
201 }
202
203 pub fn bandwidth_ewma(&self) -> u64 {
205 self.bandwidth
206 .global
207 .estimate
208 .load(std::sync::atomic::Ordering::Relaxed)
209 }
210
211 pub fn bandwidth_limit(&self) -> Option<u64> {
213 self.bandwidth.config.global_bps
214 }
215
216 pub fn throughput_rps(&self) -> u64 {
218 self.throughput
219 .global_estimator
220 .estimate
221 .load(std::sync::atomic::Ordering::Relaxed)
222 }
223
224 pub fn throughput_limit(&self) -> Option<u32> {
226 self.throughput.config.global_rps
227 }
228
229 pub fn bandwidth_total_bytes(&self) -> u64 {
231 self.bandwidth
232 .total_bytes
233 .load(std::sync::atomic::Ordering::Relaxed)
234 }
235
236 pub fn throughput_total_admitted(&self) -> u64 {
238 self.throughput
239 .total_admitted
240 .load(std::sync::atomic::Ordering::Relaxed)
241 }
242}
243
/// Lock-free EWMA rate estimator: producers add raw counts to
/// `accumulator`; a periodic driver drains it with `update_ewma`, which
/// smooths the per-tick rate into `estimate`.
#[derive(Debug)]
struct EwmaEstimator {
    /// Raw count gathered since the last tick (drained on each update).
    accumulator: Arc<AtomicU64>,
    /// Most recent smoothed rate, floored to a whole number.
    estimate: Arc<AtomicU64>,
}

impl EwmaEstimator {
    /// Fresh estimator with both counters at zero.
    fn new() -> Self {
        Self {
            accumulator: Arc::new(AtomicU64::new(0)),
            estimate: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Drains the accumulator, converts the count to a rate via
    /// `to_rate` (1 / tick duration), and folds it into `ewma` with
    /// alpha = 0.2; the floored result is published to `estimate`.
    ///
    /// `ewma` is the caller-owned full-precision state — keeping it as
    /// an `f64` outside the struct avoids rounding it to the integer
    /// `estimate` between ticks.
    fn update_ewma(&self, ewma: &mut f64, to_rate: f64) {
        use std::sync::atomic::Ordering::Relaxed;

        const ALPHA: f64 = 0.2;
        let drained = self.accumulator.swap(0, Relaxed);
        let sample = drained as f64 * to_rate;
        *ewma = (1.0 - ALPHA) * *ewma + ALPHA * sample;
        self.estimate.store(ewma.floor() as u64, Relaxed);
    }
}
278
/// Bandwidth limiter based on EWMA byte-rate estimates.
///
/// Bytes are recorded into lock-free accumulators (see `accumulators`);
/// a background task started by `start` periodically folds them into
/// EWMA rate estimates that `check` compares against the limits.
#[derive(Debug)]
struct BandwidthRateLimiter {
    config: BandwidthLimits,
    /// Server-wide byte-rate estimator.
    global: Arc<EwmaEstimator>,
    /// Monotonic total of all recorded bytes (only ever incremented).
    total_bytes: Arc<AtomicU64>,
    /// Per-usecase estimators, created lazily in `accumulators`.
    usecases: Arc<papaya::HashMap<String, Arc<EwmaEstimator>>>,
    /// Per-scope-set estimators, created lazily in `accumulators`.
    scopes: Arc<papaya::HashMap<Scopes, Arc<EwmaEstimator>>>,
}
291
impl BandwidthRateLimiter {
    /// Creates the limiter with empty estimator maps; no estimation
    /// happens until `start` is called.
    fn new(config: BandwidthLimits) -> Self {
        Self {
            config,
            global: Arc::new(EwmaEstimator::new()),
            total_bytes: Arc::new(AtomicU64::new(0)),
            usecases: Arc::new(papaya::HashMap::new()),
            scopes: Arc::new(papaya::HashMap::new()),
        }
    }

    /// Spawns the background `estimator` loop on the ambient tokio runtime.
    fn start(&self) {
        let global = Arc::clone(&self.global);
        let usecases = Arc::clone(&self.usecases);
        let scopes = Arc::clone(&self.scopes);
        let global_limit = self.config.global_bps;
        tokio::task::spawn(async move {
            Self::estimator(global, usecases, scopes, global_limit).await;
        });
    }

    /// Background loop: every tick, drains each byte accumulator into its
    /// EWMA bytes-per-second estimate and publishes gauges.
    ///
    /// The local `*_ewmas` maps hold the full-precision `f64` state for
    /// each estimator between ticks; entries are never removed, so they
    /// grow with the number of distinct usecases/scope sets seen.
    async fn estimator(
        global: Arc<EwmaEstimator>,
        usecases: Arc<papaya::HashMap<String, Arc<EwmaEstimator>>>,
        scopes: Arc<papaya::HashMap<Scopes, Arc<EwmaEstimator>>>,
        global_limit: Option<u64>,
    ) {
        // 50ms tick; `to_bps` converts a per-tick byte count into bytes/second.
        const TICK: Duration = Duration::from_millis(50); let mut interval = tokio::time::interval(TICK);
        let to_bps = 1.0 / TICK.as_secs_f64(); let mut global_ewma: f64 = 0.0;
        let mut usecase_ewmas: std::collections::HashMap<String, f64> =
            std::collections::HashMap::new();
        let mut scope_ewmas: std::collections::HashMap<Scopes, f64> =
            std::collections::HashMap::new();

        loop {
            interval.tick().await;

            global.update_ewma(&mut global_ewma, to_bps);
            objectstore_metrics::gauge!("server.bandwidth.ewma"@b: global_ewma.floor() as u64);
            if let Some(limit) = global_limit {
                objectstore_metrics::gauge!("server.bandwidth.limit"@b: limit);
            }

            // Scoped so the map pin guard is released between sections.
            {
                let guard = usecases.pin();
                for (key, estimator) in guard.iter() {
                    let ewma = usecase_ewmas.entry(key.clone()).or_insert(0.0);
                    estimator.update_ewma(ewma, to_bps);
                }
            }

            {
                let guard = scopes.pin();
                for (key, estimator) in guard.iter() {
                    let ewma = scope_ewmas.entry(key.clone()).or_insert(0.0);
                    estimator.update_ewma(ewma, to_bps);
                }
            }

            objectstore_metrics::gauge!("server.rate_limiter.bandwidth.scope_map_size": scopes.len());
            objectstore_metrics::gauge!("server.rate_limiter.bandwidth.usecase_map_size": usecases.len());
        }
    }

    /// Returns `true` when the request is admitted.
    ///
    /// With no `global_bps` configured everything is admitted (the
    /// usecase/scope caps derive from the global one, so they are
    /// disabled too). Otherwise the request is rejected when the global,
    /// per-usecase, or per-scope EWMA estimate exceeds its cap; a missing
    /// usecase/scope estimator (nothing recorded yet) counts as admitted.
    fn check(&self, context: &ObjectContext) -> bool {
        let Some(global_bps) = self.config.global_bps else {
            return true;
        };

        if self
            .global
            .estimate
            .load(std::sync::atomic::Ordering::Relaxed)
            > global_bps
        {
            return false;
        }

        if let Some(usecase_bps) = self.usecase_bps() {
            let guard = self.usecases.pin();
            if let Some(estimator) = guard.get(&context.usecase)
                && estimator
                    .estimate
                    .load(std::sync::atomic::Ordering::Relaxed)
                    > usecase_bps
            {
                return false;
            }
        }

        if let Some(scope_bps) = self.scope_bps() {
            let guard = self.scopes.pin();
            if let Some(estimator) = guard.get(&context.scopes)
                && estimator
                    .estimate
                    .load(std::sync::atomic::Ordering::Relaxed)
                    > scope_bps
            {
                return false;
            }
        }

        true
    }

    /// Every accumulator that bytes for `context` should be added to:
    /// the running total, the global estimator, and — when the matching
    /// percentage cap is configured — the per-usecase and per-scope
    /// estimators, which are created here on first use.
    fn accumulators(&self, context: &ObjectContext) -> Vec<Arc<AtomicU64>> {
        let mut accs = vec![
            Arc::clone(&self.total_bytes),
            Arc::clone(&self.global.accumulator),
        ];

        if self.usecase_bps().is_some() {
            let guard = self.usecases.pin();
            let estimator = guard
                .get_or_insert_with(context.usecase.clone(), || Arc::new(EwmaEstimator::new()));
            accs.push(Arc::clone(&estimator.accumulator));
        }

        if self.scope_bps().is_some() {
            let guard = self.scopes.pin();
            let estimator =
                guard.get_or_insert_with(context.scopes.clone(), || Arc::new(EwmaEstimator::new()));
            accs.push(Arc::clone(&estimator.accumulator));
        }

        accs
    }

    /// Per-usecase cap in bytes/second: `usecase_pct` percent of
    /// `global_bps`; `None` unless both are configured.
    fn usecase_bps(&self) -> Option<u64> {
        let global_bps = self.config.global_bps?;
        let pct = self.config.usecase_pct?;
        Some(((global_bps as f64) * (pct as f64 / 100.0)) as u64)
    }

    /// Per-scope cap in bytes/second: `scope_pct` percent of `global_bps`.
    fn scope_bps(&self) -> Option<u64> {
        let global_bps = self.config.global_bps?;
        let pct = self.config.scope_pct?;
        Some(((global_bps as f64) * (pct as f64 / 100.0)) as u64)
    }
}
454
/// Requests-per-second limiter backed by token buckets.
///
/// Buckets exist at up to four levels — global, per-usecase, per-scope,
/// and per-rule — and a request must obtain a token from every
/// applicable bucket to be admitted (see `check`).
#[derive(Debug)]
struct ThroughputRateLimiter {
    config: ThroughputLimits,
    /// Global bucket; present only when `global_rps` is configured.
    global: Option<Mutex<TokenBucket>>,
    /// EWMA estimator fed one count per admitted request.
    global_estimator: Arc<EwmaEstimator>,
    /// Monotonic count of admitted requests (only ever incremented).
    total_admitted: Arc<AtomicU64>,
    /// Per-usecase buckets, created lazily in `check`.
    usecases: Arc<papaya::HashMap<String, Mutex<TokenBucket>>>,
    /// Per-scope-set buckets, created lazily in `check`.
    scopes: Arc<papaya::HashMap<Scopes, Mutex<TokenBucket>>>,
    /// Per-rule buckets, keyed by the rule's index in `config.rules`.
    rules: papaya::HashMap<usize, Mutex<TokenBucket>>,
}
469
impl ThroughputRateLimiter {
    /// Creates the limiter; the global bucket exists only when
    /// `global_rps` is configured.
    fn new(config: ThroughputLimits) -> Self {
        let global = config
            .global_rps
            .map(|rps| Mutex::new(TokenBucket::new(rps, config.burst)));

        Self {
            config,
            global,
            global_estimator: Arc::new(EwmaEstimator::new()),
            total_admitted: Arc::new(AtomicU64::new(0)),
            usecases: Arc::new(papaya::HashMap::new()),
            scopes: Arc::new(papaya::HashMap::new()),
            rules: papaya::HashMap::new(),
        }
    }

    /// Spawns a background task that folds the admitted-request count
    /// into an EWMA requests-per-second estimate every 50ms and reports
    /// gauges for the estimate, the limit, and the bucket-map sizes.
    fn start(&self) {
        let usecases = Arc::clone(&self.usecases);
        let scopes = Arc::clone(&self.scopes);
        let global_estimator = Arc::clone(&self.global_estimator);
        let global_limit = self.config.global_rps;
        tokio::task::spawn(async move {
            const TICK: Duration = Duration::from_millis(50);
            let mut interval = tokio::time::interval(TICK);
            // Converts a per-tick count into requests/second.
            let to_rps = 1.0 / TICK.as_secs_f64();
            let mut global_ewma: f64 = 0.0;
            loop {
                interval.tick().await;
                global_estimator.update_ewma(&mut global_ewma, to_rps);
                objectstore_metrics::gauge!("server.throughput.ewma": global_ewma.floor() as u64);
                if let Some(limit) = global_limit {
                    objectstore_metrics::gauge!("server.rate_limiter.throughput.limit": u64::from(limit));
                }
                objectstore_metrics::gauge!("server.rate_limiter.throughput.scope_map_size": scopes.len());
                objectstore_metrics::gauge!("server.rate_limiter.throughput.usecase_map_size": usecases.len());
            }
        });
    }

    /// Tries to admit one request, taking a token from every applicable
    /// bucket in order: global, per-usecase, per-scope, then each
    /// matching rule.
    ///
    /// NOTE(review): tokens taken from earlier buckets are not refunded
    /// when a later bucket rejects — presumably intentional (the rejected
    /// request still consumed capacity at the coarser levels); confirm
    /// before changing.
    ///
    /// The estimator accumulator and `total_admitted` are only bumped
    /// once the request has passed every bucket.
    fn check(&self, context: &ObjectContext) -> bool {
        if let Some(ref global) = self.global
            && !global.lock().unwrap().try_acquire()
        {
            return false;
        }

        if let Some(usecase_rps) = self.usecase_rps() {
            // Bucket is created on first sight of this usecase.
            let guard = self.usecases.pin();
            let bucket = guard
                .get_or_insert_with(context.usecase.clone(), || self.create_bucket(usecase_rps));
            if !bucket.lock().unwrap().try_acquire() {
                return false;
            }
        }

        if let Some(scope_rps) = self.scope_rps() {
            // Bucket is created on first sight of this scope set.
            let guard = self.scopes.pin();
            let bucket =
                guard.get_or_insert_with(context.scopes.clone(), || self.create_bucket(scope_rps));
            if !bucket.lock().unwrap().try_acquire() {
                return false;
            }
        }

        // Rule buckets are keyed by rule index; every matching rule with a
        // resolvable limit must admit the request.
        for (idx, rule) in self.config.rules.iter().enumerate() {
            if !rule.matches(context) {
                continue;
            }
            let Some(rule_rps) = self.rule_rps(rule) else {
                continue;
            };
            let guard = self.rules.pin();
            let bucket = guard.get_or_insert_with(idx, || self.create_bucket(rule_rps));
            if !bucket.lock().unwrap().try_acquire() {
                return false;
            }
        }

        // Admitted: count it for the rate estimate and the lifetime total.
        self.global_estimator
            .accumulator
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        self.total_admitted
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);

        true
    }

    /// New token bucket at `rps` with the shared configured burst.
    fn create_bucket(&self, rps: u32) -> Mutex<TokenBucket> {
        Mutex::new(TokenBucket::new(rps, self.config.burst))
    }

    /// Per-usecase RPS: `usecase_pct` percent of `global_rps`; `None`
    /// unless both are configured.
    fn usecase_rps(&self) -> Option<u32> {
        let global_rps = self.config.global_rps?;
        let pct = self.config.usecase_pct?;
        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
    }

    /// Per-scope RPS: `scope_pct` percent of `global_rps`.
    fn scope_rps(&self) -> Option<u32> {
        let global_rps = self.config.global_rps?;
        let pct = self.config.scope_pct?;
        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
    }

    /// Effective limit for `rule`: its absolute `rps`, its `pct` of the
    /// global limit, or the minimum of the two when both resolve.
    fn rule_rps(&self, rule: &ThroughputRule) -> Option<u32> {
        let pct_limit = rule.pct.and_then(|p| {
            self.config
                .global_rps
                .map(|g| ((g as f64) * (p as f64 / 100.0)) as u32)
        });

        match (rule.rps, pct_limit) {
            (Some(r), Some(p)) => Some(r.min(p)),
            (Some(r), None) => Some(r),
            (None, Some(p)) => Some(p),
            (None, None) => None,
        }
    }
}
603
/// A classic token bucket: holds at most `rps + burst` tokens, refilled
/// continuously at `rps` tokens per second, and starts full.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens added per second of elapsed time.
    refill_rate: f64,
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Tokens currently available (fractional).
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_update: Instant,
}

impl TokenBucket {
    /// Creates a full bucket: `rps + burst` tokens, refilling at `rps`/s.
    pub fn new(rps: u32, burst: u32) -> Self {
        let capacity = (rps + burst) as f64;
        Self {
            refill_rate: rps as f64,
            capacity,
            tokens: capacity,
            last_update: Instant::now(),
        }
    }

    /// Takes one token if available, refilling from elapsed time first.
    ///
    /// The refill is only committed once it crosses a whole-token
    /// boundary; otherwise `last_update` is left untouched so fractional
    /// refill time keeps accumulating instead of being discarded on
    /// every call.
    pub fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_update).as_secs_f64();
        let candidate = (self.tokens + elapsed * self.refill_rate).min(self.capacity);

        if candidate.floor() > self.tokens.floor() {
            self.tokens = candidate;
            self.last_update = now;
        }

        if self.tokens < 1.0 {
            return false;
        }
        self.tokens -= 1.0;
        true
    }
}
655
/// A `PayloadStream` wrapper that counts the bytes flowing through it.
///
/// Every successfully yielded chunk's length is added to each counter in
/// `accumulators` (e.g. the ones from `RateLimiter::bytes_accumulators`).
pub(crate) struct MeteredPayloadStream {
    /// The underlying stream being metered.
    inner: PayloadStream,
    /// Shared byte counters bumped once per polled chunk.
    accumulators: Vec<Arc<AtomicU64>>,
}
664
665impl MeteredPayloadStream {
666 pub fn new(inner: PayloadStream, accumulators: Vec<Arc<AtomicU64>>) -> Self {
667 Self {
668 inner,
669 accumulators,
670 }
671 }
672}
673
674impl std::fmt::Debug for MeteredPayloadStream {
675 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
676 f.debug_struct("MeteredPayloadStream")
677 .field("accumulators", &self.accumulators)
678 .finish()
679 }
680}
681
682impl Stream for MeteredPayloadStream {
683 type Item = std::io::Result<Bytes>;
684
685 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
686 let this = self.get_mut();
687 let res = this.inner.as_mut().poll_next(cx);
688 if let Poll::Ready(Some(Ok(ref bytes))) = res {
689 let len = bytes.len() as u64;
690 for acc in &this.accumulators {
691 acc.fetch_add(len, std::sync::atomic::Ordering::Relaxed);
692 }
693 }
694 res
695 }
696}
697
698#[cfg(test)]
699mod tests {
700 use objectstore_service::id::ObjectContext;
701 use objectstore_types::scope::{Scope, Scopes};
702
703 use super::*;
704
    /// Builds the minimal `ObjectContext` ("testing" usecase, one org
    /// scope) shared by the tests below.
    fn make_context() -> ObjectContext {
        ObjectContext {
            usecase: "testing".into(),
            scopes: Scopes::from_iter([Scope::create("org", "1").unwrap()]),
        }
    }
711
    #[test]
    fn ewma_estimator_update_applies_alpha() {
        let estimator = EwmaEstimator::new();
        // 50ms tick, so a per-tick count is scaled by 20 to get a rate.
        const TICK: f64 = 0.05; let to_rate = 1.0 / TICK;
        let mut ewma: f64 = 0.0;

        // 10 counts in one tick = 200/s; EWMA from 0 with alpha 0.2 -> 40.
        estimator
            .accumulator
            .store(10, std::sync::atomic::Ordering::Relaxed);
        estimator.update_ewma(&mut ewma, to_rate);
        assert_eq!(
            estimator
                .estimate
                .load(std::sync::atomic::Ordering::Relaxed),
            40
        );

        // The update drains the accumulator back to zero.
        assert_eq!(
            estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            0
        );
    }
740
    #[test]
    fn throughput_check_increments_accumulator() {
        let limiter = ThroughputRateLimiter::new(ThroughputLimits {
            global_rps: Some(1000),
            ..Default::default()
        });

        // Nothing admitted yet.
        assert_eq!(
            limiter
                .global_estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            0
        );

        let context = make_context();
        assert!(limiter.check(&context));
        assert!(limiter.check(&context));

        // Each admitted request bumps the rate-estimator accumulator.
        assert_eq!(
            limiter
                .global_estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            2
        );
    }
768
    #[test]
    fn throughput_rejected_does_not_increment_accumulator() {
        // 1 rps with no burst: exactly one token available up front.
        let limiter = ThroughputRateLimiter::new(ThroughputLimits {
            global_rps: Some(1),
            burst: 0,
            ..Default::default()
        });

        let context = make_context();
        assert!(limiter.check(&context));
        assert!(!limiter.check(&context));

        // Only the admitted request is counted.
        assert_eq!(
            limiter
                .global_estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            1 );
    }
790
    #[test]
    fn bandwidth_rejection_does_not_increment_throughput_accumulator() {
        // Bandwidth cap of zero: any nonzero estimate rejects.
        let limiter = RateLimiter::new(RateLimits {
            throughput: ThroughputLimits {
                global_rps: Some(1000),
                ..Default::default()
            },
            bandwidth: BandwidthLimits {
                global_bps: Some(0),
                ..Default::default()
            },
        });

        // Force the bandwidth estimate over the (zero) limit.
        limiter
            .bandwidth
            .global
            .estimate
            .store(1, std::sync::atomic::Ordering::Relaxed);

        let context = make_context();
        assert!(!limiter.check(&context));

        // Bandwidth is checked first, so the throughput side never ran.
        assert_eq!(
            limiter
                .throughput
                .global_estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            0
        );
    }
828
    #[test]
    fn rate_limiter_accessors_with_config() {
        let rate_limiter = RateLimiter::new(RateLimits {
            throughput: ThroughputLimits {
                global_rps: Some(500),
                ..Default::default()
            },
            bandwidth: BandwidthLimits {
                global_bps: Some(1_000_000),
                ..Default::default()
            },
        });

        // Limits echo the config; estimates start at zero before `start`.
        assert_eq!(rate_limiter.bandwidth_limit(), Some(1_000_000));
        assert_eq!(rate_limiter.throughput_limit(), Some(500));
        assert_eq!(rate_limiter.bandwidth_ewma(), 0);
        assert_eq!(rate_limiter.throughput_rps(), 0);
    }
848
    #[test]
    fn rate_limiter_accessors_no_limits() {
        let rate_limiter = RateLimiter::new(RateLimits::default());

        // Default config has no caps configured and zeroed estimates.
        assert_eq!(rate_limiter.bandwidth_limit(), None);
        assert_eq!(rate_limiter.throughput_limit(), None);
        assert_eq!(rate_limiter.bandwidth_ewma(), 0);
        assert_eq!(rate_limiter.throughput_rps(), 0);
    }
858}