1use std::fmt::{self, Debug};
2use std::sync::Arc;
3
4use relay_base_schema::metrics::MetricNamespace;
5use relay_base_schema::organization::OrganizationId;
6use relay_common::time::UnixTimestamp;
7use relay_log::protocol::value;
8use relay_redis::redis::{self, FromRedisValue, Script};
9use relay_redis::{AsyncRedisClient, RedisError, RedisScripts};
10use thiserror::Error;
11
12use crate::cache::OpportunisticQuotaCache;
13use crate::global::GlobalLimiter;
14use crate::quota::{ItemScoping, Quota, QuotaScope};
15use crate::rate_limit::{RateLimit, RateLimits, RetryAfter};
16use crate::statsd::{QuotaCounters, QuotaTimers};
17use crate::{REJECT_ALL_SECS, cache};
18
/// Additional seconds added on top of a quota window's end when computing the expiry of its
/// Redis key (see `RedisQuota::key_expiry`).
///
/// NOTE(review): presumably keeps counters alive slightly past the end of the window to tolerate
/// clock skew between Relay instances — confirm.
const GRACE: u64 = 60;
23
/// Errors that can occur while checking quotas with the [`RedisRateLimiter`].
#[derive(Debug, Error)]
pub enum RateLimitingError {
    /// Communication with Redis failed (connection or script invocation).
    #[error("failed to communicate with redis")]
    Redis(
        #[from]
        #[source]
        RedisError,
    ),

    /// The global rate limits could not be checked.
    #[error("failed to check global rate limits")]
    UnreachableGlobalRateLimits,
}
39
/// Builds the Redis key that holds refunds for the counter stored at `counter_key`.
///
/// The refund key is simply the counter key prefixed with `r:`.
fn get_refunded_quota_key(counter_key: &str) -> String {
    const REFUND_PREFIX: &str = "r:";

    let mut refund_key = String::with_capacity(REFUND_PREFIX.len() + counter_key.len());
    refund_key.push_str(REFUND_PREFIX);
    refund_key.push_str(counter_key);
    refund_key
}
47
/// Display adapter that renders the wrapped value when present and nothing otherwise.
struct OptionalDisplay<T>(Option<T>);

