relay_quotas/
redis.rs

1use std::fmt::{self, Debug};
2
3use relay_common::time::UnixTimestamp;
4use relay_log::protocol::value;
5use relay_redis::redis::{self, FromRedisValue, Script};
6use relay_redis::{AsyncRedisClient, RedisError, RedisScripts};
7use thiserror::Error;
8
9use crate::REJECT_ALL_SECS;
10use crate::global::GlobalLimiter;
11use crate::quota::{ItemScoping, Quota, QuotaScope};
12use crate::rate_limit::{RateLimit, RateLimits, RetryAfter};
13
/// Grace period, in seconds, added on top of a quota window's expiry when
/// computing the expiry of its Redis key.
///
/// The clock of the Redis instance that stores quota counters may not be in
/// sync with the clock of the machine running this code. Keeping keys around a
/// little longer than the window compensates for that drift.
const GRACE: u64 = 60;
18
19/// An error returned by [`RedisRateLimiter`].
20#[derive(Debug, Error)]
21pub enum RateLimitingError {
22    /// Failed to communicate with Redis.
23    #[error("failed to communicate with redis")]
24    Redis(
25        #[from]
26        #[source]
27        RedisError,
28    ),
29
30    /// Failed to check global rate limits via the service.
31    #[error("failed to check global rate limits")]
32    UnreachableGlobalRateLimits,
33}
34
/// Derives the refund key that belongs to a quota counter key.
///
/// The refund key is the counter key prefixed with `r:`. Refund keys track
/// credits that should be applied back to a quota, which allows for more
/// flexible quota management.
fn get_refunded_quota_key(counter_key: &str) -> String {
    ["r:", counter_key].concat()
}
42
/// Display adapter for an `Option`: prints the inner value when it is `Some`
/// and prints nothing at all when it is `None`.
struct OptionalDisplay<T>(Option<T>);

