1use std::future::Future;
2
3use itertools::Itertools;
4use relay_base_schema::metrics::MetricNamespace;
5use relay_redis::{AsyncRedisClient, AsyncRedisConnection, RedisError, RedisScripts};
6
7use crate::RateLimitingError;
8use crate::redis::RedisQuota;
9
/// Fraction used to size budget requests to Redis.
///
/// For quotas with a limit, a default request reserves `limit * DEFAULT_BUDGET_RATIO`; for
/// unlimited quotas it reserves `quantity / DEFAULT_BUDGET_RATIO`.
const DEFAULT_BUDGET_RATIO: f32 = 1e-3;
12
13pub trait GlobalLimiter {
15 fn check_global_rate_limits<'a>(
17 &self,
18 global_quotas: &'a [RedisQuota<'a>],
19 ) -> impl Future<Output = Result<Vec<&'a RedisQuota<'a>>, RateLimitingError>> + Send;
20}
21
22#[derive(Debug, Default)]
27pub struct GlobalRateLimiter {
28 limits: hashbrown::HashMap<Key, GlobalRateLimit>,
29}
30
31impl GlobalRateLimiter {
32 pub async fn filter_rate_limited<'a>(
40 &mut self,
41 client: &AsyncRedisClient,
42 quotas: &'a [RedisQuota<'a>],
43 ) -> Result<Vec<&'a RedisQuota<'a>>, RateLimitingError> {
44 let mut connection = client.get_connection().await?;
45
46 let mut rate_limited = vec![];
47 let mut not_rate_limited = vec![];
48
49 let min_by_keyref = quotas
50 .iter()
51 .into_grouping_map_by(|q| KeyRef::new(q))
52 .min_by_key(|_, q| q.limit());
53
54 for (key, quota) in min_by_keyref {
55 let global_rate_limit = self.limits.entry_ref(&key).or_default();
56
57 if global_rate_limit
58 .check_rate_limited(&mut connection, quota, key)
59 .await?
60 {
61 rate_limited.push(quota);
62 } else {
63 not_rate_limited.push(quota);
64 }
65 }
66
67 if rate_limited.is_empty() {
68 for quota in not_rate_limited {
69 if let Some(val) = self.limits.get_mut(&KeyRef::new(quota)) {
70 val.budget -= quota.quantity();
71 }
72 }
73 }
74
75 Ok(rate_limited)
76 }
77}
78
79#[derive(Clone, Copy, Hash, Debug, Eq, PartialEq, Ord, PartialOrd)]
84struct KeyRef<'a> {
85 prefix: &'a str,
86 window: u64,
87 namespace: Option<MetricNamespace>,
88}
89
90impl<'a> KeyRef<'a> {
91 fn new(quota: &'a RedisQuota<'a>) -> Self {
92 Self {
93 prefix: quota.prefix(),
94 window: quota.window(),
95 namespace: quota.namespace,
96 }
97 }
98
99 fn redis_key(&self, slot: u64) -> RedisKey {
100 RedisKey::new(self, slot)
101 }
102}
103
104impl hashbrown::Equivalent<Key> for KeyRef<'_> {
105 fn equivalent(&self, key: &Key) -> bool {
106 let Key {
107 prefix,
108 window,
109 namespace,
110 } = key;
111
112 self.prefix == prefix && self.window == *window && self.namespace == *namespace
113 }
114}
115
116#[derive(Debug, Clone, PartialEq, Eq, Hash)]
121struct Key {
122 prefix: String,
123 window: u64,
124 namespace: Option<MetricNamespace>,
125}
126
127impl From<&KeyRef<'_>> for Key {
128 fn from(value: &KeyRef<'_>) -> Self {
129 Key {
130 prefix: value.prefix.to_owned(),
131 window: value.window,
132 namespace: value.namespace,
133 }
134 }
135}
136
/// Name of the Redis key that holds a quota's global counter for one slot.
#[derive(Debug)]
struct RedisKey(String);
142
143impl RedisKey {
144 fn new(key: &KeyRef<'_>, slot: u64) -> Self {
145 Self(format!(
146 "global_quota:{id}{window}{namespace:?}:{slot}",
147 id = key.prefix,
148 window = key.window,
149 namespace = key.namespace,
150 slot = slot,
151 ))
152 }
153}
154
/// Local state of one global rate limit.
///
/// Budget is reserved from Redis in advance and consumed locally. The state tracks a single
/// slot and is reset when a quota for a newer slot (or a much older one) is checked.
#[derive(Debug)]
struct GlobalRateLimit {
    /// Reserved quantity that can still be admitted locally.
    budget: u64,
    /// Counter value returned by the Redis script on the last reservation.
    last_seen_redis_value: u64,
    /// Slot that `budget` and `last_seen_redis_value` belong to.
    slot: u64,
}
169
170impl GlobalRateLimit {
171 fn new() -> Self {
173 Self {
174 budget: 0,
175 last_seen_redis_value: 0,
176 slot: 0,
177 }
178 }
179
180 pub async fn check_rate_limited(
186 &mut self,
187 connection: &mut AsyncRedisConnection,
188 quota: &RedisQuota<'_>,
189 key: KeyRef<'_>,
190 ) -> Result<bool, RateLimitingError> {
191 let quota_slot = quota.slot();
192 let quantity = quota.quantity();
193
194 if quota_slot > self.slot || quota_slot + 1 < self.slot {
203 self.budget = 0;
204 self.last_seen_redis_value = 0;
205 self.slot = quota_slot;
206 }
207
208 if self.budget >= quantity {
209 return Ok(false);
210 }
211
212 let redis_key = key.redis_key(quota_slot);
213 let reserved = self
214 .try_reserve(connection, quantity, quota, redis_key)
215 .await
216 .map_err(RateLimitingError::Redis)?;
217 self.budget += reserved;
218
219 Ok(self.budget < quantity)
220 }
221
222 async fn try_reserve(
227 &mut self,
228 connection: &mut AsyncRedisConnection,
229 quantity: u64,
230 quota: &RedisQuota<'_>,
231 redis_key: RedisKey,
232 ) -> Result<u64, RedisError> {
233 let min_required_budget = quantity.saturating_sub(self.budget);
234 let max_available_budget = quota
235 .limit
236 .unwrap_or(u64::MAX)
237 .saturating_sub(self.last_seen_redis_value);
238
239 if min_required_budget > max_available_budget {
240 return Ok(0);
241 }
242
243 let budget_to_reserve = min_required_budget.max(self.default_request_size(quantity, quota));
244
245 let (budget, value): (u64, u64) = RedisScripts::load_global_quota()
246 .prepare_invoke()
247 .key(redis_key.0)
248 .arg(budget_to_reserve)
249 .arg(quota.limit())
250 .arg(quota.key_expiry())
251 .invoke_async(connection)
252 .await
253 .map_err(RedisError::Redis)?;
254
255 self.last_seen_redis_value = value;
256
257 Ok(budget)
258 }
259
260 fn default_request_size(&self, quantity: u64, quota: &RedisQuota) -> u64 {
266 match quota.limit {
267 Some(limit) => (limit as f32 * DEFAULT_BUDGET_RATIO) as u64,
268 None => (quantity as f32 / DEFAULT_BUDGET_RATIO) as u64,
270 }
271 }
272}
273
274impl Default for GlobalRateLimit {
275 fn default() -> Self {
276 Self::new()
277 }
278}
279
#[cfg(test)]
mod tests {
    // These tests need a running Redis at `RELAY_REDIS_URL` (default `redis://127.0.0.1:6379`).

