relay_quotas/
redis.rs

1use std::fmt::{self, Debug};
2use std::sync::Arc;
3
4use itertools::Itertools;
5use relay_base_schema::metrics::MetricNamespace;
6use relay_base_schema::organization::OrganizationId;
7use relay_common::time::UnixTimestamp;
8use relay_log::protocol::value;
9use relay_redis::redis::{self, FromRedisValue, ParsingError, Script};
10use relay_redis::{AsyncRedisClient, RedisError, RedisScripts};
11use thiserror::Error;
12
13use crate::cache::OpportunisticQuotaCache;
14use crate::quota::{ItemScoping, Quota, QuotaScope};
15use crate::rate_limit::{RateLimit, RateLimits, RetryAfter};
16use crate::statsd::{QuotaCounters, QuotaTimers};
17use crate::{REJECT_ALL_SECS, cache};
18
/// The `grace` period accounts for clock drift in TTL calculation, since the
/// clock on the Redis instance used to store quota metrics may not be in sync
/// with the computer running this code.
22const GRACE: u64 = 60;
23
24/// An error returned by [`RedisRateLimiter`].
25#[derive(Debug, Error)]
26#[error("failed to communicate with redis")]
27pub struct RateLimitingError(
28    #[from]
29    #[source]
30    pub RedisError,
31);
32
/// Derives the refund key that belongs to a quota counter key.
///
/// The refund key is the counter key prefixed with `r:`. Refund keys track credits that should
/// be applied to a quota.
fn get_refunded_quota_key(counter_key: &str) -> String {
    ["r:", counter_key].concat()
}
40
/// Display adapter for an [`Option`]: `Some` renders the inner value, `None` renders nothing.
struct OptionalDisplay<T>(Option<T>);