impl<T: fmt::Display> fmt::Display for OptionalDisplay<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // An absent value contributes an empty string to the surrounding output.
        if let Some(value) = &self.0 {
            write!(f, "{value}")
        } else {
            Ok(())
        }
    }
}
62
/// Owned counterpart of [`RedisQuota`] that holds its own copy of the [`Quota`].
///
/// Created via `RedisQuota::build_owned`; call [`OwnedRedisQuota::build_ref`] to get back a
/// borrowing [`RedisQuota`].
#[derive(Debug, Clone)]
pub struct OwnedRedisQuota {
    /// The quota definition.
    quota: Quota,
    /// Scoping of the item being checked against the quota.
    scoping: ItemScoping,
    /// The quota id, used as the Redis key prefix.
    prefix: Arc<str>,
    /// Length of the quota window in seconds.
    window: u64,
    /// Quantity to count against the quota.
    quantity: u64,
    /// Point in time at which the quota is evaluated.
    timestamp: UnixTimestamp,
}
79
80impl OwnedRedisQuota {
81 pub fn build_ref(&self) -> RedisQuota<'_> {
83 RedisQuota {
84 quota: &self.quota,
85 scoping: self.scoping,
86 prefix: Arc::clone(&self.prefix),
87 window: self.window,
88 quantity: self.quantity,
89 timestamp: self.timestamp,
90 }
91 }
92}
93
/// A [`Quota`] that can be tracked in Redis, bound to a specific item, quantity and time.
///
/// Only quotas with an `id` and a `window` can be represented (see `RedisQuota::new`).
/// Dereferences to the underlying [`Quota`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct RedisQuota<'a> {
    /// The borrowed quota definition.
    quota: &'a Quota,
    /// Scoping of the item being checked against the quota.
    scoping: ItemScoping,
    /// The quota id, used as the Redis key prefix.
    prefix: Arc<str>,
    /// Length of the quota window in seconds.
    window: u64,
    /// Quantity to count against the quota.
    quantity: u64,
    /// Point in time at which the quota is evaluated.
    timestamp: UnixTimestamp,
}
110
111impl<'a> RedisQuota<'a> {
112 pub fn new(
118 quota: &'a Quota,
119 quantity: u64,
120 scoping: ItemScoping,
121 timestamp: UnixTimestamp,
122 ) -> Option<Self> {
123 let prefix = quota.id.clone()?;
125 let window = quota.window?;
126
127 Some(Self {
128 quota,
129 scoping,
130 prefix,
131 quantity,
132 window,
133 timestamp,
134 })
135 }
136
137 pub fn build_owned(&self) -> OwnedRedisQuota {
140 OwnedRedisQuota {
141 quota: self.quota.clone(),
142 scoping: self.scoping,
143 prefix: Arc::clone(&self.prefix),
144 window: self.window,
145 quantity: self.quantity,
146 timestamp: self.timestamp,
147 }
148 }
149
150 pub fn window(&self) -> u64 {
152 self.window
153 }
154
155 pub fn prefix(&self) -> &str {
157 &self.prefix
158 }
159
160 pub fn quantity(&self) -> u64 {
162 self.quantity
163 }
164
165 pub fn limit(&self) -> i64 {
170 self.limit
171 .and_then(|limit| limit.try_into().ok())
173 .unwrap_or(-1)
174 }
175
176 fn shift(&self) -> u64 {
177 if self.quota.scope == QuotaScope::Global {
178 0
179 } else {
180 self.scoping.organization_id.value() % self.window
181 }
182 }
183
184 pub fn slot(&self) -> u64 {
188 (self.timestamp.as_secs() - self.shift()) / self.window
189 }
190
191 pub fn expiry(&self) -> UnixTimestamp {
193 let next_slot = self.slot() + 1;
194 let next_start = next_slot * self.window + self.shift();
195 UnixTimestamp::from_secs(next_start)
196 }
197
198 pub fn key_expiry(&self) -> u64 {
202 self.expiry().as_secs() + GRACE
203 }
204
205 pub fn key(&self) -> QuotaCacheKey {
211 let subscope = match self.quota.scope {
214 QuotaScope::Global => None,
215 QuotaScope::Organization => None,
216 scope => self.scoping.scope_id(scope),
217 };
218
219 QuotaCacheKey {
220 id: Arc::clone(&self.prefix),
221 org: self.scoping.organization_id,
222 subscope,
223 namespace: self.namespace,
224 slot: self.slot(),
225 }
226 }
227
228 fn for_cache(&self) -> cache::Quota<QuotaCacheKey> {
230 cache::Quota {
231 limit: self.limit(),
232 window: self.window,
233 key: self.key(),
234 expiry: UnixTimestamp::from_secs(self.key_expiry()),
235 }
236 }
237}
238
/// Gives direct access to the fields of the underlying [`Quota`] (e.g. `scope`, `limit`).
impl std::ops::Deref for RedisQuota<'_> {
    type Target = Quota;

    fn deref(&self) -> &Self::Target {
        self.quota
    }
}
246
/// Identifies a quota counter for a single window.
///
/// Used as the key of the opportunistic quota cache. Its [`Display`](fmt::Display)
/// representation is used as the Redis counter key.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct QuotaCacheKey {
    /// The quota id.
    id: Arc<str>,
    /// The organization the item belongs to.
    org: OrganizationId,
    /// Scope id for scopes narrower than the organization, if any.
    subscope: Option<u64>,
    /// Metric namespace of the quota, if any.
    namespace: Option<MetricNamespace>,
    /// Index of the window this counter belongs to.
    slot: u64,
}
260
261impl fmt::Display for QuotaCacheKey {
262 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263 write!(
264 f,
265 "quota:{id}{{{org}}}{subscope}{namespace}:{slot}",
266 id = self.id,
267 org = self.org,
268 subscope = OptionalDisplay(self.subscope),
269 namespace = OptionalDisplay(self.namespace),
270 slot = self.slot,
271 )
272 }
273}
274
/// A rate limiter that checks and counts quotas in Redis using a Lua script.
///
/// Quotas with global scope are delegated to the configured [`GlobalLimiter`]. All other
/// trackable quotas are evaluated in a single script invocation. An optional
/// [`OpportunisticQuotaCache`] can accept some requests locally without going to Redis.
#[derive(Clone)]
pub struct RedisRateLimiter<T> {
    /// Client used to obtain Redis connections.
    client: AsyncRedisClient,
    /// Optional local cache of quota consumption.
    cache: Option<Arc<OpportunisticQuotaCache<QuotaCacheKey>>>,
    /// The `is_rate_limited` Lua script.
    script: &'static Script,
    /// Upper bound for retry-after values in seconds, if configured.
    max_limit: Option<u64>,
    /// Limiter responsible for quotas with global scope.
    global_limiter: T,
}
293
impl<T: GlobalLimiter> RedisRateLimiter<T> {
    /// Creates a rate limiter that talks to Redis through `client` and delegates global quotas to
    /// `global_limiter`.
    ///
    /// The cache and the retry-after cap start out disabled.
    pub fn new(client: AsyncRedisClient, global_limiter: T) -> Self {
        RedisRateLimiter {
            client,
            cache: None,
            script: RedisScripts::load_is_rate_limited(),
            max_limit: None,
            global_limiter,
        }
    }

