1use std::fmt::{self, Debug};
2use std::sync::Arc;
3
4use relay_base_schema::metrics::MetricNamespace;
5use relay_base_schema::organization::OrganizationId;
6use relay_common::time::UnixTimestamp;
7use relay_log::protocol::value;
8use relay_redis::redis::{self, FromRedisValue, Script};
9use relay_redis::{AsyncRedisClient, RedisError, RedisScripts};
10use thiserror::Error;
11
12use crate::cache::OpportunisticQuotaCache;
13use crate::global::GlobalLimiter;
14use crate::quota::{ItemScoping, Quota, QuotaScope};
15use crate::rate_limit::{RateLimit, RateLimits, RetryAfter};
16use crate::statsd::{QuotaCounters, QuotaTimers};
17use crate::{REJECT_ALL_SECS, cache};
18
/// Number of seconds added on top of a quota's expiry when computing the expiry of its Redis key
/// (see [`RedisQuota::key_expiry`]).
const GRACE: u64 = 60;
23
/// An error returned by [`RedisRateLimiter`].
#[derive(Debug, Error)]
pub enum RateLimitingError {
    /// Communication with Redis failed; wraps the underlying [`RedisError`].
    #[error("failed to communicate with redis")]
    Redis(
        #[from]
        #[source]
        RedisError,
    ),

    /// Checking the global rate limits failed.
    #[error("failed to check global rate limits")]
    UnreachableGlobalRateLimits,
}
39
/// Derives the name of the refund key that accompanies a quota counter key.
///
/// The refund key is the counter key with an `r:` prefix.
fn get_refunded_quota_key(counter_key: &str) -> String {
    const PREFIX: &str = "r:";

    let mut refund_key = String::with_capacity(PREFIX.len() + counter_key.len());
    refund_key.push_str(PREFIX);
    refund_key.push_str(counter_key);
    refund_key
}
47
/// Display adapter for optional values.
///
/// Formats the wrapped value with its own `Display` implementation when it is `Some`, and writes
/// nothing at all when it is `None`.
struct OptionalDisplay<T>(Option<T>);