impl<T> fmt::Display for OptionalDisplay<T>
where
    T: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `None` contributes no output at all, not even a placeholder.
        if let Some(value) = &self.0 {
            write!(f, "{value}")?;
        }

        Ok(())
    }
}
55
56/// Owned version of [`RedisQuota`].
57#[derive(Debug, Clone)]
58pub struct OwnedRedisQuota {
59    /// The original quota.
60    quota: Quota,
61    /// Scopes of the item being tracked.
62    scoping: ItemScoping,
63    /// The Redis key prefix mapped from the quota id.
64    prefix: Arc<str>,
65    /// The Redis window in seconds mapped from the quota.
66    window: u64,
67    /// The quantity being checked.
68    quantity: u64,
69    /// The ingestion timestamp determining the rate limiting bucket.
70    timestamp: UnixTimestamp,
71}
72
73impl OwnedRedisQuota {
74    /// Returns an instance of [`RedisQuota`] which borrows from this [`OwnedRedisQuota`].
75    pub fn build_ref(&self) -> RedisQuota<'_> {
76        RedisQuota {
77            quota: &self.quota,
78            scoping: self.scoping,
79            prefix: Arc::clone(&self.prefix),
80            window: self.window,
81            quantity: self.quantity,
82            timestamp: self.timestamp,
83        }
84    }
85}
86
87/// Reference to information required for tracking quotas in Redis.
88#[derive(Debug, Clone, Eq, PartialEq)]
89pub struct RedisQuota<'a> {
90    /// The original quota.
91    quota: &'a Quota,
92    /// Scopes of the item being tracked.
93    scoping: ItemScoping,
94    /// The Redis key prefix mapped from the quota id.
95    prefix: Arc<str>,
96    /// The Redis window in seconds mapped from the quota.
97    window: u64,
98    /// The quantity being checked.
99    quantity: u64,
100    /// The ingestion timestamp determining the rate limiting bucket.
101    timestamp: UnixTimestamp,
102}
103
104impl<'a> RedisQuota<'a> {
105    /// Creates a new [`RedisQuota`] from a [`Quota`], item scoping, and timestamp.
106    ///
107    /// Returns `None` if the quota cannot be tracked in Redis because it's missing
108    /// required fields (ID or window). This allows forward compatibility with
109    /// future quota types.
110    pub fn new(
111        quota: &'a Quota,
112        quantity: u64,
113        scoping: ItemScoping,
114        timestamp: UnixTimestamp,
115    ) -> Option<Self> {
116        // These fields indicate that we *can* track this quota.
117        let prefix = quota.id.clone()?;
118        let window = quota.window?;
119
120        Some(Self {
121            quota,
122            scoping,
123            prefix,
124            quantity,
125            window,
126            timestamp,
127        })
128    }
129
130    /// Converts this [`RedisQuota`] to an [`OwnedRedisQuota`] leaving the original
131    /// struct in place.
132    pub fn build_owned(&self) -> OwnedRedisQuota {
133        OwnedRedisQuota {
134            quota: self.quota.clone(),
135            scoping: self.scoping,
136            prefix: Arc::clone(&self.prefix),
137            window: self.window,
138            quantity: self.quantity,
139            timestamp: self.timestamp,
140        }
141    }
142
143    /// Returns the window size of the quota in seconds.
144    pub fn window(&self) -> u64 {
145        self.window
146    }
147
148    /// Returns the prefix of the quota used for Redis key generation.
149    pub fn prefix(&self) -> &str {
150        &self.prefix
151    }
152
153    /// Returns the quantity to rate limit.
154    pub fn quantity(&self) -> u64 {
155        self.quantity
156    }
157
158    /// Returns the limit value formatted for Redis.
159    ///
160    /// Returns `-1` for unlimited quotas or when the limit doesn't fit into an `i64`.
161    /// Otherwise, returns the limit value as an `i64`.
162    pub fn limit(&self) -> i64 {
163        self.limit
164            // If it does not fit into i64, treat as unlimited:
165            .and_then(|limit| limit.try_into().ok())
166            .unwrap_or(-1)
167    }
168
169    fn shift(&self) -> u64 {
170        self.scoping.organization_id.value() % self.window
171    }
172
173    /// Returns the current time slot of the quota based on the timestamp.
174    ///
175    /// Slots are used to determine the time bucket for rate limiting.
176    pub fn slot(&self) -> u64 {
177        (self.timestamp.as_secs() - self.shift()) / self.window
178    }
179
180    /// Returns the timestamp when the current quota window will expire.
181    pub fn expiry(&self) -> UnixTimestamp {
182        let next_slot = self.slot() + 1;
183        let next_start = next_slot * self.window + self.shift();
184        UnixTimestamp::from_secs(next_start)
185    }
186
187    /// Returns when the Redis key should expire.
188    ///
189    /// This is the expiry time plus a grace period.
190    pub fn key_expiry(&self) -> u64 {
191        self.expiry().as_secs() + GRACE
192    }
193
194    /// Returns the Redis key for this quota.
195    ///
196    /// The key includes the quota ID, organization ID, and other scoping information
197    /// based on the quota's scope type. Keys are structured to ensure proper isolation
198    /// between different organizations and scopes.
199    pub fn key(&self) -> QuotaCacheKey {
200        // The subscope id is only formatted into the key if the quota is not organization-scoped.
201        // The organization id is always included.
202        let subscope = match self.quota.scope {
203            QuotaScope::Organization => None,
204            scope => self.scoping.scope_id(scope),
205        };
206
207        QuotaCacheKey {
208            id: Arc::clone(&self.prefix),
209            org: self.scoping.organization_id,
210            subscope,
211            namespace: self.namespace,
212            slot: self.slot(),
213        }
214    }
215
216    /// Returns a [`cache::Quota`] built from this [`RedisQuota`].
217    fn for_cache(&self) -> cache::Quota<QuotaCacheKey> {
218        cache::Quota {
219            limit: self.limit(),
220            window: self.window,
221            key: self.key(),
222            expiry: UnixTimestamp::from_secs(self.key_expiry()),
223        }
224    }
225}
226
227impl std::ops::Deref for RedisQuota<'_> {
228    type Target = Quota;
229
230    fn deref(&self) -> &Self::Target {
231        self.quota
232    }
233}
234
235/// A key which uniquely identifies a quota.
236///
237/// Can be used as a Redis cache key by using the [`fmt::Display`] trait.
238///
239/// See also: [`RedisQuota::key`].
240#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
241pub struct QuotaCacheKey {
242    id: Arc<str>,
243    org: OrganizationId,
244    subscope: Option<u64>,
245    namespace: Option<MetricNamespace>,
246    slot: u64,
247}
248
249impl fmt::Display for QuotaCacheKey {
250    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251        write!(
252            f,
253            "quota:{id}{{{org}}}{subscope}{namespace}:{slot}",
254            id = self.id,
255            org = self.org,
256            subscope = OptionalDisplay(self.subscope),
257            namespace = OptionalDisplay(self.namespace),
258            slot = self.slot,
259        )
260    }
261}
262
263/// A service that executes quotas and checks for rate limits in a shared cache.
264///
265/// Quotas handle tracking a project's usage and respond whether a project has been
266/// configured to throttle incoming data if they go beyond the specified quota.
267///
268/// Quotas can specify a window to be tracked in, such as per minute or per hour. Additionally,
269/// quotas allow to specify the data categories they apply to, for example error events or
270/// attachments. For more information on quota parameters, see [`Quota`].
271///
272/// Requires the `redis` feature.
273#[derive(Clone)]
274pub struct RedisRateLimiter {
275    client: AsyncRedisClient,
276    cache: Option<Arc<OpportunisticQuotaCache<QuotaCacheKey>>>,
277    script: &'static Script,
278    max_limit: Option<u64>,
279}
280
281impl RedisRateLimiter {
282    /// Creates a new [`RedisRateLimiter`] instance.
283    pub fn new(client: AsyncRedisClient) -> Self {
284        RedisRateLimiter {
285            client,
286            cache: None,
287            script: RedisScripts::load_is_rate_limited(),
288            max_limit: None,
289        }
290    }
291
292    /// Sets the maximum rate limit in seconds.
293    ///
294    /// By default, this rate limiter will return rate limits based on the quotas' `window` fields.
295    /// If a maximum rate limit is set, the returned rate limit will be bounded by this value.
296    pub fn max_limit(mut self, max_limit: Option<u64>) -> Self {
297        self.max_limit = max_limit;
298        self
299    }
300
301    /// Enables an opportunistic cache for quotas.
302    ///
303    /// The opportunistic cache, opportunistically serves quotas from a local cache, reducing the
304    /// load on Redis heavily.
305    ///
306    /// Caching considers a ratio of the remaining quota to be available and periodically
307    /// synchronizes with Redis.
308    pub fn cache(mut self, cache_ratio: Option<f32>, max: Option<f32>) -> Self {
309        self.cache = cache_ratio
310            .map(OpportunisticQuotaCache::new)
311            .map(|c| c.with_max(max))
312            .map(Arc::new);
313
314        self
315    }
316
317    /// Checks whether any of the quotas in effect have been exceeded and records consumption.
318    ///
319    /// By invoking this method, the caller signals that data is being ingested and needs to be
320    /// counted against the quota. This increment happens atomically if none of the quotas have been
321    /// exceeded. Otherwise, a rate limit is returned and data is not counted against the quotas.
322    ///
323    /// If no key is specified, then only organization-wide and project-wide quotas are checked. If
324    /// a key is specified, then key-quotas are also checked.
325    ///
326    /// When `over_accept_once` is set to `true` and the current quota would be exceeded by the
327    /// provided `quantity`, the data is accepted once and subsequent requests will be rejected
328    /// until the quota refreshes.
329    ///
330    /// A `quantity` of `0` can be used to check if the quota limit has been reached or exceeded
331    /// without incrementing it in the success case. This is useful for checking quotas in a different
332    /// data category.
333    pub async fn is_rate_limited<'a>(
334        &self,
335        quotas: impl IntoIterator<Item = &'a Quota>,
336        item_scoping: ItemScoping,
337        quantity: usize,
338        over_accept_once: bool,
339    ) -> Result<RateLimits, RateLimitingError> {
340        let timestamp = UnixTimestamp::now();
341        let mut invocation = self.script.prepare_invoke();
342        let mut tracked_quotas = Vec::new();
343        let mut rate_limits = RateLimits::new();
344
345        let quantity = u64::try_from(quantity).unwrap_or(u64::MAX);
346
347        for quota in quotas {
348            if !quota.matches(item_scoping) {
349                // Silently skip all quotas that do not apply to this item.
350            } else if quota.limit == Some(0) {
351                // A zero-sized quota is strongest. Do not call into Redis at all, and do not
352                // increment any keys, as one quota has reached capacity (this is how regular quotas
353                // behave as well).
354                let retry_after = self.retry_after(REJECT_ALL_SECS);
355                rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
356            } else if let Some(mut quota) =
357                RedisQuota::new(quota, quantity, item_scoping, timestamp)
358            {
359                if let Some(cache) = &self.cache {
360                    quota.quantity = match cache.check_quota(quota.for_cache(), quantity) {
361                        cache::Action::Accept => continue,
362                        cache::Action::Check(quantity) => quantity,
363                    };
364                }
365
366                let redis_key = quota.key().to_string();
367                // Remaining quotas are expected to be track-able in Redis.
368                let refund_key = get_refunded_quota_key(&redis_key);
369
370                invocation.key(redis_key);
371                invocation.key(refund_key);
372
373                invocation.arg(quota.limit());
374                invocation.arg(quota.key_expiry());
375                invocation.arg(quota.quantity);
376                invocation.arg(over_accept_once);
377
378                tracked_quotas.push(quota);
379            } else {
380                // This quota is neither a static reject-all, nor can it be tracked in Redis due to
381                // missing fields. We're skipping this for forward-compatibility.
382                relay_log::with_scope(
383                    |scope| scope.set_extra("quota", value::to_value(quota).unwrap()),
384                    || relay_log::warn!("skipping unsupported quota"),
385                )
386            }
387        }
388
389        // Either there are no quotas to run against Redis, or we already have a rate limit from a
390        // zero-sized quota. In either cases, skip invoking the script and return early.
391        if tracked_quotas.is_empty() || rate_limits.is_limited() {
392            return Ok(rate_limits);
393        }
394
395        let mut connection = self.client.get_connection().await?;
396        let result: ScriptResult = invocation
397            .invoke_async(&mut connection)
398            .await
399            .map_err(RedisError::Redis)?;
400
401        for (quota, state) in tracked_quotas.iter().zip(result.0) {
402            if state.is_rejected {
403                // We can calculate the error by comparing how much the cache added to the
404                // quantity with remaining difference of the consumption and limit.
405                let cache_error = {
406                    let remaining = quota.limit().saturating_sub(state.consumed).max(0) as u64;
407                    let cache_quantity = quota.quantity.saturating_sub(quantity);
408
409                    cache_quantity.saturating_sub(remaining)
410                };
411                relay_statsd::metric!(
412                    counter(QuotaCounters::CacheError) += cache_error,
413                    category = item_scoping.category.name(),
414                );
415
416                let retry_after = self.retry_after((quota.expiry() - timestamp).as_secs());
417                rate_limits.add(RateLimit::from_quota(quota, *item_scoping, retry_after));
418            } else if let Some(cache) = &self.cache {
419                // Only update the cache if it's really necessary. Quotas which are being rejected,
420                // will not be able to be handled from the cache anyways.
421                cache.set_quota(quota.for_cache(), state.consumed);
422            }
423        }
424        drop(connection);
425
426        if let Some(cache) = &self.cache {
427            let vacuum_start = std::time::Instant::now();
428            if cache.try_vacuum(timestamp) {
429                relay_statsd::metric!(
430                    timer(QuotaTimers::CacheVacuumDuration) = vacuum_start.elapsed()
431                );
432            }
433        }
434
435        Ok(rate_limits)
436    }
437
438    /// Creates a [`RetryAfter`] value that is bounded by the configured [`max_limit`](Self::max_limit).
439    ///
440    /// If a maximum rate limit has been set, the returned value will not exceed that limit.
441    fn retry_after(&self, mut seconds: u64) -> RetryAfter {
442        if let Some(max_limit) = self.max_limit {
443            seconds = std::cmp::min(seconds, max_limit);
444        }
445
446        RetryAfter::from_secs(seconds)
447    }
448}
449
450/// The result returned from the rate limiting Redis script.
451#[derive(Debug)]
452struct ScriptResult(Vec<QuotaState>);
453
454impl FromRedisValue for ScriptResult {
455    fn from_redis_value(v: redis::Value) -> Result<Self, ParsingError> {
456        let seq = v.into_sequence().map_err(|v| {
457            format!("Expected a sequence from the rate limiting script (value was: {v:?})")
458        })?;
459
460        if !seq.len().is_multiple_of(2) {
461            return Err(format!(
462                "Expected an even number of values from the rate limiting script (value was: {seq:?})"
463            ).into());
464        }
465
466        let mut quotas = Vec::with_capacity(seq.len() / 2);
467        for (is_rejected, consumed) in seq.into_iter().tuples() {
468            quotas.push(QuotaState {
469                is_rejected: bool::from_redis_value(is_rejected)?,
470                consumed: i64::from_redis_value(consumed)?,
471            });
472        }
473
474        Ok(Self(quotas))
475    }
476}
477
/// Per-quota state reported by the rate limiting script.
#[derive(Debug)]
struct QuotaState {
    /// `true` if the quota rejects the request.
    is_rejected: bool,
    /// The amount of the quota that has already been consumed.
    consumed: i64,
}
486
487#[cfg(test)]
488mod tests {
489    use std::time::{SystemTime, UNIX_EPOCH};
490
491    use super::*;
492    use crate::MetricNamespaceScoping;
493    use crate::quota::{DataCategories, DataCategory, ReasonCode, Scoping};
494    use crate::rate_limit::RateLimitScope;
495    use relay_base_schema::metrics::MetricNamespace;
496    use relay_base_schema::organization::OrganizationId;
497    use relay_base_schema::project::{ProjectId, ProjectKey};
498    use relay_redis::RedisConfigOptions;
499    use relay_redis::redis::AsyncCommands;
500    use smallvec::smallvec;
501
502    fn build_rate_limiter() -> RedisRateLimiter {
503        let url = std::env::var("RELAY_REDIS_URL")
504            .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_owned());
505        let client =
506            AsyncRedisClient::single("test", &url, &RedisConfigOptions::default()).unwrap();
507
508        RedisRateLimiter {
509            client,
510            cache: None,
511            script: RedisScripts::load_is_rate_limited(),
512            max_limit: None,
513        }
514    }
515
516    #[tokio::test]
517    async fn test_zero_size_quotas() {
518        let quotas = &[
519            Quota {
520                id: None,
521                categories: DataCategories::new(),
522                scope: QuotaScope::Organization,
523                scope_id: None,
524                limit: Some(0),
525                window: None,
526                reason_code: Some(ReasonCode::new("get_lost")),
527                namespace: None,
528            },
529            Quota {
530                id: Some("42".into()),
531                categories: DataCategories::new(),
532                scope: QuotaScope::Organization,
533                scope_id: None,
534                limit: None,
535                window: Some(42),
536                reason_code: Some(ReasonCode::new("unlimited")),
537                namespace: None,
538            },
539        ];
540
541        let scoping = ItemScoping {
542            category: DataCategory::Error,
543            scoping: Scoping {
544                organization_id: OrganizationId::new(42),
545                project_id: ProjectId::new(43),
546                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
547                key_id: Some(44),
548            },
549            namespace: MetricNamespaceScoping::None,
550        };
551
552        let rate_limits: Vec<RateLimit> = build_rate_limiter()
553            .is_rate_limited(quotas, scoping, 1, false)
554            .await
555            .expect("rate limiting failed")
556            .into_iter()
557            .collect();
558
559        assert_eq!(
560            rate_limits,
561            vec![RateLimit {
562                categories: DataCategories::new(),
563                scope: RateLimitScope::Organization(OrganizationId::new(42)),
564                reason_code: Some(ReasonCode::new("get_lost")),
565                retry_after: rate_limits[0].retry_after,
566                namespaces: smallvec![],
567            }]
568        );
569    }
570
571    /// Tests that a quota with and without namespace are counted separately.
572    #[tokio::test]
573    async fn test_namespace_quota() {
574        let quota_limit = 5;
575        let get_quota = |namespace: Option<MetricNamespace>| -> Quota {
576            Quota {
577                id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4()).into()),
578                categories: DataCategories::new(),
579                scope: QuotaScope::Organization,
580                scope_id: None,
581                limit: Some(quota_limit),
582                window: Some(600),
583                reason_code: Some(ReasonCode::new(format!("ns: {namespace:?}"))),
584                namespace,
585            }
586        };
587
588        let quotas = &[get_quota(None)];
589        let quota_with_namespace = &[get_quota(Some(MetricNamespace::Transactions))];
590
591        let scoping = ItemScoping {
592            category: DataCategory::Error,
593            scoping: Scoping {
594                organization_id: OrganizationId::new(42),
595                project_id: ProjectId::new(43),
596                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
597                key_id: Some(44),
598            },
599            namespace: MetricNamespaceScoping::Some(MetricNamespace::Transactions),
600        };
601
602        let rate_limiter = build_rate_limiter();
603
604        // First confirm normal behaviour without namespace.
605        for i in 0..10 {
606            let rate_limits: Vec<RateLimit> = rate_limiter
607                .is_rate_limited(quotas, scoping, 1, false)
608                .await
609                .expect("rate limiting failed")
610                .into_iter()
611                .collect();
612
613            if i < quota_limit {
614                assert_eq!(rate_limits, vec![]);
615            } else {
616                assert_eq!(
617                    rate_limits[0].reason_code,
618                    Some(ReasonCode::new("ns: None"))
619                );
620            }
621        }
622
623        // Then, send identical quota with namespace and confirm it counts separately.
624        for i in 0..10 {
625            let rate_limits: Vec<RateLimit> = rate_limiter
626                .is_rate_limited(quota_with_namespace, scoping, 1, false)
627                .await
628                .expect("rate limiting failed")
629                .into_iter()
630                .collect();
631
632            if i < quota_limit {
633                assert_eq!(rate_limits, vec![]);
634            } else {
635                assert_eq!(
636                    rate_limits[0].reason_code,
637                    Some(ReasonCode::new("ns: Some(Transactions)"))
638                );
639            }
640        }
641    }
642
643    #[tokio::test]
644    async fn test_simple_quota() {
645        let quotas = &[Quota {
646            id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4()).into()),
647            categories: DataCategories::new(),
648            scope: QuotaScope::Organization,
649            scope_id: None,
650            limit: Some(5),
651            window: Some(60),
652            reason_code: Some(ReasonCode::new("get_lost")),
653            namespace: None,
654        }];
655
656        let scoping = ItemScoping {
657            category: DataCategory::Error,
658            scoping: Scoping {
659                organization_id: OrganizationId::new(42),
660                project_id: ProjectId::new(43),
661                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
662                key_id: Some(44),
663            },
664            namespace: MetricNamespaceScoping::None,
665        };
666
667        let rate_limiter = build_rate_limiter();
668
669        for i in 0..10 {
670            let rate_limits: Vec<RateLimit> = rate_limiter
671                .is_rate_limited(quotas, scoping, 1, false)
672                .await
673                .expect("rate limiting failed")
674                .into_iter()
675                .collect();
676
677            if i >= 5 {
678                assert_eq!(
679                    rate_limits,
680                    vec![RateLimit {
681                        categories: DataCategories::new(),
682                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
683                        reason_code: Some(ReasonCode::new("get_lost")),
684                        retry_after: rate_limits[0].retry_after,
685                        namespaces: smallvec![],
686                    }]
687                );
688            } else {
689                assert_eq!(rate_limits, vec![]);
690            }
691        }
692    }
693
694    #[tokio::test]
695    async fn test_quantity_0() {
696        let quotas = &[Quota {
697            id: Some(format!("test_quantity_0_{}", uuid::Uuid::new_v4()).into()),
698            categories: DataCategories::new(),
699            scope: QuotaScope::Organization,
700            scope_id: None,
701            limit: Some(1),
702            window: Some(60),
703            reason_code: Some(ReasonCode::new("get_lost")),
704            namespace: None,
705        }];
706
707        let scoping = ItemScoping {
708            category: DataCategory::Error,
709            scoping: Scoping {
710                organization_id: OrganizationId::new(42),
711                project_id: ProjectId::new(43),
712                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
713                key_id: Some(44),
714            },
715            namespace: MetricNamespaceScoping::None,
716        };
717
718        let rate_limiter = build_rate_limiter();
719
720        // limit is 1, so first call not rate limited
721        assert!(
722            !rate_limiter
723                .is_rate_limited(quotas, scoping, 1, false)
724                .await
725                .unwrap()
726                .is_limited()
727        );
728
729        // quota is now exhausted
730        assert!(
731            rate_limiter
732                .is_rate_limited(quotas, scoping, 1, false)
733                .await
734                .unwrap()
735                .is_limited()
736        );
737
738        // quota is exhausted, regardless of the quantity
739        assert!(
740            rate_limiter
741                .is_rate_limited(quotas, scoping, 0, false)
742                .await
743                .unwrap()
744                .is_limited()
745        );
746
747        // quota is exhausted, regardless of the quantity
748        assert!(
749            rate_limiter
750                .is_rate_limited(quotas, scoping, 1, false)
751                .await
752                .unwrap()
753                .is_limited()
754        );
755    }
756
757    #[tokio::test]
758    async fn test_quota_go_over() {
759        let quotas = &[Quota {
760            id: Some(format!("test_quota_go_over{}", uuid::Uuid::new_v4()).into()),
761            categories: DataCategories::new(),
762            scope: QuotaScope::Organization,
763            scope_id: None,
764            limit: Some(2),
765            window: Some(60),
766            reason_code: Some(ReasonCode::new("get_lost")),
767            namespace: None,
768        }];
769
770        let scoping = ItemScoping {
771            category: DataCategory::Error,
772            scoping: Scoping {
773                organization_id: OrganizationId::new(42),
774                project_id: ProjectId::new(43),
775                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
776                key_id: Some(44),
777            },
778            namespace: MetricNamespaceScoping::None,
779        };
780
781        let rate_limiter = build_rate_limiter();
782
783        // limit is 2, so first call not rate limited
784        let is_limited = rate_limiter
785            .is_rate_limited(quotas, scoping, 1, true)
786            .await
787            .unwrap()
788            .is_limited();
789        assert!(!is_limited);
790
791        // go over limit, but first call is over-accepted
792        let is_limited = rate_limiter
793            .is_rate_limited(quotas, scoping, 2, true)
794            .await
795            .unwrap()
796            .is_limited();
797        assert!(!is_limited);
798
799        // quota is exhausted, regardless of the quantity
800        let is_limited = rate_limiter
801            .is_rate_limited(quotas, scoping, 0, true)
802            .await
803            .unwrap()
804            .is_limited();
805        assert!(is_limited);
806
807        // quota is exhausted, regardless of the quantity
808        let is_limited = rate_limiter
809            .is_rate_limited(quotas, scoping, 1, true)
810            .await
811            .unwrap()
812            .is_limited();
813        assert!(is_limited);
814    }
815
816    #[tokio::test]
817    async fn test_bails_immediately_without_any_quota() {
818        let scoping = ItemScoping {
819            category: DataCategory::Error,
820            scoping: Scoping {
821                organization_id: OrganizationId::new(42),
822                project_id: ProjectId::new(43),
823                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
824                key_id: Some(44),
825            },
826            namespace: MetricNamespaceScoping::None,
827        };
828
829        let rate_limits: Vec<RateLimit> = build_rate_limiter()
830            .is_rate_limited(&[], scoping, 1, false)
831            .await
832            .expect("rate limiting failed")
833            .into_iter()
834            .collect();
835
836        assert_eq!(rate_limits, vec![]);
837    }
838
839    #[tokio::test]
840    async fn test_limited_with_unlimited_quota() {
841        let quotas = &[
842            Quota {
843                id: Some("q0".into()),
844                categories: DataCategories::new(),
845                scope: QuotaScope::Organization,
846                scope_id: None,
847                limit: None,
848                window: Some(1),
849                reason_code: Some(ReasonCode::new("project_quota0")),
850                namespace: None,
851            },
852            Quota {
853                id: Some("q1".into()),
854                categories: DataCategories::new(),
855                scope: QuotaScope::Organization,
856                scope_id: None,
857                limit: Some(1),
858                window: Some(1),
859                reason_code: Some(ReasonCode::new("project_quota1")),
860                namespace: None,
861            },
862        ];
863
864        let scoping = ItemScoping {
865            category: DataCategory::Error,
866            scoping: Scoping {
867                organization_id: OrganizationId::new(42),
868                project_id: ProjectId::new(43),
869                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
870                key_id: Some(44),
871            },
872            namespace: MetricNamespaceScoping::None,
873        };
874
875        let rate_limiter = build_rate_limiter();
876
877        for i in 0..1 {
878            let rate_limits: Vec<RateLimit> = rate_limiter
879                .is_rate_limited(quotas, scoping, 1, false)
880                .await
881                .expect("rate limiting failed")
882                .into_iter()
883                .collect();
884
885            if i == 0 {
886                assert_eq!(rate_limits, &[]);
887            } else {
888                assert_eq!(
889                    rate_limits,
890                    vec![RateLimit {
891                        categories: DataCategories::new(),
892                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
893                        reason_code: Some(ReasonCode::new("project_quota1")),
894                        retry_after: rate_limits[0].retry_after,
895                        namespaces: smallvec![],
896                    }]
897                );
898            }
899        }
900    }
901
902    #[tokio::test]
903    async fn test_quota_with_quantity() {
904        let quotas = &[Quota {
905            id: Some(format!("test_quantity_quota_{}", uuid::Uuid::new_v4()).into()),
906            categories: DataCategories::new(),
907            scope: QuotaScope::Organization,
908            scope_id: None,
909            limit: Some(500),
910            window: Some(60),
911            reason_code: Some(ReasonCode::new("get_lost")),
912            namespace: None,
913        }];
914
915        let scoping = ItemScoping {
916            category: DataCategory::Error,
917            scoping: Scoping {
918                organization_id: OrganizationId::new(42),
919                project_id: ProjectId::new(43),
920                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
921                key_id: Some(44),
922            },
923            namespace: MetricNamespaceScoping::None,
924        };
925
926        let rate_limiter = build_rate_limiter();
927
928        for i in 0..10 {
929            let rate_limits: Vec<RateLimit> = rate_limiter
930                .is_rate_limited(quotas, scoping, 100, false)
931                .await
932                .expect("rate limiting failed")
933                .into_iter()
934                .collect();
935
936            if i >= 5 {
937                assert_eq!(
938                    rate_limits,
939                    vec![RateLimit {
940                        categories: DataCategories::new(),
941                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
942                        reason_code: Some(ReasonCode::new("get_lost")),
943                        retry_after: rate_limits[0].retry_after,
944                        namespaces: smallvec![],
945                    }]
946                );
947            } else {
948                assert_eq!(rate_limits, vec![]);
949            }
950        }
951    }
952
953    #[tokio::test]
954    async fn test_get_redis_key_scoped() {
955        let quota = Quota {
956            id: Some("foo".into()),
957            categories: DataCategories::new(),
958            scope: QuotaScope::Project,
959            scope_id: Some("42".into()),
960            window: Some(2),
961            limit: Some(0),
962            reason_code: None,
963            namespace: None,
964        };
965
966        let scoping = ItemScoping {
967            category: DataCategory::Error,
968            scoping: Scoping {
969                organization_id: OrganizationId::new(69420),
970                project_id: ProjectId::new(42),
971                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
972                key_id: Some(4711),
973            },
974            namespace: MetricNamespaceScoping::None,
975        };
976
977        let timestamp = UnixTimestamp::from_secs(123_123_123);
978        let redis_quota = RedisQuota::new(&quota, 0, scoping, timestamp).unwrap();
979        assert_eq!(redis_quota.key().to_string(), "quota:foo{69420}42:61561561");
980    }
981
982    #[tokio::test]
983    async fn test_get_redis_key_unscoped() {
984        let quota = Quota {
985            id: Some("foo".into()),
986            categories: DataCategories::new(),
987            scope: QuotaScope::Organization,
988            scope_id: None,
989            window: Some(10),
990            limit: Some(0),
991            reason_code: None,
992            namespace: None,
993        };
994
995        let scoping = ItemScoping {
996            category: DataCategory::Error,
997            scoping: Scoping {
998                organization_id: OrganizationId::new(69420),
999                project_id: ProjectId::new(42),
1000                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
1001                key_id: Some(4711),
1002            },
1003            namespace: MetricNamespaceScoping::None,
1004        };
1005
1006        let timestamp = UnixTimestamp::from_secs(234_531);
1007        let redis_quota = RedisQuota::new(&quota, 0, scoping, timestamp).unwrap();
1008        assert_eq!(redis_quota.key().to_string(), "quota:foo{69420}:23453");
1009    }
1010
1011    #[tokio::test]
1012    async fn test_large_redis_limit_large() {
1013        let quota = Quota {
1014            id: Some("foo".into()),
1015            categories: DataCategories::new(),
1016            scope: QuotaScope::Organization,
1017            scope_id: None,
1018            window: Some(10),
1019            limit: Some(9223372036854775808), // i64::MAX + 1
1020            reason_code: None,
1021            namespace: None,
1022        };
1023
1024        let scoping = ItemScoping {
1025            category: DataCategory::Error,
1026            scoping: Scoping {
1027                organization_id: OrganizationId::new(69420),
1028                project_id: ProjectId::new(42),
1029                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
1030                key_id: Some(4711),
1031            },
1032            namespace: MetricNamespaceScoping::None,
1033        };
1034
1035        let timestamp = UnixTimestamp::from_secs(234_531);
1036        let redis_quota = RedisQuota::new(&quota, 0, scoping, timestamp).unwrap();
1037        assert_eq!(redis_quota.limit(), -1);
1038    }
1039
1040    #[tokio::test]
1041    async fn test_is_rate_limited_script() {
1042        let now = SystemTime::now()
1043            .duration_since(UNIX_EPOCH)
1044            .map(|duration| duration.as_secs())
1045            .unwrap();
1046
1047        let rate_limiter = build_rate_limiter();
1048        let mut conn = rate_limiter.client.get_connection().await.unwrap();
1049
1050        // define a few keys with random seed such that they do not collide with repeated test runs
1051        let foo = format!("foo___{now}");
1052        let r_foo = format!("r:foo___{now}");
1053        let bar = format!("bar___{now}");
1054        let r_bar = format!("r:bar___{now}");
1055        let apple = format!("apple___{now}");
1056        let orange = format!("orange___{now}");
1057        let baz = format!("baz___{now}");
1058
1059        let script = RedisScripts::load_is_rate_limited();
1060
1061        macro_rules! assert_invocation {
1062            ($invocation:expr, $($tt:tt)*) => {{
1063                let result = $invocation
1064                    .invoke_async::<ScriptResult>(&mut conn)
1065                    .await
1066                    .unwrap();
1067
1068                insta::assert_debug_snapshot!(result, $($tt)*);
1069            }};
1070        }
1071
1072        let mut invocation = script.prepare_invoke();
1073        invocation
1074            .key(&foo) // key
1075            .key(&r_foo) // refund key
1076            .key(&bar) // key
1077            .key(&r_bar) // refund key
1078            .arg(1) // limit
1079            .arg(now + 60) // expiry
1080            .arg(1) // quantity
1081            .arg(false) // over accept once
1082            .arg(2) // limit
1083            .arg(now + 120) // expiry
1084            .arg(1) // quantity
1085            .arg(false); // over accept once
1086
1087        // Craft a new invocation similar to the previous one, but it only applies to the quota
1088        // with a higher limit (2).
1089        let mut invocation2 = script.prepare_invoke();
1090        invocation2
1091            .key(&bar) // key
1092            .key(&r_bar) // refund key
1093            .arg(2) // limit
1094            .arg(now + 120) // expiry
1095            .arg(1) // quantity
1096            .arg(false); // over accept once
1097
1098        // 1 quantity used from both quotas.
1099        assert_invocation!(invocation, @r"
1100        ScriptResult(
1101            [
1102                QuotaState {
1103                    is_rejected: false,
1104                    consumed: 1,
1105                },
1106                QuotaState {
1107                    is_rejected: false,
1108                    consumed: 1,
1109                },
1110            ],
1111        )
1112        "
1113        );
1114
1115        // This invocation fails the rate limit on the first quota.
1116        // -> No changes are made to the counters.
1117        assert_invocation!(invocation, @r"
1118        ScriptResult(
1119            [
1120                QuotaState {
1121                    is_rejected: true,
1122                    consumed: 1,
1123                },
1124                QuotaState {
1125                    is_rejected: false,
1126                    consumed: 1,
1127                },
1128            ],
1129        )
1130        "
1131        );
1132
1133        // Another call, same result as before, just making sure there were no changes applied.
1134        assert_invocation!(invocation, @r"
1135        ScriptResult(
1136            [
1137                QuotaState {
1138                    is_rejected: true,
1139                    consumed: 1,
1140                },
1141                QuotaState {
1142                    is_rejected: false,
1143                    consumed: 1,
1144                },
1145            ],
1146        )
1147        "
1148        );
1149
1150        // Using the second invocation which only considers a quota with a higher limit, usage for
1151        // that quota is now 2.
1152        assert_invocation!(invocation2, @r"
1153        ScriptResult(
1154            [
1155                QuotaState {
1156                    is_rejected: false,
1157                    consumed: 2,
1158                },
1159            ],
1160        )
1161        "
1162        );
1163
1164        // Same invocation, but this time the limit is reached, quota should not increase.
1165        assert_invocation!(invocation2, @r"
1166        ScriptResult(
1167            [
1168                QuotaState {
1169                    is_rejected: true,
1170                    consumed: 2,
1171                },
1172            ],
1173        )
1174        "
1175        );
1176
1177        // Check again with the original invocation, this now yields `[1, 2]`.
1178        assert_invocation!(invocation, @r"
1179        ScriptResult(
1180            [
1181                QuotaState {
1182                    is_rejected: true,
1183                    consumed: 1,
1184                },
1185                QuotaState {
1186                    is_rejected: true,
1187                    consumed: 2,
1188                },
1189            ],
1190        )
1191        "
1192        );
1193
1194        assert_eq!(conn.get::<_, String>(&foo).await.unwrap(), "1");
1195        let ttl: u64 = conn.ttl(&foo).await.unwrap();
1196        assert!(ttl >= 59);
1197        assert!(ttl <= 60);
1198
1199        assert_eq!(conn.get::<_, String>(&bar).await.unwrap(), "2");
1200        let ttl: u64 = conn.ttl(&bar).await.unwrap();
1201        assert!(ttl >= 119);
1202        assert!(ttl <= 120);
1203
1204        // make sure "refund/negative" keys haven't been incremented
1205        let () = conn.get(r_foo).await.unwrap();
1206        let () = conn.get(r_bar).await.unwrap();
1207
1208        // Test that refunded quotas work
1209        let () = conn.set(&apple, 5).await.unwrap();
1210
1211        let mut invocation = script.prepare_invoke();
1212        invocation
1213            .key(&orange) // key
1214            .key(&baz) // refund key
1215            .arg(1) // limit
1216            .arg(now + 60) // expiry
1217            .arg(1) // quantity
1218            .arg(false);
1219
1220        // increment, current quota usage is 1.
1221        assert_invocation!(invocation, @r"
1222        ScriptResult(
1223            [
1224                QuotaState {
1225                    is_rejected: false,
1226                    consumed: 1,
1227                },
1228            ],
1229        )
1230        "
1231        );
1232
1233        // test that it's rate limited without refund.
1234        assert_invocation!(invocation, @r"
1235        ScriptResult(
1236            [
1237                QuotaState {
1238                    is_rejected: true,
1239                    consumed: 1,
1240                },
1241            ],
1242        )
1243        "
1244        );
1245
1246        // Make sure, the counter wasn't incremented.
1247        assert_invocation!(invocation, @r"
1248        ScriptResult(
1249            [
1250                QuotaState {
1251                    is_rejected: true,
1252                    consumed: 1,
1253                },
1254            ],
1255        )
1256        "
1257        );
1258
1259        let mut invocation = script.prepare_invoke();
1260        invocation
1261            .key(&orange) // key
1262            .key(&apple) // refund key
1263            .arg(1) // limit
1264            .arg(now + 60) // expiry
1265            .arg(1) // quantity
1266            .arg(false);
1267
1268        // test that refund key is used
1269        assert_invocation!(invocation, @r"
1270        ScriptResult(
1271            [
1272                QuotaState {
1273                    is_rejected: false,
1274                    consumed: -3,
1275                },
1276            ],
1277        )
1278        "
1279        );
1280    }
1281
1282    /// Usual rate limiting with a cache should just work as expected.
1283    #[tokio::test]
1284    async fn test_quota_with_cache() {
1285        let quotas = &[Quota {
1286            id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4()).into()),
1287            categories: DataCategories::new(),
1288            scope: QuotaScope::Organization,
1289            scope_id: None,
1290            limit: Some(50),
1291            window: Some(60),
1292            reason_code: Some(ReasonCode::new("get_lost")),
1293            namespace: None,
1294        }];
1295
1296        let scoping = ItemScoping {
1297            category: DataCategory::Error,
1298            scoping: Scoping {
1299                organization_id: OrganizationId::new(42),
1300                project_id: ProjectId::new(43),
1301                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
1302                key_id: Some(44),
1303            },
1304            namespace: MetricNamespaceScoping::None,
1305        };
1306
1307        // For this test, with only a single rate limiter accessing Redis and always a quantity of
1308        // `1`, the parameters for the cache should not make any difference.
1309        let rate_limiter = build_rate_limiter().cache(Some(0.1), Some(0.9));
1310
1311        for _ in 0..50 {
1312            let rate_limits = rate_limiter
1313                .is_rate_limited(quotas, scoping, 1, false)
1314                .await
1315                .unwrap();
1316
1317            assert!(rate_limits.is_empty());
1318        }
1319
1320        let rate_limits: Vec<RateLimit> = rate_limiter
1321            .is_rate_limited(quotas, scoping, 1, false)
1322            .await
1323            .expect("rate limiting failed")
1324            .into_iter()
1325            .collect();
1326
1327        assert_eq!(
1328            rate_limits,
1329            vec![RateLimit {
1330                categories: DataCategories::new(),
1331                scope: RateLimitScope::Organization(OrganizationId::new(42)),
1332                reason_code: Some(ReasonCode::new("get_lost")),
1333                retry_after: rate_limits[0].retry_after,
1334                namespaces: smallvec![],
1335            }]
1336        );
1337    }
1338
1339    #[tokio::test]
1340    async fn test_quota_with_cache_slightly_over_account() {
1341        let window = 60;
1342        let limit = 50 * window;
1343
1344        let quotas = &[Quota {
1345            id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4()).into()),
1346            categories: DataCategories::new(),
1347            scope: QuotaScope::Organization,
1348            scope_id: None,
1349            limit: Some(limit),
1350            window: Some(window),
1351            reason_code: Some(ReasonCode::new("get_lost")),
1352            namespace: None,
1353        }];
1354
1355        let scoping = ItemScoping {
1356            category: DataCategory::Error,
1357            scoping: Scoping {
1358                organization_id: OrganizationId::new(42),
1359                project_id: ProjectId::new(43),
1360                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
1361                key_id: Some(44),
1362            },
1363            namespace: MetricNamespaceScoping::None,
1364        };
1365
1366        // 10% Quota cache.
1367        let rate_limiter1 = build_rate_limiter().cache(Some(0.1), None);
1368        let rate_limiter2 = build_rate_limiter().cache(Some(0.1), None);
1369
1370        // Prime the cache.
1371        let rate_limits = rate_limiter1
1372            .is_rate_limited(quotas, scoping, 1, false)
1373            .await
1374            .unwrap();
1375        assert!(rate_limits.is_empty());
1376        // Reserve 3 out 5 in the cache.
1377        let rate_limits = rate_limiter1
1378            .is_rate_limited(quotas, scoping, 3, false)
1379            .await
1380            .unwrap();
1381        assert!(rate_limits.is_empty());
1382
1383        // Consume right up to the limit on the other limiter
1384        let rate_limits = rate_limiter2
1385            .is_rate_limited(quotas, scoping, limit as usize - 1, false)
1386            .await
1387            .unwrap();
1388        assert!(rate_limits.is_empty());
1389
1390        // There is still one more slot in the cache.
1391        let rate_limits = rate_limiter1
1392            .is_rate_limited(quotas, scoping, 1, false)
1393            .await
1394            .unwrap();
1395        assert!(rate_limits.is_empty());
1396
1397        // This should now rate limit, as the cache is exhausted and Redis is checked.
1398        let rate_limits: Vec<RateLimit> = rate_limiter1
1399            .is_rate_limited(quotas, scoping, 1, false)
1400            .await
1401            .unwrap()
1402            .into_iter()
1403            .collect();
1404
1405        assert_eq!(
1406            rate_limits,
1407            vec![RateLimit {
1408                categories: DataCategories::new(),
1409                scope: RateLimitScope::Organization(OrganizationId::new(42)),
1410                reason_code: Some(ReasonCode::new("get_lost")),
1411                retry_after: rate_limits[0].retry_after,
1412                namespaces: smallvec![],
1413            }]
1414        );
1415    }
1416}