impl<T: fmt::Display> fmt::Display for OptionalDisplay<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(inner) = &self.0 {
            write!(f, "{inner}")
        } else {
            Ok(())
        }
    }
}
57
58/// Owned version of [`RedisQuota`].
59#[derive(Debug, Clone)]
60pub struct OwnedRedisQuota {
61    /// The original quota.
62    quota: Quota,
63    /// Scopes of the item being tracked.
64    scoping: ItemScoping,
65    /// The Redis key prefix mapped from the quota id.
66    prefix: String,
67    /// The redis window in seconds mapped from the quota.
68    window: u64,
69    /// The ingestion timestamp determining the rate limiting bucket.
70    timestamp: UnixTimestamp,
71}
72
73impl OwnedRedisQuota {
74    /// Returns an instance of [`RedisQuota`] which borrows from this [`OwnedRedisQuota`].
75    pub fn build_ref(&self) -> RedisQuota<'_> {
76        RedisQuota {
77            quota: &self.quota,
78            scoping: self.scoping,
79            prefix: &self.prefix,
80            window: self.window,
81            timestamp: self.timestamp,
82        }
83    }
84}
85
86/// Reference to information required for tracking quotas in Redis.
87#[derive(Debug, Clone, Eq, PartialEq)]
88pub struct RedisQuota<'a> {
89    /// The original quota.
90    quota: &'a Quota,
91    /// Scopes of the item being tracked.
92    scoping: ItemScoping,
93    /// The Redis key prefix mapped from the quota id.
94    prefix: &'a str,
95    /// The redis window in seconds mapped from the quota.
96    window: u64,
97    /// The ingestion timestamp determining the rate limiting bucket.
98    timestamp: UnixTimestamp,
99}
100
101impl<'a> RedisQuota<'a> {
102    /// Creates a new [`RedisQuota`] from a [`Quota`], item scoping, and timestamp.
103    ///
104    /// Returns `None` if the quota cannot be tracked in Redis because it's missing
105    /// required fields (ID or window). This allows forward compatibility with
106    /// future quota types.
107    pub fn new(quota: &'a Quota, scoping: ItemScoping, timestamp: UnixTimestamp) -> Option<Self> {
108        // These fields indicate that we *can* track this quota.
109        let prefix = quota.id.as_deref()?;
110        let window = quota.window?;
111
112        Some(Self {
113            quota,
114            scoping,
115            prefix,
116            window,
117            timestamp,
118        })
119    }
120
121    /// Converts this [`RedisQuota`] to an [`OwnedRedisQuota`] leaving the original
122    /// struct in place.
123    pub fn build_owned(&self) -> OwnedRedisQuota {
124        OwnedRedisQuota {
125            quota: self.quota.clone(),
126            scoping: self.scoping,
127            prefix: self.prefix.to_owned(),
128            window: self.window,
129            timestamp: self.timestamp,
130        }
131    }
132
133    /// Returns the window size of the quota in seconds.
134    pub fn window(&self) -> u64 {
135        self.window
136    }
137
138    /// Returns the prefix of the quota used for Redis key generation.
139    pub fn prefix(&self) -> &'a str {
140        self.prefix
141    }
142
143    /// Returns the limit value formatted for Redis.
144    ///
145    /// Returns `-1` for unlimited quotas or when the limit doesn't fit into an `i64`.
146    /// Otherwise, returns the limit value as an `i64`.
147    pub fn limit(&self) -> i64 {
148        self.limit
149            // If it does not fit into i64, treat as unlimited:
150            .and_then(|limit| limit.try_into().ok())
151            .unwrap_or(-1)
152    }
153
154    fn shift(&self) -> u64 {
155        if self.quota.scope == QuotaScope::Global {
156            0
157        } else {
158            self.scoping.organization_id.value() % self.window
159        }
160    }
161
162    /// Returns the current time slot of the quota based on the timestamp.
163    ///
164    /// Slots are used to determine the time bucket for rate limiting.
165    pub fn slot(&self) -> u64 {
166        (self.timestamp.as_secs() - self.shift()) / self.window
167    }
168
169    /// Returns the timestamp when the current quota window will expire.
170    pub fn expiry(&self) -> UnixTimestamp {
171        let next_slot = self.slot() + 1;
172        let next_start = next_slot * self.window + self.shift();
173        UnixTimestamp::from_secs(next_start)
174    }
175
176    /// Returns when the Redis key should expire.
177    ///
178    /// This is the expiry time plus a grace period.
179    pub fn key_expiry(&self) -> u64 {
180        self.expiry().as_secs() + GRACE
181    }
182
183    /// Returns the Redis key for this quota.
184    ///
185    /// The key includes the quota ID, organization ID, and other scoping information
186    /// based on the quota's scope type. Keys are structured to ensure proper isolation
187    /// between different organizations and scopes.
188    pub fn key(&self) -> String {
189        // The subscope id is only formatted into the key if the quota is not organization-scoped.
190        // The organization id is always included.
191        let subscope = match self.quota.scope {
192            QuotaScope::Global => None,
193            QuotaScope::Organization => None,
194            scope => self.scoping.scope_id(scope),
195        };
196
197        let org = self.scoping.organization_id;
198
199        format!(
200            "quota:{id}{{{org}}}{subscope}{namespace}:{slot}",
201            id = self.prefix,
202            org = org,
203            subscope = OptionalDisplay(subscope),
204            namespace = OptionalDisplay(self.namespace),
205            slot = self.slot(),
206        )
207    }
208}
209
210impl std::ops::Deref for RedisQuota<'_> {
211    type Target = Quota;
212
213    fn deref(&self) -> &Self::Target {
214        self.quota
215    }
216}
217
218/// A service that executes quotas and checks for rate limits in a shared cache.
219///
220/// Quotas handle tracking a project's usage and respond whether a project has been
221/// configured to throttle incoming data if they go beyond the specified quota.
222///
223/// Quotas can specify a window to be tracked in, such as per minute or per hour. Additionally,
224/// quotas allow to specify the data categories they apply to, for example error events or
225/// attachments. For more information on quota parameters, see [`Quota`].
226///
227/// Requires the `redis` feature.
228#[derive(Clone)]
229pub struct RedisRateLimiter<T> {
230    client: AsyncRedisClient,
231    script: &'static Script,
232    max_limit: Option<u64>,
233    global_limiter: T,
234}
235
236impl<T: GlobalLimiter> RedisRateLimiter<T> {
237    /// Creates a new [`RedisRateLimiter`] instance.
238    pub fn new(client: AsyncRedisClient, global_limiter: T) -> Self {
239        RedisRateLimiter {
240            client,
241            script: RedisScripts::load_is_rate_limited(),
242            max_limit: None,
243            global_limiter,
244        }
245    }
246
247    /// Sets the maximum rate limit in seconds.
248    ///
249    /// By default, this rate limiter will return rate limits based on the quotas' `window` fields.
250    /// If a maximum rate limit is set, the returned rate limit will be bounded by this value.
251    pub fn max_limit(mut self, max_limit: Option<u64>) -> Self {
252        self.max_limit = max_limit;
253        self
254    }
255
256    /// Checks whether any of the quotas in effect have been exceeded and records consumption.
257    ///
258    /// By invoking this method, the caller signals that data is being ingested and needs to be
259    /// counted against the quota. This increment happens atomically if none of the quotas have been
260    /// exceeded. Otherwise, a rate limit is returned and data is not counted against the quotas.
261    ///
262    /// If no key is specified, then only organization-wide and project-wide quotas are checked. If
263    /// a key is specified, then key-quotas are also checked.
264    ///
265    /// When `over_accept_once` is set to `true` and the current quota would be exceeded by the
266    /// provided `quantity`, the data is accepted once and subsequent requests will be rejected
267    /// until the quota refreshes.
268    ///
269    /// A `quantity` of `0` can be used to check if the quota limit has been reached or exceeded
270    /// without incrementing it in the success case. This is useful for checking quotas in a different
271    /// data category.
272    pub async fn is_rate_limited<'a>(
273        &self,
274        quotas: impl IntoIterator<Item = &'a Quota>,
275        item_scoping: ItemScoping,
276        quantity: usize,
277        over_accept_once: bool,
278    ) -> Result<RateLimits, RateLimitingError> {
279        let timestamp = UnixTimestamp::now();
280        let mut invocation = self.script.prepare_invoke();
281        let mut tracked_quotas = Vec::new();
282        let mut rate_limits = RateLimits::new();
283
284        let mut global_quotas = vec![];
285
286        for quota in quotas {
287            if !quota.matches(item_scoping) {
288                // Silently skip all quotas that do not apply to this item.
289            } else if quota.limit == Some(0) {
290                // A zero-sized quota is strongest. Do not call into Redis at all, and do not
291                // increment any keys, as one quota has reached capacity (this is how regular quotas
292                // behave as well).
293                let retry_after = self.retry_after(REJECT_ALL_SECS);
294                rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
295            } else if let Some(quota) = RedisQuota::new(quota, item_scoping, timestamp) {
296                if quota.scope == QuotaScope::Global {
297                    global_quotas.push(quota);
298                } else {
299                    let key = quota.key();
300                    // Remaining quotas are expected to be trackable in Redis.
301                    let refund_key = get_refunded_quota_key(&key);
302
303                    invocation.key(key);
304                    invocation.key(refund_key);
305
306                    invocation.arg(quota.limit());
307                    invocation.arg(quota.key_expiry());
308                    invocation.arg(quantity);
309                    invocation.arg(over_accept_once);
310
311                    tracked_quotas.push(quota);
312                }
313            } else {
314                // This quota is neither a static reject-all, nor can it be tracked in Redis due to
315                // missing fields. We're skipping this for forward-compatibility.
316                relay_log::with_scope(
317                    |scope| scope.set_extra("quota", value::to_value(quota).unwrap()),
318                    || relay_log::warn!("skipping unsupported quota"),
319                )
320            }
321        }
322
323        // We check the global rate limits before the other limits. This step must be separate from
324        // checking the other rate limits, since those are checked with a Redis script that works
325        // under the invariant that all keys are within the same Redis instance (given their partitioning).
326        // Global keys on the other hand are always on the same instance, so if they were to be mixed
327        // with normal keys the script will end up referencing keys from multiple instances, making it
328        // impossible for the script to work.
329        let rate_limited_global_quotas = self
330            .global_limiter
331            .check_global_rate_limits(&global_quotas, quantity)
332            .await?;
333
334        for quota in rate_limited_global_quotas {
335            let retry_after = self.retry_after((quota.expiry() - timestamp).as_secs());
336            rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
337        }
338
339        // Either there are no quotas to run against Redis, or we already have a rate limit from a
340        // zero-sized quota. In either cases, skip invoking the script and return early.
341        if tracked_quotas.is_empty() || rate_limits.is_limited() {
342            return Ok(rate_limits);
343        }
344
345        // We get the redis client after the global rate limiting since we don't want to hold the
346        // client across await points, otherwise it might be held for too long, and we will run out
347        // of connections.
348        let mut connection = self.client.get_connection().await?;
349        let results: ScriptResult = invocation
350            .invoke_async(&mut connection)
351            .await
352            .map_err(RedisError::Redis)?;
353
354        for (quota, state) in tracked_quotas.iter().zip(results.0) {
355            if state.is_rejected {
356                let retry_after = self.retry_after((quota.expiry() - timestamp).as_secs());
357                rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
358            }
359        }
360
361        Ok(rate_limits)
362    }
363
364    /// Creates a [`RetryAfter`] value that is bounded by the configured [`max_limit`](Self::max_limit).
365    ///
366    /// If a maximum rate limit has been set, the returned value will not exceed that limit.
367    fn retry_after(&self, mut seconds: u64) -> RetryAfter {
368        if let Some(max_limit) = self.max_limit {
369            seconds = std::cmp::min(seconds, max_limit);
370        }
371
372        RetryAfter::from_secs(seconds)
373    }
374}
375
376/// The result returned from the rate limiting Redis script.
377#[derive(Debug)]
378struct ScriptResult(Vec<QuotaState>);
379
380impl FromRedisValue for ScriptResult {
381    fn from_redis_value(v: &redis::Value) -> redis::RedisResult<Self> {
382        let Some(seq) = v.as_sequence() else {
383            return Err(redis::RedisError::from((
384                redis::ErrorKind::TypeError,
385                "Expected a sequence from the rate limiting script",
386                format!("{v:?}"),
387            )));
388        };
389
390        let (chunks, rem) = seq.as_chunks();
391        if !rem.is_empty() {
392            return Err(redis::RedisError::from((
393                redis::ErrorKind::TypeError,
394                "Expected an even number of values from the rate limiting script",
395                format!("{v:?}"),
396            )));
397        }
398
399        let mut result = Vec::with_capacity(chunks.len());
400        for [is_rejected, consumed] in chunks {
401            result.push(QuotaState {
402                is_rejected: bool::from_redis_value(is_rejected)?,
403                consumed: i64::from_redis_value(consumed)?,
404            });
405        }
406
407        Ok(Self(result))
408    }
409}
410
/// Outcome reported by the rate limiting script for one quota.
#[derive(Debug)]
struct QuotaState {
    /// `true` when this quota rejects the request.
    is_rejected: bool,
    /// Amount of the quota already consumed, not counting the requested quantity.
    #[expect(unused, reason = "not yet used")]
    consumed: i64,
}
420
421#[cfg(test)]
422mod tests {
423    use std::time::{SystemTime, UNIX_EPOCH};
424
425    use super::*;
426    use crate::quota::{DataCategories, DataCategory, ReasonCode, Scoping};
427    use crate::rate_limit::RateLimitScope;
428    use crate::{GlobalRateLimiter, MetricNamespaceScoping};
429    use relay_base_schema::metrics::MetricNamespace;
430    use relay_base_schema::organization::OrganizationId;
431    use relay_base_schema::project::{ProjectId, ProjectKey};
432    use relay_redis::RedisConfigOptions;
433    use relay_redis::redis::AsyncCommands;
434    use smallvec::smallvec;
435    use tokio::sync::Mutex;
436
437    struct MockGlobalLimiter {
438        client: AsyncRedisClient,
439        global_rate_limiter: Mutex<GlobalRateLimiter>,
440    }
441
442    impl GlobalLimiter for MockGlobalLimiter {
443        async fn check_global_rate_limits<'a>(
444            &self,
445            global_quotas: &'a [RedisQuota<'a>],
446            quantity: usize,
447        ) -> Result<Vec<&'a RedisQuota<'a>>, RateLimitingError> {
448            self.global_rate_limiter
449                .lock()
450                .await
451                .filter_rate_limited(&self.client, global_quotas, quantity)
452                .await
453        }
454    }
455
456    fn build_rate_limiter() -> RedisRateLimiter<MockGlobalLimiter> {
457        let url = std::env::var("RELAY_REDIS_URL")
458            .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_owned());
459        let client =
460            AsyncRedisClient::single("test", &url, &RedisConfigOptions::default()).unwrap();
461
462        let global_limiter = MockGlobalLimiter {
463            client: client.clone(),
464            global_rate_limiter: Mutex::new(GlobalRateLimiter::default()),
465        };
466
467        RedisRateLimiter {
468            client,
469            script: RedisScripts::load_is_rate_limited(),
470            max_limit: None,
471            global_limiter,
472        }
473    }
474
475    #[tokio::test]
476    async fn test_zero_size_quotas() {
477        let quotas = &[
478            Quota {
479                id: None,
480                categories: DataCategories::new(),
481                scope: QuotaScope::Organization,
482                scope_id: None,
483                limit: Some(0),
484                window: None,
485                reason_code: Some(ReasonCode::new("get_lost")),
486                namespace: None,
487            },
488            Quota {
489                id: Some("42".into()),
490                categories: DataCategories::new(),
491                scope: QuotaScope::Organization,
492                scope_id: None,
493                limit: None,
494                window: Some(42),
495                reason_code: Some(ReasonCode::new("unlimited")),
496                namespace: None,
497            },
498        ];
499
500        let scoping = ItemScoping {
501            category: DataCategory::Error,
502            scoping: Scoping {
503                organization_id: OrganizationId::new(42),
504                project_id: ProjectId::new(43),
505                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
506                key_id: Some(44),
507            },
508            namespace: MetricNamespaceScoping::None,
509        };
510
511        let rate_limits: Vec<RateLimit> = build_rate_limiter()
512            .is_rate_limited(quotas, scoping, 1, false)
513            .await
514            .expect("rate limiting failed")
515            .into_iter()
516            .collect();
517
518        assert_eq!(
519            rate_limits,
520            vec![RateLimit {
521                categories: DataCategories::new(),
522                scope: RateLimitScope::Organization(OrganizationId::new(42)),
523                reason_code: Some(ReasonCode::new("get_lost")),
524                retry_after: rate_limits[0].retry_after,
525                namespaces: smallvec![],
526            }]
527        );
528    }
529
530    /// Tests that a quota with and without namespace are counted separately.
531    #[tokio::test]
532    async fn test_non_global_namespace_quota() {
533        let quota_limit = 5;
534        let get_quota = |namespace: Option<MetricNamespace>| -> Quota {
535            Quota {
536                id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4()).into()),
537                categories: DataCategories::new(),
538                scope: QuotaScope::Organization,
539                scope_id: None,
540                limit: Some(quota_limit),
541                window: Some(600),
542                reason_code: Some(ReasonCode::new(format!("ns: {namespace:?}"))),
543                namespace,
544            }
545        };
546
547        let quotas = &[get_quota(None)];
548        let quota_with_namespace = &[get_quota(Some(MetricNamespace::Transactions))];
549
550        let scoping = ItemScoping {
551            category: DataCategory::Error,
552            scoping: Scoping {
553                organization_id: OrganizationId::new(42),
554                project_id: ProjectId::new(43),
555                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
556                key_id: Some(44),
557            },
558            namespace: MetricNamespaceScoping::Some(MetricNamespace::Transactions),
559        };
560
561        let rate_limiter = build_rate_limiter();
562
563        // First confirm normal behaviour without namespace.
564        for i in 0..10 {
565            let rate_limits: Vec<RateLimit> = rate_limiter
566                .is_rate_limited(quotas, scoping, 1, false)
567                .await
568                .expect("rate limiting failed")
569                .into_iter()
570                .collect();
571
572            if i < quota_limit {
573                assert_eq!(rate_limits, vec![]);
574            } else {
575                assert_eq!(
576                    rate_limits[0].reason_code,
577                    Some(ReasonCode::new("ns: None"))
578                );
579            }
580        }
581
582        // Then, send identical quota with namespace and confirm it counts separately.
583        for i in 0..10 {
584            let rate_limits: Vec<RateLimit> = rate_limiter
585                .is_rate_limited(quota_with_namespace, scoping, 1, false)
586                .await
587                .expect("rate limiting failed")
588                .into_iter()
589                .collect();
590
591            if i < quota_limit {
592                assert_eq!(rate_limits, vec![]);
593            } else {
594                assert_eq!(
595                    rate_limits[0].reason_code,
596                    Some(ReasonCode::new("ns: Some(Transactions)"))
597                );
598            }
599        }
600    }
601
602    #[tokio::test]
603    async fn test_simple_quota() {
604        let quotas = &[Quota {
605            id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4()).into()),
606            categories: DataCategories::new(),
607            scope: QuotaScope::Organization,
608            scope_id: None,
609            limit: Some(5),
610            window: Some(60),
611            reason_code: Some(ReasonCode::new("get_lost")),
612            namespace: None,
613        }];
614
615        let scoping = ItemScoping {
616            category: DataCategory::Error,
617            scoping: Scoping {
618                organization_id: OrganizationId::new(42),
619                project_id: ProjectId::new(43),
620                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
621                key_id: Some(44),
622            },
623            namespace: MetricNamespaceScoping::None,
624        };
625
626        let rate_limiter = build_rate_limiter();
627
628        for i in 0..10 {
629            let rate_limits: Vec<RateLimit> = rate_limiter
630                .is_rate_limited(quotas, scoping, 1, false)
631                .await
632                .expect("rate limiting failed")
633                .into_iter()
634                .collect();
635
636            if i >= 5 {
637                assert_eq!(
638                    rate_limits,
639                    vec![RateLimit {
640                        categories: DataCategories::new(),
641                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
642                        reason_code: Some(ReasonCode::new("get_lost")),
643                        retry_after: rate_limits[0].retry_after,
644                        namespaces: smallvec![],
645                    }]
646                );
647            } else {
648                assert_eq!(rate_limits, vec![]);
649            }
650        }
651    }
652
653    #[tokio::test]
654    async fn test_simple_global_quota() {
655        let quotas = &[Quota {
656            id: Some(format!("test_simple_global_quota_{}", uuid::Uuid::new_v4()).into()),
657            categories: DataCategories::new(),
658            scope: QuotaScope::Global,
659            scope_id: None,
660            limit: Some(5),
661            window: Some(60),
662            reason_code: Some(ReasonCode::new("get_lost")),
663            namespace: None,
664        }];
665
666        let scoping = ItemScoping {
667            category: DataCategory::Error,
668            scoping: Scoping {
669                organization_id: OrganizationId::new(42),
670                project_id: ProjectId::new(43),
671                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
672                key_id: Some(44),
673            },
674            namespace: MetricNamespaceScoping::None,
675        };
676
677        let rate_limiter = build_rate_limiter();
678
679        for i in 0..10 {
680            let rate_limits: Vec<RateLimit> = rate_limiter
681                .is_rate_limited(quotas, scoping, 1, false)
682                .await
683                .expect("rate limiting failed")
684                .into_iter()
685                .collect();
686
687            if i >= 5 {
688                assert_eq!(
689                    rate_limits,
690                    vec![RateLimit {
691                        categories: DataCategories::new(),
692                        scope: RateLimitScope::Global,
693                        reason_code: Some(ReasonCode::new("get_lost")),
694                        retry_after: rate_limits[0].retry_after,
695                        namespaces: smallvec![],
696                    }]
697                );
698            } else {
699                assert_eq!(rate_limits, vec![]);
700            }
701        }
702    }
703
704    #[tokio::test]
705    async fn test_quantity_0() {
706        let quotas = &[Quota {
707            id: Some(format!("test_quantity_0_{}", uuid::Uuid::new_v4()).into()),
708            categories: DataCategories::new(),
709            scope: QuotaScope::Organization,
710            scope_id: None,
711            limit: Some(1),
712            window: Some(60),
713            reason_code: Some(ReasonCode::new("get_lost")),
714            namespace: None,
715        }];
716
717        let scoping = ItemScoping {
718            category: DataCategory::Error,
719            scoping: Scoping {
720                organization_id: OrganizationId::new(42),
721                project_id: ProjectId::new(43),
722                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
723                key_id: Some(44),
724            },
725            namespace: MetricNamespaceScoping::None,
726        };
727
728        let rate_limiter = build_rate_limiter();
729
730        // limit is 1, so first call not rate limited
731        assert!(
732            !rate_limiter
733                .is_rate_limited(quotas, scoping, 1, false)
734                .await
735                .unwrap()
736                .is_limited()
737        );
738
739        // quota is now exhausted
740        assert!(
741            rate_limiter
742                .is_rate_limited(quotas, scoping, 1, false)
743                .await
744                .unwrap()
745                .is_limited()
746        );
747
748        // quota is exhausted, regardless of the quantity
749        assert!(
750            rate_limiter
751                .is_rate_limited(quotas, scoping, 0, false)
752                .await
753                .unwrap()
754                .is_limited()
755        );
756
757        // quota is exhausted, regardless of the quantity
758        assert!(
759            rate_limiter
760                .is_rate_limited(quotas, scoping, 1, false)
761                .await
762                .unwrap()
763                .is_limited()
764        );
765    }
766
767    #[tokio::test]
768    async fn test_quota_go_over() {
769        let quotas = &[Quota {
770            id: Some(format!("test_quota_go_over{}", uuid::Uuid::new_v4()).into()),
771            categories: DataCategories::new(),
772            scope: QuotaScope::Organization,
773            scope_id: None,
774            limit: Some(2),
775            window: Some(60),
776            reason_code: Some(ReasonCode::new("get_lost")),
777            namespace: None,
778        }];
779
780        let scoping = ItemScoping {
781            category: DataCategory::Error,
782            scoping: Scoping {
783                organization_id: OrganizationId::new(42),
784                project_id: ProjectId::new(43),
785                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
786                key_id: Some(44),
787            },
788            namespace: MetricNamespaceScoping::None,
789        };
790
791        let rate_limiter = build_rate_limiter();
792
793        // limit is 2, so first call not rate limited
794        let is_limited = rate_limiter
795            .is_rate_limited(quotas, scoping, 1, true)
796            .await
797            .unwrap()
798            .is_limited();
799        assert!(!is_limited);
800
801        // go over limit, but first call is over-accepted
802        let is_limited = rate_limiter
803            .is_rate_limited(quotas, scoping, 2, true)
804            .await
805            .unwrap()
806            .is_limited();
807        assert!(!is_limited);
808
809        // quota is exhausted, regardless of the quantity
810        let is_limited = rate_limiter
811            .is_rate_limited(quotas, scoping, 0, true)
812            .await
813            .unwrap()
814            .is_limited();
815        assert!(is_limited);
816
817        // quota is exhausted, regardless of the quantity
818        let is_limited = rate_limiter
819            .is_rate_limited(quotas, scoping, 1, true)
820            .await
821            .unwrap()
822            .is_limited();
823        assert!(is_limited);
824    }
825
826    #[tokio::test]
827    async fn test_bails_immediately_without_any_quota() {
828        let scoping = ItemScoping {
829            category: DataCategory::Error,
830            scoping: Scoping {
831                organization_id: OrganizationId::new(42),
832                project_id: ProjectId::new(43),
833                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
834                key_id: Some(44),
835            },
836            namespace: MetricNamespaceScoping::None,
837        };
838
839        let rate_limits: Vec<RateLimit> = build_rate_limiter()
840            .is_rate_limited(&[], scoping, 1, false)
841            .await
842            .expect("rate limiting failed")
843            .into_iter()
844            .collect();
845
846        assert_eq!(rate_limits, vec![]);
847    }
848
849    #[tokio::test]
850    async fn test_limited_with_unlimited_quota() {
851        let quotas = &[
852            Quota {
853                id: Some("q0".into()),
854                categories: DataCategories::new(),
855                scope: QuotaScope::Organization,
856                scope_id: None,
857                limit: None,
858                window: Some(1),
859                reason_code: Some(ReasonCode::new("project_quota0")),
860                namespace: None,
861            },
862            Quota {
863                id: Some("q1".into()),
864                categories: DataCategories::new(),
865                scope: QuotaScope::Organization,
866                scope_id: None,
867                limit: Some(1),
868                window: Some(1),
869                reason_code: Some(ReasonCode::new("project_quota1")),
870                namespace: None,
871            },
872        ];
873
874        let scoping = ItemScoping {
875            category: DataCategory::Error,
876            scoping: Scoping {
877                organization_id: OrganizationId::new(42),
878                project_id: ProjectId::new(43),
879                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
880                key_id: Some(44),
881            },
882            namespace: MetricNamespaceScoping::None,
883        };
884
885        let rate_limiter = build_rate_limiter();
886
887        for i in 0..1 {
888            let rate_limits: Vec<RateLimit> = rate_limiter
889                .is_rate_limited(quotas, scoping, 1, false)
890                .await
891                .expect("rate limiting failed")
892                .into_iter()
893                .collect();
894
895            if i == 0 {
896                assert_eq!(rate_limits, &[]);
897            } else {
898                assert_eq!(
899                    rate_limits,
900                    vec![RateLimit {
901                        categories: DataCategories::new(),
902                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
903                        reason_code: Some(ReasonCode::new("project_quota1")),
904                        retry_after: rate_limits[0].retry_after,
905                        namespaces: smallvec![],
906                    }]
907                );
908            }
909        }
910    }
911
912    #[tokio::test]
913    async fn test_quota_with_quantity() {
914        let quotas = &[Quota {
915            id: Some(format!("test_quantity_quota_{}", uuid::Uuid::new_v4()).into()),
916            categories: DataCategories::new(),
917            scope: QuotaScope::Organization,
918            scope_id: None,
919            limit: Some(500),
920            window: Some(60),
921            reason_code: Some(ReasonCode::new("get_lost")),
922            namespace: None,
923        }];
924
925        let scoping = ItemScoping {
926            category: DataCategory::Error,
927            scoping: Scoping {
928                organization_id: OrganizationId::new(42),
929                project_id: ProjectId::new(43),
930                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
931                key_id: Some(44),
932            },
933            namespace: MetricNamespaceScoping::None,
934        };
935
936        let rate_limiter = build_rate_limiter();
937
938        for i in 0..10 {
939            let rate_limits: Vec<RateLimit> = rate_limiter
940                .is_rate_limited(quotas, scoping, 100, false)
941                .await
942                .expect("rate limiting failed")
943                .into_iter()
944                .collect();
945
946            if i >= 5 {
947                assert_eq!(
948                    rate_limits,
949                    vec![RateLimit {
950                        categories: DataCategories::new(),
951                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
952                        reason_code: Some(ReasonCode::new("get_lost")),
953                        retry_after: rate_limits[0].retry_after,
954                        namespaces: smallvec![],
955                    }]
956                );
957            } else {
958                assert_eq!(rate_limits, vec![]);
959            }
960        }
961    }
962
963    #[tokio::test]
964    async fn test_get_redis_key_scoped() {
965        let quota = Quota {
966            id: Some("foo".into()),
967            categories: DataCategories::new(),
968            scope: QuotaScope::Project,
969            scope_id: Some("42".into()),
970            window: Some(2),
971            limit: Some(0),
972            reason_code: None,
973            namespace: None,
974        };
975
976        let scoping = ItemScoping {
977            category: DataCategory::Error,
978            scoping: Scoping {
979                organization_id: OrganizationId::new(69420),
980                project_id: ProjectId::new(42),
981                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
982                key_id: Some(4711),
983            },
984            namespace: MetricNamespaceScoping::None,
985        };
986
987        let timestamp = UnixTimestamp::from_secs(123_123_123);
988        let redis_quota = RedisQuota::new(&quota, scoping, timestamp).unwrap();
989        assert_eq!(redis_quota.key(), "quota:foo{69420}42:61561561");
990    }
991
992    #[tokio::test]
993    async fn test_get_redis_key_unscoped() {
994        let quota = Quota {
995            id: Some("foo".into()),
996            categories: DataCategories::new(),
997            scope: QuotaScope::Organization,
998            scope_id: None,
999            window: Some(10),
1000            limit: Some(0),
1001            reason_code: None,
1002            namespace: None,
1003        };
1004
1005        let scoping = ItemScoping {
1006            category: DataCategory::Error,
1007            scoping: Scoping {
1008                organization_id: OrganizationId::new(69420),
1009                project_id: ProjectId::new(42),
1010                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
1011                key_id: Some(4711),
1012            },
1013            namespace: MetricNamespaceScoping::None,
1014        };
1015
1016        let timestamp = UnixTimestamp::from_secs(234_531);
1017        let redis_quota = RedisQuota::new(&quota, scoping, timestamp).unwrap();
1018        assert_eq!(redis_quota.key(), "quota:foo{69420}:23453");
1019    }
1020
1021    #[tokio::test]
1022    async fn test_large_redis_limit_large() {
1023        let quota = Quota {
1024            id: Some("foo".into()),
1025            categories: DataCategories::new(),
1026            scope: QuotaScope::Organization,
1027            scope_id: None,
1028            window: Some(10),
1029            limit: Some(9223372036854775808), // i64::MAX + 1
1030            reason_code: None,
1031            namespace: None,
1032        };
1033
1034        let scoping = ItemScoping {
1035            category: DataCategory::Error,
1036            scoping: Scoping {
1037                organization_id: OrganizationId::new(69420),
1038                project_id: ProjectId::new(42),
1039                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
1040                key_id: Some(4711),
1041            },
1042            namespace: MetricNamespaceScoping::None,
1043        };
1044
1045        let timestamp = UnixTimestamp::from_secs(234_531);
1046        let redis_quota = RedisQuota::new(&quota, scoping, timestamp).unwrap();
1047        assert_eq!(redis_quota.limit(), -1);
1048    }
1049
1050    #[tokio::test]
1051    async fn test_is_rate_limited_script() {
1052        let now = SystemTime::now()
1053            .duration_since(UNIX_EPOCH)
1054            .map(|duration| duration.as_secs())
1055            .unwrap();
1056
1057        let rate_limiter = build_rate_limiter();
1058        let mut conn = rate_limiter.client.get_connection().await.unwrap();
1059
1060        // define a few keys with random seed such that they do not collide with repeated test runs
1061        let foo = format!("foo___{now}");
1062        let r_foo = format!("r:foo___{now}");
1063        let bar = format!("bar___{now}");
1064        let r_bar = format!("r:bar___{now}");
1065        let apple = format!("apple___{now}");
1066        let orange = format!("orange___{now}");
1067        let baz = format!("baz___{now}");
1068
1069        let script = RedisScripts::load_is_rate_limited();
1070
1071        macro_rules! assert_invocation {
1072            ($invocation:expr, $($tt:tt)*) => {{
1073                let result = $invocation
1074                    .invoke_async::<ScriptResult>(&mut conn)
1075                    .await
1076                    .unwrap();
1077
1078                insta::assert_debug_snapshot!(result, $($tt)*);
1079            }};
1080        }
1081
1082        let mut invocation = script.prepare_invoke();
1083        invocation
1084            .key(&foo) // key
1085            .key(&r_foo) // refund key
1086            .key(&bar) // key
1087            .key(&r_bar) // refund key
1088            .arg(1) // limit
1089            .arg(now + 60) // expiry
1090            .arg(1) // quantity
1091            .arg(false) // over accept once
1092            .arg(2) // limit
1093            .arg(now + 120) // expiry
1094            .arg(1) // quantity
1095            .arg(false); // over accept once
1096
1097        // Craft a new invocation similar to the previous one, but it only applies to the quota
1098        // with a higher limit (2).
1099        let mut invocation2 = script.prepare_invoke();
1100        invocation2
1101            .key(&bar) // key
1102            .key(&r_bar) // refund key
1103            .arg(2) // limit
1104            .arg(now + 120) // expiry
1105            .arg(1) // quantity
1106            .arg(false); // over accept once
1107
1108        // Current usage is 0. But current values are now incremented by 1 (quantity).
1109        assert_invocation!(invocation, @r"
1110        ScriptResult(
1111            [
1112                QuotaState {
1113                    is_rejected: false,
1114                    consumed: 0,
1115                },
1116                QuotaState {
1117                    is_rejected: false,
1118                    consumed: 0,
1119                },
1120            ],
1121        )
1122        "
1123        );
1124
1125        // The usage was incremented in the last invocation, this invocation fails the rate limit
1126        // on the first quota. -> No changes are made to the counters, the next invocation still
1127        // needs to be `[1, 1]`.
1128        assert_invocation!(invocation, @r"
1129        ScriptResult(
1130            [
1131                QuotaState {
1132                    is_rejected: true,
1133                    consumed: 1,
1134                },
1135                QuotaState {
1136                    is_rejected: false,
1137                    consumed: 1,
1138                },
1139            ],
1140        )
1141        "
1142        );
1143
1144        // The item should still be rate limited by the first key (1), but *not*
1145        // rate limited by the second key (2) even though this is the third time
1146        // we've checked the quotas. This ensures items that are rejected by a lower
1147        // quota don't affect unrelated items that share a parent quota.
1148        assert_invocation!(invocation, @r"
1149        ScriptResult(
1150            [
1151                QuotaState {
1152                    is_rejected: true,
1153                    consumed: 1,
1154                },
1155                QuotaState {
1156                    is_rejected: false,
1157                    consumed: 1,
1158                },
1159            ],
1160        )
1161        "
1162        );
1163
1164        // Using the second invocation which only considers a quota with a higher limit, this
1165        // should still yield the current value of `1` and the next invocation should yield `2`.
1166        assert_invocation!(invocation2, @r"
1167        ScriptResult(
1168            [
1169                QuotaState {
1170                    is_rejected: false,
1171                    consumed: 1,
1172                },
1173            ],
1174        )
1175        "
1176        );
1177
1178        // This now yields `2`. This is also the invocation at the limit, which means it should no
1179        // longer increment the counter.
1180        assert_invocation!(invocation2, @r"
1181        ScriptResult(
1182            [
1183                QuotaState {
1184                    is_rejected: true,
1185                    consumed: 2,
1186                },
1187            ],
1188        )
1189        "
1190        );
1191
1192        // Check again with the original invocation, this now yields `[1, 2]`.
1193        assert_invocation!(invocation, @r"
1194        ScriptResult(
1195            [
1196                QuotaState {
1197                    is_rejected: true,
1198                    consumed: 1,
1199                },
1200                QuotaState {
1201                    is_rejected: true,
1202                    consumed: 2,
1203                },
1204            ],
1205        )
1206        "
1207        );
1208
1209        assert_eq!(conn.get::<_, String>(&foo).await.unwrap(), "1");
1210        let ttl: u64 = conn.ttl(&foo).await.unwrap();
1211        assert!(ttl >= 59);
1212        assert!(ttl <= 60);
1213
1214        assert_eq!(conn.get::<_, String>(&bar).await.unwrap(), "2");
1215        let ttl: u64 = conn.ttl(&bar).await.unwrap();
1216        assert!(ttl >= 119);
1217        assert!(ttl <= 120);
1218
1219        // make sure "refund/negative" keys haven't been incremented
1220        let () = conn.get(r_foo).await.unwrap();
1221        let () = conn.get(r_bar).await.unwrap();
1222
1223        // Test that refunded quotas work
1224        let () = conn.set(&apple, 5).await.unwrap();
1225
1226        let mut invocation = script.prepare_invoke();
1227        invocation
1228            .key(&orange) // key
1229            .key(&baz) // refund key
1230            .arg(1) // limit
1231            .arg(now + 60) // expiry
1232            .arg(1) // quantity
1233            .arg(false);
1234
1235        // increment, current quota is 0.
1236        assert_invocation!(invocation, @r"
1237        ScriptResult(
1238            [
1239                QuotaState {
1240                    is_rejected: false,
1241                    consumed: 0,
1242                },
1243            ],
1244        )
1245        "
1246        );
1247
1248        // test that it's rate limited without refund.
1249        assert_invocation!(invocation, @r"
1250        ScriptResult(
1251            [
1252                QuotaState {
1253                    is_rejected: true,
1254                    consumed: 1,
1255                },
1256            ],
1257        )
1258        "
1259        );
1260
1261        // Make sure, the counter wasn't incremented.
1262        assert_invocation!(invocation, @r"
1263        ScriptResult(
1264            [
1265                QuotaState {
1266                    is_rejected: true,
1267                    consumed: 1,
1268                },
1269            ],
1270        )
1271        "
1272        );
1273
1274        let mut invocation = script.prepare_invoke();
1275        invocation
1276            .key(&orange) // key
1277            .key(&apple) // refund key
1278            .arg(1) // limit
1279            .arg(now + 60) // expiry
1280            .arg(1) // quantity
1281            .arg(false);
1282
1283        // test that refund key is used
1284        assert_invocation!(invocation, @r"
1285        ScriptResult(
1286            [
1287                QuotaState {
1288                    is_rejected: false,
1289                    consumed: -4,
1290                },
1291            ],
1292        )
1293        "
1294        );
1295    }
1296}