1use std::fmt;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::sync::{Mutex, atomic::AtomicU64};
15use std::task::{Context, Poll};
16use std::time::{Duration, Instant};
17
18use bytes::Bytes;
19use futures_util::Stream;
20use objectstore_service::id::ObjectContext;
21use objectstore_types::scope::Scopes;
22use serde::{Deserialize, Serialize};
23
/// Identifies which specific limit caused a request to be rejected.
///
/// Rendered via [`RateLimitRejection::as_str`] as the `reason` tag on
/// rejection metrics and logs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RateLimitRejection {
    /// Global bytes-per-second limit exceeded.
    BandwidthGlobal,
    /// Per-usecase bytes-per-second limit exceeded.
    BandwidthUsecase,
    /// Per-scope bytes-per-second limit exceeded.
    BandwidthScope,
    /// Global requests-per-second limit exceeded.
    ThroughputGlobal,
    /// Per-usecase requests-per-second limit exceeded.
    ThroughputUsecase,
    /// Per-scope requests-per-second limit exceeded.
    ThroughputScope,
    /// A configured throughput rule's limit was exceeded.
    ThroughputRule,
}
42
43impl RateLimitRejection {
44 pub fn as_str(self) -> &'static str {
46 match self {
47 RateLimitRejection::BandwidthGlobal => "bandwidth_global",
48 RateLimitRejection::BandwidthUsecase => "bandwidth_usecase",
49 RateLimitRejection::BandwidthScope => "bandwidth_scope",
50 RateLimitRejection::ThroughputGlobal => "throughput_global",
51 RateLimitRejection::ThroughputUsecase => "throughput_usecase",
52 RateLimitRejection::ThroughputScope => "throughput_scope",
53 RateLimitRejection::ThroughputRule => "throughput_rule",
54 }
55 }
56}
57
58impl fmt::Display for RateLimitRejection {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 f.write_str(self.as_str())
61 }
62}
63
/// Top-level rate limiting configuration for the server.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct RateLimits {
    /// Request-rate (requests/second) limits.
    pub throughput: ThroughputLimits,
    /// Bandwidth (bytes/second) limits.
    pub bandwidth: BandwidthLimits,
}
72
/// Throughput (requests/second) limit configuration.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct ThroughputLimits {
    /// Global requests-per-second ceiling. `None` disables throughput
    /// limiting, including the percentage-based limits below.
    pub global_rps: Option<u32>,

    /// Extra tokens granted on top of the sustained rate; every bucket's
    /// capacity is `rps + burst` (see `TokenBucket::new`).
    pub burst: u32,

    /// Per-usecase limit as a percentage of `global_rps`.
    pub usecase_pct: Option<u8>,

    /// Per-scope limit as a percentage of `global_rps`.
    pub scope_pct: Option<u8>,

    /// Additional targeted rules; every matching rule is enforced.
    pub rules: Vec<ThroughputRule>,
}
107
/// A targeted throughput rule matching on usecase and/or scope values.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ThroughputRule {
    /// Usecase this rule applies to; `None` matches every usecase.
    pub usecase: Option<String>,

    /// Scope `(name, value)` pairs that must all be present with equal
    /// values in the request's scopes. Empty matches everything.
    pub scopes: Vec<(String, String)>,

    /// Absolute requests-per-second limit for matching requests.
    pub rps: Option<u32>,

    /// Limit as a percentage of the global rps; when both `rps` and `pct`
    /// resolve, the smaller value wins (see `ThroughputRateLimiter::rule_rps`).
    pub pct: Option<u8>,
}
138
139impl ThroughputRule {
140 pub fn matches(&self, context: &ObjectContext) -> bool {
142 if let Some(ref rule_usecase) = self.usecase
143 && rule_usecase != &context.usecase
144 {
145 return false;
146 }
147
148 for (scope_name, scope_value) in &self.scopes {
149 match context.scopes.get_value(scope_name) {
150 Some(value) if value == scope_value => (),
151 _ => return false,
152 }
153 }
154
155 true
156 }
157}
158
/// Bandwidth (bytes/second) limit configuration.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct BandwidthLimits {
    /// Global bytes-per-second ceiling. `None` disables bandwidth
    /// limiting, including the percentage-based limits below.
    pub global_bps: Option<u64>,

    /// Per-usecase limit as a percentage of `global_bps`.
    pub usecase_pct: Option<u8>,

    /// Per-scope limit as a percentage of `global_bps`.
    pub scope_pct: Option<u8>,
}
180
/// Combined bandwidth + throughput rate limiter for incoming requests.
///
/// Construct with [`RateLimiter::new`], then call [`RateLimiter::start`]
/// (inside a tokio runtime) to spawn the background estimator tasks.
#[derive(Debug)]
pub struct RateLimiter {
    bandwidth: BandwidthRateLimiter,
    throughput: ThroughputRateLimiter,
}
195
196impl RateLimiter {
197 pub fn new(config: RateLimits) -> Self {
201 Self {
202 bandwidth: BandwidthRateLimiter::new(config.bandwidth),
203 throughput: ThroughputRateLimiter::new(config.throughput),
204 }
205 }
206
207 pub fn start(&self) {
211 self.bandwidth.start();
212 self.throughput.start();
213 }
214
215 pub fn check(&self, context: &ObjectContext) -> bool {
221 let rejection = self
225 .bandwidth
226 .check(context)
227 .or_else(|| self.throughput.check(context));
228
229 let Some(rejection) = rejection else {
230 return true;
231 };
232
233 objectstore_metrics::count!("server.request.rate_limited", reason = rejection.as_str());
234 objectstore_log::warn!(
235 reason = rejection.as_str(),
236 "Request rejected: rate limit exceeded"
237 );
238 false
239 }
240
241 pub fn bytes_accumulators(&self, context: &ObjectContext) -> Vec<Arc<AtomicU64>> {
245 self.bandwidth.accumulators(context)
246 }
247
248 pub fn record_bandwidth(&self, context: &ObjectContext, bytes: u64) {
253 for acc in self.bandwidth.accumulators(context) {
254 acc.fetch_add(bytes, std::sync::atomic::Ordering::Relaxed);
255 }
256 }
257
258 pub fn bandwidth_ewma(&self) -> u64 {
260 self.bandwidth
261 .global
262 .estimate
263 .load(std::sync::atomic::Ordering::Relaxed)
264 }
265
266 pub fn bandwidth_limit(&self) -> Option<u64> {
268 self.bandwidth.config.global_bps
269 }
270
271 pub fn throughput_rps(&self) -> u64 {
273 self.throughput
274 .global_estimator
275 .estimate
276 .load(std::sync::atomic::Ordering::Relaxed)
277 }
278
279 pub fn throughput_limit(&self) -> Option<u32> {
281 self.throughput.config.global_rps
282 }
283
284 pub fn bandwidth_total_bytes(&self) -> u64 {
286 self.bandwidth
287 .total_bytes
288 .load(std::sync::atomic::Ordering::Relaxed)
289 }
290
291 pub fn throughput_total_admitted(&self) -> u64 {
293 self.throughput
294 .total_admitted
295 .load(std::sync::atomic::Ordering::Relaxed)
296 }
297}
298
/// Tracks an exponentially-weighted moving average of a rate.
///
/// Producers add raw counts into `accumulator`; a periodic driver drains it,
/// folds the resulting per-second rate into an EWMA it owns, and publishes
/// the floored result into `estimate` for lock-free reads.
#[derive(Debug)]
struct EwmaEstimator {
    // Raw count accumulated since the last `update_ewma` call.
    accumulator: Arc<AtomicU64>,
    // Latest published EWMA value, floored to an integer.
    estimate: Arc<AtomicU64>,
}