impl<T: fmt::Display> fmt::Display for OptionalDisplay<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0
            .as_ref()
            .map_or(Ok(()), |inner| write!(f, "{inner}"))
    }
}
62
/// Owned counterpart of [`RedisQuota`].
///
/// Holds a full copy of the [`Quota`] instead of a reference to it. Created through
/// [`RedisQuota::build_owned`] and converted back with [`OwnedRedisQuota::build_ref`].
#[derive(Debug, Clone)]
pub struct OwnedRedisQuota {
    /// The original quota.
    quota: Quota,
    /// Scoping of the item this quota is checked for.
    scoping: ItemScoping,
    /// The quota's id, used as the prefix of the Redis key.
    prefix: Arc<str>,
    /// The quota's window (taken from `Quota::window`).
    window: u64,
    /// The quantity that is checked against the quota.
    quantity: u64,
    /// The timestamp used to determine the quota's slot.
    timestamp: UnixTimestamp,
}
79
80impl OwnedRedisQuota {
81 pub fn build_ref(&self) -> RedisQuota<'_> {
83 RedisQuota {
84 quota: &self.quota,
85 scoping: self.scoping,
86 prefix: Arc::clone(&self.prefix),
87 window: self.window,
88 quantity: self.quantity,
89 timestamp: self.timestamp,
90 }
91 }
92}
93
/// A [`Quota`] together with everything needed to track it in Redis.
///
/// Dereferences to the underlying [`Quota`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct RedisQuota<'a> {
    /// The original quota.
    quota: &'a Quota,
    /// Scoping of the item this quota is checked for.
    scoping: ItemScoping,
    /// The quota's id, used as the prefix of the Redis key.
    prefix: Arc<str>,
    /// The quota's window (taken from `Quota::window`).
    window: u64,
    /// The quantity that is checked against the quota.
    quantity: u64,
    /// The timestamp used to determine the quota's slot.
    timestamp: UnixTimestamp,
}
110
111impl<'a> RedisQuota<'a> {
112 pub fn new(
118 quota: &'a Quota,
119 quantity: u64,
120 scoping: ItemScoping,
121 timestamp: UnixTimestamp,
122 ) -> Option<Self> {
123 let prefix = quota.id.clone()?;
125 let window = quota.window?;
126
127 Some(Self {
128 quota,
129 scoping,
130 prefix,
131 quantity,
132 window,
133 timestamp,
134 })
135 }
136
137 pub fn build_owned(&self) -> OwnedRedisQuota {
140 OwnedRedisQuota {
141 quota: self.quota.clone(),
142 scoping: self.scoping,
143 prefix: Arc::clone(&self.prefix),
144 window: self.window,
145 quantity: self.quantity,
146 timestamp: self.timestamp,
147 }
148 }
149
150 pub fn window(&self) -> u64 {
152 self.window
153 }
154
155 pub fn prefix(&self) -> &str {
157 &self.prefix
158 }
159
160 pub fn quantity(&self) -> u64 {
162 self.quantity
163 }
164
165 pub fn limit(&self) -> i64 {
170 self.limit
171 .and_then(|limit| limit.try_into().ok())
173 .unwrap_or(-1)
174 }
175
176 fn shift(&self) -> u64 {
177 if self.quota.scope == QuotaScope::Global {
178 0
179 } else {
180 self.scoping.organization_id.value() % self.window
181 }
182 }
183
184 pub fn slot(&self) -> u64 {
188 (self.timestamp.as_secs() - self.shift()) / self.window
189 }
190
191 pub fn expiry(&self) -> UnixTimestamp {
193 let next_slot = self.slot() + 1;
194 let next_start = next_slot * self.window + self.shift();
195 UnixTimestamp::from_secs(next_start)
196 }
197
198 pub fn key_expiry(&self) -> u64 {
202 self.expiry().as_secs() + GRACE
203 }
204
205 pub fn key(&self) -> QuotaCacheKey {
211 let subscope = match self.quota.scope {
214 QuotaScope::Global => None,
215 QuotaScope::Organization => None,
216 scope => self.scoping.scope_id(scope),
217 };
218
219 QuotaCacheKey {
220 id: Arc::clone(&self.prefix),
221 org: self.scoping.organization_id,
222 subscope,
223 namespace: self.namespace,
224 slot: self.slot(),
225 }
226 }
227
228 fn for_cache(&self) -> cache::Quota<QuotaCacheKey> {
230 cache::Quota {
231 limit: self.limit(),
232 window: self.window,
233 key: self.key(),
234 expiry: UnixTimestamp::from_secs(self.key_expiry()),
235 }
236 }
237}
238
/// Gives direct access to the fields and methods of the underlying [`Quota`].
impl std::ops::Deref for RedisQuota<'_> {
    type Target = Quota;

    fn deref(&self) -> &Self::Target {
        self.quota
    }
}
246
/// Key identifying a quota within a single slot.
///
/// Used as the key of the [`OpportunisticQuotaCache`]; its `Display` output is used as the Redis
/// key of the quota. Created by [`RedisQuota::key`].
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct QuotaCacheKey {
    /// The id (prefix) of the quota.
    id: Arc<str>,
    /// The organization the quota is tracked for.
    org: OrganizationId,
    /// Scope id for quotas that are neither global nor organization scoped.
    subscope: Option<u64>,
    /// The metric namespace of the quota, if any.
    namespace: Option<MetricNamespace>,
    /// The slot of the quota.
    slot: u64,
}
260
261impl fmt::Display for QuotaCacheKey {
262 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263 write!(
264 f,
265 "quota:{id}{{{org}}}{subscope}{namespace}:{slot}",
266 id = self.id,
267 org = self.org,
268 subscope = OptionalDisplay(self.subscope),
269 namespace = OptionalDisplay(self.namespace),
270 slot = self.slot,
271 )
272 }
273}
274
/// A rate limiter that tracks quotas in Redis.
///
/// Quotas with global scope are delegated to the global limiter `T`; all other quotas are checked
/// with a Redis script, optionally fronted by an [`OpportunisticQuotaCache`].
#[derive(Clone)]
pub struct RedisRateLimiter<T> {
    /// Client used to obtain Redis connections.
    client: AsyncRedisClient,
    /// Optional local cache consulted before invoking the Redis script.
    cache: Option<Arc<OpportunisticQuotaCache<QuotaCacheKey>>>,
    /// The rate limiting script.
    script: &'static Script,
    /// Optional upper bound, in seconds, for returned retry-after values.
    max_limit: Option<u64>,
    /// Limiter responsible for quotas with global scope.
    global_limiter: T,
}
293
impl<T: GlobalLimiter> RedisRateLimiter<T> {
    /// Creates a new rate limiter without a cache and without a maximum retry-after limit.
    pub fn new(client: AsyncRedisClient, global_limiter: T) -> Self {
        RedisRateLimiter {
            client,
            cache: None,
            script: RedisScripts::load_is_rate_limited(),
            max_limit: None,
            global_limiter,
        }
    }

    /// Sets the maximum retry-after value, in seconds, returned by this limiter.
    ///
    /// With `None`, retry-after values are not capped.
    pub fn max_limit(mut self, max_limit: Option<u64>) -> Self {
        self.max_limit = max_limit;
        self
    }

    /// Configures the opportunistic quota cache.
    ///
    /// A cache is created from `cache_ratio` and configured with `max`. Passing `None` as
    /// `cache_ratio` disables the cache.
    pub fn cache(mut self, cache_ratio: Option<f32>, max: Option<f32>) -> Self {
        self.cache = cache_ratio
            .map(OpportunisticQuotaCache::new)
            .map(|c| c.with_max(max))
            .map(Arc::new);

        self
    }

