1use std::fmt;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::sync::{Mutex, atomic::AtomicU64};
15use std::task::{Context, Poll};
16use std::time::{Duration, Instant};
17
18use bytes::Bytes;
19use futures_util::Stream;
20use objectstore_service::id::ObjectContext;
21use objectstore_types::scope::Scopes;
22use serde::{Deserialize, Serialize};
23
/// Reason a request was rejected by the rate limiter.
///
/// Serialized to a stable snake_case label via [`RateLimitRejection::as_str`]
/// and used as the `reason` tag on rejection metrics and log lines.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RateLimitRejection {
    /// Global bytes-per-second budget exceeded.
    BandwidthGlobal,
    /// Per-usecase bandwidth carve-out exceeded.
    BandwidthUsecase,
    /// Per-scope bandwidth carve-out exceeded.
    BandwidthScope,
    /// Global requests-per-second bucket empty.
    ThroughputGlobal,
    /// Per-usecase throughput bucket empty.
    ThroughputUsecase,
    /// Per-scope throughput bucket empty.
    ThroughputScope,
    /// A matching targeted rule's bucket was empty.
    ThroughputRule,
}

impl RateLimitRejection {
    /// Stable snake_case label for this rejection reason, suitable for
    /// metric tags and structured log fields.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::BandwidthGlobal => "bandwidth_global",
            Self::BandwidthUsecase => "bandwidth_usecase",
            Self::BandwidthScope => "bandwidth_scope",
            Self::ThroughputGlobal => "throughput_global",
            Self::ThroughputUsecase => "throughput_usecase",
            Self::ThroughputScope => "throughput_scope",
            Self::ThroughputRule => "throughput_rule",
        }
    }
}

impl fmt::Display for RateLimitRejection {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
63
/// Combined rate-limiting configuration for the server.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct RateLimits {
    /// Request-throughput (requests per second) limits.
    pub throughput: ThroughputLimits,
    /// Bandwidth (bytes per second) limits.
    pub bandwidth: BandwidthLimits,
}
72
/// Configuration for requests-per-second throughput limiting.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct ThroughputLimits {
    /// Global requests-per-second limit; `None` disables the global bucket.
    pub global_rps: Option<u32>,

    /// Extra burst capacity added on top of the steady rate when sizing
    /// every token bucket (see `TokenBucket::new`).
    pub burst: u32,

    /// Per-usecase limit as a percentage of `global_rps`; only effective
    /// when `global_rps` is also set.
    pub usecase_pct: Option<u8>,

    /// Per-scope limit as a percentage of `global_rps`; only effective
    /// when `global_rps` is also set.
    pub scope_pct: Option<u8>,

    /// Targeted per-usecase/per-scope rules, evaluated in order.
    pub rules: Vec<ThroughputRule>,
}
107
/// A targeted throughput limit applied only to requests matching a specific
/// usecase and/or set of scope values (see [`ThroughputRule::matches`]).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ThroughputRule {
    /// Usecase this rule applies to; `None` matches any usecase.
    pub usecase: Option<String>,

    /// `(scope name, scope value)` pairs that must all be present on the
    /// request for the rule to apply; an empty list matches any request.
    pub scopes: Vec<(String, String)>,

    /// Absolute requests-per-second limit for matching requests.
    pub rps: Option<u32>,

    /// Limit as a percentage of the global RPS; when both `rps` and `pct`
    /// resolve to a value, the smaller limit wins (see `rule_rps`).
    pub pct: Option<u8>,
}
138
139impl ThroughputRule {
140 pub fn matches(&self, context: &ObjectContext) -> bool {
142 if let Some(ref rule_usecase) = self.usecase
143 && rule_usecase != &context.usecase
144 {
145 return false;
146 }
147
148 for (scope_name, scope_value) in &self.scopes {
149 match context.scopes.get_value(scope_name) {
150 Some(value) if value == scope_value => (),
151 _ => return false,
152 }
153 }
154
155 true
156 }
157}
158
/// Configuration for bytes-per-second bandwidth limiting.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct BandwidthLimits {
    /// Global bytes-per-second limit; `None` disables bandwidth limiting.
    pub global_bps: Option<u64>,

    /// Per-usecase limit as a percentage of `global_bps`; only effective
    /// when `global_bps` is also set.
    pub usecase_pct: Option<u8>,

    /// Per-scope limit as a percentage of `global_bps`; only effective
    /// when `global_bps` is also set.
    pub scope_pct: Option<u8>,
}
180
/// Combined bandwidth + throughput rate limiter.
///
/// Bandwidth is enforced reactively from EWMA estimates of observed byte
/// rates; throughput is enforced with token buckets consumed per request.
#[derive(Debug)]
pub struct RateLimiter {
    bandwidth: BandwidthRateLimiter,
    throughput: ThroughputRateLimiter,
}
195
impl RateLimiter {
    /// Builds both sub-limiters from `config`; no background work starts
    /// until [`RateLimiter::start`] is called.
    pub fn new(config: RateLimits) -> Self {
        Self {
            bandwidth: BandwidthRateLimiter::new(config.bandwidth),
            throughput: ThroughputRateLimiter::new(config.throughput),
        }
    }