impl EwmaEstimator {
    fn new() -> Self {
        Self {
            accumulator: Arc::new(AtomicU64::new(0)),
            estimate: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Drains the accumulator, folds the resulting rate into the
    /// caller-owned `ewma` state, and publishes the floored value.
    ///
    /// `to_rate` converts the drained count into a per-second rate
    /// (typically `1 / elapsed_seconds`).
    fn update_ewma(&self, ewma: &mut f64, to_rate: f64) {
        use std::sync::atomic::Ordering::Relaxed;

        // Smoothing factor: weight given to the newest sample.
        const ALPHA: f64 = 0.2;

        let drained = self.accumulator.swap(0, Relaxed);
        let sample = drained as f64 * to_rate;
        *ewma = ALPHA * sample + (1.0 - ALPHA) * *ewma;
        self.estimate.store(ewma.floor() as u64, Relaxed);
    }
}
333
/// Enforces bandwidth (bytes/second) limits using EWMA rate estimates.
#[derive(Debug)]
struct BandwidthRateLimiter {
    config: BandwidthLimits,
    /// Estimator for the global byte rate across all requests.
    global: Arc<EwmaEstimator>,
    /// Monotonic total of all bytes ever recorded (never reset).
    total_bytes: Arc<AtomicU64>,
    /// Per-usecase estimators, created lazily on first use.
    /// NOTE(review): entries are never evicted, so this grows with
    /// usecase cardinality — confirm that is bounded in practice.
    usecases: Arc<papaya::HashMap<String, Arc<EwmaEstimator>>>,
    /// Per-scope estimators, created lazily on first use.
    /// NOTE(review): never evicted either; grows with scope cardinality.
    scopes: Arc<papaya::HashMap<Scopes, Arc<EwmaEstimator>>>,
}
346
impl BandwidthRateLimiter {
    fn new(config: BandwidthLimits) -> Self {
        Self {
            config,
            global: Arc::new(EwmaEstimator::new()),
            total_bytes: Arc::new(AtomicU64::new(0)),
            usecases: Arc::new(papaya::HashMap::new()),
            scopes: Arc::new(papaya::HashMap::new()),
        }
    }

    /// Spawns the background estimator task. Requires a tokio runtime;
    /// the task runs forever and its handle is intentionally detached.
    fn start(&self) {
        let global = Arc::clone(&self.global);
        let usecases = Arc::clone(&self.usecases);
        let scopes = Arc::clone(&self.scopes);
        let global_limit = self.config.global_bps;
        tokio::task::spawn(async move {
            Self::estimator(global, usecases, scopes, global_limit).await;
        });
    }

    /// Periodically converts accumulated byte counts into EWMA rate
    /// estimates and emits observability gauges. Never returns.
    async fn estimator(
        global: Arc<EwmaEstimator>,
        usecases: Arc<papaya::HashMap<String, Arc<EwmaEstimator>>>,
        scopes: Arc<papaya::HashMap<Scopes, Arc<EwmaEstimator>>>,
        global_limit: Option<u64>,
    ) {
        const TICK: Duration = Duration::from_millis(50);
        let mut interval = tokio::time::interval(TICK);
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        // The first tick completes immediately; consume it so `last`
        // marks the start of a full tick period.
        interval.tick().await;
        let mut last = Instant::now();
        // EWMA state lives in this task; the shared estimators only hold
        // the published integer snapshots.
        let mut global_ewma: f64 = 0.0;
        // NOTE(review): these maps mirror the shared papaya maps and are
        // never pruned, so they grow with usecase/scope cardinality.
        let mut usecase_ewmas: std::collections::HashMap<String, f64> =
            std::collections::HashMap::new();
        let mut scope_ewmas: std::collections::HashMap<Scopes, f64> =
            std::collections::HashMap::new();

        loop {
            interval.tick().await;

            // Use the measured elapsed time (not TICK) so skipped or late
            // ticks still yield a correct count-to-rate conversion factor.
            let now = Instant::now();
            let to_bps = 1.0 / now.duration_since(last).as_secs_f64();
            last = now;

            global.update_ewma(&mut global_ewma, to_bps);
            objectstore_metrics::gauge!("server.bandwidth.ewma" = global_ewma.floor() as u64);
            if let Some(limit) = global_limit {
                objectstore_metrics::gauge!("server.bandwidth.limit" = limit);
            }

            // Scoped so the map guard is dropped before the next await.
            {
                let guard = usecases.pin();
                for (key, estimator) in guard.iter() {
                    let ewma = usecase_ewmas.entry(key.clone()).or_insert(0.0);
                    estimator.update_ewma(ewma, to_bps);
                }
            }

            {
                let guard = scopes.pin();
                for (key, estimator) in guard.iter() {
                    let ewma = scope_ewmas.entry(key.clone()).or_insert(0.0);
                    estimator.update_ewma(ewma, to_bps);
                }
            }

            objectstore_metrics::gauge!(
                "server.rate_limiter.bandwidth.scope_map_size" = scopes.len()
            );
            objectstore_metrics::gauge!(
                "server.rate_limiter.bandwidth.usecase_map_size" = usecases.len()
            );
        }
    }

    /// Returns the first exceeded bandwidth limit, or `None` to admit.
    ///
    /// Only reads the published EWMA estimates, so checking consumes
    /// nothing. With no global limit configured, all bandwidth limiting
    /// is disabled.
    fn check(&self, context: &ObjectContext) -> Option<RateLimitRejection> {
        let global_bps = self.config.global_bps?;

        if self
            .global
            .estimate
            .load(std::sync::atomic::Ordering::Relaxed)
            > global_bps
        {
            return Some(RateLimitRejection::BandwidthGlobal);
        }

        if let Some(usecase_bps) = self.usecase_bps() {
            // A usecase with no estimator yet has recorded no traffic and
            // therefore cannot be over its limit.
            let guard = self.usecases.pin();
            if let Some(estimator) = guard.get(&context.usecase)
                && estimator
                    .estimate
                    .load(std::sync::atomic::Ordering::Relaxed)
                    > usecase_bps
            {
                return Some(RateLimitRejection::BandwidthUsecase);
            }
        }

        if let Some(scope_bps) = self.scope_bps() {
            let guard = self.scopes.pin();
            if let Some(estimator) = guard.get(&context.scopes)
                && estimator
                    .estimate
                    .load(std::sync::atomic::Ordering::Relaxed)
                    > scope_bps
            {
                return Some(RateLimitRejection::BandwidthScope);
            }
        }

        None
    }

    /// Returns every accumulator that bytes for `context` must be counted
    /// into: the grand total, the global estimator, and — when the
    /// respective percentage limit is configured — the lazily-created
    /// per-usecase and per-scope estimators.
    fn accumulators(&self, context: &ObjectContext) -> Vec<Arc<AtomicU64>> {
        let mut accs = vec![
            Arc::clone(&self.total_bytes),
            Arc::clone(&self.global.accumulator),
        ];

        if self.usecase_bps().is_some() {
            let guard = self.usecases.pin();
            let estimator = guard
                .get_or_insert_with(context.usecase.clone(), || Arc::new(EwmaEstimator::new()));
            accs.push(Arc::clone(&estimator.accumulator));
        }

        if self.scope_bps().is_some() {
            let guard = self.scopes.pin();
            let estimator =
                guard.get_or_insert_with(context.scopes.clone(), || Arc::new(EwmaEstimator::new()));
            accs.push(Arc::clone(&estimator.accumulator));
        }

        accs
    }

    /// Per-usecase limit derived from `global_bps` × `usecase_pct`;
    /// `None` when either part is unconfigured.
    fn usecase_bps(&self) -> Option<u64> {
        let global_bps = self.config.global_bps?;
        let pct = self.config.usecase_pct?;
        Some(((global_bps as f64) * (pct as f64 / 100.0)) as u64)
    }

    /// Per-scope limit derived from `global_bps` × `scope_pct`;
    /// `None` when either part is unconfigured.
    fn scope_bps(&self) -> Option<u64> {
        let global_bps = self.config.global_bps?;
        let pct = self.config.scope_pct?;
        Some(((global_bps as f64) * (pct as f64 / 100.0)) as u64)
    }
}
519
/// Enforces throughput (requests/second) limits with token buckets.
#[derive(Debug)]
struct ThroughputRateLimiter {
    config: ThroughputLimits,
    /// Global token bucket; `None` when no global limit is configured.
    global: Option<Mutex<TokenBucket>>,
    /// EWMA over admitted requests (observability only; not enforcing).
    global_estimator: Arc<EwmaEstimator>,
    /// Monotonic count of admitted requests (never reset).
    total_admitted: Arc<AtomicU64>,
    /// Per-usecase buckets, created lazily.
    /// NOTE(review): entries are never evicted.
    usecases: Arc<papaya::HashMap<String, Mutex<TokenBucket>>>,
    /// Per-scope buckets, created lazily.
    /// NOTE(review): entries are never evicted.
    scopes: Arc<papaya::HashMap<Scopes, Mutex<TokenBucket>>>,
    /// Per-rule buckets, keyed by the rule's index in `config.rules`.
    rules: papaya::HashMap<usize, Mutex<TokenBucket>>,
}
534
impl ThroughputRateLimiter {
    fn new(config: ThroughputLimits) -> Self {
        // The global bucket only exists when a global limit is configured.
        let global = config
            .global_rps
            .map(|rps| Mutex::new(TokenBucket::new(rps, config.burst)));

        Self {
            config,
            global,
            global_estimator: Arc::new(EwmaEstimator::new()),
            total_admitted: Arc::new(AtomicU64::new(0)),
            usecases: Arc::new(papaya::HashMap::new()),
            scopes: Arc::new(papaya::HashMap::new()),
            rules: papaya::HashMap::new(),
        }
    }

    /// Spawns a background task that updates the admitted-request EWMA and
    /// emits observability gauges. Requires a tokio runtime; the task runs
    /// forever and its handle is intentionally detached.
    fn start(&self) {
        let usecases = Arc::clone(&self.usecases);
        let scopes = Arc::clone(&self.scopes);
        let global_estimator = Arc::clone(&self.global_estimator);
        let global_limit = self.config.global_rps;
        tokio::task::spawn(async move {
            const TICK: Duration = Duration::from_millis(50);
            let mut interval = tokio::time::interval(TICK);
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
            // The first tick completes immediately; consume it so `last`
            // marks the start of a full tick period.
            interval.tick().await;
            let mut last = Instant::now();
            let mut global_ewma: f64 = 0.0;
            loop {
                interval.tick().await;
                // Use measured elapsed time so late/skipped ticks still
                // convert counts to a correct per-second rate.
                let now = Instant::now();
                let to_rps = 1.0 / now.duration_since(last).as_secs_f64();
                last = now;
                global_estimator.update_ewma(&mut global_ewma, to_rps);
                objectstore_metrics::gauge!("server.throughput.ewma" = global_ewma.floor() as u64);
                if let Some(limit) = global_limit {
                    // NOTE(review): this gauge name's shape differs from the
                    // bandwidth limit gauge (`server.bandwidth.limit`) —
                    // confirm the asymmetry is intended.
                    objectstore_metrics::gauge!(
                        "server.rate_limiter.throughput.limit" = u64::from(limit)
                    );
                }
                objectstore_metrics::gauge!(
                    "server.rate_limiter.throughput.scope_map_size" = scopes.len()
                );
                objectstore_metrics::gauge!(
                    "server.rate_limiter.throughput.usecase_map_size" = usecases.len()
                );
            }
        });
    }

    /// Returns the first exceeded throughput limit, or `None` to admit the
    /// request (which also counts it toward the EWMA and the admitted total).
    ///
    /// Checks run global → usecase → scope → rules. Each level consumes a
    /// token from its bucket as it passes; a rejection at a later level does
    /// NOT refund tokens already taken at earlier levels.
    fn check(&self, context: &ObjectContext) -> Option<RateLimitRejection> {
        if let Some(ref global) = self.global {
            let acquired = global.lock().unwrap().try_acquire();
            if !acquired {
                return Some(RateLimitRejection::ThroughputGlobal);
            }
        }

        if let Some(usecase_rps) = self.usecase_rps() {
            // Buckets are created lazily per usecase on first sight.
            let guard = self.usecases.pin();
            let bucket = guard
                .get_or_insert_with(context.usecase.clone(), || self.create_bucket(usecase_rps));
            if !bucket.lock().unwrap().try_acquire() {
                return Some(RateLimitRejection::ThroughputUsecase);
            }
        }

        if let Some(scope_rps) = self.scope_rps() {
            let guard = self.scopes.pin();
            let bucket =
                guard.get_or_insert_with(context.scopes.clone(), || self.create_bucket(scope_rps));
            if !bucket.lock().unwrap().try_acquire() {
                return Some(RateLimitRejection::ThroughputScope);
            }
        }

        // Every matching rule with a resolvable limit must admit the
        // request; rule buckets are keyed by the rule's config index.
        for (idx, rule) in self.config.rules.iter().enumerate() {
            if !rule.matches(context) {
                continue;
            }
            let Some(rule_rps) = self.rule_rps(rule) else {
                continue;
            };
            let guard = self.rules.pin();
            let bucket = guard.get_or_insert_with(idx, || self.create_bucket(rule_rps));
            if !bucket.lock().unwrap().try_acquire() {
                return Some(RateLimitRejection::ThroughputRule);
            }
        }

        // Only admitted requests are counted, so the EWMA reflects the
        // admitted rate rather than the offered rate.
        self.global_estimator
            .accumulator
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        self.total_admitted
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);

        None
    }

    /// Builds a bucket with the given rate and the shared configured burst.
    fn create_bucket(&self, rps: u32) -> Mutex<TokenBucket> {
        Mutex::new(TokenBucket::new(rps, self.config.burst))
    }

    /// Per-usecase limit derived from `global_rps` × `usecase_pct`;
    /// `None` when either part is unconfigured.
    fn usecase_rps(&self) -> Option<u32> {
        let global_rps = self.config.global_rps?;
        let pct = self.config.usecase_pct?;
        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
    }

    /// Per-scope limit derived from `global_rps` × `scope_pct`;
    /// `None` when either part is unconfigured.
    fn scope_rps(&self) -> Option<u32> {
        let global_rps = self.config.global_rps?;
        let pct = self.config.scope_pct?;
        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
    }

    /// Effective limit for a rule: the minimum of its absolute `rps` and
    /// its `pct` of the global limit; `None` when neither resolves.
    fn rule_rps(&self, rule: &ThroughputRule) -> Option<u32> {
        let pct_limit = rule.pct.and_then(|p| {
            self.config
                .global_rps
                .map(|g| ((g as f64) * (p as f64 / 100.0)) as u32)
        });

        match (rule.rps, pct_limit) {
            (Some(r), Some(p)) => Some(r.min(p)),
            (Some(r), None) => Some(r),
            (None, Some(p)) => Some(p),
            (None, None) => None,
        }
    }
}
680
/// A token bucket: refills continuously at `refill_rate` tokens/second up to
/// `capacity`, and hands out one token per successful `try_acquire`.
#[derive(Debug)]
struct TokenBucket {
    // Sustained refill rate in tokens per second.
    refill_rate: f64,
    // Maximum number of tokens the bucket can hold (rps + burst).
    capacity: f64,
    // Tokens currently available (may be fractional).
    tokens: f64,
    // Timestamp of the last committed refill.
    last_update: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full: capacity `rps + burst`,
    /// refilling at `rps` tokens per second.
    pub fn new(rps: u32, burst: u32) -> Self {
        let capacity = f64::from(rps + burst);
        Self {
            refill_rate: f64::from(rps),
            capacity,
            tokens: capacity,
            last_update: Instant::now(),
        }
    }