    /// Checks the given quotas against `quantity` and returns all rate limits that apply.
    ///
    /// Each quota is handled as follows:
    ///  - Quotas that do not match `item_scoping` are skipped.
    ///  - Quotas with a limit of `0` always produce a rate limit with a retry-after of
    ///    `REJECT_ALL_SECS` (subject to the configured maximum).
    ///  - Quotas that can be converted into a [`RedisQuota`] are checked by the global limiter if
    ///    they have global scope, and by the Redis script otherwise. If a cache is configured it
    ///    is consulted first and may accept the quantity without the quota being sent to Redis.
    ///  - All remaining quotas are skipped with a warning.
    ///
    /// The Redis script is not invoked if there are no quotas to check in Redis or if a rate
    /// limit has already been determined before that point.
    ///
    /// `over_accept_once` is forwarded unchanged to the Redis script as the last argument of
    /// every quota.
    ///
    /// # Errors
    ///
    /// Returns an error if the global limiter fails, if no Redis connection can be obtained, or
    /// if invoking the script fails.
    pub async fn is_rate_limited<'a>(
        &self,
        quotas: impl IntoIterator<Item = &'a Quota>,
        item_scoping: ItemScoping,
        quantity: usize,
        over_accept_once: bool,
    ) -> Result<RateLimits, RateLimitingError> {
        let timestamp = UnixTimestamp::now();
        let mut invocation = self.script.prepare_invoke();
        let mut tracked_quotas = Vec::new();
        let mut rate_limits = RateLimits::new();

        let mut global_quotas = vec![];

        // Quantities that do not fit into a `u64` saturate at `u64::MAX`.
        let quantity = u64::try_from(quantity).unwrap_or(u64::MAX);

        for quota in quotas {
            if !quota.matches(item_scoping) {
                // Quotas that do not apply to this item are skipped.
            } else if quota.limit == Some(0) {
                // A zero-sized quota rejects everything; there is nothing to track in Redis.
                let retry_after = self.retry_after(REJECT_ALL_SECS);
                rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
            } else if let Some(mut quota) =
                RedisQuota::new(quota, quantity, item_scoping, timestamp)
            {
                if quota.scope == QuotaScope::Global {
                    global_quotas.push(quota);
                } else {
                    if let Some(cache) = &self.cache {
                        // The cache either accepts the quantity locally, or returns the quantity
                        // that has to be checked against Redis.
                        quota.quantity = match cache.check_quota(quota.for_cache(), quantity) {
                            cache::Action::Accept => continue,
                            cache::Action::Check(quantity) => quantity,
                        };
                    }

                    let redis_key = quota.key().to_string();
                    let refund_key = get_refunded_quota_key(&redis_key);

                    // Every tracked quota contributes two keys (counter, refund) ...
                    invocation.key(redis_key);
                    invocation.key(refund_key);

                    // ... and four arguments (limit, key expiry, quantity, over_accept_once).
                    invocation.arg(quota.limit());
                    invocation.arg(quota.key_expiry());
                    invocation.arg(quota.quantity);
                    invocation.arg(over_accept_once);

                    tracked_quotas.push(quota);
                }
            } else {
                // `RedisQuota::new` returned `None`: the quota cannot be tracked in Redis.
                relay_log::with_scope(
                    |scope| scope.set_extra("quota", value::to_value(quota).unwrap()),
                    || relay_log::warn!("skipping unsupported quota"),
                )
            }
        }

        if !global_quotas.is_empty() {
            let rate_limited_global_quotas = self
                .global_limiter
                .check_global_rate_limits(&global_quotas)
                .await?;

            for quota in rate_limited_global_quotas {
                let retry_after = self.retry_after((quota.expiry() - timestamp).as_secs());
                rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
            }
        }

        // Skip the Redis round trip if there is nothing to check or a limit already applies.
        if tracked_quotas.is_empty() || rate_limits.is_limited() {
            return Ok(rate_limits);
        }

        let mut connection = self.client.get_connection().await?;
        let result: ScriptResult = invocation
            .invoke_async(&mut connection)
            .await
            .map_err(RedisError::Redis)?;

        // Results are paired with the tracked quotas in the order they were added to the
        // invocation.
        for (quota, state) in tracked_quotas.iter().zip(result.0) {
            if state.is_rejected {
                // Reported as the `CacheError` counter: the part of the quantity accumulated
                // through the cache (checked quantity minus the quantity of this call) that
                // exceeds what was still remaining of the limit.
                let cache_error = {
                    let remaining = quota.limit().saturating_sub(state.consumed).max(0) as u64;
                    let cache_quantity = quota.quantity.saturating_sub(quantity);

                    cache_quantity.saturating_sub(remaining)
                };
                relay_statsd::metric!(
                    counter(QuotaCounters::CacheError) += cache_error,
                    category = item_scoping.category.name(),
                );

                let retry_after = self.retry_after((quota.expiry() - timestamp).as_secs());
                rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
            } else if let Some(cache) = &self.cache {
                // Accepted quotas update the cache with the consumption reported by Redis.
                cache.set_quota(quota.for_cache(), state.consumed);
            }
        }

        if let Some(cache) = &self.cache {
            // The vacuum duration is only reported when `try_vacuum` returns `true`.
            let vacuum_start = std::time::Instant::now();
            if cache.try_vacuum(timestamp) {
                relay_statsd::metric!(
                    timer(QuotaTimers::CacheVacuumDuration) = vacuum_start.elapsed()
                );
            }
        }

        Ok(rate_limits)
    }

    /// Creates a [`RetryAfter`] from `seconds`, capped at the configured maximum limit if any.
    fn retry_after(&self, mut seconds: u64) -> RetryAfter {
        if let Some(max_limit) = self.max_limit {
            seconds = std::cmp::min(seconds, max_limit);
        }

        RetryAfter::from_secs(seconds)
    }
}
489
/// The result returned from the rate limiting script: one [`QuotaState`] per checked quota.
#[derive(Debug)]
struct ScriptResult(Vec<QuotaState>);
493
494impl FromRedisValue for ScriptResult {
495 fn from_redis_value(v: &redis::Value) -> redis::RedisResult<Self> {
496 let Some(seq) = v.as_sequence() else {
497 return Err(redis::RedisError::from((
498 redis::ErrorKind::TypeError,
499 "Expected a sequence from the rate limiting script",
500 format!("{v:?}"),
501 )));
502 };
503
504 let (chunks, rem) = seq.as_chunks();
505 if !rem.is_empty() {
506 return Err(redis::RedisError::from((
507 redis::ErrorKind::TypeError,
508 "Expected an even number of values from the rate limiting script",
509 format!("{v:?}"),
510 )));
511 }
512
513 let mut quotas = Vec::with_capacity(chunks.len());
514 for [is_rejected, consumed] in chunks {
515 quotas.push(QuotaState {
516 is_rejected: bool::from_redis_value(is_rejected)?,
517 consumed: i64::from_redis_value(consumed)?,
518 });
519 }
520
521 Ok(Self(quotas))
522 }
523}
524
/// The state of a single quota as reported by the rate limiting script.
#[derive(Debug)]
struct QuotaState {
    /// Whether the script rejected the quantity for this quota.
    is_rejected: bool,
    /// The consumption reported by the script for this quota.
    consumed: i64,
}
533
534#[cfg(test)]
535mod tests {
536 use std::time::{SystemTime, UNIX_EPOCH};
537
538 use super::*;
539 use crate::quota::{DataCategories, DataCategory, ReasonCode, Scoping};
540 use crate::rate_limit::RateLimitScope;
541 use crate::{GlobalRateLimiter, MetricNamespaceScoping};
542 use relay_base_schema::metrics::MetricNamespace;
543 use relay_base_schema::organization::OrganizationId;
544 use relay_base_schema::project::{ProjectId, ProjectKey};
545 use relay_redis::RedisConfigOptions;
546 use relay_redis::redis::AsyncCommands;
547 use smallvec::smallvec;
548 use tokio::sync::Mutex;
549
    /// Test implementation of [`GlobalLimiter`] that forwards to a [`GlobalRateLimiter`] guarded
    /// by a mutex.
    struct MockGlobalLimiter {
        /// Redis client passed to the wrapped limiter on every check.
        client: AsyncRedisClient,
        /// The wrapped limiter; locked for the duration of each check.
        global_rate_limiter: Mutex<GlobalRateLimiter>,
    }

    impl GlobalLimiter for MockGlobalLimiter {
        /// Returns the subset of `global_quotas` reported by
        /// `GlobalRateLimiter::filter_rate_limited`.
        async fn check_global_rate_limits<'a>(
            &self,
            global_quotas: &'a [RedisQuota<'a>],
        ) -> Result<Vec<&'a RedisQuota<'a>>, RateLimitingError> {
            self.global_rate_limiter
                .lock()
                .await
                .filter_rate_limited(&self.client, global_quotas)
                .await
        }
    }