    /// Caps the retry-after value of returned rate limits at `max_limit` seconds.
    ///
    /// `None` removes the cap.
    pub fn max_limit(mut self, max_limit: Option<u64>) -> Self {
        self.max_limit = max_limit;
        self
    }

    /// Configures the opportunistic quota cache.
    ///
    /// `cache_ratio` is passed to [`OpportunisticQuotaCache::new`] and `max` to its `with_max`.
    /// Passing `None` as `cache_ratio` disables the cache.
    pub fn cache(mut self, cache_ratio: Option<f32>, max: Option<f32>) -> Self {
        self.cache = cache_ratio
            .map(OpportunisticQuotaCache::new)
            .map(|c| c.with_max(max))
            .map(Arc::new);

        self
    }

    /// Checks the given quotas for an item and counts `quantity` against them.
    ///
    /// Quotas that do not match `item_scoping` are ignored. Quotas with a limit of `0` reject
    /// immediately. Global quotas are checked through the global limiter. All other trackable
    /// quotas are sent to Redis in one script invocation. Quotas without an id or window are
    /// skipped with a warning.
    ///
    /// If a zero-sized or global quota already produced a rate limit, Redis is not consulted.
    ///
    /// `over_accept_once` is forwarded to the script for every quota. Per the tests, it allows a
    /// request that would exceed the remaining quota to be accepted once.
    ///
    /// A `quantity` of `0` is allowed. Per the tests, it is still rejected once the quota is
    /// exhausted.
    ///
    /// # Errors
    ///
    /// Returns [`RateLimitingError`] if the global limiter fails or if Redis cannot be reached or
    /// the script invocation fails.
    pub async fn is_rate_limited<'a>(
        &self,
        quotas: impl IntoIterator<Item = &'a Quota>,
        item_scoping: ItemScoping,
        quantity: usize,
        over_accept_once: bool,
    ) -> Result<RateLimits, RateLimitingError> {
        let timestamp = UnixTimestamp::now();
        let mut invocation = self.script.prepare_invoke();
        let mut tracked_quotas = Vec::new();
        let mut rate_limits = RateLimits::new();

        let mut global_quotas = vec![];

        // Saturate instead of failing on platforms where `usize` is wider than `u64`.
        let quantity = u64::try_from(quantity).unwrap_or(u64::MAX);

        for quota in quotas {
            if !quota.matches(item_scoping) {
                // Quotas that do not apply to this item are skipped silently.
            } else if quota.limit == Some(0) {
                // A zero-sized quota rejects everything without consulting Redis.
                let retry_after = self.retry_after(REJECT_ALL_SECS);
                rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
            } else if let Some(mut quota) =
                RedisQuota::new(quota, quantity, item_scoping, timestamp)
            {
                if quota.scope == QuotaScope::Global {
                    global_quotas.push(quota);
                } else {
                    if let Some(cache) = &self.cache {
                        // The cache either accepts the request locally (skip Redis for this
                        // quota) or returns the quantity that must be checked in Redis.
                        // NOTE(review): presumably that quantity includes consumption accumulated
                        // locally since the last sync — confirm in `cache`.
                        quota.quantity = match cache.check_quota(quota.for_cache(), quantity) {
                            cache::Action::Accept => continue,
                            cache::Action::Check(quantity) => quantity,
                        };
                    }

                    let redis_key = quota.key().to_string();
                    let refund_key = get_refunded_quota_key(&redis_key);

                    // The script consumes two keys and four arguments per quota, in this exact
                    // order: counter key, refund key; limit, key expiry, quantity,
                    // over_accept_once.
                    invocation.key(redis_key);
                    invocation.key(refund_key);

                    invocation.arg(quota.limit());
                    invocation.arg(quota.key_expiry());
                    invocation.arg(quota.quantity);
                    invocation.arg(over_accept_once);

                    tracked_quotas.push(quota);
                }
            } else {
                relay_log::with_scope(
                    |scope| scope.set_extra("quota", value::to_value(quota).unwrap()),
                    || relay_log::warn!("skipping unsupported quota"),
                )
            }
        }

        let rate_limited_global_quotas = self
            .global_limiter
            .check_global_rate_limits(&global_quotas)
            .await?;

        for quota in rate_limited_global_quotas {
            let retry_after = self.retry_after((quota.expiry() - timestamp).as_secs());
            rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
        }

        // Nothing to check in Redis, or the item is already rejected by a zero-sized or global
        // quota.
        if tracked_quotas.is_empty() || rate_limits.is_limited() {
            return Ok(rate_limits);
        }

        let mut connection = self.client.get_connection().await?;
        let result: ScriptResult = invocation
            .invoke_async(&mut connection)
            .await
            .map_err(RedisError::Redis)?;

        // Script results are positional and correspond to `tracked_quotas` in order.
        for (quota, state) in tracked_quotas.iter().zip(result.0) {
            if state.is_rejected {
                // Amount by which the quantity added on top of this request (the difference
                // between `quota.quantity` and `quantity`) exceeded what was still remaining in
                // Redis.
                let cache_error = {
                    let remaining = quota.limit().saturating_sub(state.consumed).max(0) as u64;
                    let cache_quantity = quota.quantity.saturating_sub(quantity);

                    cache_quantity.saturating_sub(remaining)
                };
                relay_statsd::metric!(
                    counter(QuotaCounters::CacheError) += cache_error,
                    category = item_scoping.category.name(),
                );

                let retry_after = self.retry_after((quota.expiry() - timestamp).as_secs());
                rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
            } else if let Some(cache) = &self.cache {
                // Record the consumption reported by Redis for accepted quotas.
                cache.set_quota(quota.for_cache(), state.consumed);
            }
        }

        if let Some(cache) = &self.cache {
            let vacuum_start = std::time::Instant::now();
            if cache.try_vacuum(timestamp) {
                relay_statsd::metric!(
                    timer(QuotaTimers::CacheVacuumDuration) = vacuum_start.elapsed()
                );
            }
        }

        Ok(rate_limits)
    }

    /// Builds a [`RetryAfter`] of `seconds`, capped at the configured `max_limit` if any.
    fn retry_after(&self, mut seconds: u64) -> RetryAfter {
        if let Some(max_limit) = self.max_limit {
            seconds = std::cmp::min(seconds, max_limit);
        }

        RetryAfter::from_secs(seconds)
    }
}
487
/// Parsed result of the `is_rate_limited` script: one [`QuotaState`] per checked quota, in the
/// order the quotas were passed to the script.
#[derive(Debug)]
struct ScriptResult(Vec<QuotaState>);
491
492impl FromRedisValue for ScriptResult {
493 fn from_redis_value(v: &redis::Value) -> redis::RedisResult<Self> {
494 let Some(seq) = v.as_sequence() else {
495 return Err(redis::RedisError::from((
496 redis::ErrorKind::TypeError,
497 "Expected a sequence from the rate limiting script",
498 format!("{v:?}"),
499 )));
500 };
501
502 let (chunks, rem) = seq.as_chunks();
503 if !rem.is_empty() {
504 return Err(redis::RedisError::from((
505 redis::ErrorKind::TypeError,
506 "Expected an even number of values from the rate limiting script",
507 format!("{v:?}"),
508 )));
509 }
510
511 let mut quotas = Vec::with_capacity(chunks.len());
512 for [is_rejected, consumed] in chunks {
513 quotas.push(QuotaState {
514 is_rejected: bool::from_redis_value(is_rejected)?,
515 consumed: i64::from_redis_value(consumed)?,
516 });
517 }
518
519 Ok(Self(quotas))
520 }
521}
522
/// Outcome of the rate limiting script for a single quota.
#[derive(Debug)]
struct QuotaState {
    /// Whether the script rejected the request for this quota.
    is_rejected: bool,
    /// Consumption reported by the script after subtracting refunds.
    ///
    /// Can be negative when the refund key holds more than the counter (see the script test).
    consumed: i64,
}
531
#[cfg(test)]
mod tests {
    use std::time::{SystemTime, UNIX_EPOCH};

