1use std::fmt;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::sync::{Mutex, atomic::AtomicU64};
15use std::task::{Context, Poll};
16use std::time::{Duration, Instant};
17
18use bytes::Bytes;
19use futures_util::Stream;
20use objectstore_service::id::ObjectContext;
21use objectstore_types::scope::Scopes;
22use serde::{Deserialize, Serialize};
23
/// Why a request was turned away by the rate limiter.
///
/// Each variant names the limit family (bandwidth or throughput) and the
/// granularity at which that limit was exceeded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RateLimitRejection {
    /// The global bandwidth estimate is above `BandwidthLimits::global_bps`.
    BandwidthGlobal,
    /// The request's usecase bandwidth estimate is above its share of the global limit.
    BandwidthUsecase,
    /// The request's scope bandwidth estimate is above its share of the global limit.
    BandwidthScope,
    /// The global throughput token bucket had no token available.
    ThroughputGlobal,
    /// The per-usecase throughput token bucket had no token available.
    ThroughputUsecase,
    /// The per-scope throughput token bucket had no token available.
    ThroughputScope,
    /// The token bucket of a matching `ThroughputRule` had no token available.
    ThroughputRule,
}

impl RateLimitRejection {
    /// Returns the snake_case label used as the `reason` metric tag and log field.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::BandwidthGlobal => "bandwidth_global",
            Self::BandwidthUsecase => "bandwidth_usecase",
            Self::BandwidthScope => "bandwidth_scope",
            Self::ThroughputGlobal => "throughput_global",
            Self::ThroughputUsecase => "throughput_usecase",
            Self::ThroughputScope => "throughput_scope",
            Self::ThroughputRule => "throughput_rule",
        }
    }
}