    /// Spawns the background EWMA/metrics ticker tasks for both sub-limiters.
    /// Uses `tokio::task::spawn`, so this must run inside a tokio runtime.
    pub fn start(&self) {
        self.bandwidth.start();
        self.throughput.start();
    }

    /// Returns `true` when the request may proceed.
    ///
    /// Bandwidth is checked before throughput; the first rejection wins and
    /// is counted and logged with its reason label.
    pub fn check(&self, context: &ObjectContext) -> bool {
        let rejection = self
            .bandwidth
            .check(context)
            .or_else(|| self.throughput.check(context));

        let Some(rejection) = rejection else {
            return true;
        };

        objectstore_metrics::count!("server.request.rate_limited", reason = rejection.as_str());
        objectstore_log::warn!(
            reason = rejection.as_str(),
            "Request rejected: rate limit exceeded"
        );
        false
    }

    /// The byte counters (process total, global estimator, and — when the
    /// percentage limits are configured — per-usecase and per-scope
    /// estimators) that payload streams should add transferred bytes to.
    pub fn bytes_accumulators(&self, context: &ObjectContext) -> Vec<Arc<AtomicU64>> {
        self.bandwidth.accumulators(context)
    }

    /// Records `bytes` of transfer against every bandwidth accumulator that
    /// applies to `context`.
    pub fn record_bandwidth(&self, context: &ObjectContext, bytes: u64) {
        for acc in self.bandwidth.accumulators(context) {
            acc.fetch_add(bytes, std::sync::atomic::Ordering::Relaxed);
        }
    }