    /// Attempts to take one token, refilling first based on elapsed time.
    /// Returns `false` when fewer than one token is available.
    pub fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_update).as_secs_f64();
        let candidate = (self.tokens + elapsed * self.refill_rate).min(self.capacity);

        // Commit the refill only once it yields at least one whole extra
        // token; otherwise keep accruing fractional credit against the old
        // timestamp so small refills are not repeatedly discarded.
        if candidate.floor() > self.tokens.floor() {
            self.tokens = candidate;
            self.last_update = now;
        }

        if self.tokens < 1.0 {
            return false;
        }
        self.tokens -= 1.0;
        true
    }
}
732
/// A payload stream wrapper that counts bytes as they flow through.
///
/// Every successfully yielded chunk's length is added to each attached
/// accumulator; all other poll results pass through untouched.
pub(crate) struct MeteredPayloadStream<S> {
    // The wrapped byte stream.
    inner: S,
    // Accumulators credited with the length of every yielded chunk.
    accumulators: Vec<Arc<AtomicU64>>,
}

impl<S> MeteredPayloadStream<S> {
    /// Wraps `inner`, attaching the byte-count accumulators to update.
    pub fn new(inner: S, accumulators: Vec<Arc<AtomicU64>>) -> Self {
        MeteredPayloadStream {
            inner,
            accumulators,
        }
    }
}

