1use relay_quotas::{
2 GlobalLimiter, GlobalRateLimiter, OwnedRedisQuota, RateLimitingError, RedisQuota,
3};
4use relay_redis::AsyncRedisClient;
5use relay_system::{
6 Addr, AsyncResponse, FromMessage, Interface, MessageResponse, Receiver, Sender, Service,
7};
8
9pub struct CheckRateLimited {
14 pub global_quotas: Vec<OwnedRedisQuota>,
15 pub quantity: usize,
16}
17
18pub enum GlobalRateLimits {
23 CheckRateLimited(
25 CheckRateLimited,
26 Sender<Result<Vec<OwnedRedisQuota>, RateLimitingError>>,
27 ),
28}
29
30impl Interface for GlobalRateLimits {}
31
32impl FromMessage<CheckRateLimited> for GlobalRateLimits {
33 type Response = AsyncResponse<Result<Vec<OwnedRedisQuota>, RateLimitingError>>;
34
35 fn from_message(
36 message: CheckRateLimited,
37 sender: <Self::Response as MessageResponse>::Sender,
38 ) -> Self {
39 Self::CheckRateLimited(message, sender)
40 }
41}
42
43#[derive(Clone)]
48pub struct GlobalRateLimitsServiceHandle {
49 tx: Addr<GlobalRateLimits>,
50}
51
52impl GlobalLimiter for GlobalRateLimitsServiceHandle {
53 async fn check_global_rate_limits<'a>(
54 &self,
55 global_quotas: &'a [RedisQuota<'a>],
56 quantity: usize,
57 ) -> Result<Vec<&'a RedisQuota<'a>>, RateLimitingError> {
58 let owned_global_quotas = global_quotas
60 .iter()
61 .map(|q| q.build_owned())
62 .collect::<Vec<_>>();
63
64 let rate_limited_owned_global_quotas = self
65 .tx
66 .send(CheckRateLimited {
67 global_quotas: owned_global_quotas,
68 quantity,
69 })
70 .await
71 .map_err(|_| RateLimitingError::UnreachableGlobalRateLimits)??;
72
73 let res = rate_limited_owned_global_quotas
86 .iter()
87 .filter_map(|owned_global_quota| {
88 let global_quota = owned_global_quota.build_ref();
89 global_quotas.iter().find(|x| **x == global_quota)
90 })
91 .collect::<Vec<_>>();
92 Ok(res)
93 }
94}
95
96impl From<Addr<GlobalRateLimits>> for GlobalRateLimitsServiceHandle {
97 fn from(tx: Addr<GlobalRateLimits>) -> Self {
98 Self { tx }
99 }
100}
101
102#[derive(Debug)]
107pub struct GlobalRateLimitsService {
108 client: AsyncRedisClient,
109 limiter: GlobalRateLimiter,
110}
111
112impl GlobalRateLimitsService {
113 pub fn new(client: AsyncRedisClient) -> Self {
118 Self {
119 client,
120 limiter: GlobalRateLimiter::default(),
121 }
122 }
123
124 async fn handle_message(
126 client: &AsyncRedisClient,
127 limiter: &mut GlobalRateLimiter,
128 message: GlobalRateLimits,
129 ) {
130 match message {
131 GlobalRateLimits::CheckRateLimited(check_rate_limited, sender) => {
132 let result =
133 Self::handle_check_rate_limited(client, limiter, check_rate_limited).await;
134 sender.send(result);
135 }
136 }
137 }
138
139 async fn handle_check_rate_limited(
144 client: &AsyncRedisClient,
145 limiter: &mut GlobalRateLimiter,
146 check_rate_limited: CheckRateLimited,
147 ) -> Result<Vec<OwnedRedisQuota>, RateLimitingError> {
148 let quotas = check_rate_limited
149 .global_quotas
150 .iter()
151 .map(|q| q.build_ref())
152 .collect::<Vec<_>>();
153
154 limiter
155 .filter_rate_limited(client, "as, check_rate_limited.quantity)
156 .await
157 .map(|q| q.into_iter().map(|q| q.build_owned()).collect::<Vec<_>>())
158 }
159}
160
161impl Service for GlobalRateLimitsService {
162 type Interface = GlobalRateLimits;
163
164 async fn run(mut self, mut rx: Receiver<Self::Interface>) {
165 loop {
166 let Some(message) = rx.recv().await else {
167 break;
168 };
169
170 Self::handle_message(&self.client, &mut self.limiter, message).await;
171 }
172 }
173}
174
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use relay_base_schema::data_category::DataCategory;
    use relay_base_schema::organization::OrganizationId;
    use relay_base_schema::project::{ProjectId, ProjectKey};
    use relay_common::time::UnixTimestamp;
    use relay_quotas::{DataCategories, Quota, QuotaScope, RedisQuota, Scoping};
    use relay_redis::{AsyncRedisClient, RedisConfigOptions};
    use relay_system::Service;

    use crate::services::global_rate_limits::{CheckRateLimited, GlobalRateLimitsService};

    /// Connects to `RELAY_REDIS_URL`, or to a local Redis when the variable is unset.
    fn build_redis_client() -> AsyncRedisClient {
        let redis_url = std::env::var("RELAY_REDIS_URL")
            .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_owned());
        let options = RedisConfigOptions::default();

        AsyncRedisClient::single(&redis_url, &options).unwrap()
    }

    /// Builds a global quota with a fresh random id, so repeated runs never share counters.
    fn build_quota(window: u64, limit: impl Into<Option<u64>>) -> Quota {
        let quota_id = uuid::Uuid::new_v4().to_string();

        Quota {
            id: Some(quota_id),
            categories: DataCategories::new(),
            scope: QuotaScope::Global,
            scope_id: None,
            window: Some(window),
            limit: limit.into(),
            reason_code: None,
            namespace: None,
        }
    }

    /// Binds `quota` to `scoping` for the metric bucket category at the current time.
    fn build_redis_quota<'a>(quota: &'a Quota, scoping: &'a Scoping) -> RedisQuota<'a> {
        let item_scoping = scoping.item(DataCategory::MetricBucket);
        RedisQuota::new(quota, item_scoping, UnixTimestamp::now()).unwrap()
    }

    #[tokio::test]
    async fn test_global_rate_limits_service() {
        let addr = GlobalRateLimitsService::new(build_redis_client()).start_detached();

        let scoping = Scoping {
            organization_id: OrganizationId::new(69420),
            project_id: ProjectId::new(42),
            project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
            key_id: Some(4711),
        };

        let quota1 = build_quota(10, 100);
        let quota2 = build_quota(10, 150);
        let quota3 = build_quota(10, 200);
        let quantity = 175;

        let duplicated = build_redis_quota(&quota2, &scoping);
        // `quota2` is listed twice on purpose, to send duplicate quotas in one request.
        let global_quotas = [
            build_redis_quota(&quota1, &scoping),
            duplicated.clone(),
            duplicated,
            build_redis_quota(&quota3, &scoping),
        ]
        .iter()
        .map(|redis_quota| redis_quota.build_owned())
        .collect();

        let request = CheckRateLimited {
            global_quotas,
            quantity,
        };

        let limited = addr.send(request).await.unwrap().unwrap();

        // Only the quotas with limits below `quantity` are expected back. The results are
        // compared as a set, so the duplicate `quota2` entry collapses into one `150`.
        let limited_limits: BTreeSet<_> = limited
            .iter()
            .map(|owned| owned.build_ref().limit())
            .collect();
        assert_eq!(BTreeSet::from([100, 150, 150]), limited_limits);
    }
}