1use relay_quotas::{
2 GlobalLimiter, GlobalRateLimiter, OwnedRedisQuota, RateLimitingError, RedisQuota,
3};
4use relay_redis::AsyncRedisClient;
5use relay_system::{
6 Addr, AsyncResponse, FromMessage, Interface, MessageResponse, Receiver, Sender, Service,
7};
8
9pub struct CheckRateLimited {
14 pub global_quotas: Vec<OwnedRedisQuota>,
15}
16
17pub enum GlobalRateLimits {
22 CheckRateLimited(
24 CheckRateLimited,
25 Sender<Result<Vec<OwnedRedisQuota>, RateLimitingError>>,
26 ),
27}
28
29impl Interface for GlobalRateLimits {}
30
31impl FromMessage<CheckRateLimited> for GlobalRateLimits {
32 type Response = AsyncResponse<Result<Vec<OwnedRedisQuota>, RateLimitingError>>;
33
34 fn from_message(
35 message: CheckRateLimited,
36 sender: <Self::Response as MessageResponse>::Sender,
37 ) -> Self {
38 Self::CheckRateLimited(message, sender)
39 }
40}
41
42#[derive(Clone)]
47pub struct GlobalRateLimitsServiceHandle {
48 tx: Addr<GlobalRateLimits>,
49}
50
51impl GlobalLimiter for GlobalRateLimitsServiceHandle {
52 async fn check_global_rate_limits<'a>(
53 &self,
54 global_quotas: &'a [RedisQuota<'a>],
55 ) -> Result<Vec<&'a RedisQuota<'a>>, RateLimitingError> {
56 let owned_global_quotas = global_quotas
58 .iter()
59 .map(|q| q.build_owned())
60 .collect::<Vec<_>>();
61
62 let rate_limited_owned_global_quotas = self
63 .tx
64 .send(CheckRateLimited {
65 global_quotas: owned_global_quotas,
66 })
67 .await
68 .map_err(|_| RateLimitingError::UnreachableGlobalRateLimits)??;
69
70 let res = rate_limited_owned_global_quotas
83 .iter()
84 .filter_map(|owned_global_quota| {
85 let global_quota = owned_global_quota.build_ref();
86 global_quotas.iter().find(|x| **x == global_quota)
87 })
88 .collect::<Vec<_>>();
89 Ok(res)
90 }
91}
92
93impl From<Addr<GlobalRateLimits>> for GlobalRateLimitsServiceHandle {
94 fn from(tx: Addr<GlobalRateLimits>) -> Self {
95 Self { tx }
96 }
97}
98
99#[derive(Debug)]
104pub struct GlobalRateLimitsService {
105 client: AsyncRedisClient,
106 limiter: GlobalRateLimiter,
107}
108
109impl GlobalRateLimitsService {
110 pub fn new(client: AsyncRedisClient) -> Self {
115 Self {
116 client,
117 limiter: GlobalRateLimiter::default(),
118 }
119 }
120
121 async fn handle_message(
123 client: &AsyncRedisClient,
124 limiter: &mut GlobalRateLimiter,
125 message: GlobalRateLimits,
126 ) {
127 match message {
128 GlobalRateLimits::CheckRateLimited(check_rate_limited, sender) => {
129 let result =
130 Self::handle_check_rate_limited(client, limiter, check_rate_limited).await;
131 sender.send(result);
132 }
133 }
134 }
135
136 async fn handle_check_rate_limited(
141 client: &AsyncRedisClient,
142 limiter: &mut GlobalRateLimiter,
143 check_rate_limited: CheckRateLimited,
144 ) -> Result<Vec<OwnedRedisQuota>, RateLimitingError> {
145 let quotas = check_rate_limited
146 .global_quotas
147 .iter()
148 .map(|q| q.build_ref())
149 .collect::<Vec<_>>();
150
151 limiter
152 .filter_rate_limited(client, "as)
153 .await
154 .map(|q| q.into_iter().map(|q| q.build_owned()).collect::<Vec<_>>())
155 }
156}
157
158impl Service for GlobalRateLimitsService {
159 type Interface = GlobalRateLimits;
160
161 async fn run(mut self, mut rx: Receiver<Self::Interface>) {
162 loop {
163 let Some(message) = rx.recv().await else {
164 break;
165 };
166
167 Self::handle_message(&self.client, &mut self.limiter, message).await;
168 }
169 }
170}
171
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use relay_base_schema::data_category::DataCategory;
    use relay_base_schema::organization::OrganizationId;
    use relay_base_schema::project::{ProjectId, ProjectKey};
    use relay_common::time::UnixTimestamp;
    use relay_quotas::{DataCategories, Quota, QuotaScope, RedisQuota, Scoping};
    use relay_redis::{AsyncRedisClient, RedisConfigOptions};
    use relay_system::Service;

    use crate::services::global_rate_limits::{CheckRateLimited, GlobalRateLimitsService};

    /// Connects to the Redis server named by `RELAY_REDIS_URL`.
    ///
    /// Falls back to `redis://127.0.0.1:6379` when the variable is unset.
    fn build_redis_client() -> AsyncRedisClient {
        let url = std::env::var("RELAY_REDIS_URL")
            .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_owned());

        AsyncRedisClient::single("test", &url, &RedisConfigOptions::default()).unwrap()
    }

    /// Builds a global-scope quota with the given window and limit.
    ///
    /// The quota gets a random UUID id, presumably so that keys do not collide across
    /// test runs.
    fn build_quota(window: u64, limit: impl Into<Option<u64>>) -> Quota {
        Quota {
            id: Some(uuid::Uuid::new_v4().to_string().into()),
            categories: DataCategories::new(),
            scope: QuotaScope::Global,
            scope_id: None,
            window: Some(window),
            limit: limit.into(),
            reason_code: None,
            namespace: None,
        }
    }

    /// Wraps `quota` in a [`RedisQuota`] for `quantity` metric buckets at the current time.
    fn build_redis_quota<'a>(
        quota: &'a Quota,
        quantity: u64,
        scoping: &'a Scoping,
    ) -> RedisQuota<'a> {
        let scoping = scoping.item(DataCategory::MetricBucket);
        RedisQuota::new(quota, quantity, scoping, UnixTimestamp::now()).unwrap()
    }

    #[tokio::test]
    async fn test_global_rate_limits_service() {
        let client = build_redis_client();
        let service = GlobalRateLimitsService::new(client);
        let tx = service.start_detached();

        let scoping = Scoping {
            organization_id: OrganizationId::new(69420),
            project_id: ProjectId::new(42),
            project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
            key_id: Some(4711),
        };

        let quota1 = build_quota(10, 100);
        let quota2 = build_quota(10, 150);
        let quota3 = build_quota(10, 200);
        let quantity = 175;

        let redis_quota_2 = build_redis_quota(&quota2, quantity, &scoping);
        let redis_quotas = [
            build_redis_quota(&quota1, quantity, &scoping),
            // The same quota is passed twice to exercise duplicate handling.
            redis_quota_2.clone(),
            redis_quota_2,
            build_redis_quota(&quota3, quantity, &scoping),
        ]
        .iter()
        .map(|q| q.build_owned())
        .collect();

        let check_rate_limited = CheckRateLimited {
            global_quotas: redis_quotas,
        };

        let rate_limited_quotas = tx.send(check_rate_limited).await.unwrap().unwrap();

        // Only the quotas whose limit is below the requested quantity (100 and 150) are
        // rate limited; the 200 quota still has room for 175.
        //
        // The comparison is set-based, so the duplicated 150 quota is asserted once. The
        // expected literal lists it only once to make that explicit.
        assert_eq!(
            BTreeSet::from([100, 150]),
            rate_limited_quotas
                .iter()
                .map(|quota| quota.build_ref().limit())
                .collect()
        );
    }
}