// Manual impl: `S` need not be `Debug`, so the inner stream is omitted.
impl<S> fmt::Debug for MeteredPayloadStream<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MeteredPayloadStream")
            .field("accumulators", &self.accumulators)
            .finish()
    }
}
758
759impl<S, E> Stream for MeteredPayloadStream<S>
760where
761 S: Stream<Item = Result<Bytes, E>> + Unpin,
762{
763 type Item = Result<Bytes, E>;
764
765 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
766 let this = self.get_mut();
767 let res = Pin::new(&mut this.inner).poll_next(cx);
768 if let Poll::Ready(Some(Ok(ref bytes))) = res {
769 let len = bytes.len() as u64;
770 for acc in &this.accumulators {
771 acc.fetch_add(len, std::sync::atomic::Ordering::Relaxed);
772 }
773 }
774 res
775 }
776}
777
#[cfg(test)]
mod tests {
    use objectstore_service::id::ObjectContext;
    use objectstore_types::scope::{Scope, Scopes};

    use super::*;

    /// Minimal request context shared by the tests below.
    fn make_context() -> ObjectContext {
        ObjectContext {
            usecase: "testing".into(),
            scopes: Scopes::from_iter([Scope::create("org", "1").unwrap()]),
        }
    }

    #[test]
    fn ewma_estimator_update_applies_alpha() {
        let estimator = EwmaEstimator::new();
        // Simulate a 50ms tick: counts are converted to per-second rates.
        const TICK: f64 = 0.05;
        let to_rate = 1.0 / TICK;
        let mut ewma: f64 = 0.0;

        estimator
            .accumulator
            .store(10, std::sync::atomic::Ordering::Relaxed);
        estimator.update_ewma(&mut ewma, to_rate);
        // 10 / 0.05s = 200/s; EWMA from zero = 0.2 * 200 = 40.
        assert_eq!(
            estimator
                .estimate
                .load(std::sync::atomic::Ordering::Relaxed),
            40
        );

        // The accumulator must be drained by the update.
        assert_eq!(
            estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            0
        );
    }

