relay_quotas/
redis.rs

1use std::fmt::{self, Debug};
2
3use relay_common::time::UnixTimestamp;
4use relay_log::protocol::value;
5use relay_redis::redis::Script;
6use relay_redis::{AsyncRedisClient, RedisError, RedisScripts};
7use thiserror::Error;
8
9use crate::REJECT_ALL_SECS;
10use crate::global::GlobalLimiter;
11use crate::quota::{ItemScoping, Quota, QuotaScope};
12use crate::rate_limit::{RateLimit, RateLimits, RetryAfter};
13
/// Extra lifetime, in seconds, added to every quota key's TTL.
///
/// The clock of the Redis instance holding the counters can drift relative to the host
/// running this code; keeping keys around a little longer absorbs that drift.
const GRACE: u64 = 60;
18
19/// An error returned by [`RedisRateLimiter`].
20#[derive(Debug, Error)]
21pub enum RateLimitingError {
22    /// Failed to communicate with Redis.
23    #[error("failed to communicate with redis")]
24    Redis(
25        #[from]
26        #[source]
27        RedisError,
28    ),
29
30    /// Failed to check global rate limits via the service.
31    #[error("failed to check global rate limits")]
32    UnreachableGlobalRateLimits,
33}
34
/// Derives the refund key that belongs to the given counter key.
///
/// The refund key is the counter key prefixed with `r:`. It is passed to the rate limiting
/// script next to its counter key so credits can be tracked for the same quota.
fn get_refunded_quota_key(counter_key: &str) -> String {
    ["r:", counter_key].concat()
}
42
/// Displays the wrapped value when it is `Some`, and nothing at all when it is `None`.
struct OptionalDisplay<T>(Option<T>);