    /// Current smoothed global bandwidth estimate in bytes per second.
    pub fn bandwidth_ewma(&self) -> u64 {
        self.bandwidth
            .global
            .estimate
            .load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Configured global bandwidth limit (bytes per second), if any.
    pub fn bandwidth_limit(&self) -> Option<u64> {
        self.bandwidth.config.global_bps
    }

    /// Current smoothed global throughput estimate in requests per second.
    pub fn throughput_rps(&self) -> u64 {
        self.throughput
            .global_estimator
            .estimate
            .load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Configured global throughput limit (requests per second), if any.
    pub fn throughput_limit(&self) -> Option<u32> {
        self.throughput.config.global_rps
    }

    /// Total bytes recorded since startup (monotonic counter).
    pub fn bandwidth_total_bytes(&self) -> u64 {
        self.bandwidth
            .total_bytes
            .load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Total requests admitted by the throughput limiter since startup.
    pub fn throughput_total_admitted(&self) -> u64 {
        self.throughput
            .total_admitted
            .load(std::sync::atomic::Ordering::Relaxed)
    }
}
298
/// A raw event counter paired with an exponentially-weighted moving average
/// of its per-tick rate.
///
/// Writers add counts to `accumulator`; a periodic ticker drains it through
/// [`EwmaEstimator::update_ewma`] and publishes the smoothed rate into
/// `estimate` for lock-free reads.
#[derive(Debug)]
struct EwmaEstimator {
    accumulator: Arc<AtomicU64>,
    estimate: Arc<AtomicU64>,
}

impl EwmaEstimator {
    fn new() -> Self {
        Self {
            accumulator: Arc::new(AtomicU64::new(0)),
            estimate: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Drains the accumulator, converts the drained count to a rate using
    /// `to_rate` (1 / tick duration in seconds), folds the sample into
    /// `ewma`, and publishes the floored result to `estimate`.
    ///
    /// `ewma` is caller-owned so the float state never needs atomics.
    fn update_ewma(&self, ewma: &mut f64, to_rate: f64) {
        use std::sync::atomic::Ordering;

        // Smoothing factor: each tick contributes 20% of the new sample.
        const ALPHA: f64 = 0.2;

        let drained = self.accumulator.swap(0, Ordering::Relaxed);
        let sample = (drained as f64) * to_rate;
        *ewma = ALPHA * sample + (1.0 - ALPHA) * *ewma;
        self.estimate.store(ewma.floor() as u64, Ordering::Relaxed);
    }
}
333
/// Reactive bandwidth limiter: requests are rejected while the smoothed
/// (EWMA) observed byte rate exceeds the configured limit.
#[derive(Debug)]
struct BandwidthRateLimiter {
    /// Static configuration (global limit plus percentage carve-outs).
    config: BandwidthLimits,
    /// Estimator for the process-wide byte rate.
    global: Arc<EwmaEstimator>,
    /// Monotonic total of all bytes ever recorded.
    total_bytes: Arc<AtomicU64>,
    /// Lazily-populated per-usecase estimators (entries are never removed).
    usecases: Arc<papaya::HashMap<String, Arc<EwmaEstimator>>>,
    /// Lazily-populated per-scope estimators (entries are never removed).
    scopes: Arc<papaya::HashMap<Scopes, Arc<EwmaEstimator>>>,
}
346
impl BandwidthRateLimiter {
    fn new(config: BandwidthLimits) -> Self {
        Self {
            config,
            global: Arc::new(EwmaEstimator::new()),
            total_bytes: Arc::new(AtomicU64::new(0)),
            usecases: Arc::new(papaya::HashMap::new()),
            scopes: Arc::new(papaya::HashMap::new()),
        }
    }

    /// Spawns the background task that periodically folds accumulated byte
    /// counts into EWMA estimates and reports gauges. Must be called inside
    /// a tokio runtime.
    fn start(&self) {
        let global = Arc::clone(&self.global);
        let usecases = Arc::clone(&self.usecases);
        let scopes = Arc::clone(&self.scopes);
        let global_limit = self.config.global_bps;
        tokio::task::spawn(async move {
            Self::estimator(global, usecases, scopes, global_limit).await;
        });
    }

    /// Background loop: every tick, drain all accumulators into their EWMAs
    /// and publish gauges. Loops forever (dies with the runtime).
    ///
    /// NOTE(review): the local `usecase_ewmas`/`scope_ewmas` maps — like the
    /// shared papaya maps they mirror — only ever grow; with high-cardinality
    /// scopes this is unbounded memory growth. Confirm cardinality is bounded.
    async fn estimator(
        global: Arc<EwmaEstimator>,
        usecases: Arc<papaya::HashMap<String, Arc<EwmaEstimator>>>,
        scopes: Arc<papaya::HashMap<Scopes, Arc<EwmaEstimator>>>,
        global_limit: Option<u64>,
    ) {
        // 50ms tick; `to_bps` converts a per-tick byte count to bytes/second.
        const TICK: Duration = Duration::from_millis(50); let mut interval = tokio::time::interval(TICK);
        let to_bps = 1.0 / TICK.as_secs_f64(); let mut global_ewma: f64 = 0.0;
        let mut usecase_ewmas: std::collections::HashMap<String, f64> =
            std::collections::HashMap::new();
        let mut scope_ewmas: std::collections::HashMap<Scopes, f64> =
            std::collections::HashMap::new();

        loop {
            interval.tick().await;

            global.update_ewma(&mut global_ewma, to_bps);
            objectstore_metrics::gauge!("server.bandwidth.ewma" = global_ewma.floor() as u64);
            if let Some(limit) = global_limit {
                objectstore_metrics::gauge!("server.bandwidth.limit" = limit);
            }

            // Update every per-usecase estimator, keeping its float EWMA
            // state in the task-local map keyed by usecase.
            {
                let guard = usecases.pin();
                for (key, estimator) in guard.iter() {
                    let ewma = usecase_ewmas.entry(key.clone()).or_insert(0.0);
                    estimator.update_ewma(ewma, to_bps);
                }
            }

            // Same for per-scope estimators.
            {
                let guard = scopes.pin();
                for (key, estimator) in guard.iter() {
                    let ewma = scope_ewmas.entry(key.clone()).or_insert(0.0);
                    estimator.update_ewma(ewma, to_bps);
                }
            }

            objectstore_metrics::gauge!(
                "server.rate_limiter.bandwidth.scope_map_size" = scopes.len()
            );
            objectstore_metrics::gauge!(
                "server.rate_limiter.bandwidth.usecase_map_size" = usecases.len()
            );
        }
    }

    /// Returns the first bandwidth limit the request exceeds, or `None` when
    /// it may proceed. Checks run global → usecase → scope, each comparing
    /// the published EWMA estimate (strictly) against its limit.
    fn check(&self, context: &ObjectContext) -> Option<RateLimitRejection> {
        // No global limit configured means bandwidth limiting is disabled.
        let global_bps = self.config.global_bps?;

        if self
            .global
            .estimate
            .load(std::sync::atomic::Ordering::Relaxed)
            > global_bps
        {
            return Some(RateLimitRejection::BandwidthGlobal);
        }

        if let Some(usecase_bps) = self.usecase_bps() {
            // A usecase without an estimator yet has recorded no traffic and
            // cannot be over its limit.
            let guard = self.usecases.pin();
            if let Some(estimator) = guard.get(&context.usecase)
                && estimator
                    .estimate
                    .load(std::sync::atomic::Ordering::Relaxed)
                    > usecase_bps
            {
                return Some(RateLimitRejection::BandwidthUsecase);
            }
        }

        if let Some(scope_bps) = self.scope_bps() {
            let guard = self.scopes.pin();
            if let Some(estimator) = guard.get(&context.scopes)
                && estimator
                    .estimate
                    .load(std::sync::atomic::Ordering::Relaxed)
                    > scope_bps
            {
                return Some(RateLimitRejection::BandwidthScope);
            }
        }

        None
    }

    /// Accumulators that byte counts for `context` should be added to:
    /// always the process total and the global estimator, plus per-usecase /
    /// per-scope estimators when those percentage limits are configured
    /// (their map entries are created on first use).
    fn accumulators(&self, context: &ObjectContext) -> Vec<Arc<AtomicU64>> {
        let mut accs = vec![
            Arc::clone(&self.total_bytes),
            Arc::clone(&self.global.accumulator),
        ];

        if self.usecase_bps().is_some() {
            let guard = self.usecases.pin();
            let estimator = guard
                .get_or_insert_with(context.usecase.clone(), || Arc::new(EwmaEstimator::new()));
            accs.push(Arc::clone(&estimator.accumulator));
        }

        if self.scope_bps().is_some() {
            let guard = self.scopes.pin();
            let estimator =
                guard.get_or_insert_with(context.scopes.clone(), || Arc::new(EwmaEstimator::new()));
            accs.push(Arc::clone(&estimator.accumulator));
        }

        accs
    }

    /// Per-usecase limit in bytes/second: `usecase_pct` percent of the
    /// global limit; `None` when either value is unconfigured.
    fn usecase_bps(&self) -> Option<u64> {
        let global_bps = self.config.global_bps?;
        let pct = self.config.usecase_pct?;
        Some(((global_bps as f64) * (pct as f64 / 100.0)) as u64)
    }

    /// Per-scope limit in bytes/second: `scope_pct` percent of the global
    /// limit; `None` when either value is unconfigured.
    fn scope_bps(&self) -> Option<u64> {
        let global_bps = self.config.global_bps?;
        let pct = self.config.scope_pct?;
        Some(((global_bps as f64) * (pct as f64 / 100.0)) as u64)
    }
}
511
/// Token-bucket throughput limiter with global, per-usecase, per-scope and
/// per-rule buckets.
#[derive(Debug)]
struct ThroughputRateLimiter {
    /// Static configuration (global RPS, burst, carve-outs and rules).
    config: ThroughputLimits,
    /// Global bucket; present only when `config.global_rps` is set.
    global: Option<Mutex<TokenBucket>>,
    /// EWMA over admitted requests, published as the RPS estimate.
    global_estimator: Arc<EwmaEstimator>,
    /// Monotonic total of admitted requests.
    total_admitted: Arc<AtomicU64>,
    /// Lazily-created per-usecase buckets (entries are never removed).
    usecases: Arc<papaya::HashMap<String, Mutex<TokenBucket>>>,
    /// Lazily-created per-scope buckets (entries are never removed).
    scopes: Arc<papaya::HashMap<Scopes, Mutex<TokenBucket>>>,
    /// Buckets for `config.rules`, keyed by rule index.
    rules: papaya::HashMap<usize, Mutex<TokenBucket>>,
}
526
impl ThroughputRateLimiter {
    fn new(config: ThroughputLimits) -> Self {
        // The global bucket exists only when a global RPS limit is set.
        let global = config
            .global_rps
            .map(|rps| Mutex::new(TokenBucket::new(rps, config.burst)));

        Self {
            config,
            global,
            global_estimator: Arc::new(EwmaEstimator::new()),
            total_admitted: Arc::new(AtomicU64::new(0)),
            usecases: Arc::new(papaya::HashMap::new()),
            scopes: Arc::new(papaya::HashMap::new()),
            rules: papaya::HashMap::new(),
        }
    }

    /// Spawns the background task that folds the admitted-request counter
    /// into an RPS EWMA and reports limiter gauges every 50ms. Must be
    /// called inside a tokio runtime.
    fn start(&self) {
        let usecases = Arc::clone(&self.usecases);
        let scopes = Arc::clone(&self.scopes);
        let global_estimator = Arc::clone(&self.global_estimator);
        let global_limit = self.config.global_rps;
        tokio::task::spawn(async move {
            // 50ms tick; `to_rps` converts per-tick counts to requests/second.
            const TICK: Duration = Duration::from_millis(50);
            let mut interval = tokio::time::interval(TICK);
            let to_rps = 1.0 / TICK.as_secs_f64();
            let mut global_ewma: f64 = 0.0;
            loop {
                interval.tick().await;
                global_estimator.update_ewma(&mut global_ewma, to_rps);
                objectstore_metrics::gauge!("server.throughput.ewma" = global_ewma.floor() as u64);
                if let Some(limit) = global_limit {
                    objectstore_metrics::gauge!(
                        "server.rate_limiter.throughput.limit" = u64::from(limit)
                    );
                }
                objectstore_metrics::gauge!(
                    "server.rate_limiter.throughput.scope_map_size" = scopes.len()
                );
                objectstore_metrics::gauge!(
                    "server.rate_limiter.throughput.usecase_map_size" = usecases.len()
                );
            }
        });
    }

    /// Returns the first throughput limit the request exceeds, or `None`
    /// when it is admitted; admission also counts the request toward the RPS
    /// estimate and the total-admitted counter.
    ///
    /// Checks run global → usecase → scope → matching rules. Note that each
    /// stage consumes its token before later stages run, so a request
    /// rejected by a later stage still spends tokens in earlier buckets.
    fn check(&self, context: &ObjectContext) -> Option<RateLimitRejection> {
        if let Some(ref global) = self.global {
            let acquired = global.lock().unwrap().try_acquire();
            if !acquired {
                return Some(RateLimitRejection::ThroughputGlobal);
            }
        }

        if let Some(usecase_rps) = self.usecase_rps() {
            // Bucket is created on the usecase's first request.
            let guard = self.usecases.pin();
            let bucket = guard
                .get_or_insert_with(context.usecase.clone(), || self.create_bucket(usecase_rps));
            if !bucket.lock().unwrap().try_acquire() {
                return Some(RateLimitRejection::ThroughputUsecase);
            }
        }

        if let Some(scope_rps) = self.scope_rps() {
            let guard = self.scopes.pin();
            let bucket =
                guard.get_or_insert_with(context.scopes.clone(), || self.create_bucket(scope_rps));
            if !bucket.lock().unwrap().try_acquire() {
                return Some(RateLimitRejection::ThroughputScope);
            }
        }

        // Every matching rule with a resolvable limit must admit the request;
        // rule buckets are keyed by the rule's index in the config.
        for (idx, rule) in self.config.rules.iter().enumerate() {
            if !rule.matches(context) {
                continue;
            }
            let Some(rule_rps) = self.rule_rps(rule) else {
                continue;
            };
            let guard = self.rules.pin();
            let bucket = guard.get_or_insert_with(idx, || self.create_bucket(rule_rps));
            if !bucket.lock().unwrap().try_acquire() {
                return Some(RateLimitRejection::ThroughputRule);
            }
        }

        // Only admitted requests are counted (see the unit tests below).
        self.global_estimator
            .accumulator
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        self.total_admitted
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);

        None
    }

    /// A fresh bucket at `rps` with the configured shared burst allowance.
    fn create_bucket(&self, rps: u32) -> Mutex<TokenBucket> {
        Mutex::new(TokenBucket::new(rps, self.config.burst))
    }

    /// Per-usecase limit: `usecase_pct` percent of the global RPS; `None`
    /// when either value is unconfigured.
    fn usecase_rps(&self) -> Option<u32> {
        let global_rps = self.config.global_rps?;
        let pct = self.config.usecase_pct?;
        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
    }

    /// Per-scope limit: `scope_pct` percent of the global RPS; `None` when
    /// either value is unconfigured.
    fn scope_rps(&self) -> Option<u32> {
        let global_rps = self.config.global_rps?;
        let pct = self.config.scope_pct?;
        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
    }

    /// Effective limit for a rule: the minimum of its absolute `rps` and its
    /// `pct`-of-global limit, whichever are present; `None` when neither
    /// resolves (the rule is then skipped).
    fn rule_rps(&self, rule: &ThroughputRule) -> Option<u32> {
        let pct_limit = rule.pct.and_then(|p| {
            self.config
                .global_rps
                .map(|g| ((g as f64) * (p as f64 / 100.0)) as u32)
        });

        match (rule.rps, pct_limit) {
            (Some(r), Some(p)) => Some(r.min(p)),
            (Some(r), None) => Some(r),
            (None, Some(p)) => Some(p),
            (None, None) => None,
        }
    }
}
667
/// Classic token bucket: refills at `refill_rate` tokens/second up to
/// `capacity`, and each admitted request consumes one token.
#[derive(Debug)]
struct TokenBucket {
    refill_rate: f64,
    capacity: f64,
    tokens: f64,
    last_update: Instant,
}

impl TokenBucket {
    /// Creates a bucket limited to `rps` sustained requests/second with
    /// `burst` extra tokens of headroom; the bucket starts full.
    pub fn new(rps: u32, burst: u32) -> Self {
        let cap = f64::from(rps + burst);
        Self {
            refill_rate: f64::from(rps),
            capacity: cap,
            tokens: cap,
            last_update: Instant::now(),
        }
    }

    /// Consumes one token if available, refilling from elapsed time first.
    /// Returns `true` when the request is admitted.
    pub fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_update).as_secs_f64();
        let candidate = self.capacity.min(self.tokens + elapsed * self.refill_rate);

        // Commit the refill only once at least one whole token has accrued;
        // otherwise `last_update` stays put so fractional progress keeps
        // accumulating instead of being rounded away on every call.
        if candidate.floor() > self.tokens.floor() {
            self.last_update = now;
            self.tokens = candidate;
        }

        if self.tokens < 1.0 {
            return false;
        }
        self.tokens -= 1.0;
        true
    }
}
719
/// Stream adapter that adds the length of every `Ok` chunk it yields to a
/// set of shared byte accumulators (used for bandwidth accounting).
pub(crate) struct MeteredPayloadStream<S> {
    inner: S,
    accumulators: Vec<Arc<AtomicU64>>,
}
728
729impl<S> MeteredPayloadStream<S> {
730 pub fn new(inner: S, accumulators: Vec<Arc<AtomicU64>>) -> Self {
731 Self {
732 inner,
733 accumulators,
734 }
735 }
736}
737
738impl<S> fmt::Debug for MeteredPayloadStream<S> {
739 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
740 f.debug_struct("MeteredPayloadStream")
741 .field("accumulators", &self.accumulators)
742 .finish()
743 }
744}
745
impl<S, E> Stream for MeteredPayloadStream<S>
where
    S: Stream<Item = Result<Bytes, E>> + Unpin,
{
    type Item = Result<Bytes, E>;

    /// Delegates to the inner stream; whenever a chunk is produced, its byte
    /// length is added to every accumulator. Errors, `Pending`, and
    /// end-of-stream pass through without counting anything.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `S: Unpin` lets us re-pin the inner stream from `&mut self`.
        let this = self.get_mut();
        let res = Pin::new(&mut this.inner).poll_next(cx);
        if let Poll::Ready(Some(Ok(ref bytes))) = res {
            let len = bytes.len() as u64;
            for acc in &this.accumulators {
                acc.fetch_add(len, std::sync::atomic::Ordering::Relaxed);
            }
        }
        res
    }
}
764
#[cfg(test)]
mod tests {
    use objectstore_service::id::ObjectContext;
    use objectstore_types::scope::{Scope, Scopes};