impl fmt::Display for RateLimitRejection {
    /// Writes the same label as [`RateLimitRejection::as_str`], ignoring width/fill flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = self.as_str();
        f.write_str(label)
    }
}
63
/// Complete rate-limiting configuration: request-rate and byte-rate limits.
///
/// `RateLimiter::new` builds one throughput limiter and one bandwidth limiter
/// from the two sections. The `Default` value configures no limits.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct RateLimits {
    /// Request-rate limits, enforced with token buckets.
    pub throughput: ThroughputLimits,
    /// Byte-rate limits, enforced against moving-average estimates.
    pub bandwidth: BandwidthLimits,
}
72
/// Throughput (requests per second) limits, enforced with token buckets.
///
/// The percentage-based limits (`usecase_pct`, `scope_pct`, and each rule's
/// `pct`) are derived from `global_rps`. When `global_rps` is `None`, those
/// percentages have no effect; a rule's absolute `rps` still applies.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct ThroughputLimits {
    /// Maximum admitted requests per second across all traffic.
    /// `None` means no global bucket is created.
    pub global_rps: Option<u32>,

    /// Extra tokens added on top of the per-second rate to every bucket's
    /// capacity (capacity = rps + burst), which allows short bursts.
    pub burst: u32,

    /// Per-usecase limit as a percentage of `global_rps`.
    /// Each usecase gets its own bucket, created on first use.
    pub usecase_pct: Option<u8>,

    /// Per-scope limit as a percentage of `global_rps`.
    /// Each distinct `Scopes` value gets its own bucket, created on first use.
    pub scope_pct: Option<u8>,

    /// Additional targeted limits, evaluated in order after the global,
    /// usecase and scope buckets. Every matching rule with an effective limit
    /// must admit the request.
    pub rules: Vec<ThroughputRule>,
}
107
108#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
114pub struct ThroughputRule {
115 pub usecase: Option<String>,
119
120 pub scopes: Vec<(String, String)>,
125
126 pub rps: Option<u32>,
131
132 pub pct: Option<u8>,
137}
138
139impl ThroughputRule {
140 pub fn matches(&self, context: &ObjectContext) -> bool {
142 if let Some(ref rule_usecase) = self.usecase
143 && rule_usecase != &context.usecase
144 {
145 return false;
146 }
147
148 for (scope_name, scope_value) in &self.scopes {
149 match context.scopes.get_value(scope_name) {
150 Some(value) if value == scope_value => (),
151 _ => return false,
152 }
153 }
154
155 true
156 }
157}
158
/// Bandwidth (bytes per second) limits, checked against moving-average
/// estimates of transferred bytes.
///
/// A request is rejected up front when the relevant estimate already exceeds
/// its limit.
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct BandwidthLimits {
    /// Global limit in bytes per second. `None` disables all bandwidth
    /// limiting, including the percentage-based limits below.
    pub global_bps: Option<u64>,

    /// Per-usecase limit as a percentage of `global_bps`.
    pub usecase_pct: Option<u8>,

    /// Per-scope limit as a percentage of `global_bps`.
    pub scope_pct: Option<u8>,
}
180
181#[derive(Debug)]
191pub struct RateLimiter {
192 bandwidth: BandwidthRateLimiter,
193 throughput: ThroughputRateLimiter,
194}
195
196impl RateLimiter {
197 pub fn new(config: RateLimits) -> Self {
201 Self {
202 bandwidth: BandwidthRateLimiter::new(config.bandwidth),
203 throughput: ThroughputRateLimiter::new(config.throughput),
204 }
205 }
206
207 pub fn start(&self) {
211 self.bandwidth.start();
212 self.throughput.start();
213 }
214
215 pub fn check(&self, context: &ObjectContext, key: Option<&str>) -> bool {
221 let rejection = self
225 .bandwidth
226 .check(context)
227 .or_else(|| self.throughput.check(context));
228
229 let Some(rejection) = rejection else {
230 return true;
231 };
232
233 objectstore_metrics::count!(
234 "server.request.rate_limited",
235 reason = rejection.as_str(),
236 usecase = context.usecase.clone()
237 );
238 objectstore_log::warn!(
239 reason = rejection.as_str(),
240 usecase = &context.usecase,
241 scopes = %context.scopes.as_api_path(),
242 key,
243 "Request rejected: rate limit exceeded"
244 );
245 false
246 }
247
248 pub fn bytes_accumulators(&self, context: &ObjectContext) -> Vec<Arc<AtomicU64>> {
252 self.bandwidth.accumulators(context)
253 }
254
255 pub fn record_bandwidth(&self, context: &ObjectContext, bytes: u64) {
260 for acc in self.bandwidth.accumulators(context) {
261 acc.fetch_add(bytes, std::sync::atomic::Ordering::Relaxed);
262 }
263 }
264
265 pub fn bandwidth_ewma(&self) -> u64 {
267 self.bandwidth
268 .global
269 .estimate
270 .load(std::sync::atomic::Ordering::Relaxed)
271 }
272
273 pub fn bandwidth_limit(&self) -> Option<u64> {
275 self.bandwidth.config.global_bps
276 }
277
278 pub fn throughput_rps(&self) -> u64 {
280 self.throughput
281 .global_estimator
282 .estimate
283 .load(std::sync::atomic::Ordering::Relaxed)
284 }
285
286 pub fn throughput_limit(&self) -> Option<u32> {
288 self.throughput.config.global_rps
289 }
290
291 pub fn bandwidth_total_bytes(&self) -> u64 {
293 self.bandwidth
294 .total_bytes
295 .load(std::sync::atomic::Ordering::Relaxed)
296 }
297
298 pub fn throughput_total_admitted(&self) -> u64 {
300 self.throughput
301 .total_admitted
302 .load(std::sync::atomic::Ordering::Relaxed)
303 }
304}
305
/// Exponentially weighted moving average over a periodically drained counter.
///
/// Producers add events to `accumulator`. A background task calls
/// `update_ewma` once per tick, which turns the drained count into a rate and
/// publishes the floored average in `estimate` for readers.
#[derive(Debug)]
struct EwmaEstimator {
    /// Events (bytes or requests) counted since the last update.
    accumulator: Arc<AtomicU64>,
    /// Most recently published average, floored to an integer rate.
    estimate: Arc<AtomicU64>,
}