impl<T: fmt::Display> fmt::Display for OptionalDisplay<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write!` formats the inner value with default options, so width/fill flags given to
        // the wrapper itself are deliberately not forwarded.
        if let Some(inner) = &self.0 {
            write!(f, "{inner}")
        } else {
            Ok(())
        }
    }
}
57
58/// Owned version of [`RedisQuota`].
59#[derive(Debug, Clone)]
60pub struct OwnedRedisQuota {
61    /// The original quota.
62    quota: Quota,
63    /// Scopes of the item being tracked.
64    scoping: ItemScoping,
65    /// The Redis key prefix mapped from the quota id.
66    prefix: String,
67    /// The redis window in seconds mapped from the quota.
68    window: u64,
69    /// The ingestion timestamp determining the rate limiting bucket.
70    timestamp: UnixTimestamp,
71}
72
73impl OwnedRedisQuota {
74    /// Returns an instance of [`RedisQuota`] which borrows from this [`OwnedRedisQuota`].
75    pub fn build_ref(&self) -> RedisQuota {
76        RedisQuota {
77            quota: &self.quota,
78            scoping: self.scoping,
79            prefix: &self.prefix,
80            window: self.window,
81            timestamp: self.timestamp,
82        }
83    }
84}
85
86/// Reference to information required for tracking quotas in Redis.
87#[derive(Debug, Clone, Eq, PartialEq)]
88pub struct RedisQuota<'a> {
89    /// The original quota.
90    quota: &'a Quota,
91    /// Scopes of the item being tracked.
92    scoping: ItemScoping,
93    /// The Redis key prefix mapped from the quota id.
94    prefix: &'a str,
95    /// The redis window in seconds mapped from the quota.
96    window: u64,
97    /// The ingestion timestamp determining the rate limiting bucket.
98    timestamp: UnixTimestamp,
99}
100
101impl<'a> RedisQuota<'a> {
102    /// Creates a new [`RedisQuota`] from a [`Quota`], item scoping, and timestamp.
103    ///
104    /// Returns `None` if the quota cannot be tracked in Redis because it's missing
105    /// required fields (ID or window). This allows forward compatibility with
106    /// future quota types.
107    pub fn new(quota: &'a Quota, scoping: ItemScoping, timestamp: UnixTimestamp) -> Option<Self> {
108        // These fields indicate that we *can* track this quota.
109        let prefix = quota.id.as_deref()?;
110        let window = quota.window?;
111
112        Some(Self {
113            quota,
114            scoping,
115            prefix,
116            window,
117            timestamp,
118        })
119    }
120
121    /// Converts this [`RedisQuota`] to an [`OwnedRedisQuota`] leaving the original
122    /// struct in place.
123    pub fn build_owned(&self) -> OwnedRedisQuota {
124        OwnedRedisQuota {
125            quota: self.quota.clone(),
126            scoping: self.scoping,
127            prefix: self.prefix.to_owned(),
128            window: self.window,
129            timestamp: self.timestamp,
130        }
131    }
132
133    /// Returns the window size of the quota in seconds.
134    pub fn window(&self) -> u64 {
135        self.window
136    }
137
138    /// Returns the prefix of the quota used for Redis key generation.
139    pub fn prefix(&self) -> &'a str {
140        self.prefix
141    }
142
143    /// Returns the limit value formatted for Redis.
144    ///
145    /// Returns `-1` for unlimited quotas or when the limit doesn't fit into an `i64`.
146    /// Otherwise, returns the limit value as an `i64`.
147    pub fn limit(&self) -> i64 {
148        self.limit
149            // If it does not fit into i64, treat as unlimited:
150            .and_then(|limit| limit.try_into().ok())
151            .unwrap_or(-1)
152    }
153
154    fn shift(&self) -> u64 {
155        if self.quota.scope == QuotaScope::Global {
156            0
157        } else {
158            self.scoping.organization_id.value() % self.window
159        }
160    }
161
162    /// Returns the current time slot of the quota based on the timestamp.
163    ///
164    /// Slots are used to determine the time bucket for rate limiting.
165    pub fn slot(&self) -> u64 {
166        (self.timestamp.as_secs() - self.shift()) / self.window
167    }
168
169    /// Returns the timestamp when the current quota window will expire.
170    pub fn expiry(&self) -> UnixTimestamp {
171        let next_slot = self.slot() + 1;
172        let next_start = next_slot * self.window + self.shift();
173        UnixTimestamp::from_secs(next_start)
174    }
175
176    /// Returns when the Redis key should expire.
177    ///
178    /// This is the expiry time plus a grace period.
179    pub fn key_expiry(&self) -> u64 {
180        self.expiry().as_secs() + GRACE
181    }
182
183    /// Returns the Redis key for this quota.
184    ///
185    /// The key includes the quota ID, organization ID, and other scoping information
186    /// based on the quota's scope type. Keys are structured to ensure proper isolation
187    /// between different organizations and scopes.
188    pub fn key(&self) -> String {
189        // The subscope id is only formatted into the key if the quota is not organization-scoped.
190        // The organization id is always included.
191        let subscope = match self.quota.scope {
192            QuotaScope::Global => None,
193            QuotaScope::Organization => None,
194            scope => self.scoping.scope_id(scope),
195        };
196
197        let org = self.scoping.organization_id;
198
199        format!(
200            "quota:{id}{{{org}}}{subscope}{namespace}:{slot}",
201            id = self.prefix,
202            org = org,
203            subscope = OptionalDisplay(subscope),
204            namespace = OptionalDisplay(self.namespace),
205            slot = self.slot(),
206        )
207    }
208}
209
210impl std::ops::Deref for RedisQuota<'_> {
211    type Target = Quota;
212
213    fn deref(&self) -> &Self::Target {
214        self.quota
215    }
216}
217
218/// A service that executes quotas and checks for rate limits in a shared cache.
219///
220/// Quotas handle tracking a project's usage and respond whether a project has been
221/// configured to throttle incoming data if they go beyond the specified quota.
222///
223/// Quotas can specify a window to be tracked in, such as per minute or per hour. Additionally,
224/// quotas allow to specify the data categories they apply to, for example error events or
225/// attachments. For more information on quota parameters, see [`Quota`].
226///
227/// Requires the `redis` feature.
228#[derive(Clone)]
229pub struct RedisRateLimiter<T> {
230    client: AsyncRedisClient,
231    script: &'static Script,
232    max_limit: Option<u64>,
233    global_limiter: T,
234}
235
236impl<T: GlobalLimiter> RedisRateLimiter<T> {
237    /// Creates a new [`RedisRateLimiter`] instance.
238    pub fn new(client: AsyncRedisClient, global_limiter: T) -> Self {
239        RedisRateLimiter {
240            client,
241            script: RedisScripts::load_is_rate_limited(),
242            max_limit: None,
243            global_limiter,
244        }
245    }
246
247    /// Sets the maximum rate limit in seconds.
248    ///
249    /// By default, this rate limiter will return rate limits based on the quotas' `window` fields.
250    /// If a maximum rate limit is set, the returned rate limit will be bounded by this value.
251    pub fn max_limit(mut self, max_limit: Option<u64>) -> Self {
252        self.max_limit = max_limit;
253        self
254    }
255
256    /// Checks whether any of the quotas in effect have been exceeded and records consumption.
257    ///
258    /// By invoking this method, the caller signals that data is being ingested and needs to be
259    /// counted against the quota. This increment happens atomically if none of the quotas have been
260    /// exceeded. Otherwise, a rate limit is returned and data is not counted against the quotas.
261    ///
262    /// If no key is specified, then only organization-wide and project-wide quotas are checked. If
263    /// a key is specified, then key-quotas are also checked.
264    ///
265    /// When `over_accept_once` is set to `true` and the current quota would be exceeded by the
266    /// provided `quantity`, the data is accepted once and subsequent requests will be rejected
267    /// until the quota refreshes.
268    ///
269    /// A `quantity` of `0` can be used to check if the quota limit has been reached or exceeded
270    /// without incrementing it in the success case. This is useful for checking quotas in a different
271    /// data category.
272    pub async fn is_rate_limited<'a>(
273        &self,
274        quotas: impl IntoIterator<Item = &'a Quota>,
275        item_scoping: ItemScoping,
276        quantity: usize,
277        over_accept_once: bool,
278    ) -> Result<RateLimits, RateLimitingError> {
279        let timestamp = UnixTimestamp::now();
280        let mut invocation = self.script.prepare_invoke();
281        let mut tracked_quotas = Vec::new();
282        let mut rate_limits = RateLimits::new();
283
284        let mut global_quotas = vec![];
285
286        for quota in quotas {
287            if !quota.matches(item_scoping) {
288                // Silently skip all quotas that do not apply to this item.
289            } else if quota.limit == Some(0) {
290                // A zero-sized quota is strongest. Do not call into Redis at all, and do not
291                // increment any keys, as one quota has reached capacity (this is how regular quotas
292                // behave as well).
293                let retry_after = self.retry_after(REJECT_ALL_SECS);
294                rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
295            } else if let Some(quota) = RedisQuota::new(quota, item_scoping, timestamp) {
296                if quota.scope == QuotaScope::Global {
297                    global_quotas.push(quota);
298                } else {
299                    let key = quota.key();
300                    // Remaining quotas are expected to be trackable in Redis.
301                    let refund_key = get_refunded_quota_key(&key);
302
303                    invocation.key(key);
304                    invocation.key(refund_key);
305
306                    invocation.arg(quota.limit());
307                    invocation.arg(quota.key_expiry());
308                    invocation.arg(quantity);
309                    invocation.arg(over_accept_once);
310
311                    tracked_quotas.push(quota);
312                }
313            } else {
314                // This quota is neither a static reject-all, nor can it be tracked in Redis due to
315                // missing fields. We're skipping this for forward-compatibility.
316                relay_log::with_scope(
317                    |scope| scope.set_extra("quota", value::to_value(quota).unwrap()),
318                    || relay_log::warn!("skipping unsupported quota"),
319                )
320            }
321        }
322
323        // We check the global rate limits before the other limits. This step must be separate from
324        // checking the other rate limits, since those are checked with a Redis script that works
325        // under the invariant that all keys are within the same Redis instance (given their partitioning).
326        // Global keys on the other hand are always on the same instance, so if they were to be mixed
327        // with normal keys the script will end up referencing keys from multiple instances, making it
328        // impossible for the script to work.
329        let rate_limited_global_quotas = self
330            .global_limiter
331            .check_global_rate_limits(&global_quotas, quantity)
332            .await?;
333
334        for quota in rate_limited_global_quotas {
335            let retry_after = self.retry_after((quota.expiry() - timestamp).as_secs());
336            rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
337        }
338
339        // Either there are no quotas to run against Redis, or we already have a rate limit from a
340        // zero-sized quota. In either cases, skip invoking the script and return early.
341        if tracked_quotas.is_empty() || rate_limits.is_limited() {
342            return Ok(rate_limits);
343        }
344
345        // We get the redis client after the global rate limiting since we don't want to hold the
346        // client across await points, otherwise it might be held for too long, and we will run out
347        // of connections.
348        let mut connection = self.client.get_connection().await?;
349        let rejections: Vec<bool> = invocation
350            .invoke_async(&mut connection)
351            .await
352            .map_err(RedisError::Redis)?;
353
354        for (quota, is_rejected) in tracked_quotas.iter().zip(rejections) {
355            if is_rejected {
356                let retry_after = self.retry_after((quota.expiry() - timestamp).as_secs());
357                rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
358            }
359        }
360
361        Ok(rate_limits)
362    }
363
364    /// Creates a [`RetryAfter`] value that is bounded by the configured [`max_limit`](Self::max_limit).
365    ///
366    /// If a maximum rate limit has been set, the returned value will not exceed that limit.
367    fn retry_after(&self, mut seconds: u64) -> RetryAfter {
368        if let Some(max_limit) = self.max_limit {
369            seconds = std::cmp::min(seconds, max_limit);
370        }
371
372        RetryAfter::from_secs(seconds)
373    }
374}
375
376#[cfg(test)]
377mod tests {
378    use std::time::{SystemTime, UNIX_EPOCH};
379
380    use super::*;
381    use crate::quota::{DataCategories, DataCategory, ReasonCode, Scoping};
382    use crate::rate_limit::RateLimitScope;
383    use crate::{GlobalRateLimiter, MetricNamespaceScoping};
384    use relay_base_schema::metrics::MetricNamespace;
385    use relay_base_schema::organization::OrganizationId;
386    use relay_base_schema::project::{ProjectId, ProjectKey};
387    use relay_redis::RedisConfigOptions;
388    use relay_redis::redis::AsyncCommands;
389    use smallvec::smallvec;
390    use tokio::sync::Mutex;
391
392    struct MockGlobalLimiter {
393        client: AsyncRedisClient,
394        global_rate_limiter: Mutex<GlobalRateLimiter>,
395    }
396
397    impl GlobalLimiter for MockGlobalLimiter {
398        async fn check_global_rate_limits<'a>(
399            &self,
400            global_quotas: &'a [RedisQuota<'a>],
401            quantity: usize,
402        ) -> Result<Vec<&'a RedisQuota<'a>>, RateLimitingError> {
403            self.global_rate_limiter
404                .lock()
405                .await
406                .filter_rate_limited(&self.client, global_quotas, quantity)
407                .await
408        }
409    }
410
411    fn build_rate_limiter() -> RedisRateLimiter<MockGlobalLimiter> {
412        let url = std::env::var("RELAY_REDIS_URL")
413            .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_owned());
414        let client = AsyncRedisClient::single(&url, &RedisConfigOptions::default()).unwrap();
415
416        let global_limiter = MockGlobalLimiter {
417            client: client.clone(),
418            global_rate_limiter: Mutex::new(GlobalRateLimiter::default()),
419        };
420
421        RedisRateLimiter {
422            client,
423            script: RedisScripts::load_is_rate_limited(),
424            max_limit: None,
425            global_limiter,
426        }
427    }
428
429    #[tokio::test]
430    async fn test_zero_size_quotas() {
431        let quotas = &[
432            Quota {
433                id: None,
434                categories: DataCategories::new(),
435                scope: QuotaScope::Organization,
436                scope_id: None,
437                limit: Some(0),
438                window: None,
439                reason_code: Some(ReasonCode::new("get_lost")),
440                namespace: None,
441            },
442            Quota {
443                id: Some("42".to_owned()),
444                categories: DataCategories::new(),
445                scope: QuotaScope::Organization,
446                scope_id: None,
447                limit: None,
448                window: Some(42),
449                reason_code: Some(ReasonCode::new("unlimited")),
450                namespace: None,
451            },
452        ];
453
454        let scoping = ItemScoping {
455            category: DataCategory::Error,
456            scoping: Scoping {
457                organization_id: OrganizationId::new(42),
458                project_id: ProjectId::new(43),
459                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
460                key_id: Some(44),
461            },
462            namespace: MetricNamespaceScoping::None,
463        };
464
465        let rate_limits: Vec<RateLimit> = build_rate_limiter()
466            .is_rate_limited(quotas, scoping, 1, false)
467            .await
468            .expect("rate limiting failed")
469            .into_iter()
470            .collect();
471
472        assert_eq!(
473            rate_limits,
474            vec![RateLimit {
475                categories: DataCategories::new(),
476                scope: RateLimitScope::Organization(OrganizationId::new(42)),
477                reason_code: Some(ReasonCode::new("get_lost")),
478                retry_after: rate_limits[0].retry_after,
479                namespaces: smallvec![],
480            }]
481        );
482    }
483
484    /// Tests that a quota with and without namespace are counted separately.
485    #[tokio::test]
486    async fn test_non_global_namespace_quota() {
487        let quota_limit = 5;
488        let get_quota = |namespace: Option<MetricNamespace>| -> Quota {
489            Quota {
490                id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4())),
491                categories: DataCategories::new(),
492                scope: QuotaScope::Organization,
493                scope_id: None,
494                limit: Some(quota_limit),
495                window: Some(600),
496                reason_code: Some(ReasonCode::new(format!("ns: {namespace:?}"))),
497                namespace,
498            }
499        };
500
501        let quotas = &[get_quota(None)];
502        let quota_with_namespace = &[get_quota(Some(MetricNamespace::Transactions))];
503
504        let scoping = ItemScoping {
505            category: DataCategory::Error,
506            scoping: Scoping {
507                organization_id: OrganizationId::new(42),
508                project_id: ProjectId::new(43),
509                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
510                key_id: Some(44),
511            },
512            namespace: MetricNamespaceScoping::Some(MetricNamespace::Transactions),
513        };
514
515        let rate_limiter = build_rate_limiter();
516
517        // First confirm normal behaviour without namespace.
518        for i in 0..10 {
519            let rate_limits: Vec<RateLimit> = rate_limiter
520                .is_rate_limited(quotas, scoping, 1, false)
521                .await
522                .expect("rate limiting failed")
523                .into_iter()
524                .collect();
525
526            if i < quota_limit {
527                assert_eq!(rate_limits, vec![]);
528            } else {
529                assert_eq!(
530                    rate_limits[0].reason_code,
531                    Some(ReasonCode::new("ns: None"))
532                );
533            }
534        }
535
536        // Then, send identical quota with namespace and confirm it counts separately.
537        for i in 0..10 {
538            let rate_limits: Vec<RateLimit> = rate_limiter
539                .is_rate_limited(quota_with_namespace, scoping, 1, false)
540                .await
541                .expect("rate limiting failed")
542                .into_iter()
543                .collect();
544
545            if i < quota_limit {
546                assert_eq!(rate_limits, vec![]);
547            } else {
548                assert_eq!(
549                    rate_limits[0].reason_code,
550                    Some(ReasonCode::new("ns: Some(Transactions)"))
551                );
552            }
553        }
554    }
555
556    #[tokio::test]
557    async fn test_simple_quota() {
558        let quotas = &[Quota {
559            id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4())),
560            categories: DataCategories::new(),
561            scope: QuotaScope::Organization,
562            scope_id: None,
563            limit: Some(5),
564            window: Some(60),
565            reason_code: Some(ReasonCode::new("get_lost")),
566            namespace: None,
567        }];
568
569        let scoping = ItemScoping {
570            category: DataCategory::Error,
571            scoping: Scoping {
572                organization_id: OrganizationId::new(42),
573                project_id: ProjectId::new(43),
574                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
575                key_id: Some(44),
576            },
577            namespace: MetricNamespaceScoping::None,
578        };
579
580        let rate_limiter = build_rate_limiter();
581
582        for i in 0..10 {
583            let rate_limits: Vec<RateLimit> = rate_limiter
584                .is_rate_limited(quotas, scoping, 1, false)
585                .await
586                .expect("rate limiting failed")
587                .into_iter()
588                .collect();
589
590            if i >= 5 {
591                assert_eq!(
592                    rate_limits,
593                    vec![RateLimit {
594                        categories: DataCategories::new(),
595                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
596                        reason_code: Some(ReasonCode::new("get_lost")),
597                        retry_after: rate_limits[0].retry_after,
598                        namespaces: smallvec![],
599                    }]
600                );
601            } else {
602                assert_eq!(rate_limits, vec![]);
603            }
604        }
605    }
606
607    #[tokio::test]
608    async fn test_simple_global_quota() {
609        let quotas = &[Quota {
610            id: Some(format!("test_simple_global_quota_{}", uuid::Uuid::new_v4())),
611            categories: DataCategories::new(),
612            scope: QuotaScope::Global,
613            scope_id: None,
614            limit: Some(5),
615            window: Some(60),
616            reason_code: Some(ReasonCode::new("get_lost")),
617            namespace: None,
618        }];
619
620        let scoping = ItemScoping {
621            category: DataCategory::Error,
622            scoping: Scoping {
623                organization_id: OrganizationId::new(42),
624                project_id: ProjectId::new(43),
625                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
626                key_id: Some(44),
627            },
628            namespace: MetricNamespaceScoping::None,
629        };
630
631        let rate_limiter = build_rate_limiter();
632
633        for i in 0..10 {
634            let rate_limits: Vec<RateLimit> = rate_limiter
635                .is_rate_limited(quotas, scoping, 1, false)
636                .await
637                .expect("rate limiting failed")
638                .into_iter()
639                .collect();
640
641            if i >= 5 {
642                assert_eq!(
643                    rate_limits,
644                    vec![RateLimit {
645                        categories: DataCategories::new(),
646                        scope: RateLimitScope::Global,
647                        reason_code: Some(ReasonCode::new("get_lost")),
648                        retry_after: rate_limits[0].retry_after,
649                        namespaces: smallvec![],
650                    }]
651                );
652            } else {
653                assert_eq!(rate_limits, vec![]);
654            }
655        }
656    }
657
658    #[tokio::test]
659    async fn test_quantity_0() {
660        let quotas = &[Quota {
661            id: Some(format!("test_quantity_0_{}", uuid::Uuid::new_v4())),
662            categories: DataCategories::new(),
663            scope: QuotaScope::Organization,
664            scope_id: None,
665            limit: Some(1),
666            window: Some(60),
667            reason_code: Some(ReasonCode::new("get_lost")),
668            namespace: None,
669        }];
670
671        let scoping = ItemScoping {
672            category: DataCategory::Error,
673            scoping: Scoping {
674                organization_id: OrganizationId::new(42),
675                project_id: ProjectId::new(43),
676                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
677                key_id: Some(44),
678            },
679            namespace: MetricNamespaceScoping::None,
680        };
681
682        let rate_limiter = build_rate_limiter();
683
684        // limit is 1, so first call not rate limited
685        assert!(
686            !rate_limiter
687                .is_rate_limited(quotas, scoping, 1, false)
688                .await
689                .unwrap()
690                .is_limited()
691        );
692
693        // quota is now exhausted
694        assert!(
695            rate_limiter
696                .is_rate_limited(quotas, scoping, 1, false)
697                .await
698                .unwrap()
699                .is_limited()
700        );
701
702        // quota is exhausted, regardless of the quantity
703        assert!(
704            rate_limiter
705                .is_rate_limited(quotas, scoping, 0, false)
706                .await
707                .unwrap()
708                .is_limited()
709        );
710
711        // quota is exhausted, regardless of the quantity
712        assert!(
713            rate_limiter
714                .is_rate_limited(quotas, scoping, 1, false)
715                .await
716                .unwrap()
717                .is_limited()
718        );
719    }
720
721    #[tokio::test]
722    async fn test_quota_go_over() {
723        let quotas = &[Quota {
724            id: Some(format!("test_quota_go_over{}", uuid::Uuid::new_v4())),
725            categories: DataCategories::new(),
726            scope: QuotaScope::Organization,
727            scope_id: None,
728            limit: Some(2),
729            window: Some(60),
730            reason_code: Some(ReasonCode::new("get_lost")),
731            namespace: None,
732        }];
733
734        let scoping = ItemScoping {
735            category: DataCategory::Error,
736            scoping: Scoping {
737                organization_id: OrganizationId::new(42),
738                project_id: ProjectId::new(43),
739                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
740                key_id: Some(44),
741            },
742            namespace: MetricNamespaceScoping::None,
743        };
744
745        let rate_limiter = build_rate_limiter();
746
747        // limit is 2, so first call not rate limited
748        let is_limited = rate_limiter
749            .is_rate_limited(quotas, scoping, 1, true)
750            .await
751            .unwrap()
752            .is_limited();
753        assert!(!is_limited);
754
755        // go over limit, but first call is over-accepted
756        let is_limited = rate_limiter
757            .is_rate_limited(quotas, scoping, 2, true)
758            .await
759            .unwrap()
760            .is_limited();
761        assert!(!is_limited);
762
763        // quota is exhausted, regardless of the quantity
764        let is_limited = rate_limiter
765            .is_rate_limited(quotas, scoping, 0, true)
766            .await
767            .unwrap()
768            .is_limited();
769        assert!(is_limited);
770
771        // quota is exhausted, regardless of the quantity
772        let is_limited = rate_limiter
773            .is_rate_limited(quotas, scoping, 1, true)
774            .await
775            .unwrap()
776            .is_limited();
777        assert!(is_limited);
778    }
779
780    #[tokio::test]
781    async fn test_bails_immediately_without_any_quota() {
782        let scoping = ItemScoping {
783            category: DataCategory::Error,
784            scoping: Scoping {
785                organization_id: OrganizationId::new(42),
786                project_id: ProjectId::new(43),
787                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
788                key_id: Some(44),
789            },
790            namespace: MetricNamespaceScoping::None,
791        };
792
793        let rate_limits: Vec<RateLimit> = build_rate_limiter()
794            .is_rate_limited(&[], scoping, 1, false)
795            .await
796            .expect("rate limiting failed")
797            .into_iter()
798            .collect();
799
800        assert_eq!(rate_limits, vec![]);
801    }
802
803    #[tokio::test]
804    async fn test_limited_with_unlimited_quota() {
805        let quotas = &[
806            Quota {
807                id: Some("q0".to_owned()),
808                categories: DataCategories::new(),
809                scope: QuotaScope::Organization,
810                scope_id: None,
811                limit: None,
812                window: Some(1),
813                reason_code: Some(ReasonCode::new("project_quota0")),
814                namespace: None,
815            },
816            Quota {
817                id: Some("q1".to_owned()),
818                categories: DataCategories::new(),
819                scope: QuotaScope::Organization,
820                scope_id: None,
821                limit: Some(1),
822                window: Some(1),
823                reason_code: Some(ReasonCode::new("project_quota1")),
824                namespace: None,
825            },
826        ];
827
828        let scoping = ItemScoping {
829            category: DataCategory::Error,
830            scoping: Scoping {
831                organization_id: OrganizationId::new(42),
832                project_id: ProjectId::new(43),
833                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
834                key_id: Some(44),
835            },
836            namespace: MetricNamespaceScoping::None,
837        };
838
839        let rate_limiter = build_rate_limiter();
840
841        for i in 0..1 {
842            let rate_limits: Vec<RateLimit> = rate_limiter
843                .is_rate_limited(quotas, scoping, 1, false)
844                .await
845                .expect("rate limiting failed")
846                .into_iter()
847                .collect();
848
849            if i == 0 {
850                assert_eq!(rate_limits, &[]);
851            } else {
852                assert_eq!(
853                    rate_limits,
854                    vec![RateLimit {
855                        categories: DataCategories::new(),
856                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
857                        reason_code: Some(ReasonCode::new("project_quota1")),
858                        retry_after: rate_limits[0].retry_after,
859                        namespaces: smallvec![],
860                    }]
861                );
862            }
863        }
864    }
865
866    #[tokio::test]
867    async fn test_quota_with_quantity() {
868        let quotas = &[Quota {
869            id: Some(format!("test_quantity_quota_{}", uuid::Uuid::new_v4())),
870            categories: DataCategories::new(),
871            scope: QuotaScope::Organization,
872            scope_id: None,
873            limit: Some(500),
874            window: Some(60),
875            reason_code: Some(ReasonCode::new("get_lost")),
876            namespace: None,
877        }];
878
879        let scoping = ItemScoping {
880            category: DataCategory::Error,
881            scoping: Scoping {
882                organization_id: OrganizationId::new(42),
883                project_id: ProjectId::new(43),
884                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
885                key_id: Some(44),
886            },
887            namespace: MetricNamespaceScoping::None,
888        };
889
890        let rate_limiter = build_rate_limiter();
891
892        for i in 0..10 {
893            let rate_limits: Vec<RateLimit> = rate_limiter
894                .is_rate_limited(quotas, scoping, 100, false)
895                .await
896                .expect("rate limiting failed")
897                .into_iter()
898                .collect();
899
900            if i >= 5 {
901                assert_eq!(
902                    rate_limits,
903                    vec![RateLimit {
904                        categories: DataCategories::new(),
905                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
906                        reason_code: Some(ReasonCode::new("get_lost")),
907                        retry_after: rate_limits[0].retry_after,
908                        namespaces: smallvec![],
909                    }]
910                );
911            } else {
912                assert_eq!(rate_limits, vec![]);
913            }
914        }
915    }
916
917    #[tokio::test]
918    async fn test_get_redis_key_scoped() {
919        let quota = Quota {
920            id: Some("foo".to_owned()),
921            categories: DataCategories::new(),
922            scope: QuotaScope::Project,
923            scope_id: Some("42".to_owned()),
924            window: Some(2),
925            limit: Some(0),
926            reason_code: None,
927            namespace: None,
928        };
929
930        let scoping = ItemScoping {
931            category: DataCategory::Error,
932            scoping: Scoping {
933                organization_id: OrganizationId::new(69420),
934                project_id: ProjectId::new(42),
935                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
936                key_id: Some(4711),
937            },
938            namespace: MetricNamespaceScoping::None,
939        };
940
941        let timestamp = UnixTimestamp::from_secs(123_123_123);
942        let redis_quota = RedisQuota::new(&quota, scoping, timestamp).unwrap();
943        assert_eq!(redis_quota.key(), "quota:foo{69420}42:61561561");
944    }
945
946    #[tokio::test]
947    async fn test_get_redis_key_unscoped() {
948        let quota = Quota {
949            id: Some("foo".to_owned()),
950            categories: DataCategories::new(),
951            scope: QuotaScope::Organization,
952            scope_id: None,
953            window: Some(10),
954            limit: Some(0),
955            reason_code: None,
956            namespace: None,
957        };
958
959        let scoping = ItemScoping {
960            category: DataCategory::Error,
961            scoping: Scoping {
962                organization_id: OrganizationId::new(69420),
963                project_id: ProjectId::new(42),
964                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
965                key_id: Some(4711),
966            },
967            namespace: MetricNamespaceScoping::None,
968        };
969
970        let timestamp = UnixTimestamp::from_secs(234_531);
971        let redis_quota = RedisQuota::new(&quota, scoping, timestamp).unwrap();
972        assert_eq!(redis_quota.key(), "quota:foo{69420}:23453");
973    }
974
975    #[tokio::test]
976    async fn test_large_redis_limit_large() {
977        let quota = Quota {
978            id: Some("foo".to_owned()),
979            categories: DataCategories::new(),
980            scope: QuotaScope::Organization,
981            scope_id: None,
982            window: Some(10),
983            limit: Some(9223372036854775808), // i64::MAX + 1
984            reason_code: None,
985            namespace: None,
986        };
987
988        let scoping = ItemScoping {
989            category: DataCategory::Error,
990            scoping: Scoping {
991                organization_id: OrganizationId::new(69420),
992                project_id: ProjectId::new(42),
993                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
994                key_id: Some(4711),
995            },
996            namespace: MetricNamespaceScoping::None,
997        };
998
999        let timestamp = UnixTimestamp::from_secs(234_531);
1000        let redis_quota = RedisQuota::new(&quota, scoping, timestamp).unwrap();
1001        assert_eq!(redis_quota.limit(), -1);
1002    }
1003
1004    #[tokio::test]
1005    #[allow(clippy::disallowed_names, clippy::let_unit_value)]
1006    async fn test_is_rate_limited_script() {
1007        let now = SystemTime::now()
1008            .duration_since(UNIX_EPOCH)
1009            .map(|duration| duration.as_secs())
1010            .unwrap();
1011
1012        let rate_limiter = build_rate_limiter();
1013        let mut conn = rate_limiter.client.get_connection().await.unwrap();
1014
1015        // define a few keys with random seed such that they do not collide with repeated test runs
1016        let foo = format!("foo___{now}");
1017        let r_foo = format!("r:foo___{now}");
1018        let bar = format!("bar___{now}");
1019        let r_bar = format!("r:bar___{now}");
1020        let apple = format!("apple___{now}");
1021        let orange = format!("orange___{now}");
1022        let baz = format!("baz___{now}");
1023
1024        let script = RedisScripts::load_is_rate_limited();
1025
1026        let mut invocation = script.prepare_invoke();
1027        invocation
1028            .key(&foo) // key
1029            .key(&r_foo) // refund key
1030            .key(&bar) // key
1031            .key(&r_bar) // refund key
1032            .arg(1) // limit
1033            .arg(now + 60) // expiry
1034            .arg(1) // quantity
1035            .arg(false) // over accept once
1036            .arg(2) // limit
1037            .arg(now + 120) // expiry
1038            .arg(1) // quantity
1039            .arg(false); // over accept once
1040
1041        // The item should not be rate limited by either key.
1042        assert_eq!(
1043            invocation
1044                .invoke_async::<Vec<bool>>(&mut conn)
1045                .await
1046                .unwrap(),
1047            vec![false, false]
1048        );
1049
1050        // The item should be rate limited by the first key (1).
1051        assert_eq!(
1052            invocation
1053                .invoke_async::<Vec<bool>>(&mut conn)
1054                .await
1055                .unwrap(),
1056            vec![true, false]
1057        );
1058
1059        // The item should still be rate limited by the first key (1), but *not*
1060        // rate limited by the second key (2) even though this is the third time
1061        // we've checked the quotas. This ensures items that are rejected by a lower
1062        // quota don't affect unrelated items that share a parent quota.
1063        assert_eq!(
1064            invocation
1065                .invoke_async::<Vec<bool>>(&mut conn)
1066                .await
1067                .unwrap(),
1068            vec![true, false]
1069        );
1070
1071        assert_eq!(conn.get::<_, String>(&foo).await.unwrap(), "1");
1072        let ttl: u64 = conn.ttl(&foo).await.unwrap();
1073        assert!(ttl >= 59);
1074        assert!(ttl <= 60);
1075
1076        assert_eq!(conn.get::<_, String>(&bar).await.unwrap(), "1");
1077        let ttl: u64 = conn.ttl(&bar).await.unwrap();
1078        assert!(ttl >= 119);
1079        assert!(ttl <= 120);
1080
1081        // make sure "refund/negative" keys haven't been incremented
1082        let () = conn.get(r_foo).await.unwrap();
1083        let () = conn.get(r_bar).await.unwrap();
1084
1085        // Test that refunded quotas work
1086        let () = conn.set(&apple, 5).await.unwrap();
1087
1088        let mut invocation = script.prepare_invoke();
1089        invocation
1090            .key(&orange) // key
1091            .key(&baz) // refund key
1092            .arg(1) // limit
1093            .arg(now + 60) // expiry
1094            .arg(1) // quantity
1095            .arg(false);
1096
1097        // increment
1098        assert_eq!(
1099            invocation
1100                .invoke_async::<Vec<bool>>(&mut conn)
1101                .await
1102                .unwrap(),
1103            vec![false]
1104        );
1105
1106        // test that it's rate limited without refund
1107        assert_eq!(
1108            invocation
1109                .invoke_async::<Vec<bool>>(&mut conn)
1110                .await
1111                .unwrap(),
1112            vec![true]
1113        );
1114
1115        let mut invocation = script.prepare_invoke();
1116        invocation
1117            .key(&orange) // key
1118            .key(&apple) // refund key
1119            .arg(1) // limit
1120            .arg(now + 60) // expiry
1121            .arg(1) // quantity
1122            .arg(false);
1123
1124        // test that refund key is used
1125        assert_eq!(
1126            invocation
1127                .invoke_async::<Vec<bool>>(&mut conn)
1128                .await
1129                .unwrap(),
1130            vec![false]
1131        );
1132    }
1133}