    use super::*;

    /// A fixed request context ("testing" usecase, org=1 scope) shared by
    /// the tests below.
    fn make_context() -> ObjectContext {
        ObjectContext {
            usecase: "testing".into(),
            scopes: Scopes::from_iter([Scope::create("org", "1").unwrap()]),
        }
    }

    #[test]
    fn ewma_estimator_update_applies_alpha() {
        let estimator = EwmaEstimator::new();
        // 50ms tick: 10 counts drained in one tick => a 200/s sample.
        const TICK: f64 = 0.05; let to_rate = 1.0 / TICK;
        let mut ewma: f64 = 0.0;

        estimator
            .accumulator
            .store(10, std::sync::atomic::Ordering::Relaxed);
        estimator.update_ewma(&mut ewma, to_rate);
        // First update from 0 with ALPHA = 0.2: 0.2 * 200 = 40.
        assert_eq!(
            estimator
                .estimate
                .load(std::sync::atomic::Ordering::Relaxed),
            40
        );

        // The update drains the accumulator back to zero.
        assert_eq!(
            estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            0
        );
    }

    #[test]
    fn throughput_check_increments_accumulator() {
        // Generous limit so both checks are admitted.
        let limiter = ThroughputRateLimiter::new(ThroughputLimits {
            global_rps: Some(1000),
            ..Default::default()
        });

        assert_eq!(
            limiter
                .global_estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            0
        );

        let context = make_context();
        assert!(limiter.check(&context).is_none());
        assert!(limiter.check(&context).is_none());

        // Each admitted request counts toward the RPS accumulator.
        assert_eq!(
            limiter
                .global_estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            2
        );
    }

