relay_quotas/
redis.rs

1use std::fmt::{self, Debug};
2
3use relay_common::time::UnixTimestamp;
4use relay_log::protocol::value;
5use relay_redis::redis::Script;
6use relay_redis::{AsyncRedisClient, RedisError, RedisScripts};
7use thiserror::Error;
8
9use crate::REJECT_ALL_SECS;
10use crate::global::GlobalLimiter;
11use crate::quota::{ItemScoping, Quota, QuotaScope};
12use crate::rate_limit::{RateLimit, RateLimits, RetryAfter};
13
/// Grace period, in seconds, that is added to a quota window's expiry when computing the
/// expiry of the corresponding Redis key.
///
/// The `grace` period allows accommodating for clock drift in TTL
/// calculation since the clock on the Redis instance used to store quota
/// metrics may not be in sync with the computer running this code.
const GRACE: u64 = 60;
18
19/// An error returned by [`RedisRateLimiter`].
20#[derive(Debug, Error)]
21pub enum RateLimitingError {
22    /// Failed to communicate with Redis.
23    #[error("failed to communicate with redis")]
24    Redis(
25        #[from]
26        #[source]
27        RedisError,
28    ),
29
30    /// Failed to check global rate limits via the service.
31    #[error("failed to check global rate limits")]
32    UnreachableGlobalRateLimits,
33}
34
/// Derives the refund key that belongs to a counter key.
///
/// The refund key is the counter key prefixed with `r:`. Refund keys are used to track
/// credits that should be applied to a quota, allowing for more flexible quota management.
fn get_refunded_quota_key(counter_key: &str) -> String {
    ["r:", counter_key].concat()
}
42
/// Formatting adapter around an `Option`: `Some(value)` is rendered exactly like `value`,
/// while `None` renders as nothing at all.
struct OptionalDisplay<T>(Option<T>);