567
568 fn build_rate_limiter() -> RedisRateLimiter<MockGlobalLimiter> {
569 let url = std::env::var("RELAY_REDIS_URL")
570 .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_owned());
571 let client =
572 AsyncRedisClient::single("test", &url, &RedisConfigOptions::default()).unwrap();
573
574 let global_limiter = MockGlobalLimiter {
575 client: client.clone(),
576 global_rate_limiter: Mutex::new(GlobalRateLimiter::default()),
577 };
578
579 RedisRateLimiter {
580 client,
581 cache: None,
582 script: RedisScripts::load_is_rate_limited(),
583 max_limit: None,
584 global_limiter,
585 }
586 }
587
588 #[tokio::test]
589 async fn test_zero_size_quotas() {
590 let quotas = &[
591 Quota {
592 id: None,
593 categories: DataCategories::new(),
594 scope: QuotaScope::Organization,
595 scope_id: None,
596 limit: Some(0),
597 window: None,
598 reason_code: Some(ReasonCode::new("get_lost")),
599 namespace: None,
600 },
601 Quota {
602 id: Some("42".into()),
603 categories: DataCategories::new(),
604 scope: QuotaScope::Organization,
605 scope_id: None,
606 limit: None,
607 window: Some(42),
608 reason_code: Some(ReasonCode::new("unlimited")),
609 namespace: None,
610 },
611 ];
612
613 let scoping = ItemScoping {
614 category: DataCategory::Error,
615 scoping: Scoping {
616 organization_id: OrganizationId::new(42),
617 project_id: ProjectId::new(43),
618 project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
619 key_id: Some(44),
620 },
621 namespace: MetricNamespaceScoping::None,
622 };
623
624 let rate_limits: Vec<RateLimit> = build_rate_limiter()
625 .is_rate_limited(quotas, scoping, 1, false)
626 .await
627 .expect("rate limiting failed")
628 .into_iter()
629 .collect();
630
631 assert_eq!(
632 rate_limits,
633 vec![RateLimit {
634 categories: DataCategories::new(),
635 scope: RateLimitScope::Organization(OrganizationId::new(42)),
636 reason_code: Some(ReasonCode::new("get_lost")),
637 retry_after: rate_limits[0].retry_after,
638 namespaces: smallvec![],
639 }]
640 );
641 }
642
    /// Checks organization quotas with and without a namespace against an item scoped to the
    /// transactions namespace: both are enforced after `quota_limit` accepted items.
    #[tokio::test]
    async fn test_non_global_namespace_quota() {
        let quota_limit = 5;
        // Every quota gets a unique id so that runs do not share Redis keys.
        let get_quota = |namespace: Option<MetricNamespace>| -> Quota {
            Quota {
                id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4()).into()),
                categories: DataCategories::new(),
                scope: QuotaScope::Organization,
                scope_id: None,
                limit: Some(quota_limit),
                window: Some(600),
                reason_code: Some(ReasonCode::new(format!("ns: {namespace:?}"))),
                namespace,
            }
        };

        let quotas = &[get_quota(None)];
        let quota_with_namespace = &[get_quota(Some(MetricNamespace::Transactions))];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::Some(MetricNamespace::Transactions),
        };

        let rate_limiter = build_rate_limiter();

        // Quota without a namespace: the first `quota_limit` items pass, the rest are limited.
        for i in 0..10 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .await
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i < quota_limit {
                assert_eq!(rate_limits, vec![]);
            } else {
                assert_eq!(
                    rate_limits[0].reason_code,
                    Some(ReasonCode::new("ns: None"))
                );
            }
        }

        // Quota with the transactions namespace: same expectations, tracked independently.
        for i in 0..10 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quota_with_namespace, scoping, 1, false)
                .await
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i < quota_limit {
                assert_eq!(rate_limits, vec![]);
            } else {
                assert_eq!(
                    rate_limits[0].reason_code,
                    Some(ReasonCode::new("ns: Some(Transactions)"))
                );
            }
        }
    }