    use super::*;
    use crate::quota::{DataCategories, DataCategory, ReasonCode, Scoping};
    use crate::rate_limit::RateLimitScope;
    use crate::{GlobalRateLimiter, MetricNamespaceScoping};
    use relay_base_schema::metrics::MetricNamespace;
    use relay_base_schema::organization::OrganizationId;
    use relay_base_schema::project::{ProjectId, ProjectKey};
    use relay_redis::RedisConfigOptions;
    use relay_redis::redis::AsyncCommands;
    use smallvec::smallvec;
    use tokio::sync::Mutex;

    /// Global limiter backed by a real [`GlobalRateLimiter`] against the test Redis instance.
    struct MockGlobalLimiter {
        client: AsyncRedisClient,
        global_rate_limiter: Mutex<GlobalRateLimiter>,
    }

    impl GlobalLimiter for MockGlobalLimiter {
        async fn check_global_rate_limits<'a>(
            &self,
            global_quotas: &'a [RedisQuota<'a>],
        ) -> Result<Vec<&'a RedisQuota<'a>>, RateLimitingError> {
            self.global_rate_limiter
                .lock()
                .await
                .filter_rate_limited(&self.client, global_quotas)
                .await
        }
    }

    /// Builds a rate limiter connected to `RELAY_REDIS_URL` (default: local Redis).
    fn build_rate_limiter() -> RedisRateLimiter<MockGlobalLimiter> {
        let url = std::env::var("RELAY_REDIS_URL")
            .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_owned());
        let client =
            AsyncRedisClient::single("test", &url, &RedisConfigOptions::default()).unwrap();

        let global_limiter = MockGlobalLimiter {
            client: client.clone(),
            global_rate_limiter: Mutex::new(GlobalRateLimiter::default()),
        };

        RedisRateLimiter::new(client, global_limiter)
    }

    #[tokio::test]
    async fn test_zero_size_quotas() {
        let quotas = &[
            Quota {
                id: None,
                categories: DataCategories::new(),
                scope: QuotaScope::Organization,
                scope_id: None,
                limit: Some(0),
                window: None,
                reason_code: Some(ReasonCode::new("get_lost")),
                namespace: None,
            },
            Quota {
                id: Some("42".into()),
                categories: DataCategories::new(),
                scope: QuotaScope::Organization,
                scope_id: None,
                limit: None,
                window: Some(42),
                reason_code: Some(ReasonCode::new("unlimited")),
                namespace: None,
            },
        ];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limits: Vec<RateLimit> = build_rate_limiter()
            .is_rate_limited(quotas, scoping, 1, false)
            .await
            .expect("rate limiting failed")
            .into_iter()
            .collect();

        // Only the zero-sized quota produces a rate limit; the unlimited one does not.
        assert_eq!(
            rate_limits,
            vec![RateLimit {
                categories: DataCategories::new(),
                scope: RateLimitScope::Organization(OrganizationId::new(42)),
                reason_code: Some(ReasonCode::new("get_lost")),
                retry_after: rate_limits[0].retry_after,
                namespaces: smallvec![],
            }]
        );
    }

    /// Quotas with and without a namespace are counted independently.
    #[tokio::test]
    async fn test_non_global_namespace_quota() {
        let quota_limit = 5;
        let get_quota = |namespace: Option<MetricNamespace>| -> Quota {
            Quota {
                id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4()).into()),
                categories: DataCategories::new(),
                scope: QuotaScope::Organization,
                scope_id: None,
                limit: Some(quota_limit),
                window: Some(600),
                reason_code: Some(ReasonCode::new(format!("ns: {namespace:?}"))),
                namespace,
            }
        };

        let quotas = &[get_quota(None)];
        let quota_with_namespace = &[get_quota(Some(MetricNamespace::Transactions))];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::Some(MetricNamespace::Transactions),
        };

        let rate_limiter = build_rate_limiter();

        // Exhaust the quota without a namespace.
        for i in 0..10 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .await
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i < quota_limit {
                assert_eq!(rate_limits, vec![]);
            } else {
                assert_eq!(
                    rate_limits[0].reason_code,
                    Some(ReasonCode::new("ns: None"))
                );
            }
        }

        // The namespaced quota still has its full budget.
        for i in 0..10 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quota_with_namespace, scoping, 1, false)
                .await
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i < quota_limit {
                assert_eq!(rate_limits, vec![]);
            } else {
                assert_eq!(
                    rate_limits[0].reason_code,
                    Some(ReasonCode::new("ns: Some(Transactions)"))
                );
            }
        }
    }

    #[tokio::test]
    async fn test_simple_quota() {
        let quotas = &[Quota {
            id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4()).into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            limit: Some(5),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        for i in 0..10 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .await
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i >= 5 {
                assert_eq!(
                    rate_limits,
                    vec![RateLimit {
                        categories: DataCategories::new(),
                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
                        reason_code: Some(ReasonCode::new("get_lost")),
                        retry_after: rate_limits[0].retry_after,
                        namespaces: smallvec![],
                    }]
                );
            } else {
                assert_eq!(rate_limits, vec![]);
            }
        }
    }

    #[tokio::test]
    async fn test_simple_global_quota() {
        let quotas = &[Quota {
            id: Some(format!("test_simple_global_quota_{}", uuid::Uuid::new_v4()).into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Global,
            scope_id: None,
            limit: Some(5),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        for i in 0..10 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .await
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i >= 5 {
                assert_eq!(
                    rate_limits,
                    vec![RateLimit {
                        categories: DataCategories::new(),
                        scope: RateLimitScope::Global,
                        reason_code: Some(ReasonCode::new("get_lost")),
                        retry_after: rate_limits[0].retry_after,
                        namespaces: smallvec![],
                    }]
                );
            } else {
                assert_eq!(rate_limits, vec![]);
            }
        }
    }

    #[tokio::test]
    async fn test_quantity_0() {
        let quotas = &[Quota {
            id: Some(format!("test_quantity_0_{}", uuid::Uuid::new_v4()).into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            limit: Some(1),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        // The first unit fits into the quota of 1.
        assert!(
            !rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .await
                .unwrap()
                .is_limited()
        );

        // The quota is now exhausted.
        assert!(
            rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .await
                .unwrap()
                .is_limited()
        );

        // A quantity of 0 is still rejected once the quota is exhausted.
        assert!(
            rate_limiter
                .is_rate_limited(quotas, scoping, 0, false)
                .await
                .unwrap()
                .is_limited()
        );

        // Checking with quantity 0 did not reset anything.
        assert!(
            rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .await
                .unwrap()
                .is_limited()
        );
    }

    #[tokio::test]
    async fn test_quota_go_over() {
        let quotas = &[Quota {
            id: Some(format!("test_quota_go_over{}", uuid::Uuid::new_v4()).into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            limit: Some(2),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        // Consume 1 of 2.
        let is_limited = rate_limiter
            .is_rate_limited(quotas, scoping, 1, true)
            .await
            .unwrap()
            .is_limited();
        assert!(!is_limited);

        // 2 more would exceed the remaining 1, but `over_accept_once` accepts it once.
        let is_limited = rate_limiter
            .is_rate_limited(quotas, scoping, 2, true)
            .await
            .unwrap()
            .is_limited();
        assert!(!is_limited);

        // The quota is now over its limit, so even a quantity of 0 is rejected.
        let is_limited = rate_limiter
            .is_rate_limited(quotas, scoping, 0, true)
            .await
            .unwrap()
            .is_limited();
        assert!(is_limited);

        // Over-accepting does not happen a second time.
        let is_limited = rate_limiter
            .is_rate_limited(quotas, scoping, 1, true)
            .await
            .unwrap()
            .is_limited();
        assert!(is_limited);
    }

    #[tokio::test]
    async fn test_bails_immediately_without_any_quota() {
        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limits: Vec<RateLimit> = build_rate_limiter()
            .is_rate_limited(&[], scoping, 1, false)
            .await
            .expect("rate limiting failed")
            .into_iter()
            .collect();

        assert_eq!(rate_limits, vec![]);
    }

    #[tokio::test]
    async fn test_limited_with_unlimited_quota() {
        let quotas = &[
            Quota {
                id: Some("q0".into()),
                categories: DataCategories::new(),
                scope: QuotaScope::Organization,
                scope_id: None,
                limit: None,
                window: Some(1),
                reason_code: Some(ReasonCode::new("project_quota0")),
                namespace: None,
            },
            Quota {
                id: Some("q1".into()),
                categories: DataCategories::new(),
                scope: QuotaScope::Organization,
                scope_id: None,
                limit: Some(1),
                window: Some(1),
                reason_code: Some(ReasonCode::new("project_quota1")),
                namespace: None,
            },
        ];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        // NOTE(review): this loop runs exactly once, so the `else` branch below is never
        // exercised. Extending it would be flaky with a 1-second window if the calls straddle a
        // second boundary.
        for i in 0..1 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .await
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i == 0 {
                assert_eq!(rate_limits, &[]);
            } else {
                assert_eq!(
                    rate_limits,
                    vec![RateLimit {
                        categories: DataCategories::new(),
                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
                        reason_code: Some(ReasonCode::new("project_quota1")),
                        retry_after: rate_limits[0].retry_after,
                        namespaces: smallvec![],
                    }]
                );
            }
        }
    }

    #[tokio::test]
    async fn test_quota_with_quantity() {
        let quotas = &[Quota {
            id: Some(format!("test_quantity_quota_{}", uuid::Uuid::new_v4()).into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            limit: Some(500),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        // Five requests of 100 exhaust the limit of 500.
        for i in 0..10 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quotas, scoping, 100, false)
                .await
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i >= 5 {
                assert_eq!(
                    rate_limits,
                    vec![RateLimit {
                        categories: DataCategories::new(),
                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
                        reason_code: Some(ReasonCode::new("get_lost")),
                        retry_after: rate_limits[0].retry_after,
                        namespaces: smallvec![],
                    }]
                );
            } else {
                assert_eq!(rate_limits, vec![]);
            }
        }
    }

    #[tokio::test]
    async fn test_get_redis_key_scoped() {
        let quota = Quota {
            id: Some("foo".into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Project,
            scope_id: Some("42".into()),
            window: Some(2),
            limit: Some(0),
            reason_code: None,
            namespace: None,
        };

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(69420),
                project_id: ProjectId::new(42),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(4711),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let timestamp = UnixTimestamp::from_secs(123_123_123);
        let redis_quota = RedisQuota::new(&quota, 0, scoping, timestamp).unwrap();
        assert_eq!(redis_quota.key().to_string(), "quota:foo{69420}42:61561561");
    }

    #[tokio::test]
    async fn test_get_redis_key_unscoped() {
        let quota = Quota {
            id: Some("foo".into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            window: Some(10),
            limit: Some(0),
            reason_code: None,
            namespace: None,
        };

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(69420),
                project_id: ProjectId::new(42),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(4711),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let timestamp = UnixTimestamp::from_secs(234_531);
        let redis_quota = RedisQuota::new(&quota, 0, scoping, timestamp).unwrap();
        assert_eq!(redis_quota.key().to_string(), "quota:foo{69420}:23453");
    }

    #[tokio::test]
    async fn test_large_redis_limit_large() {
        let quota = Quota {
            id: Some("foo".into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            window: Some(10),
            limit: Some(9223372036854775808), // i64::MAX + 1
            reason_code: None,
            namespace: None,
        };

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(69420),
                project_id: ProjectId::new(42),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(4711),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let timestamp = UnixTimestamp::from_secs(234_531);
        let redis_quota = RedisQuota::new(&quota, 0, scoping, timestamp).unwrap();
        // Limits that do not fit into an i64 are sent to the script as unlimited (-1).
        assert_eq!(redis_quota.limit(), -1);
    }

    #[tokio::test]
    async fn test_is_rate_limited_script() {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|duration| duration.as_secs())
            .unwrap();

        let rate_limiter = build_rate_limiter();
        let mut conn = rate_limiter.client.get_connection().await.unwrap();

        // Timestamped key names keep runs of this test from interfering with each other.
        let foo = format!("foo___{now}");
        let r_foo = format!("r:foo___{now}");
        let bar = format!("bar___{now}");
        let r_bar = format!("r:bar___{now}");
        let apple = format!("apple___{now}");
        let orange = format!("orange___{now}");
        let baz = format!("baz___{now}");

        let script = RedisScripts::load_is_rate_limited();

        macro_rules! assert_invocation {
            ($invocation:expr, $($tt:tt)*) => {{
                let result = $invocation
                    .invoke_async::<ScriptResult>(&mut conn)
                    .await
                    .unwrap();

                insta::assert_debug_snapshot!(result, $($tt)*);
            }};
        }

        let mut invocation = script.prepare_invoke();
        invocation
            .key(&foo) // counter key
            .key(&r_foo) // refund key
            .key(&bar)
            .key(&r_bar)
            .arg(1) // limit
            .arg(now + 60) // key expiry
            .arg(1) // quantity
            .arg(false) // over accept once
            .arg(2)
            .arg(now + 120)
            .arg(1)
            .arg(false);

        // Checks only `bar`, with the same parameters as above.
        let mut invocation2 = script.prepare_invoke();
        invocation2
            .key(&bar)
            .key(&r_bar)
            .arg(2) // limit
            .arg(now + 120) // key expiry
            .arg(1) // quantity
            .arg(false); // over accept once

        // Neither quota is limited yet; both consume one unit.
        assert_invocation!(invocation, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: false,
                    consumed: 1,
                },
                QuotaState {
                    is_rejected: false,
                    consumed: 1,
                },
            ],
        )
        "
        );

        // `foo` (limit 1) now rejects. `bar` is not counted because the request was rejected.
        assert_invocation!(invocation, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: true,
                    consumed: 1,
                },
                QuotaState {
                    is_rejected: false,
                    consumed: 1,
                },
            ],
        )
        "
        );

        // Still rejected by `foo`, and `bar` is still untouched: items rejected by one quota do
        // not consume unrelated quotas.
        assert_invocation!(invocation, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: true,
                    consumed: 1,
                },
                QuotaState {
                    is_rejected: false,
                    consumed: 1,
                },
            ],
        )
        "
        );

        // Checking `bar` alone consumes its second unit.
        assert_invocation!(invocation2, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: false,
                    consumed: 2,
                },
            ],
        )
        "
        );

        // `bar` is now exhausted.
        assert_invocation!(invocation2, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: true,
                    consumed: 2,
                },
            ],
        )
        "
        );

        // Both quotas reject now.
        assert_invocation!(invocation, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: true,
                    consumed: 1,
                },
                QuotaState {
                    is_rejected: true,
                    consumed: 2,
                },
            ],
        )
        "
        );

        // The counters hold the consumed values and expire according to the passed expiry.
        assert_eq!(conn.get::<_, String>(&foo).await.unwrap(), "1");
        let ttl: u64 = conn.ttl(&foo).await.unwrap();
        assert!(ttl >= 59);
        assert!(ttl <= 60);

        assert_eq!(conn.get::<_, String>(&bar).await.unwrap(), "2");
        let ttl: u64 = conn.ttl(&bar).await.unwrap();
        assert!(ttl >= 119);
        assert!(ttl <= 120);

        // The script never writes refund keys, so reading them yields nil.
        let () = conn.get(r_foo).await.unwrap();
        let () = conn.get(r_bar).await.unwrap();

        // Pre-populate a refund key that is used further below.
        let () = conn.set(&apple, 5).await.unwrap();

        let mut invocation = script.prepare_invoke();
        invocation
            .key(&orange) // counter key
            .key(&baz) // refund key
            .arg(1) // limit
            .arg(now + 60) // key expiry
            .arg(1) // quantity
            .arg(false); // over accept once

        // The first unit of `orange` fits into the limit of 1.
        assert_invocation!(invocation, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: false,
                    consumed: 1,
                },
            ],
        )
        "
        );

        // `orange` is exhausted.
        assert_invocation!(invocation, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: true,
                    consumed: 1,
                },
            ],
        )
        "
        );

        assert_invocation!(invocation, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: true,
                    consumed: 1,
                },
            ],
        )
        "
        );

        let mut invocation = script.prepare_invoke();
        invocation
            .key(&orange) // counter key
            .key(&apple) // refund key
            .arg(1) // limit
            .arg(now + 60) // key expiry
            .arg(1) // quantity
            .arg(false); // over accept once

        // With 5 refunded units in `apple`, `orange` has room again and consumption goes
        // negative.
        assert_invocation!(invocation, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: false,
                    consumed: -3,
                },
            ],
        )
        "
        );
    }

    /// The cache must never let more than the limit through within a single limiter.
    #[tokio::test]
    async fn test_quota_with_cache() {
        let quotas = &[Quota {
            id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4()).into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            limit: Some(50),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter().cache(Some(0.1), Some(0.9));

        for _ in 0..50 {
            let rate_limits = rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .await
                .unwrap();

            assert!(rate_limits.is_empty());
        }

        let rate_limits: Vec<RateLimit> = rate_limiter
            .is_rate_limited(quotas, scoping, 1, false)
            .await
            .expect("rate limiting failed")
            .into_iter()
            .collect();

        assert_eq!(
            rate_limits,
            vec![RateLimit {
                categories: DataCategories::new(),
                scope: RateLimitScope::Organization(OrganizationId::new(42)),
                reason_code: Some(ReasonCode::new("get_lost")),
                retry_after: rate_limits[0].retry_after,
                namespaces: smallvec![],
            }]
        );
    }

    /// Two limiters with independent caches share one Redis quota. Locally accepted quantities
    /// can let the quota be slightly over-accounted before a rejection happens.
    #[tokio::test]
    async fn test_quota_with_cache_slightly_over_account() {
        let window = 60;
        let limit = 50 * window;

        let quotas = &[Quota {
            id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4()).into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            limit: Some(limit),
            window: Some(window),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter1 = build_rate_limiter().cache(Some(0.1), None);
        let rate_limiter2 = build_rate_limiter().cache(Some(0.1), None);

        let rate_limits = rate_limiter1
            .is_rate_limited(quotas, scoping, 1, false)
            .await
            .unwrap();
        assert!(rate_limits.is_empty());
        let rate_limits = rate_limiter1
            .is_rate_limited(quotas, scoping, 3, false)
            .await
            .unwrap();
        assert!(rate_limits.is_empty());

        // The second limiter consumes almost the entire quota.
        let rate_limits = rate_limiter2
            .is_rate_limited(quotas, scoping, limit as usize - 1, false)
            .await
            .unwrap();
        assert!(rate_limits.is_empty());

        // The first limiter still accepts this request.
        let rate_limits = rate_limiter1
            .is_rate_limited(quotas, scoping, 1, false)
            .await
            .unwrap();
        assert!(rate_limits.is_empty());

        let rate_limits: Vec<RateLimit> = rate_limiter1
            .is_rate_limited(quotas, scoping, 1, false)
            .await
            .unwrap()
            .into_iter()
            .collect();

        assert_eq!(
            rate_limits,
            vec![RateLimit {
                categories: DataCategories::new(),
                scope: RateLimitScope::Organization(OrganizationId::new(42)),
                reason_code: Some(ReasonCode::new("get_lost")),
                retry_after: rate_limits[0].retry_after,
                namespaces: smallvec![],
            }]
        );
    }
}