impl<T: fmt::Display> fmt::Display for OptionalDisplay<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Nothing to write for `None`; this is what makes the wrapper "transparent".
        let Some(inner) = self.0.as_ref() else {
            return Ok(());
        };

        write!(f, "{inner}")
    }
}
57
58/// Owned version of [`RedisQuota`].
59#[derive(Debug, Clone)]
60pub struct OwnedRedisQuota {
61    /// The original quota.
62    quota: Quota,
63    /// Scopes of the item being tracked.
64    scoping: ItemScoping,
65    /// The Redis key prefix mapped from the quota id.
66    prefix: String,
67    /// The redis window in seconds mapped from the quota.
68    window: u64,
69    /// The ingestion timestamp determining the rate limiting bucket.
70    timestamp: UnixTimestamp,
71}
72
73impl OwnedRedisQuota {
74    /// Returns an instance of [`RedisQuota`] which borrows from this [`OwnedRedisQuota`].
75    pub fn build_ref(&self) -> RedisQuota {
76        RedisQuota {
77            quota: &self.quota,
78            scoping: self.scoping,
79            prefix: &self.prefix,
80            window: self.window,
81            timestamp: self.timestamp,
82        }
83    }
84}
85
/// Reference to information required for tracking quotas in Redis.
///
/// Borrows the [`Quota`] and the key prefix for the lifetime `'a` and carries the scoping,
/// window and timestamp by value.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct RedisQuota<'a> {
    /// The original quota.
    quota: &'a Quota,
    /// Scopes of the item being tracked.
    scoping: ItemScoping,
    /// The Redis key prefix mapped from the quota id.
    prefix: &'a str,
    /// The redis window in seconds mapped from the quota.
    window: u64,
    /// The ingestion timestamp determining the rate limiting bucket.
    timestamp: UnixTimestamp,
}
100
101impl<'a> RedisQuota<'a> {
102    /// Creates a new [`RedisQuota`] from a [`Quota`], item scoping, and timestamp.
103    ///
104    /// Returns `None` if the quota cannot be tracked in Redis because it's missing
105    /// required fields (ID or window). This allows forward compatibility with
106    /// future quota types.
107    pub fn new(quota: &'a Quota, scoping: ItemScoping, timestamp: UnixTimestamp) -> Option<Self> {
108        // These fields indicate that we *can* track this quota.
109        let prefix = quota.id.as_deref()?;
110        let window = quota.window?;
111
112        Some(Self {
113            quota,
114            scoping,
115            prefix,
116            window,
117            timestamp,
118        })
119    }
120
121    /// Converts this [`RedisQuota`] to an [`OwnedRedisQuota`] leaving the original
122    /// struct in place.
123    pub fn build_owned(&self) -> OwnedRedisQuota {
124        OwnedRedisQuota {
125            quota: self.quota.clone(),
126            scoping: self.scoping,
127            prefix: self.prefix.to_string(),
128            window: self.window,
129            timestamp: self.timestamp,
130        }
131    }
132
133    /// Returns the window size of the quota in seconds.
134    pub fn window(&self) -> u64 {
135        self.window
136    }
137
138    /// Returns the prefix of the quota used for Redis key generation.
139    pub fn prefix(&self) -> &'a str {
140        self.prefix
141    }
142
143    /// Returns the limit value formatted for Redis.
144    ///
145    /// Returns `-1` for unlimited quotas or when the limit doesn't fit into an `i64`.
146    /// Otherwise, returns the limit value as an `i64`.
147    pub fn limit(&self) -> i64 {
148        self.limit
149            // If it does not fit into i64, treat as unlimited:
150            .and_then(|limit| limit.try_into().ok())
151            .unwrap_or(-1)
152    }
153
154    fn shift(&self) -> u64 {
155        if self.quota.scope == QuotaScope::Global {
156            0
157        } else {
158            self.scoping.organization_id.value() % self.window
159        }
160    }
161
162    /// Returns the current time slot of the quota based on the timestamp.
163    ///
164    /// Slots are used to determine the time bucket for rate limiting.
165    pub fn slot(&self) -> u64 {
166        (self.timestamp.as_secs() - self.shift()) / self.window
167    }
168
169    /// Returns the timestamp when the current quota window will expire.
170    pub fn expiry(&self) -> UnixTimestamp {
171        let next_slot = self.slot() + 1;
172        let next_start = next_slot * self.window + self.shift();
173        UnixTimestamp::from_secs(next_start)
174    }
175
176    /// Returns when the Redis key should expire.
177    ///
178    /// This is the expiry time plus a grace period.
179    pub fn key_expiry(&self) -> u64 {
180        self.expiry().as_secs() + GRACE
181    }
182
183    /// Returns the Redis key for this quota.
184    ///
185    /// The key includes the quota ID, organization ID, and other scoping information
186    /// based on the quota's scope type. Keys are structured to ensure proper isolation
187    /// between different organizations and scopes.
188    pub fn key(&self) -> String {
189        // The subscope id is only formatted into the key if the quota is not organization-scoped.
190        // The organization id is always included.
191        let subscope = match self.quota.scope {
192            QuotaScope::Global => None,
193            QuotaScope::Organization => None,
194            scope => self.scoping.scope_id(scope),
195        };
196
197        let org = self.scoping.organization_id;
198
199        format!(
200            "quota:{id}{{{org}}}{subscope}{namespace}:{slot}",
201            id = self.prefix,
202            org = org,
203            subscope = OptionalDisplay(subscope),
204            namespace = OptionalDisplay(self.namespace),
205            slot = self.slot(),
206        )
207    }
208}
209
210impl std::ops::Deref for RedisQuota<'_> {
211    type Target = Quota;
212
213    fn deref(&self) -> &Self::Target {
214        self.quota
215    }
216}
217
/// A service that executes quotas and checks for rate limits in a shared cache.
///
/// Quotas handle tracking a project's usage and respond whether a project has been
/// configured to throttle incoming data if they go beyond the specified quota.
///
/// Quotas can specify a window to be tracked in, such as per minute or per hour. Additionally,
/// quotas allow to specify the data categories they apply to, for example error events or
/// attachments. For more information on quota parameters, see [`Quota`].
///
/// Requires the `redis` feature.
pub struct RedisRateLimiter<T> {
    /// Redis client from which a connection is taken to run the rate limiting script.
    client: AsyncRedisClient,
    /// The script that checks and increments all non-global quotas in one invocation.
    script: &'static Script,
    /// Optional upper bound, in seconds, for the retry-after of returned rate limits.
    max_limit: Option<u64>,
    /// Checks the quotas with global scope, which are not passed to the script.
    global_limiter: T,
}
234
235impl<T: GlobalLimiter> RedisRateLimiter<T> {
236    /// Creates a new [`RedisRateLimiter`] instance.
237    pub fn new(client: AsyncRedisClient, global_limiter: T) -> Self {
238        RedisRateLimiter {
239            client,
240            script: RedisScripts::load_is_rate_limited(),
241            max_limit: None,
242            global_limiter,
243        }
244    }
245
246    /// Sets the maximum rate limit in seconds.
247    ///
248    /// By default, this rate limiter will return rate limits based on the quotas' `window` fields.
249    /// If a maximum rate limit is set, the returned rate limit will be bounded by this value.
250    pub fn max_limit(mut self, max_limit: Option<u64>) -> Self {
251        self.max_limit = max_limit;
252        self
253    }
254
255    /// Checks whether any of the quotas in effect have been exceeded and records consumption.
256    ///
257    /// By invoking this method, the caller signals that data is being ingested and needs to be
258    /// counted against the quota. This increment happens atomically if none of the quotas have been
259    /// exceeded. Otherwise, a rate limit is returned and data is not counted against the quotas.
260    ///
261    /// If no key is specified, then only organization-wide and project-wide quotas are checked. If
262    /// a key is specified, then key-quotas are also checked.
263    ///
264    /// When `over_accept_once` is set to `true` and the current quota would be exceeded by the
265    /// provided `quantity`, the data is accepted once and subsequent requests will be rejected
266    /// until the quota refreshes.
267    ///
268    /// A `quantity` of `0` can be used to check if the quota limit has been reached or exceeded
269    /// without incrementing it in the success case. This is useful for checking quotas in a different
270    /// data category.
271    pub async fn is_rate_limited<'a>(
272        &self,
273        quotas: impl IntoIterator<Item = &'a Quota>,
274        item_scoping: ItemScoping,
275        quantity: usize,
276        over_accept_once: bool,
277    ) -> Result<RateLimits, RateLimitingError> {
278        let timestamp = UnixTimestamp::now();
279        let mut invocation = self.script.prepare_invoke();
280        let mut tracked_quotas = Vec::new();
281        let mut rate_limits = RateLimits::new();
282
283        let mut global_quotas = vec![];
284
285        for quota in quotas {
286            if !quota.matches(item_scoping) {
287                // Silently skip all quotas that do not apply to this item.
288            } else if quota.limit == Some(0) {
289                // A zero-sized quota is strongest. Do not call into Redis at all, and do not
290                // increment any keys, as one quota has reached capacity (this is how regular quotas
291                // behave as well).
292                let retry_after = self.retry_after(REJECT_ALL_SECS);
293                rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
294            } else if let Some(quota) = RedisQuota::new(quota, item_scoping, timestamp) {
295                if quota.scope == QuotaScope::Global {
296                    global_quotas.push(quota);
297                } else {
298                    let key = quota.key();
299                    // Remaining quotas are expected to be trackable in Redis.
300                    let refund_key = get_refunded_quota_key(&key);
301
302                    invocation.key(key);
303                    invocation.key(refund_key);
304
305                    invocation.arg(quota.limit());
306                    invocation.arg(quota.key_expiry());
307                    invocation.arg(quantity);
308                    invocation.arg(over_accept_once);
309
310                    tracked_quotas.push(quota);
311                }
312            } else {
313                // This quota is neither a static reject-all, nor can it be tracked in Redis due to
314                // missing fields. We're skipping this for forward-compatibility.
315                relay_log::with_scope(
316                    |scope| scope.set_extra("quota", value::to_value(quota).unwrap()),
317                    || relay_log::warn!("skipping unsupported quota"),
318                )
319            }
320        }
321
322        // We check the global rate limits before the other limits. This step must be separate from
323        // checking the other rate limits, since those are checked with a Redis script that works
324        // under the invariant that all keys are within the same Redis instance (given their partitioning).
325        // Global keys on the other hand are always on the same instance, so if they were to be mixed
326        // with normal keys the script will end up referencing keys from multiple instances, making it
327        // impossible for the script to work.
328        let rate_limited_global_quotas = self
329            .global_limiter
330            .check_global_rate_limits(&global_quotas, quantity)
331            .await?;
332
333        for quota in rate_limited_global_quotas {
334            let retry_after = self.retry_after((quota.expiry() - timestamp).as_secs());
335            rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
336        }
337
338        // Either there are no quotas to run against Redis, or we already have a rate limit from a
339        // zero-sized quota. In either cases, skip invoking the script and return early.
340        if tracked_quotas.is_empty() || rate_limits.is_limited() {
341            return Ok(rate_limits);
342        }
343
344        // We get the redis client after the global rate limiting since we don't want to hold the
345        // client across await points, otherwise it might be held for too long, and we will run out
346        // of connections.
347        let mut connection = self.client.get_connection().await?;
348        let rejections: Vec<bool> = invocation
349            .invoke_async(&mut connection)
350            .await
351            .map_err(RedisError::Redis)?;
352
353        for (quota, is_rejected) in tracked_quotas.iter().zip(rejections) {
354            if is_rejected {
355                let retry_after = self.retry_after((quota.expiry() - timestamp).as_secs());
356                rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
357            }
358        }
359
360        Ok(rate_limits)
361    }
362
363    /// Creates a [`RetryAfter`] value that is bounded by the configured [`max_limit`](Self::max_limit).
364    ///
365    /// If a maximum rate limit has been set, the returned value will not exceed that limit.
366    fn retry_after(&self, mut seconds: u64) -> RetryAfter {
367        if let Some(max_limit) = self.max_limit {
368            seconds = std::cmp::min(seconds, max_limit);
369        }
370
371        RetryAfter::from_secs(seconds)
372    }
373}
374
375#[cfg(test)]
376mod tests {
377    use std::time::{SystemTime, UNIX_EPOCH};
378
379    use super::*;
380    use crate::quota::{DataCategories, DataCategory, ReasonCode, Scoping};
381    use crate::rate_limit::RateLimitScope;
382    use crate::{GlobalRateLimiter, MetricNamespaceScoping};
383    use relay_base_schema::metrics::MetricNamespace;
384    use relay_base_schema::organization::OrganizationId;
385    use relay_base_schema::project::{ProjectId, ProjectKey};
386    use relay_redis::RedisConfigOptions;
387    use relay_redis::redis::AsyncCommands;
388    use smallvec::smallvec;
389    use tokio::sync::Mutex;
390
    /// Test implementation of [`GlobalLimiter`] that checks global quotas by calling a
    /// [`GlobalRateLimiter`] directly with its own Redis client.
    struct MockGlobalLimiter {
        /// Redis client passed to the global rate limiter on every check.
        client: AsyncRedisClient,
        /// The wrapped limiter, guarded by an async mutex that is locked for each check.
        global_rate_limiter: Mutex<GlobalRateLimiter>,
    }

    impl GlobalLimiter for MockGlobalLimiter {
        /// Locks the wrapped [`GlobalRateLimiter`] and returns the result of its
        /// `filter_rate_limited` call for `global_quotas` and `quantity`.
        async fn check_global_rate_limits<'a>(
            &self,
            global_quotas: &'a [RedisQuota<'a>],
            quantity: usize,
        ) -> Result<Vec<&'a RedisQuota<'a>>, RateLimitingError> {
            self.global_rate_limiter
                .lock()
                .await
                .filter_rate_limited(&self.client, global_quotas, quantity)
                .await
        }
    }