714
    /// An organization quota with a limit of 5 accepts five items and rejects all further ones.
    #[tokio::test]
    async fn test_simple_quota() {
        let quotas = &[Quota {
            id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4()).into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            limit: Some(5),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        for i in 0..10 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .await
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i >= 5 {
                // The retry-after value is time dependent and copied from the actual result.
                assert_eq!(
                    rate_limits,
                    vec![RateLimit {
                        categories: DataCategories::new(),
                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
                        reason_code: Some(ReasonCode::new("get_lost")),
                        retry_after: rate_limits[0].retry_after,
                        namespaces: smallvec![],
                    }]
                );
            } else {
                assert_eq!(rate_limits, vec![]);
            }
        }
    }
765
    /// A global quota with a limit of 5 accepts five items and rejects all further ones with a
    /// rate limit of global scope.
    #[tokio::test]
    async fn test_simple_global_quota() {
        let quotas = &[Quota {
            id: Some(format!("test_simple_global_quota_{}", uuid::Uuid::new_v4()).into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Global,
            scope_id: None,
            limit: Some(5),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        for i in 0..10 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .await
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i >= 5 {
                // The retry-after value is time dependent and copied from the actual result.
                assert_eq!(
                    rate_limits,
                    vec![RateLimit {
                        categories: DataCategories::new(),
                        scope: RateLimitScope::Global,
                        reason_code: Some(ReasonCode::new("get_lost")),
                        retry_after: rate_limits[0].retry_after,
                        namespaces: smallvec![],
                    }]
                );
            } else {
                assert_eq!(rate_limits, vec![]);
            }
        }
    }
816
    /// Once a quota is exhausted, a check with a quantity of zero also reports it as limited.
    #[tokio::test]
    async fn test_quantity_0() {
        let quotas = &[Quota {
            id: Some(format!("test_quantity_0_{}", uuid::Uuid::new_v4()).into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            limit: Some(1),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        // The first item fits into the limit of 1.
        assert!(
            !rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .await
                .unwrap()
                .is_limited()
        );

        // The second item exceeds the limit.
        assert!(
            rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .await
                .unwrap()
                .is_limited()
        );

        // A quantity of 0 is limited as well while the quota is exhausted.
        assert!(
            rate_limiter
                .is_rate_limited(quotas, scoping, 0, false)
                .await
                .unwrap()
                .is_limited()
        );

        // Still limited afterwards.
        assert!(
            rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .await
                .unwrap()
                .is_limited()
        );
    }
879
    /// With `over_accept_once` set, a single check that pushes the consumption over the limit is
    /// still accepted; all following checks are limited.
    #[tokio::test]
    async fn test_quota_go_over() {
        let quotas = &[Quota {
            id: Some(format!("test_quota_go_over{}", uuid::Uuid::new_v4()).into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            limit: Some(2),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        // Quantity 1 of limit 2: accepted.
        let is_limited = rate_limiter
            .is_rate_limited(quotas, scoping, 1, true)
            .await
            .unwrap()
            .is_limited();
        assert!(!is_limited);

        // Quantity 2 on top of 1 exceeds the limit of 2, but is accepted once.
        let is_limited = rate_limiter
            .is_rate_limited(quotas, scoping, 2, true)
            .await
            .unwrap()
            .is_limited();
        assert!(!is_limited);

        // The quota is now exceeded, even a quantity of 0 is limited.
        let is_limited = rate_limiter
            .is_rate_limited(quotas, scoping, 0, true)
            .await
            .unwrap()
            .is_limited();
        assert!(is_limited);

        // Further items stay limited.
        let is_limited = rate_limiter
            .is_rate_limited(quotas, scoping, 1, true)
            .await
            .unwrap()
            .is_limited();
        assert!(is_limited);
    }
938
939 #[tokio::test]
940 async fn test_bails_immediately_without_any_quota() {
941 let scoping = ItemScoping {
942 category: DataCategory::Error,
943 scoping: Scoping {
944 organization_id: OrganizationId::new(42),
945 project_id: ProjectId::new(43),
946 project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
947 key_id: Some(44),
948 },
949 namespace: MetricNamespaceScoping::None,
950 };
951
952 let rate_limits: Vec<RateLimit> = build_rate_limiter()
953 .is_rate_limited(&[], scoping, 1, false)
954 .await
955 .expect("rate limiting failed")
956 .into_iter()
957 .collect();
958
959 assert_eq!(rate_limits, vec![]);
960 }
961
    /// Checks an unlimited quota (`q0`) together with a quota with a limit of 1 (`q1`).
    #[tokio::test]
    async fn test_limited_with_unlimited_quota() {
        let quotas = &[
            Quota {
                id: Some("q0".into()),
                categories: DataCategories::new(),
                scope: QuotaScope::Organization,
                scope_id: None,
                limit: None,
                window: Some(1),
                reason_code: Some(ReasonCode::new("project_quota0")),
                namespace: None,
            },
            Quota {
                id: Some("q1".into()),
                categories: DataCategories::new(),
                scope: QuotaScope::Organization,
                scope_id: None,
                limit: Some(1),
                window: Some(1),
                reason_code: Some(ReasonCode::new("project_quota1")),
                namespace: None,
            },
        ];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        // NOTE(review): the range `0..1` yields only `i == 0`, so the `else` branch below never
        // runs and only the accepted case is exercised. Confirm whether this is intentional
        // (both quotas use a one second window and fixed ids).
        for i in 0..1 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .await
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i == 0 {
                assert_eq!(rate_limits, &[]);
            } else {
                assert_eq!(
                    rate_limits,
                    vec![RateLimit {
                        categories: DataCategories::new(),
                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
                        reason_code: Some(ReasonCode::new("project_quota1")),
                        retry_after: rate_limits[0].retry_after,
                        namespaces: smallvec![],
                    }]
                );
            }
        }
    }
1024
    /// A quota with a limit of 500 accepts five checks with a quantity of 100 each and rejects
    /// all further ones.
    #[tokio::test]
    async fn test_quota_with_quantity() {
        let quotas = &[Quota {
            id: Some(format!("test_quantity_quota_{}", uuid::Uuid::new_v4()).into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            limit: Some(500),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        for i in 0..10 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quotas, scoping, 100, false)
                .await
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i >= 5 {
                // The retry-after value is time dependent and copied from the actual result.
                assert_eq!(
                    rate_limits,
                    vec![RateLimit {
                        categories: DataCategories::new(),
                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
                        reason_code: Some(ReasonCode::new("get_lost")),
                        retry_after: rate_limits[0].retry_after,
                        namespaces: smallvec![],
                    }]
                );
            } else {
                assert_eq!(rate_limits, vec![]);
            }
        }
    }
1075
1076 #[tokio::test]
1077 async fn test_get_redis_key_scoped() {
1078 let quota = Quota {
1079 id: Some("foo".into()),
1080 categories: DataCategories::new(),
1081 scope: QuotaScope::Project,
1082 scope_id: Some("42".into()),
1083 window: Some(2),
1084 limit: Some(0),
1085 reason_code: None,
1086 namespace: None,
1087 };
1088
1089 let scoping = ItemScoping {
1090 category: DataCategory::Error,
1091 scoping: Scoping {
1092 organization_id: OrganizationId::new(69420),
1093 project_id: ProjectId::new(42),
1094 project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
1095 key_id: Some(4711),
1096 },
1097 namespace: MetricNamespaceScoping::None,
1098 };
1099
1100 let timestamp = UnixTimestamp::from_secs(123_123_123);
1101 let redis_quota = RedisQuota::new("a, 0, scoping, timestamp).unwrap();
1102 assert_eq!(redis_quota.key().to_string(), "quota:foo{69420}42:61561561");
1103 }
1104
1105 #[tokio::test]
1106 async fn test_get_redis_key_unscoped() {
1107 let quota = Quota {
1108 id: Some("foo".into()),
1109 categories: DataCategories::new(),
1110 scope: QuotaScope::Organization,
1111 scope_id: None,
1112 window: Some(10),
1113 limit: Some(0),
1114 reason_code: None,
1115 namespace: None,
1116 };
1117
1118 let scoping = ItemScoping {
1119 category: DataCategory::Error,
1120 scoping: Scoping {
1121 organization_id: OrganizationId::new(69420),
1122 project_id: ProjectId::new(42),
1123 project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
1124 key_id: Some(4711),
1125 },
1126 namespace: MetricNamespaceScoping::None,
1127 };
1128
1129 let timestamp = UnixTimestamp::from_secs(234_531);
1130 let redis_quota = RedisQuota::new("a, 0, scoping, timestamp).unwrap();
1131 assert_eq!(redis_quota.key().to_string(), "quota:foo{69420}:23453");
1132 }
1133
1134 #[tokio::test]
1135 async fn test_large_redis_limit_large() {
1136 let quota = Quota {
1137 id: Some("foo".into()),
1138 categories: DataCategories::new(),
1139 scope: QuotaScope::Organization,
1140 scope_id: None,
1141 window: Some(10),
1142 limit: Some(9223372036854775808), reason_code: None,
1144 namespace: None,
1145 };
1146
1147 let scoping = ItemScoping {
1148 category: DataCategory::Error,
1149 scoping: Scoping {
1150 organization_id: OrganizationId::new(69420),
1151 project_id: ProjectId::new(42),
1152 project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
1153 key_id: Some(4711),
1154 },
1155 namespace: MetricNamespaceScoping::None,
1156 };
1157
1158 let timestamp = UnixTimestamp::from_secs(234_531);
1159 let redis_quota = RedisQuota::new("a, 0, scoping, timestamp).unwrap();
1160 assert_eq!(redis_quota.limit(), -1);
1161 }
1162
    /// Exercises the rate limiting script directly, bypassing `RedisRateLimiter::is_rate_limited`.
    ///
    /// Keys are passed in pairs of `(counter key, refund key)` and arguments in groups of
    /// `(limit, expiry, quantity, over_accept_once)`, mirroring how `is_rate_limited` builds its
    /// invocation.
    #[tokio::test]
    async fn test_is_rate_limited_script() {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|duration| duration.as_secs())
            .unwrap();

        let rate_limiter = build_rate_limiter();
        let mut conn = rate_limiter.client.get_connection().await.unwrap();

        // All keys are suffixed with the current time.
        let foo = format!("foo___{now}");
        let r_foo = format!("r:foo___{now}");
        let bar = format!("bar___{now}");
        let r_bar = format!("r:bar___{now}");
        let apple = format!("apple___{now}");
        let orange = format!("orange___{now}");
        let baz = format!("baz___{now}");

        let script = RedisScripts::load_is_rate_limited();

        // Invokes the script and compares its parsed result against an inline snapshot.
        macro_rules! assert_invocation {
            ($invocation:expr, $($tt:tt)*) => {{
                let result = $invocation
                    .invoke_async::<ScriptResult>(&mut conn)
                    .await
                    .unwrap();

                insta::assert_debug_snapshot!(result, $($tt)*);
            }};
        }

        // Checks `foo` (limit 1) and `bar` (limit 2) together.
        let mut invocation = script.prepare_invoke();
        invocation
            .key(&foo) // counter key
            .key(&r_foo) // refund key
            .key(&bar) // counter key
            .key(&r_bar) // refund key
            .arg(1) // limit
            .arg(now + 60) // expiry
            .arg(1) // quantity
            .arg(false) // over_accept_once
            .arg(2) // limit
            .arg(now + 120) // expiry
            .arg(1) // quantity
            .arg(false); // over_accept_once

        // Checks `bar` (limit 2) on its own.
        let mut invocation2 = script.prepare_invoke();
        invocation2
            .key(&bar) // counter key
            .key(&r_bar) // refund key
            .arg(2) // limit
            .arg(now + 120) // expiry
            .arg(1) // quantity
            .arg(false); // over_accept_once

        // The first call is accepted by both quotas.
        assert_invocation!(invocation, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: false,
                    consumed: 1,
                },
                QuotaState {
                    is_rejected: false,
                    consumed: 1,
                },
            ],
        )
        "
        );

        // `foo` is now exhausted; the consumption reported for `bar` stays at 1.
        assert_invocation!(invocation, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: true,
                    consumed: 1,
                },
                QuotaState {
                    is_rejected: false,
                    consumed: 1,
                },
            ],
        )
        "
        );

        // Repeating the call yields the same result.
        assert_invocation!(invocation, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: true,
                    consumed: 1,
                },
                QuotaState {
                    is_rejected: false,
                    consumed: 1,
                },
            ],
        )
        "
        );

        // Checking `bar` on its own consumes its second unit.
        assert_invocation!(invocation2, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: false,
                    consumed: 2,
                },
            ],
        )
        "
        );

        // `bar` is now exhausted as well.
        assert_invocation!(invocation2, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: true,
                    consumed: 2,
                },
            ],
        )
        "
        );

        // Both quotas are rejected in the combined invocation.
        assert_invocation!(invocation, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: true,
                    consumed: 1,
                },
                QuotaState {
                    is_rejected: true,
                    consumed: 2,
                },
            ],
        )
        "
        );

        // The counters hold the consumed quantities and expire at the passed expiry; the TTL is
        // checked with a tolerance of one second.
        assert_eq!(conn.get::<_, String>(&foo).await.unwrap(), "1");
        let ttl: u64 = conn.ttl(&foo).await.unwrap();
        assert!(ttl >= 59);
        assert!(ttl <= 60);

        assert_eq!(conn.get::<_, String>(&bar).await.unwrap(), "2");
        let ttl: u64 = conn.ttl(&bar).await.unwrap();
        assert!(ttl >= 119);
        assert!(ttl <= 120);

        // Reading the refund keys as `()` only succeeds if they hold no value.
        let () = conn.get(r_foo).await.unwrap();
        let () = conn.get(r_bar).await.unwrap();

        // `apple` is used as a refund key holding a value of 5 further below.
        let () = conn.set(&apple, 5).await.unwrap();

        // Checks `orange` (limit 1) with `baz`, which was never written, as refund key.
        let mut invocation = script.prepare_invoke();
        invocation
            .key(&orange) // counter key
            .key(&baz) // refund key
            .arg(1) // limit
            .arg(now + 60) // expiry
            .arg(1) // quantity
            .arg(false); // over_accept_once

        // Accepted once.
        assert_invocation!(invocation, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: false,
                    consumed: 1,
                },
            ],
        )
        "
        );

        // Rejected afterwards.
        assert_invocation!(invocation, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: true,
                    consumed: 1,
                },
            ],
        )
        "
        );

        // Still rejected.
        assert_invocation!(invocation, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: true,
                    consumed: 1,
                },
            ],
        )
        "
        );

        // Checks `orange` again, this time with `apple` (holding 5) as refund key.
        let mut invocation = script.prepare_invoke();
        invocation
            .key(&orange) // counter key
            .key(&apple) // refund key
            .arg(1) // limit
            .arg(now + 60) // expiry
            .arg(1) // quantity
            .arg(false); // over_accept_once

        // With the refund in place the check is accepted and the reported consumption is
        // negative.
        assert_invocation!(invocation, @r"
        ScriptResult(
            [
                QuotaState {
                    is_rejected: false,
                    consumed: -3,
                },
            ],
        )
        "
        );
    }