    use std::collections::BTreeSet;
    use std::time::Duration;

    use relay_base_schema::data_category::DataCategory;
    use relay_base_schema::organization::OrganizationId;
    use relay_base_schema::project::{ProjectId, ProjectKey};
    use relay_common::time::UnixTimestamp;
    use relay_redis::{AsyncRedisClient, RedisConfigOptions};

    use super::*;
    use crate::{DataCategories, Quota, QuotaScope, Scoping};

    fn build_redis_client() -> AsyncRedisClient {
        let url = std::env::var("RELAY_REDIS_URL")
            .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_owned());

        AsyncRedisClient::single("test", &url, &RedisConfigOptions::default()).unwrap()
    }

    /// Builds a global quota with a random id, so tests do not share Redis counters.
    fn build_quota(window: u64, limit: impl Into<Option<u64>>) -> Quota {
        Quota {
            id: Some(uuid::Uuid::new_v4().to_string().into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Global,
            scope_id: None,
            window: Some(window),
            limit: limit.into(),
            reason_code: None,
            namespace: None,
        }
    }

    fn build_scoping() -> Scoping {
        Scoping {
            organization_id: OrganizationId::new(69420),
            project_id: ProjectId::new(42),
            project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
            key_id: Some(4711),
        }
    }

    fn build_redis_quota<'a>(
        quota: &'a Quota,
        scoping: &'a Scoping,
        quantity: u64,
    ) -> RedisQuota<'a> {
        let scoping = scoping.item(DataCategory::MetricBucket);
        RedisQuota::new(quota, quantity, scoping, UnixTimestamp::now()).unwrap()
    }

    /// Every quota whose limit is below the requested quantity is reported.
    #[tokio::test]
    async fn test_multiple_rate_limits() {
        let scoping = build_scoping();

        let quota1 = build_quota(10, 100);
        let quota2 = build_quota(10, 150);
        let quota3 = build_quota(10, 200);

        let quantity = 175;
        let redis_quotas = [
            build_redis_quota(&quota1, &scoping, quantity),
            build_redis_quota(&quota2, &scoping, quantity),
            build_redis_quota(&quota3, &scoping, quantity),
        ];

        let client = build_redis_client();
        let mut counter = GlobalRateLimiter::default();

        let rate_limited_quotas = counter
            .filter_rate_limited(&client, &redis_quotas)
            .await
            .unwrap();

        // Only the quotas with a limit below `quantity` are rate limited.
        assert_eq!(
            BTreeSet::from([100, 150]),
            rate_limited_quotas
                .iter()
                .map(|quota| quota.limit())
                .collect()
        );
    }

    /// When two quotas share a key but differ in limit, only the smaller limit is checked.
    #[tokio::test]
    async fn test_use_smaller_limit() {
        let smaller_limit = 100;
        let bigger_limit = 200;

        let scoping = build_scoping();

        let mut smaller_quota = build_quota(10, smaller_limit);
        let mut bigger_quota = build_quota(10, bigger_limit);

        smaller_quota.id = Some("foobar".into());
        bigger_quota.id = Some("foobar".into());

        let quantity = bigger_limit * 2;
        let redis_quotas = [
            build_redis_quota(&smaller_quota, &scoping, quantity),
            build_redis_quota(&bigger_quota, &scoping, quantity),
        ];

        let client = build_redis_client();
        let mut counter = GlobalRateLimiter::default();

        let rate_limited_quotas = counter
            .filter_rate_limited(&client, &redis_quotas)
            .await
            .unwrap();

        assert_eq!(rate_limited_quotas.len(), 1);

        assert_eq!(
            rate_limited_quotas.first().unwrap().limit(),
            smaller_limit as i64
        );
    }

    #[tokio::test]
    async fn test_global_rate_limit() {
        let limit = 200;

        let quota = build_quota(10, limit);
        let scoping = build_scoping();
        let redis_quota = [build_redis_quota(&quota, &scoping, 90)];

        let client = build_redis_client();
        let mut counter = GlobalRateLimiter::default();

        // 90 + 90 fits into 200; the third request of 90 does not.
        let expected_rate_limit_result = [false, false, true, true].to_vec();

        for should_rate_limit in expected_rate_limit_result {
            let is_rate_limited = counter
                .filter_rate_limited(&client, &redis_quota)
                .await
                .unwrap();

            assert_eq!(should_rate_limit, !is_rate_limited.is_empty());
        }
    }

    /// A request above the limit is rejected without consuming budget.
    #[tokio::test]
    async fn test_global_rate_limit_over_under() {
        let limit = 10;

        let quota = build_quota(10, limit);
        let scoping = build_scoping();

        let client = build_redis_client();
        let mut rl = GlobalRateLimiter::default();

        let redis_quota = [build_redis_quota(&quota, &scoping, 11)];
        assert!(
            !rl.filter_rate_limited(&client, &redis_quota)
                .await
                .unwrap()
                .is_empty()
        );

        let redis_quota = [build_redis_quota(&quota, &scoping, 10)];
        assert!(
            rl.filter_rate_limited(&client, &redis_quota)
                .await
                .unwrap()
                .is_empty()
        );
    }

    /// Two independent limiters sharing one Redis counter admit exactly `limit` in total.
    #[tokio::test]
    async fn test_multiple_global_rate_limit() {
        let limit = 91_337;

        let quota = build_quota(10, limit);
        let scoping = build_scoping();

        let client = build_redis_client();

        let mut counter1 = GlobalRateLimiter::default();
        let mut counter2 = GlobalRateLimiter::default();

        let mut total = 0;
        let mut total_counter_1 = 0;
        let mut total_counter_2 = 0;

        let scoping = scoping.item(DataCategory::MetricBucket);
        let ts = UnixTimestamp::now();

        for i in 0.. {
            let quantity = i % 17;

            let quota = [RedisQuota::new(&quota, quantity, scoping, ts).unwrap()];
            if counter1
                .filter_rate_limited(&client, &quota)
                .await
                .unwrap()
                .is_empty()
            {
                total += quantity;
                total_counter_1 += quantity;
            }

            if counter2
                .filter_rate_limited(&client, &quota)
                .await
                .unwrap()
                .is_empty()
            {
                total += quantity;
                total_counter_2 += quantity;
            }

            assert!(total <= limit);
            if total == limit {
                break;
            }
        }

        assert_eq!(total, limit);

        // Both limiters should admit similar totals, differing by at most one default
        // request size (`limit * DEFAULT_BUDGET_RATIO`).
        let diff = (total_counter_1 as f32 - total_counter_2 as f32).abs();
        assert!(diff <= limit as f32 * DEFAULT_BUDGET_RATIO);
    }

    /// Budget exhausted in one slot is available again in the next slot.
    #[tokio::test]
    async fn test_global_rate_limit_slots() {
        let limit = 200;
        let window = 10;

        let ts = UnixTimestamp::now();
        let quota = build_quota(window, limit);
        let scoping = build_scoping();
        let item_scoping = scoping.item(DataCategory::MetricBucket);

        let client = build_redis_client();

        let mut rl = GlobalRateLimiter::default();

        let redis_quota = [RedisQuota::new(&quota, 200, item_scoping, ts).unwrap()];
        assert!(
            rl.filter_rate_limited(&client, &redis_quota)
                .await
                .unwrap()
                .is_empty()
        );

        let redis_quota = [RedisQuota::new(&quota, 1, item_scoping, ts).unwrap()];
        assert!(
            !rl.filter_rate_limited(&client, &redis_quota)
                .await
                .unwrap()
                .is_empty()
        );

        let redis_quota = RedisQuota::new(
            &quota,
            200,
            item_scoping,
            ts + Duration::from_secs(window + 1),
        )
        .unwrap();
        assert!(
            rl.filter_rate_limited(&client, &[redis_quota])
                .await
                .unwrap()
                .is_empty()
        );

        let redis_quota = RedisQuota::new(
            &quota,
            1,
            item_scoping,
            ts + Duration::from_secs(window + 1),
        )
        .unwrap();
        assert!(
            !rl.filter_rate_limited(&client, &[redis_quota])
                .await
                .unwrap()
                .is_empty()
        );
    }

    /// Unlimited quotas are never rate limited, but they still count towards the Redis counter.
    #[tokio::test]
    async fn test_global_rate_limit_infinite() {
        let limit = None;

        let timestamp = UnixTimestamp::now();

        let mut quota = build_quota(100, limit);
        let scoping = build_scoping();
        let item_scoping = scoping.item(DataCategory::MetricBucket);

        let client = build_redis_client();

        let mut rl = GlobalRateLimiter::default();

        let quantity = 2;
        let redis_threshold = (quantity as f32 / DEFAULT_BUDGET_RATIO) as u64;
        for _ in 0..redis_threshold + 10 {
            let redis_quota = RedisQuota::new(&quota, quantity, item_scoping, timestamp).unwrap();
            assert!(
                rl.filter_rate_limited(&client, &[redis_quota])
                    .await
                    .unwrap()
                    .is_empty()
            );
        }

        // A fresh limiter sees the shared counter, which the loop above pushed past
        // `redis_threshold`. Limiting the same quota to that value must rate limit it.
        let mut rl = GlobalRateLimiter::default();

        quota.limit = Some(redis_threshold);
        let redis_quota = RedisQuota::new(&quota, quantity, item_scoping, timestamp).unwrap();

        assert!(
            !rl.filter_rate_limited(&client, &[redis_quota])
                .await
                .unwrap()
                .is_empty()
        );
    }
}
611}