409
410    fn build_rate_limiter() -> RedisRateLimiter<MockGlobalLimiter> {
411        let url = std::env::var("RELAY_REDIS_URL")
412            .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_owned());
413        let client = AsyncRedisClient::single(&url, &RedisConfigOptions::default()).unwrap();
414
415        let global_limiter = MockGlobalLimiter {
416            client: client.clone(),
417            global_rate_limiter: Mutex::new(GlobalRateLimiter::default()),
418        };
419
420        RedisRateLimiter {
421            client,
422            script: RedisScripts::load_is_rate_limited(),
423            max_limit: None,
424            global_limiter,
425        }
426    }
427
428    #[tokio::test]
429    async fn test_zero_size_quotas() {
430        let quotas = &[
431            Quota {
432                id: None,
433                categories: DataCategories::new(),
434                scope: QuotaScope::Organization,
435                scope_id: None,
436                limit: Some(0),
437                window: None,
438                reason_code: Some(ReasonCode::new("get_lost")),
439                namespace: None,
440            },
441            Quota {
442                id: Some("42".to_owned()),
443                categories: DataCategories::new(),
444                scope: QuotaScope::Organization,
445                scope_id: None,
446                limit: None,
447                window: Some(42),
448                reason_code: Some(ReasonCode::new("unlimited")),
449                namespace: None,
450            },
451        ];
452
453        let scoping = ItemScoping {
454            category: DataCategory::Error,
455            scoping: Scoping {
456                organization_id: OrganizationId::new(42),
457                project_id: ProjectId::new(43),
458                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
459                key_id: Some(44),
460            },
461            namespace: MetricNamespaceScoping::None,
462        };
463
464        let rate_limits: Vec<RateLimit> = build_rate_limiter()
465            .is_rate_limited(quotas, scoping, 1, false)
466            .await
467            .expect("rate limiting failed")
468            .into_iter()
469            .collect();
470
471        assert_eq!(
472            rate_limits,
473            vec![RateLimit {
474                categories: DataCategories::new(),
475                scope: RateLimitScope::Organization(OrganizationId::new(42)),
476                reason_code: Some(ReasonCode::new("get_lost")),
477                retry_after: rate_limits[0].retry_after,
478                namespaces: smallvec![],
479            }]
480        );
481    }
482
483    /// Tests that a quota with and without namespace are counted separately.
484    #[tokio::test]
485    async fn test_non_global_namespace_quota() {
486        let quota_limit = 5;
487        let get_quota = |namespace: Option<MetricNamespace>| -> Quota {
488            Quota {
489                id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4())),
490                categories: DataCategories::new(),
491                scope: QuotaScope::Organization,
492                scope_id: None,
493                limit: Some(quota_limit),
494                window: Some(600),
495                reason_code: Some(ReasonCode::new(format!("ns: {:?}", namespace))),
496                namespace,
497            }
498        };
499
500        let quotas = &[get_quota(None)];
501        let quota_with_namespace = &[get_quota(Some(MetricNamespace::Transactions))];
502
503        let scoping = ItemScoping {
504            category: DataCategory::Error,
505            scoping: Scoping {
506                organization_id: OrganizationId::new(42),
507                project_id: ProjectId::new(43),
508                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
509                key_id: Some(44),
510            },
511            namespace: MetricNamespaceScoping::Some(MetricNamespace::Transactions),
512        };
513
514        let rate_limiter = build_rate_limiter();
515
516        // First confirm normal behaviour without namespace.
517        for i in 0..10 {
518            let rate_limits: Vec<RateLimit> = rate_limiter
519                .is_rate_limited(quotas, scoping, 1, false)
520                .await
521                .expect("rate limiting failed")
522                .into_iter()
523                .collect();
524
525            if i < quota_limit {
526                assert_eq!(rate_limits, vec![]);
527            } else {
528                assert_eq!(
529                    rate_limits[0].reason_code,
530                    Some(ReasonCode::new("ns: None"))
531                );
532            }
533        }
534
535        // Then, send identical quota with namespace and confirm it counts separately.
536        for i in 0..10 {
537            let rate_limits: Vec<RateLimit> = rate_limiter
538                .is_rate_limited(quota_with_namespace, scoping, 1, false)
539                .await
540                .expect("rate limiting failed")
541                .into_iter()
542                .collect();
543
544            if i < quota_limit {
545                assert_eq!(rate_limits, vec![]);
546            } else {
547                assert_eq!(
548                    rate_limits[0].reason_code,
549                    Some(ReasonCode::new("ns: Some(Transactions)"))
550                );
551            }
552        }
553    }
554
555    #[tokio::test]
556    async fn test_simple_quota() {
557        let quotas = &[Quota {
558            id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4())),
559            categories: DataCategories::new(),
560            scope: QuotaScope::Organization,
561            scope_id: None,
562            limit: Some(5),
563            window: Some(60),
564            reason_code: Some(ReasonCode::new("get_lost")),
565            namespace: None,
566        }];
567
568        let scoping = ItemScoping {
569            category: DataCategory::Error,
570            scoping: Scoping {
571                organization_id: OrganizationId::new(42),
572                project_id: ProjectId::new(43),
573                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
574                key_id: Some(44),
575            },
576            namespace: MetricNamespaceScoping::None,
577        };
578
579        let rate_limiter = build_rate_limiter();
580
581        for i in 0..10 {
582            let rate_limits: Vec<RateLimit> = rate_limiter
583                .is_rate_limited(quotas, scoping, 1, false)
584                .await
585                .expect("rate limiting failed")
586                .into_iter()
587                .collect();
588
589            if i >= 5 {
590                assert_eq!(
591                    rate_limits,
592                    vec![RateLimit {
593                        categories: DataCategories::new(),
594                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
595                        reason_code: Some(ReasonCode::new("get_lost")),
596                        retry_after: rate_limits[0].retry_after,
597                        namespaces: smallvec![],
598                    }]
599                );
600            } else {
601                assert_eq!(rate_limits, vec![]);
602            }
603        }
604    }
605
606    #[tokio::test]
607    async fn test_simple_global_quota() {
608        let quotas = &[Quota {
609            id: Some(format!("test_simple_global_quota_{}", uuid::Uuid::new_v4())),
610            categories: DataCategories::new(),
611            scope: QuotaScope::Global,
612            scope_id: None,
613            limit: Some(5),
614            window: Some(60),
615            reason_code: Some(ReasonCode::new("get_lost")),
616            namespace: None,
617        }];
618
619        let scoping = ItemScoping {
620            category: DataCategory::Error,
621            scoping: Scoping {
622                organization_id: OrganizationId::new(42),
623                project_id: ProjectId::new(43),
624                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
625                key_id: Some(44),
626            },
627            namespace: MetricNamespaceScoping::None,
628        };
629
630        let rate_limiter = build_rate_limiter();
631
632        for i in 0..10 {
633            let rate_limits: Vec<RateLimit> = rate_limiter
634                .is_rate_limited(quotas, scoping, 1, false)
635                .await
636                .expect("rate limiting failed")
637                .into_iter()
638                .collect();
639
640            if i >= 5 {
641                assert_eq!(
642                    rate_limits,
643                    vec![RateLimit {
644                        categories: DataCategories::new(),
645                        scope: RateLimitScope::Global,
646                        reason_code: Some(ReasonCode::new("get_lost")),
647                        retry_after: rate_limits[0].retry_after,
648                        namespaces: smallvec![],
649                    }]
650                );
651            } else {
652                assert_eq!(rate_limits, vec![]);
653            }
654        }
655    }
656
657    #[tokio::test]
658    async fn test_quantity_0() {
659        let quotas = &[Quota {
660            id: Some(format!("test_quantity_0_{}", uuid::Uuid::new_v4())),
661            categories: DataCategories::new(),
662            scope: QuotaScope::Organization,
663            scope_id: None,
664            limit: Some(1),
665            window: Some(60),
666            reason_code: Some(ReasonCode::new("get_lost")),
667            namespace: None,
668        }];
669
670        let scoping = ItemScoping {
671            category: DataCategory::Error,
672            scoping: Scoping {
673                organization_id: OrganizationId::new(42),
674                project_id: ProjectId::new(43),
675                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
676                key_id: Some(44),
677            },
678            namespace: MetricNamespaceScoping::None,
679        };
680
681        let rate_limiter = build_rate_limiter();
682
683        // limit is 1, so first call not rate limited
684        assert!(
685            !rate_limiter
686                .is_rate_limited(quotas, scoping, 1, false)
687                .await
688                .unwrap()
689                .is_limited()
690        );
691
692        // quota is now exhausted
693        assert!(
694            rate_limiter
695                .is_rate_limited(quotas, scoping, 1, false)
696                .await
697                .unwrap()
698                .is_limited()
699        );
700
701        // quota is exhausted, regardless of the quantity
702        assert!(
703            rate_limiter
704                .is_rate_limited(quotas, scoping, 0, false)
705                .await
706                .unwrap()
707                .is_limited()
708        );
709
710        // quota is exhausted, regardless of the quantity
711        assert!(
712            rate_limiter
713                .is_rate_limited(quotas, scoping, 1, false)
714                .await
715                .unwrap()
716                .is_limited()
717        );
718    }
719
720    #[tokio::test]
721    async fn test_quota_go_over() {
722        let quotas = &[Quota {
723            id: Some(format!("test_quota_go_over{}", uuid::Uuid::new_v4())),
724            categories: DataCategories::new(),
725            scope: QuotaScope::Organization,
726            scope_id: None,
727            limit: Some(2),
728            window: Some(60),
729            reason_code: Some(ReasonCode::new("get_lost")),
730            namespace: None,
731        }];
732
733        let scoping = ItemScoping {
734            category: DataCategory::Error,
735            scoping: Scoping {
736                organization_id: OrganizationId::new(42),
737                project_id: ProjectId::new(43),
738                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
739                key_id: Some(44),
740            },
741            namespace: MetricNamespaceScoping::None,
742        };
743
744        let rate_limiter = build_rate_limiter();
745
746        // limit is 2, so first call not rate limited
747        let is_limited = rate_limiter
748            .is_rate_limited(quotas, scoping, 1, true)
749            .await
750            .unwrap()
751            .is_limited();
752        assert!(!is_limited);
753
754        // go over limit, but first call is over-accepted
755        let is_limited = rate_limiter
756            .is_rate_limited(quotas, scoping, 2, true)
757            .await
758            .unwrap()
759            .is_limited();
760        assert!(!is_limited);
761
762        // quota is exhausted, regardless of the quantity
763        let is_limited = rate_limiter
764            .is_rate_limited(quotas, scoping, 0, true)
765            .await
766            .unwrap()
767            .is_limited();
768        assert!(is_limited);
769
770        // quota is exhausted, regardless of the quantity
771        let is_limited = rate_limiter
772            .is_rate_limited(quotas, scoping, 1, true)
773            .await
774            .unwrap()
775            .is_limited();
776        assert!(is_limited);
777    }
778
779    #[tokio::test]
780    async fn test_bails_immediately_without_any_quota() {
781        let scoping = ItemScoping {
782            category: DataCategory::Error,
783            scoping: Scoping {
784                organization_id: OrganizationId::new(42),
785                project_id: ProjectId::new(43),
786                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
787                key_id: Some(44),
788            },
789            namespace: MetricNamespaceScoping::None,
790        };
791
792        let rate_limits: Vec<RateLimit> = build_rate_limiter()
793            .is_rate_limited(&[], scoping, 1, false)
794            .await
795            .expect("rate limiting failed")
796            .into_iter()
797            .collect();
798
799        assert_eq!(rate_limits, vec![]);
800    }
801
802    #[tokio::test]
803    async fn test_limited_with_unlimited_quota() {
804        let quotas = &[
805            Quota {
806                id: Some("q0".to_string()),
807                categories: DataCategories::new(),
808                scope: QuotaScope::Organization,
809                scope_id: None,
810                limit: None,
811                window: Some(1),
812                reason_code: Some(ReasonCode::new("project_quota0")),
813                namespace: None,
814            },
815            Quota {
816                id: Some("q1".to_string()),
817                categories: DataCategories::new(),
818                scope: QuotaScope::Organization,
819                scope_id: None,
820                limit: Some(1),
821                window: Some(1),
822                reason_code: Some(ReasonCode::new("project_quota1")),
823                namespace: None,
824            },
825        ];
826
827        let scoping = ItemScoping {
828            category: DataCategory::Error,
829            scoping: Scoping {
830                organization_id: OrganizationId::new(42),
831                project_id: ProjectId::new(43),
832                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
833                key_id: Some(44),
834            },
835            namespace: MetricNamespaceScoping::None,
836        };
837
838        let rate_limiter = build_rate_limiter();
839
840        for i in 0..1 {
841            let rate_limits: Vec<RateLimit> = rate_limiter
842                .is_rate_limited(quotas, scoping, 1, false)
843                .await
844                .expect("rate limiting failed")
845                .into_iter()
846                .collect();
847
848            if i == 0 {
849                assert_eq!(rate_limits, &[]);
850            } else {
851                assert_eq!(
852                    rate_limits,
853                    vec![RateLimit {
854                        categories: DataCategories::new(),
855                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
856                        reason_code: Some(ReasonCode::new("project_quota1")),
857                        retry_after: rate_limits[0].retry_after,
858                        namespaces: smallvec![],
859                    }]
860                );
861            }
862        }
863    }
864
865    #[tokio::test]
866    async fn test_quota_with_quantity() {
867        let quotas = &[Quota {
868            id: Some(format!("test_quantity_quota_{}", uuid::Uuid::new_v4())),
869            categories: DataCategories::new(),
870            scope: QuotaScope::Organization,
871            scope_id: None,
872            limit: Some(500),
873            window: Some(60),
874            reason_code: Some(ReasonCode::new("get_lost")),
875            namespace: None,
876        }];
877
878        let scoping = ItemScoping {
879            category: DataCategory::Error,
880            scoping: Scoping {
881                organization_id: OrganizationId::new(42),
882                project_id: ProjectId::new(43),
883                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
884                key_id: Some(44),
885            },
886            namespace: MetricNamespaceScoping::None,
887        };
888
889        let rate_limiter = build_rate_limiter();
890
891        for i in 0..10 {
892            let rate_limits: Vec<RateLimit> = rate_limiter
893                .is_rate_limited(quotas, scoping, 100, false)
894                .await
895                .expect("rate limiting failed")
896                .into_iter()
897                .collect();
898
899            if i >= 5 {
900                assert_eq!(
901                    rate_limits,
902                    vec![RateLimit {
903                        categories: DataCategories::new(),
904                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
905                        reason_code: Some(ReasonCode::new("get_lost")),
906                        retry_after: rate_limits[0].retry_after,
907                        namespaces: smallvec![],
908                    }]
909                );
910            } else {
911                assert_eq!(rate_limits, vec![]);
912            }
913        }
914    }
915
916    #[tokio::test]
917    async fn test_get_redis_key_scoped() {
918        let quota = Quota {
919            id: Some("foo".to_owned()),
920            categories: DataCategories::new(),
921            scope: QuotaScope::Project,
922            scope_id: Some("42".to_owned()),
923            window: Some(2),
924            limit: Some(0),
925            reason_code: None,
926            namespace: None,
927        };
928
929        let scoping = ItemScoping {
930            category: DataCategory::Error,
931            scoping: Scoping {
932                organization_id: OrganizationId::new(69420),
933                project_id: ProjectId::new(42),
934                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
935                key_id: Some(4711),
936            },
937            namespace: MetricNamespaceScoping::None,
938        };
939
940        let timestamp = UnixTimestamp::from_secs(123_123_123);
941        let redis_quota = RedisQuota::new(&quota, scoping, timestamp).unwrap();
942        assert_eq!(redis_quota.key(), "quota:foo{69420}42:61561561");
943    }
944
945    #[tokio::test]
946    async fn test_get_redis_key_unscoped() {
947        let quota = Quota {
948            id: Some("foo".to_owned()),
949            categories: DataCategories::new(),
950            scope: QuotaScope::Organization,
951            scope_id: None,
952            window: Some(10),
953            limit: Some(0),
954            reason_code: None,
955            namespace: None,
956        };
957
958        let scoping = ItemScoping {
959            category: DataCategory::Error,
960            scoping: Scoping {
961                organization_id: OrganizationId::new(69420),
962                project_id: ProjectId::new(42),
963                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
964                key_id: Some(4711),
965            },
966            namespace: MetricNamespaceScoping::None,
967        };
968
969        let timestamp = UnixTimestamp::from_secs(234_531);
970        let redis_quota = RedisQuota::new(&quota, scoping, timestamp).unwrap();
971        assert_eq!(redis_quota.key(), "quota:foo{69420}:23453");
972    }
973
974    #[tokio::test]
975    async fn test_large_redis_limit_large() {
976        let quota = Quota {
977            id: Some("foo".to_owned()),
978            categories: DataCategories::new(),
979            scope: QuotaScope::Organization,
980            scope_id: None,
981            window: Some(10),
982            limit: Some(9223372036854775808), // i64::MAX + 1
983            reason_code: None,
984            namespace: None,
985        };
986
987        let scoping = ItemScoping {
988            category: DataCategory::Error,
989            scoping: Scoping {
990                organization_id: OrganizationId::new(69420),
991                project_id: ProjectId::new(42),
992                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
993                key_id: Some(4711),
994            },
995            namespace: MetricNamespaceScoping::None,
996        };
997
998        let timestamp = UnixTimestamp::from_secs(234_531);
999        let redis_quota = RedisQuota::new(&quota, scoping, timestamp).unwrap();
1000        assert_eq!(redis_quota.limit(), -1);
1001    }
1002
1003    #[tokio::test]
1004    #[allow(clippy::disallowed_names, clippy::let_unit_value)]
1005    async fn test_is_rate_limited_script() {
1006        let now = SystemTime::now()
1007            .duration_since(UNIX_EPOCH)
1008            .map(|duration| duration.as_secs())
1009            .unwrap();
1010
1011        let rate_limiter = build_rate_limiter();
1012        let mut conn = rate_limiter.client.get_connection().await.unwrap();
1013
1014        // define a few keys with random seed such that they do not collide with repeated test runs
1015        let foo = format!("foo___{now}");
1016        let r_foo = format!("r:foo___{now}");
1017        let bar = format!("bar___{now}");
1018        let r_bar = format!("r:bar___{now}");
1019        let apple = format!("apple___{now}");
1020        let orange = format!("orange___{now}");
1021        let baz = format!("baz___{now}");
1022
1023        let script = RedisScripts::load_is_rate_limited();
1024
1025        let mut invocation = script.prepare_invoke();
1026        invocation
1027            .key(&foo) // key
1028            .key(&r_foo) // refund key
1029            .key(&bar) // key
1030            .key(&r_bar) // refund key
1031            .arg(1) // limit
1032            .arg(now + 60) // expiry
1033            .arg(1) // quantity
1034            .arg(false) // over accept once
1035            .arg(2) // limit
1036            .arg(now + 120) // expiry
1037            .arg(1) // quantity
1038            .arg(false); // over accept once
1039
1040        // The item should not be rate limited by either key.
1041        assert_eq!(
1042            invocation
1043                .invoke_async::<Vec<bool>>(&mut conn)
1044                .await
1045                .unwrap(),
1046            vec![false, false]
1047        );
1048
1049        // The item should be rate limited by the first key (1).
1050        assert_eq!(
1051            invocation
1052                .invoke_async::<Vec<bool>>(&mut conn)
1053                .await
1054                .unwrap(),
1055            vec![true, false]
1056        );
1057
1058        // The item should still be rate limited by the first key (1), but *not*
1059        // rate limited by the second key (2) even though this is the third time
1060        // we've checked the quotas. This ensures items that are rejected by a lower
1061        // quota don't affect unrelated items that share a parent quota.
1062        assert_eq!(
1063            invocation
1064                .invoke_async::<Vec<bool>>(&mut conn)
1065                .await
1066                .unwrap(),
1067            vec![true, false]
1068        );
1069
1070        assert_eq!(conn.get::<_, String>(&foo).await.unwrap(), "1");
1071        let ttl: u64 = conn.ttl(&foo).await.unwrap();
1072        assert!(ttl >= 59);
1073        assert!(ttl <= 60);
1074
1075        assert_eq!(conn.get::<_, String>(&bar).await.unwrap(), "1");
1076        let ttl: u64 = conn.ttl(&bar).await.unwrap();
1077        assert!(ttl >= 119);
1078        assert!(ttl <= 120);
1079
1080        // make sure "refund/negative" keys haven't been incremented
1081        let () = conn.get(r_foo).await.unwrap();
1082        let () = conn.get(r_bar).await.unwrap();
1083
1084        // Test that refunded quotas work
1085        let () = conn.set(&apple, 5).await.unwrap();
1086
1087        let mut invocation = script.prepare_invoke();
1088        invocation
1089            .key(&orange) // key
1090            .key(&baz) // refund key
1091            .arg(1) // limit
1092            .arg(now + 60) // expiry
1093            .arg(1) // quantity
1094            .arg(false);
1095
1096        // increment
1097        assert_eq!(
1098            invocation
1099                .invoke_async::<Vec<bool>>(&mut conn)
1100                .await
1101                .unwrap(),
1102            vec![false]
1103        );
1104
1105        // test that it's rate limited without refund
1106        assert_eq!(
1107            invocation
1108                .invoke_async::<Vec<bool>>(&mut conn)
1109                .await
1110                .unwrap(),
1111            vec![true]
1112        );
1113
1114        let mut invocation = script.prepare_invoke();
1115        invocation
1116            .key(&orange) // key
1117            .key(&apple) // refund key
1118            .arg(1) // limit
1119            .arg(now + 60) // expiry
1120            .arg(1) // quantity
1121            .arg(false);
1122
1123        // test that refund key is used
1124        assert_eq!(
1125            invocation
1126                .invoke_async::<Vec<bool>>(&mut conn)
1127                .await
1128                .unwrap(),
1129            vec![false]
1130        );
1131    }
1132}