    #[test]
    fn throughput_check_increments_accumulator() {
        let limiter = ThroughputRateLimiter::new(ThroughputLimits {
            global_rps: Some(1000),
            ..Default::default()
        });

        assert_eq!(
            limiter
                .global_estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            0
        );

        let context = make_context();
        assert!(limiter.check(&context).is_none());
        assert!(limiter.check(&context).is_none());

        // Each admitted request counts toward the EWMA accumulator.
        assert_eq!(
            limiter
                .global_estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            2
        );
    }

    #[test]
    fn throughput_rejected_does_not_increment_accumulator() {
        // rps=1, burst=0 => exactly one token available immediately.
        let limiter = ThroughputRateLimiter::new(ThroughputLimits {
            global_rps: Some(1),
            burst: 0,
            ..Default::default()
        });

        let context = make_context();
        assert!(limiter.check(&context).is_none());
        assert!(limiter.check(&context).is_some());

        // Only the admitted request is counted; the rejection is not.
        assert_eq!(
            limiter
                .global_estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            1
        );
    }

    #[test]
    fn bandwidth_rejection_does_not_increment_throughput_accumulator() {
        let limiter = RateLimiter::new(RateLimits {
            throughput: ThroughputLimits {
                global_rps: Some(1000),
                ..Default::default()
            },
            bandwidth: BandwidthLimits {
                global_bps: Some(0),
                ..Default::default()
            },
        });

        // Force the bandwidth estimate above the zero limit so the
        // bandwidth check rejects.
        limiter
            .bandwidth
            .global
            .estimate
            .store(1, std::sync::atomic::Ordering::Relaxed);

        let context = make_context();
        assert!(!limiter.check(&context));

        // Bandwidth rejection short-circuits before the throughput
        // limiter runs, so nothing was admitted or counted there.
        assert_eq!(
            limiter
                .throughput
                .global_estimator
                .accumulator
                .load(std::sync::atomic::Ordering::Relaxed),
            0
        );
    }

    #[test]
    fn rate_limiter_accessors_with_config() {
        let rate_limiter = RateLimiter::new(RateLimits {
            throughput: ThroughputLimits {
                global_rps: Some(500),
                ..Default::default()
            },
            bandwidth: BandwidthLimits {
                global_bps: Some(1_000_000),
                ..Default::default()
            },
        });

        assert_eq!(rate_limiter.bandwidth_limit(), Some(1_000_000));
        assert_eq!(rate_limiter.throughput_limit(), Some(500));
        // Estimates start at zero before any traffic is recorded.
        assert_eq!(rate_limiter.bandwidth_ewma(), 0);
        assert_eq!(rate_limiter.throughput_rps(), 0);
    }

    #[test]
    fn rate_limiter_accessors_no_limits() {
        let rate_limiter = RateLimiter::new(RateLimits::default());

        assert_eq!(rate_limiter.bandwidth_limit(), None);
        assert_eq!(rate_limiter.throughput_limit(), None);
        assert_eq!(rate_limiter.bandwidth_ewma(), 0);
        assert_eq!(rate_limiter.throughput_rps(), 0);
    }
}