impl EwmaEstimator {
    /// Creates an estimator with an empty accumulator and a zero estimate.
    fn new() -> Self {
        let zeroed = || Arc::new(AtomicU64::new(0));
        Self {
            accumulator: zeroed(),
            estimate: zeroed(),
        }
    }

    /// Drains the accumulator and folds the resulting rate into `ewma`.
    ///
    /// `to_rate` converts the drained count into a per-second rate. Callers pass
    /// `1 / elapsed_seconds`. `ewma` is the caller-owned full-precision state.
    /// The new value, floored, is also published in `estimate`.
    fn update_ewma(&self, ewma: &mut f64, to_rate: f64) {
        // Smoothing factor: weight of the newest sample.
        const ALPHA: f64 = 0.2;

        let drained = self
            .accumulator
            .swap(0, std::sync::atomic::Ordering::Relaxed);
        let sample = drained as f64 * to_rate;
        let next = ALPHA * sample + (1.0 - ALPHA) * *ewma;
        *ewma = next;

        self.estimate
            .store(next.floor() as u64, std::sync::atomic::Ordering::Relaxed);
    }
}
340
/// Bandwidth limiter based on EWMA estimates of bytes transferred per second.
#[derive(Debug)]
struct BandwidthRateLimiter {
    config: BandwidthLimits,
    /// Estimator over all metered traffic.
    global: Arc<EwmaEstimator>,
    /// Running total of all metered bytes; only ever incremented.
    total_bytes: Arc<AtomicU64>,
    /// Per-usecase estimators, created lazily while a usecase limit is configured.
    /// NOTE(review): entries are never removed in this file.
    usecases: Arc<papaya::HashMap<String, Arc<EwmaEstimator>>>,
    /// Per-scope estimators, created lazily while a scope limit is configured.
    /// NOTE(review): entries are never removed in this file.
    scopes: Arc<papaya::HashMap<Scopes, Arc<EwmaEstimator>>>,
}
353
354impl BandwidthRateLimiter {
355 fn new(config: BandwidthLimits) -> Self {
356 Self {
357 config,
358 global: Arc::new(EwmaEstimator::new()),
359 total_bytes: Arc::new(AtomicU64::new(0)),
360 usecases: Arc::new(papaya::HashMap::new()),
361 scopes: Arc::new(papaya::HashMap::new()),
362 }
363 }
364
365 fn start(&self) {
366 let global = Arc::clone(&self.global);
367 let usecases = Arc::clone(&self.usecases);
368 let scopes = Arc::clone(&self.scopes);
369 let global_limit = self.config.global_bps;
370 tokio::task::spawn(async move {
373 Self::estimator(global, usecases, scopes, global_limit).await;
374 });
375 }
376
377 async fn estimator(
382 global: Arc<EwmaEstimator>,
383 usecases: Arc<papaya::HashMap<String, Arc<EwmaEstimator>>>,
384 scopes: Arc<papaya::HashMap<Scopes, Arc<EwmaEstimator>>>,
385 global_limit: Option<u64>,
386 ) {
387 const TICK: Duration = Duration::from_millis(50); let mut interval = tokio::time::interval(TICK);
390 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
391 interval.tick().await;
394 let mut last = Instant::now();
395 let mut global_ewma: f64 = 0.0;
396 let mut usecase_ewmas: std::collections::HashMap<String, f64> =
398 std::collections::HashMap::new();
399 let mut scope_ewmas: std::collections::HashMap<Scopes, f64> =
400 std::collections::HashMap::new();
401
402 loop {
403 interval.tick().await;
404
405 let now = Instant::now();
406 let to_bps = 1.0 / now.duration_since(last).as_secs_f64();
407 last = now;
408
409 global.update_ewma(&mut global_ewma, to_bps);
411 objectstore_metrics::gauge!("server.bandwidth.ewma" = global_ewma.floor() as u64);
412 if let Some(limit) = global_limit {
413 objectstore_metrics::gauge!("server.bandwidth.limit" = limit);
414 }
415
416 {
418 let guard = usecases.pin();
419 for (key, estimator) in guard.iter() {
420 let ewma = usecase_ewmas.entry(key.clone()).or_insert(0.0);
421 estimator.update_ewma(ewma, to_bps);
422 }
423 }
424
425 {
427 let guard = scopes.pin();
428 for (key, estimator) in guard.iter() {
429 let ewma = scope_ewmas.entry(key.clone()).or_insert(0.0);
430 estimator.update_ewma(ewma, to_bps);
431 }
432 }
433
434 objectstore_metrics::gauge!(
435 "server.rate_limiter.bandwidth.scope_map_size" = scopes.len()
436 );
437 objectstore_metrics::gauge!(
438 "server.rate_limiter.bandwidth.usecase_map_size" = usecases.len()
439 );
440 }
441 }
442
443 fn check(&self, context: &ObjectContext) -> Option<RateLimitRejection> {
444 let global_bps = self.config.global_bps?;
445
446 if self
448 .global
449 .estimate
450 .load(std::sync::atomic::Ordering::Relaxed)
451 > global_bps
452 {
453 return Some(RateLimitRejection::BandwidthGlobal);
454 }
455
456 if let Some(usecase_bps) = self.usecase_bps() {
458 let guard = self.usecases.pin();
459 if let Some(estimator) = guard.get(&context.usecase)
460 && estimator
461 .estimate
462 .load(std::sync::atomic::Ordering::Relaxed)
463 > usecase_bps
464 {
465 return Some(RateLimitRejection::BandwidthUsecase);
466 }
467 }
468
469 if let Some(scope_bps) = self.scope_bps() {
471 let guard = self.scopes.pin();
472 if let Some(estimator) = guard.get(&context.scopes)
473 && estimator
474 .estimate
475 .load(std::sync::atomic::Ordering::Relaxed)
476 > scope_bps
477 {
478 return Some(RateLimitRejection::BandwidthScope);
479 }
480 }
481
482 None
483 }
484
485 fn accumulators(&self, context: &ObjectContext) -> Vec<Arc<AtomicU64>> {
490 let mut accs = vec![
491 Arc::clone(&self.total_bytes),
492 Arc::clone(&self.global.accumulator),
493 ];
494
495 if self.usecase_bps().is_some() {
496 let guard = self.usecases.pin();
497 let estimator = guard
498 .get_or_insert_with(context.usecase.clone(), || Arc::new(EwmaEstimator::new()));
499 accs.push(Arc::clone(&estimator.accumulator));
500 }
501
502 if self.scope_bps().is_some() {
503 let guard = self.scopes.pin();
504 let estimator =
505 guard.get_or_insert_with(context.scopes.clone(), || Arc::new(EwmaEstimator::new()));
506 accs.push(Arc::clone(&estimator.accumulator));
507 }
508
509 accs
510 }
511
512 fn usecase_bps(&self) -> Option<u64> {
514 let global_bps = self.config.global_bps?;
515 let pct = self.config.usecase_pct?;
516 Some(((global_bps as f64) * (pct as f64 / 100.0)) as u64)
517 }
518
519 fn scope_bps(&self) -> Option<u64> {
521 let global_bps = self.config.global_bps?;
522 let pct = self.config.scope_pct?;
523 Some(((global_bps as f64) * (pct as f64 / 100.0)) as u64)
524 }
525}
526
/// Throughput limiter built from token buckets at global, usecase, scope and rule level.
#[derive(Debug)]
struct ThroughputRateLimiter {
    config: ThroughputLimits,
    /// Global bucket; `None` when `global_rps` is not configured.
    global: Option<Mutex<TokenBucket>>,
    /// Counts admitted requests to estimate the admitted request rate.
    global_estimator: Arc<EwmaEstimator>,
    /// Running total of admitted requests; only ever incremented.
    total_admitted: Arc<AtomicU64>,
    /// Per-usecase buckets, created on first use.
    usecases: Arc<papaya::HashMap<String, Mutex<TokenBucket>>>,
    /// Per-scope buckets, created on first use.
    scopes: Arc<papaya::HashMap<Scopes, Mutex<TokenBucket>>>,
    /// Per-rule buckets, keyed by the rule's index in `config.rules`.
    rules: papaya::HashMap<usize, Mutex<TokenBucket>>,
}
541
542impl ThroughputRateLimiter {
543 fn new(config: ThroughputLimits) -> Self {
544 let global = config
545 .global_rps
546 .map(|rps| Mutex::new(TokenBucket::new(rps, config.burst)));
547
548 Self {
549 config,
550 global,
551 global_estimator: Arc::new(EwmaEstimator::new()),
552 total_admitted: Arc::new(AtomicU64::new(0)),
553 usecases: Arc::new(papaya::HashMap::new()),
554 scopes: Arc::new(papaya::HashMap::new()),
555 rules: papaya::HashMap::new(),
556 }
557 }
558
559 fn start(&self) {
560 let usecases = Arc::clone(&self.usecases);
561 let scopes = Arc::clone(&self.scopes);
562 let global_estimator = Arc::clone(&self.global_estimator);
563 let global_limit = self.config.global_rps;
564 tokio::task::spawn(async move {
567 const TICK: Duration = Duration::from_millis(50);
568 let mut interval = tokio::time::interval(TICK);
569 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
570 interval.tick().await;
571 let mut last = Instant::now();
572 let mut global_ewma: f64 = 0.0;
573 loop {
574 interval.tick().await;
575 let now = Instant::now();
576 let to_rps = 1.0 / now.duration_since(last).as_secs_f64();
577 last = now;
578 global_estimator.update_ewma(&mut global_ewma, to_rps);
579 objectstore_metrics::gauge!("server.throughput.ewma" = global_ewma.floor() as u64);
580 if let Some(limit) = global_limit {
581 objectstore_metrics::gauge!(
582 "server.rate_limiter.throughput.limit" = u64::from(limit)
583 );
584 }
585 objectstore_metrics::gauge!(
586 "server.rate_limiter.throughput.scope_map_size" = scopes.len()
587 );
588 objectstore_metrics::gauge!(
589 "server.rate_limiter.throughput.usecase_map_size" = usecases.len()
590 );
591 }
592 });
593 }
594
595 fn check(&self, context: &ObjectContext) -> Option<RateLimitRejection> {
596 if let Some(ref global) = self.global {
600 let acquired = global.lock().unwrap().try_acquire();
601 if !acquired {
602 return Some(RateLimitRejection::ThroughputGlobal);
603 }
604 }
605
606 if let Some(usecase_rps) = self.usecase_rps() {
608 let guard = self.usecases.pin();
609 let bucket = guard
610 .get_or_insert_with(context.usecase.clone(), || self.create_bucket(usecase_rps));
611 if !bucket.lock().unwrap().try_acquire() {
612 return Some(RateLimitRejection::ThroughputUsecase);
613 }
614 }
615
616 if let Some(scope_rps) = self.scope_rps() {
618 let guard = self.scopes.pin();
619 let bucket =
620 guard.get_or_insert_with(context.scopes.clone(), || self.create_bucket(scope_rps));
621 if !bucket.lock().unwrap().try_acquire() {
622 return Some(RateLimitRejection::ThroughputScope);
623 }
624 }
625
626 for (idx, rule) in self.config.rules.iter().enumerate() {
628 if !rule.matches(context) {
629 continue;
630 }
631 let Some(rule_rps) = self.rule_rps(rule) else {
632 continue;
633 };
634 let guard = self.rules.pin();
635 let bucket = guard.get_or_insert_with(idx, || self.create_bucket(rule_rps));
636 if !bucket.lock().unwrap().try_acquire() {
637 return Some(RateLimitRejection::ThroughputRule);
638 }
639 }
640
641 self.global_estimator
645 .accumulator
646 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
647 self.total_admitted
648 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
649
650 None
651 }
652
653 fn create_bucket(&self, rps: u32) -> Mutex<TokenBucket> {
654 Mutex::new(TokenBucket::new(rps, self.config.burst))
655 }
656
657 fn usecase_rps(&self) -> Option<u32> {
659 let global_rps = self.config.global_rps?;
660 let pct = self.config.usecase_pct?;
661 Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
662 }
663
664 fn scope_rps(&self) -> Option<u32> {
666 let global_rps = self.config.global_rps?;
667 let pct = self.config.scope_pct?;
668 Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
669 }
670
671 fn rule_rps(&self, rule: &ThroughputRule) -> Option<u32> {
673 let pct_limit = rule.pct.and_then(|p| {
674 self.config
675 .global_rps
676 .map(|g| ((g as f64) * (p as f64 / 100.0)) as u32)
677 });
678
679 match (rule.rps, pct_limit) {
680 (Some(r), Some(p)) => Some(r.min(p)),
681 (Some(r), None) => Some(r),
682 (None, Some(p)) => Some(p),
683 (None, None) => None,
684 }
685 }
686}
687
/// A token bucket that refills continuously at `refill_rate` tokens per second,
/// up to `capacity`. Each admitted request takes exactly one token.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens added per second.
    refill_rate: f64,
    /// Maximum number of tokens held (`rps + burst`).
    capacity: f64,
    /// Currently available tokens; may be fractional.
    tokens: f64,
    /// Instant up to which refill has already been credited into `tokens`.
    last_update: Instant,
}

