1use std::future::Future;
2
3use itertools::Itertools;
4use relay_base_schema::metrics::MetricNamespace;
5use relay_redis::{AsyncRedisClient, AsyncRedisConnection, RedisError, RedisScripts};
6
7use crate::RateLimitingError;
8use crate::redis::RedisQuota;
9
/// Share of a quota's limit that is reserved from Redis by default with each reservation,
/// so that most checks can be answered from the locally cached budget.
const DEFAULT_BUDGET_RATIO: f32 = 1e-3;
12
13pub trait GlobalLimiter {
15 fn check_global_rate_limits<'a>(
17 &self,
18 global_quotas: &'a [RedisQuota<'a>],
19 quantity: usize,
20 ) -> impl Future<Output = Result<Vec<&'a RedisQuota<'a>>, RateLimitingError>> + Send;
21}
22
23#[derive(Debug, Default)]
28pub struct GlobalRateLimiter {
29 limits: hashbrown::HashMap<Key, GlobalRateLimit>,
30}
31
32impl GlobalRateLimiter {
33 pub async fn filter_rate_limited<'a>(
41 &mut self,
42 client: &AsyncRedisClient,
43 quotas: &'a [RedisQuota<'a>],
44 quantity: usize,
45 ) -> Result<Vec<&'a RedisQuota<'a>>, RateLimitingError> {
46 let mut connection = client.get_connection().await?;
47
48 let mut rate_limited = vec![];
49 let mut not_rate_limited = vec![];
50
51 let min_by_keyref = quotas
52 .iter()
53 .into_grouping_map_by(|q| KeyRef::new(q))
54 .min_by_key(|_, q| q.limit());
55
56 for (key, quota) in min_by_keyref {
57 let global_rate_limit = self.limits.entry_ref(&key).or_default();
58
59 if global_rate_limit
60 .check_rate_limited(&mut connection, quota, key, quantity as u64)
61 .await?
62 {
63 rate_limited.push(quota);
64 } else {
65 not_rate_limited.push(quota);
66 }
67 }
68
69 if rate_limited.is_empty() {
70 for quota in not_rate_limited {
71 if let Some(val) = self.limits.get_mut(&KeyRef::new(quota)) {
72 val.budget -= quantity as u64;
73 }
74 }
75 }
76
77 Ok(rate_limited)
78 }
79}
80
81#[derive(Clone, Copy, Hash, Debug, Eq, PartialEq, Ord, PartialOrd)]
86struct KeyRef<'a> {
87 prefix: &'a str,
88 window: u64,
89 namespace: Option<MetricNamespace>,
90}
91
92impl<'a> KeyRef<'a> {
93 fn new(quota: &'a RedisQuota<'a>) -> Self {
94 Self {
95 prefix: quota.prefix(),
96 window: quota.window(),
97 namespace: quota.namespace,
98 }
99 }
100
101 fn redis_key(&self, slot: u64) -> RedisKey {
102 RedisKey::new(self, slot)
103 }
104}
105
106impl hashbrown::Equivalent<Key> for KeyRef<'_> {
107 fn equivalent(&self, key: &Key) -> bool {
108 let Key {
109 prefix,
110 window,
111 namespace,
112 } = key;
113
114 self.prefix == prefix && self.window == *window && self.namespace == *namespace
115 }
116}
117
118#[derive(Debug, Clone, PartialEq, Eq, Hash)]
123struct Key {
124 prefix: String,
125 window: u64,
126 namespace: Option<MetricNamespace>,
127}
128
129impl From<&KeyRef<'_>> for Key {
130 fn from(value: &KeyRef<'_>) -> Self {
131 Key {
132 prefix: value.prefix.to_owned(),
133 window: value.window,
134 namespace: value.namespace,
135 }
136 }
137}
138
139#[derive(Debug)]
143struct RedisKey(String);
144
145impl RedisKey {
146 fn new(key: &KeyRef<'_>, slot: u64) -> Self {
147 Self(format!(
148 "global_quota:{id}{window}{namespace:?}:{slot}",
149 id = key.prefix,
150 window = key.window,
151 namespace = key.namespace,
152 slot = slot,
153 ))
154 }
155}
156
157#[derive(Debug)]
166struct GlobalRateLimit {
167 budget: u64,
168 last_seen_redis_value: u64,
169 slot: u64,
170}
171
172impl GlobalRateLimit {
173 fn new() -> Self {
175 Self {
176 budget: 0,
177 last_seen_redis_value: 0,
178 slot: 0,
179 }
180 }
181
182 pub async fn check_rate_limited(
188 &mut self,
189 connection: &mut AsyncRedisConnection,
190 quota: &RedisQuota<'_>,
191 key: KeyRef<'_>,
192 quantity: u64,
193 ) -> Result<bool, RateLimitingError> {
194 let quota_slot = quota.slot();
195
196 if quota_slot > self.slot || quota_slot + 1 < self.slot {
205 self.budget = 0;
206 self.last_seen_redis_value = 0;
207 self.slot = quota_slot;
208 }
209
210 if self.budget >= quantity {
211 return Ok(false);
212 }
213
214 let redis_key = key.redis_key(quota_slot);
215 let reserved = self
216 .try_reserve(connection, quantity, quota, redis_key)
217 .await
218 .map_err(RateLimitingError::Redis)?;
219 self.budget += reserved;
220
221 Ok(self.budget < quantity)
222 }
223
224 async fn try_reserve(
229 &mut self,
230 connection: &mut AsyncRedisConnection,
231 quantity: u64,
232 quota: &RedisQuota<'_>,
233 redis_key: RedisKey,
234 ) -> Result<u64, RedisError> {
235 let min_required_budget = quantity.saturating_sub(self.budget);
236 let max_available_budget = quota
237 .limit
238 .unwrap_or(u64::MAX)
239 .saturating_sub(self.last_seen_redis_value);
240
241 if min_required_budget > max_available_budget {
242 return Ok(0);
243 }
244
245 let budget_to_reserve = min_required_budget.max(self.default_request_size(quantity, quota));
246
247 let (budget, value): (u64, u64) = RedisScripts::load_global_quota()
248 .prepare_invoke()
249 .key(redis_key.0)
250 .arg(budget_to_reserve)
251 .arg(quota.limit())
252 .arg(quota.key_expiry())
253 .invoke_async(connection)
254 .await
255 .map_err(RedisError::Redis)?;
256
257 self.last_seen_redis_value = value;
258
259 Ok(budget)
260 }
261
262 fn default_request_size(&self, quantity: u64, quota: &RedisQuota) -> u64 {
268 match quota.limit {
269 Some(limit) => (limit as f32 * DEFAULT_BUDGET_RATIO) as u64,
270 None => (quantity as f32 / DEFAULT_BUDGET_RATIO) as u64,
272 }
273 }
274}
275
276impl Default for GlobalRateLimit {
277 fn default() -> Self {
278 Self::new()
279 }
280}
281
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::time::Duration;

    use relay_base_schema::data_category::DataCategory;
    use relay_base_schema::organization::OrganizationId;
    use relay_base_schema::project::{ProjectId, ProjectKey};
    use relay_common::time::UnixTimestamp;
    use relay_redis::{AsyncRedisClient, RedisConfigOptions};

    use super::*;
    use crate::{DataCategories, Quota, QuotaScope, Scoping};

    /// Connects to the Redis at `RELAY_REDIS_URL`, falling back to a local instance.
    fn build_redis_client() -> AsyncRedisClient {
        let url = std::env::var("RELAY_REDIS_URL")
            .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_owned());

        AsyncRedisClient::single(&url, &RedisConfigOptions::default()).unwrap()
    }

    /// Builds a global quota with a random id, so tests do not share Redis counters.
    fn build_quota(window: u64, limit: impl Into<Option<u64>>) -> Quota {
        Quota {
            id: Some(uuid::Uuid::new_v4().to_string()),
            categories: DataCategories::new(),
            scope: QuotaScope::Global,
            scope_id: None,
            window: Some(window),
            limit: limit.into(),
            reason_code: None,
            namespace: None,
        }
    }

    fn build_scoping() -> Scoping {
        Scoping {
            organization_id: OrganizationId::new(69420),
            project_id: ProjectId::new(42),
            project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
            key_id: Some(4711),
        }
    }

    fn build_redis_quota<'a>(quota: &'a Quota, scoping: &'a Scoping) -> RedisQuota<'a> {
        let scoping = scoping.item(DataCategory::MetricBucket);
        RedisQuota::new(quota, scoping, UnixTimestamp::now()).unwrap()
    }

    #[tokio::test]
    async fn test_multiple_rate_limits() {
        let scoping = build_scoping();

        let quota1 = build_quota(10, 100);
        let quota2 = build_quota(10, 150);
        let quota3 = build_quota(10, 200);
        let quantity = 175;

        let redis_quotas = [
            build_redis_quota(&quota1, &scoping),
            build_redis_quota(&quota2, &scoping),
            build_redis_quota(&quota3, &scoping),
        ];

        let client = build_redis_client();
        let mut counter = GlobalRateLimiter::default();

        let rate_limited_quotas = counter
            .filter_rate_limited(&client, &redis_quotas, quantity)
            .await
            .unwrap();

        // Only the quotas with a limit below `quantity` are reported.
        assert_eq!(
            BTreeSet::from([100, 150]),
            rate_limited_quotas
                .iter()
                .map(|quota| quota.limit())
                .collect()
        );
    }

    /// Quotas sharing the same key are collapsed into the one with the smaller limit.
    #[tokio::test]
    async fn test_use_smaller_limit() {
        let smaller_limit = 100;
        let bigger_limit = 200;

        let scoping = build_scoping();

        let mut smaller_quota = build_quota(10, smaller_limit);
        let mut bigger_quota = build_quota(10, bigger_limit);

        smaller_quota.id = Some("foobar".into());
        bigger_quota.id = Some("foobar".into());

        let redis_quotas = [
            build_redis_quota(&smaller_quota, &scoping),
            build_redis_quota(&bigger_quota, &scoping),
        ];

        let client = build_redis_client();
        let mut counter = GlobalRateLimiter::default();

        let rate_limited_quotas = counter
            .filter_rate_limited(&client, &redis_quotas, (bigger_limit * 2) as usize)
            .await
            .unwrap();

        assert_eq!(rate_limited_quotas.len(), 1);

        assert_eq!(
            rate_limited_quotas.first().unwrap().limit(),
            smaller_limit as i64
        );
    }

    #[tokio::test]
    async fn test_global_rate_limit() {
        let limit = 200;

        let quota = build_quota(10, limit);
        let scoping = build_scoping();
        let redis_quota = [build_redis_quota(&quota, &scoping)];

        let client = build_redis_client();
        let mut counter = GlobalRateLimiter::default();

        let expected_rate_limit_result = [false, false, true, true].to_vec();

        for should_rate_limit in expected_rate_limit_result {
            // With 90 items per request, two requests (180) fit into the limit of 200; the
            // third would exceed it.
            let is_rate_limited = counter
                .filter_rate_limited(&client, &redis_quota, 90)
                .await
                .unwrap();

            assert_eq!(should_rate_limit, !is_rate_limited.is_empty());
        }
    }

    #[tokio::test]
    async fn test_global_rate_limit_over_under() {
        let limit = 10;

        let quota = build_quota(10, limit);
        let scoping = build_scoping();

        let client = build_redis_client();
        let mut rl = GlobalRateLimiter::default();

        let redis_quota = [build_redis_quota(&quota, &scoping)];
        assert!(
            !rl.filter_rate_limited(&client, &redis_quota, 11)
                .await
                .unwrap()
                .is_empty()
        );

        assert!(
            rl.filter_rate_limited(&client, &redis_quota, 10)
                .await
                .unwrap()
                .is_empty()
        );
    }

    #[tokio::test]
    async fn test_multiple_global_rate_limit() {
        let limit = 91_337;

        let quota = build_quota(10, limit as u64);
        let scoping = build_scoping();
        let quota = [build_redis_quota(&quota, &scoping)];

        let client = build_redis_client();

        let mut counter1 = GlobalRateLimiter::default();
        let mut counter2 = GlobalRateLimiter::default();

        let mut total = 0;
        let mut total_counter_1 = 0;
        let mut total_counter_2 = 0;
        for i in 0.. {
            let quantity = i % 17;

            if counter1
                .filter_rate_limited(&client, &quota, quantity)
                .await
                .unwrap()
                .is_empty()
            {
                total += quantity;
                total_counter_1 += quantity;
            }

            if counter2
                .filter_rate_limited(&client, &quota, quantity)
                .await
                .unwrap()
                .is_empty()
            {
                total += quantity;
                total_counter_2 += quantity;
            }

            assert!(total <= limit);
            if total == limit {
                break;
            }
        }

        assert_eq!(total, limit);

        // Both limiters share the global budget. They diverge by at most one default
        // reservation.
        let diff = (total_counter_1 as f32 - total_counter_2 as f32).abs();
        assert!(diff <= limit as f32 * DEFAULT_BUDGET_RATIO);
    }

    #[tokio::test]
    async fn test_global_rate_limit_slots() {
        let limit = 200;
        let window = 10;

        let ts = UnixTimestamp::now();
        let quota = build_quota(window, limit);
        let scoping = build_scoping();
        let item_scoping = scoping.item(DataCategory::MetricBucket);

        let client = build_redis_client();

        let mut rl = GlobalRateLimiter::default();

        let redis_quota = [RedisQuota::new(&quota, item_scoping, ts).unwrap()];
        assert!(
            rl.filter_rate_limited(&client, &redis_quota, 200)
                .await
                .unwrap()
                .is_empty()
        );

        assert!(
            !rl.filter_rate_limited(&client, &redis_quota, 1)
                .await
                .unwrap()
                .is_empty()
        );

        // Moving past the window lands in the next slot, where the budget starts fresh.
        let redis_quota =
            [
                RedisQuota::new(&quota, item_scoping, ts + Duration::from_secs(window + 1))
                    .unwrap(),
            ];
        assert!(
            rl.filter_rate_limited(&client, &redis_quota, 200)
                .await
                .unwrap()
                .is_empty()
        );

        assert!(
            !rl.filter_rate_limited(&client, &redis_quota, 1)
                .await
                .unwrap()
                .is_empty()
        );
    }

    #[tokio::test]
    async fn test_global_rate_limit_infinite() {
        let limit = None;

        let timestamp = UnixTimestamp::now();

        let mut quota = build_quota(100, limit);
        let scoping = build_scoping();
        let item_scoping = scoping.item(DataCategory::MetricBucket);

        let client = build_redis_client();

        let mut rl = GlobalRateLimiter::default();

        let quantity = 2;
        let redis_threshold = (quantity as f32 / DEFAULT_BUDGET_RATIO) as u64;
        for _ in 0..redis_threshold + 10 {
            let redis_quota = RedisQuota::new(&quota, item_scoping, timestamp).unwrap();
            assert!(
                rl.filter_rate_limited(&client, &[redis_quota], quantity)
                    .await
                    .unwrap()
                    .is_empty()
            );
        }

        // A fresh limiter has no cached budget and must consult Redis again. The same quota
        // now carries a limit that the consumption above is expected to have reached.
        let mut rl = GlobalRateLimiter::default();

        quota.limit = Some(redis_threshold);
        let redis_quota = RedisQuota::new(&quota, item_scoping, timestamp).unwrap();

        assert!(
            !rl.filter_rate_limited(&client, &[redis_quota], quantity)
                .await
                .unwrap()
                .is_empty()
        );
    }
}