    #[test]
    fn throughput_rejected_does_not_increment_accumulator() {
        // Capacity 1 (rps=1, burst=0): second immediate request is rejected.
        let limiter = ThroughputRateLimiter::new(ThroughputLimits {
            global_rps: Some(1),
            burst: 0,
            ..Default::default()
        });

        let context = make_context();
        assert!(limiter.check(&context).is_none());
        assert!(limiter.check(&context).is_some());

        // Only the admitted request is counted; the rejection is not.
        assert_eq!(
            limiter
                .global_estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            1 );
    }

    #[test]
    fn bandwidth_rejection_does_not_increment_throughput_accumulator() {
        let limiter = RateLimiter::new(RateLimits {
            throughput: ThroughputLimits {
                global_rps: Some(1000),
                ..Default::default()
            },
            bandwidth: BandwidthLimits {
                global_bps: Some(0),
                ..Default::default()
            },
        });

        // Force the bandwidth estimate above the (zero) limit so the
        // bandwidth stage rejects before throughput is ever consulted.
        limiter
            .bandwidth
            .global
            .estimate
            .store(1, std::sync::atomic::Ordering::Relaxed);

        let context = make_context();
        assert!(!limiter.check(&context));

        // The throughput stage never ran, so nothing was counted.
        assert_eq!(
            limiter
                .throughput
                .global_estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            0
        );
    }

    #[test]
    fn rate_limiter_accessors_with_config() {
        let rate_limiter = RateLimiter::new(RateLimits {
            throughput: ThroughputLimits {
                global_rps: Some(500),
                ..Default::default()
            },
            bandwidth: BandwidthLimits {
                global_bps: Some(1_000_000),
                ..Default::default()
            },
        });

        assert_eq!(rate_limiter.bandwidth_limit(), Some(1_000_000));
        assert_eq!(rate_limiter.throughput_limit(), Some(500));
        // Estimates start at zero until the background ticker runs.
        assert_eq!(rate_limiter.bandwidth_ewma(), 0);
        assert_eq!(rate_limiter.throughput_rps(), 0);
    }

    #[test]
    fn rate_limiter_accessors_no_limits() {
        let rate_limiter = RateLimiter::new(RateLimits::default());

        assert_eq!(rate_limiter.bandwidth_limit(), None);
        assert_eq!(rate_limiter.throughput_limit(), None);
        assert_eq!(rate_limiter.bandwidth_ewma(), 0);
        assert_eq!(rate_limiter.throughput_rps(), 0);
    }
}
925}