impl TokenBucket {
    /// Creates a full bucket allowing `rps` tokens per second plus `burst` extra capacity.
    pub fn new(rps: u32, burst: u32) -> Self {
        // Sum in f64 so that large `rps + burst` values cannot overflow `u32`.
        let capacity = f64::from(rps) + f64::from(burst);
        Self {
            refill_rate: f64::from(rps),
            capacity,
            tokens: capacity,
            last_update: Instant::now(),
        }
    }

    /// Takes one token if available. Returns `true` if the request is admitted.
    pub fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_update).as_secs_f64();

        // Credit the refill and always advance the clock. Time spent while the
        // bucket is already full is discarded instead of being banked, which would
        // otherwise grant an extra token beyond `capacity` after an idle period.
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.capacity);
        self.last_update = now;

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
739
/// Stream adapter that adds the length of every successfully yielded chunk to
/// a set of byte counters and passes each item through unchanged.
pub(crate) struct MeteredPayloadStream<S> {
    /// The wrapped payload stream.
    inner: S,
    /// Counters incremented by each chunk's length.
    accumulators: Vec<Arc<AtomicU64>>,
}

impl<S> MeteredPayloadStream<S> {
    /// Wraps `inner`, metering its chunks into every counter in `accumulators`.
    pub fn new(inner: S, accumulators: Vec<Arc<AtomicU64>>) -> Self {
        Self {
            accumulators,
            inner,
        }
    }
}