1404
    /// With the cache enabled, a quota with a limit of 50 still accepts exactly 50 items and
    /// rejects the next one.
    #[tokio::test]
    async fn test_quota_with_cache() {
        let quotas = &[Quota {
            id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4()).into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            limit: Some(50),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        // Cache configured with a ratio of 0.1 and a max of 0.9.
        let rate_limiter = build_rate_limiter().cache(Some(0.1), Some(0.9));

        // All 50 items within the limit are accepted.
        for _ in 0..50 {
            let rate_limits = rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .await
                .unwrap();

            assert!(rate_limits.is_empty());
        }

        // The item after that is rejected.
        let rate_limits: Vec<RateLimit> = rate_limiter
            .is_rate_limited(quotas, scoping, 1, false)
            .await
            .expect("rate limiting failed")
            .into_iter()
            .collect();

        assert_eq!(
            rate_limits,
            vec![RateLimit {
                categories: DataCategories::new(),
                scope: RateLimitScope::Organization(OrganizationId::new(42)),
                reason_code: Some(ReasonCode::new("get_lost")),
                retry_after: rate_limits[0].retry_after,
                namespaces: smallvec![],
            }]
        );
    }
1461
    /// Two rate limiters with independent caches share one quota in Redis: the first limiter
    /// keeps accepting for one more check after the second limiter has nearly consumed the
    /// limit, and is limited on the check after that.
    #[tokio::test]
    async fn test_quota_with_cache_slightly_over_account() {
        let window = 60;
        let limit = 50 * window;

        let quotas = &[Quota {
            id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4()).into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            limit: Some(limit),
            window: Some(window),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        // Each limiter gets its own cache; both use the same quota id.
        let rate_limiter1 = build_rate_limiter().cache(Some(0.1), None);
        let rate_limiter2 = build_rate_limiter().cache(Some(0.1), None);

        // Two checks through the first limiter, with quantities 1 and 3.
        let rate_limits = rate_limiter1
            .is_rate_limited(quotas, scoping, 1, false)
            .await
            .unwrap();
        assert!(rate_limits.is_empty());
        let rate_limits = rate_limiter1
            .is_rate_limited(quotas, scoping, 3, false)
            .await
            .unwrap();
        assert!(rate_limits.is_empty());

        // The second limiter checks a quantity of `limit - 1` in a single call.
        let rate_limits = rate_limiter2
            .is_rate_limited(quotas, scoping, limit as usize - 1, false)
            .await
            .unwrap();
        assert!(rate_limits.is_empty());

        // The first limiter still accepts one more item.
        let rate_limits = rate_limiter1
            .is_rate_limited(quotas, scoping, 1, false)
            .await
            .unwrap();
        assert!(rate_limits.is_empty());

        // The following check through the first limiter is rate limited.
        let rate_limits: Vec<RateLimit> = rate_limiter1
            .is_rate_limited(quotas, scoping, 1, false)
            .await
            .unwrap()
            .into_iter()
            .collect();

        assert_eq!(
            rate_limits,
            vec![RateLimit {
                categories: DataCategories::new(),
                scope: RateLimitScope::Organization(OrganizationId::new(42)),
                reason_code: Some(ReasonCode::new("get_lost")),
                retry_after: rate_limits[0].retry_after,
                namespaces: smallvec![],
            }]
        );
    }
1539}