impl<S> fmt::Debug for MeteredPayloadStream<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `S` is not required to implement `Debug`, so only the counters are shown.
        let mut builder = f.debug_struct("MeteredPayloadStream");
        builder.field("accumulators", &self.accumulators);
        builder.finish()
    }
}
765
766impl<S, E> Stream for MeteredPayloadStream<S>
767where
768 S: Stream<Item = Result<Bytes, E>> + Unpin,
769{
770 type Item = Result<Bytes, E>;
771
772 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
773 let this = self.get_mut();
774 let res = Pin::new(&mut this.inner).poll_next(cx);
775 if let Poll::Ready(Some(Ok(ref bytes))) = res {
776 let len = bytes.len() as u64;
777 for acc in &this.accumulators {
778 acc.fetch_add(len, std::sync::atomic::Ordering::Relaxed);
779 }
780 }
781 res
782 }
783}
784
#[cfg(test)]
mod tests {
    use objectstore_service::id::ObjectContext;
    use objectstore_types::scope::{Scope, Scopes};

    use super::*;

    /// Loads `counter` with relaxed ordering, the same way the limiter reads it.
    fn read(counter: &std::sync::atomic::AtomicU64) -> u64 {
        counter.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// A context for usecase `testing` scoped to `org=1`.
    fn make_context() -> ObjectContext {
        let org = Scope::create("org", "1").unwrap();
        ObjectContext {
            usecase: "testing".into(),
            scopes: Scopes::from_iter([org]),
        }
    }

    #[test]
    fn ewma_estimator_update_applies_alpha() {
        // Factor that turns a per-tick count into a per-second rate for a 50ms tick.
        let to_rate = 1.0 / 0.05;
        let estimator = EwmaEstimator::new();
        let mut ewma: f64 = 0.0;

        // 10 events per tick is 200/s; the first EWMA step gives 0.2 * 200 = 40.
        estimator
            .accumulator
            .store(10, std::sync::atomic::Ordering::Relaxed);
        estimator.update_ewma(&mut ewma, to_rate);

        assert_eq!(read(&estimator.estimate), 40);
        // Updating drains the accumulator.
        assert_eq!(read(&estimator.accumulator), 0);
    }

    #[test]
    fn throughput_check_increments_accumulator() {
        let limiter = ThroughputRateLimiter::new(ThroughputLimits {
            global_rps: Some(1000),
            ..Default::default()
        });
        assert_eq!(read(&limiter.global_estimator.accumulator), 0);

        let context = make_context();
        for _ in 0..2 {
            assert!(limiter.check(&context).is_none());
        }

        assert_eq!(read(&limiter.global_estimator.accumulator), 2);
    }

    #[test]
    fn throughput_rejected_does_not_increment_accumulator() {
        // A capacity of exactly one token: the first request passes, the second is rejected.
        let limiter = ThroughputRateLimiter::new(ThroughputLimits {
            global_rps: Some(1),
            burst: 0,
            ..Default::default()
        });

        let context = make_context();
        assert!(limiter.check(&context).is_none());
        assert!(limiter.check(&context).is_some());

        // Only the admitted request was counted.
        assert_eq!(read(&limiter.global_estimator.accumulator), 1);
    }

    #[test]
    fn bandwidth_rejection_does_not_increment_throughput_accumulator() {
        let limiter = RateLimiter::new(RateLimits {
            throughput: ThroughputLimits {
                global_rps: Some(1000),
                ..Default::default()
            },
            bandwidth: BandwidthLimits {
                global_bps: Some(0),
                ..Default::default()
            },
        });

        // Push the bandwidth estimate above the zero limit.
        limiter
            .bandwidth
            .global
            .estimate
            .store(1, std::sync::atomic::Ordering::Relaxed);

        let context = make_context();
        assert!(!limiter.check(&context, None));

        // Throughput was never consulted, so nothing was counted there.
        assert_eq!(read(&limiter.throughput.global_estimator.accumulator), 0);
    }

    #[test]
    fn rate_limiter_accessors_with_config() {
        let rate_limiter = RateLimiter::new(RateLimits {
            throughput: ThroughputLimits {
                global_rps: Some(500),
                ..Default::default()
            },
            bandwidth: BandwidthLimits {
                global_bps: Some(1_000_000),
                ..Default::default()
            },
        });

        assert_eq!(rate_limiter.bandwidth_limit(), Some(1_000_000));
        assert_eq!(rate_limiter.throughput_limit(), Some(500));
        // Without `start`, no estimates have been published yet.
        assert_eq!(rate_limiter.bandwidth_ewma(), 0);
        assert_eq!(rate_limiter.throughput_rps(), 0);
    }

    #[test]
    fn rate_limiter_accessors_no_limits() {
        let rate_limiter = RateLimiter::new(RateLimits::default());

        assert_eq!(rate_limiter.bandwidth_limit(), None);
        assert_eq!(rate_limiter.throughput_limit(), None);
        assert_eq!(rate_limiter.bandwidth_ewma(), 0);
        assert_eq!(rate_limiter.throughput_rps(), 0);
    }
}