use relay_quotas::{
GlobalLimiter, GlobalRateLimiter, OwnedRedisQuota, RateLimitingError, RedisQuota,
};
use relay_redis::AsyncRedisClient;
use relay_system::{
Addr, AsyncResponse, FromMessage, Interface, MessageResponse, Receiver, Sender, Service,
};
/// Request asking the [`GlobalRateLimitsService`] which of the given global quotas are rate
/// limited.
///
/// The quotas are carried in owned form because the request is moved into the service's
/// message channel.
pub struct CheckRateLimited {
    /// The global quotas to check.
    pub global_quotas: Vec<OwnedRedisQuota>,
    /// The quantity to check against the quotas; forwarded as-is to the rate limiter.
    pub quantity: usize,
}
/// Messages accepted by the [`GlobalRateLimitsService`].
pub enum GlobalRateLimits {
    /// Checks which of the contained quotas are rate limited.
    ///
    /// The sender receives the owned quotas reported as rate limited, or a
    /// [`RateLimitingError`] if the check failed.
    CheckRateLimited(
        CheckRateLimited,
        Sender<Result<Vec<OwnedRedisQuota>, RateLimitingError>>,
    ),
}

impl Interface for GlobalRateLimits {}
impl FromMessage<CheckRateLimited> for GlobalRateLimits {
type Response = AsyncResponse<Result<Vec<OwnedRedisQuota>, RateLimitingError>>;
fn from_message(
message: CheckRateLimited,
sender: <Self::Response as MessageResponse>::Sender,
) -> Self {
Self::CheckRateLimited(message, sender)
}
}
/// Handle to a running [`GlobalRateLimitsService`] that implements [`GlobalLimiter`].
///
/// Built from the service's address via `From<Addr<GlobalRateLimits>>`.
pub struct GlobalRateLimitsServiceHandle {
    /// Address used to send requests to the service.
    tx: Addr<GlobalRateLimits>,
}
impl GlobalLimiter for GlobalRateLimitsServiceHandle {
    /// Asks the [`GlobalRateLimitsService`] which of `global_quotas` are rate limited for
    /// `quantity` and returns references to those quotas.
    ///
    /// The quotas are converted to owned form to cross the service boundary. Each owned quota
    /// in the response is mapped back to the first equal entry in `global_quotas`. Response
    /// entries with no equal input quota are dropped.
    ///
    /// # Errors
    ///
    /// - Returns [`RateLimitingError::UnreachableGlobalRateLimits`] if sending the request or
    ///   receiving its response fails.
    /// - Forwards any [`RateLimitingError`] reported by the service itself.
    async fn check_global_rate_limits<'a>(
        &self,
        global_quotas: &'a [RedisQuota<'a>],
        quantity: usize,
    ) -> Result<Vec<&'a RedisQuota<'a>>, RateLimitingError> {
        let owned_global_quotas = global_quotas
            .iter()
            .map(|q| q.build_owned())
            .collect::<Vec<_>>();

        // The first `?` handles transport failures, the second the service's own result.
        let rate_limited_owned_global_quotas = self
            .tx
            .send(CheckRateLimited {
                global_quotas: owned_global_quotas,
                quantity,
            })
            .await
            .map_err(|_| RateLimitingError::UnreachableGlobalRateLimits)??;

        // Resolve owned quotas back to the caller's borrowed ones. Duplicates in the response
        // each resolve to the first equal input quota.
        let rate_limited_global_quotas: Vec<_> = rate_limited_owned_global_quotas
            .iter()
            .filter_map(|owned_global_quota| {
                let global_quota = owned_global_quota.build_ref();
                global_quotas.iter().find(|x| **x == global_quota)
            })
            .collect();

        Ok(rate_limited_global_quotas)
    }
}
impl From<Addr<GlobalRateLimits>> for GlobalRateLimitsServiceHandle {
fn from(tx: Addr<GlobalRateLimits>) -> Self {
Self { tx }
}
}
/// Service that checks global quotas using a Redis-backed [`GlobalRateLimiter`].
///
/// Messages are handled one at a time; each is fully processed before the next is received.
#[derive(Debug)]
pub struct GlobalRateLimitsService {
    /// Redis client passed to the limiter for every check.
    client: AsyncRedisClient,
    /// Limiter whose state persists across messages.
    limiter: GlobalRateLimiter,
}
impl GlobalRateLimitsService {
pub fn new(client: AsyncRedisClient) -> Self {
Self {
client,
limiter: GlobalRateLimiter::default(),
}
}
async fn handle_message(
client: &AsyncRedisClient,
limiter: &mut GlobalRateLimiter,
message: GlobalRateLimits,
) {
match message {
GlobalRateLimits::CheckRateLimited(check_rate_limited, sender) => {
let result =
Self::handle_check_rate_limited(client, limiter, check_rate_limited).await;
sender.send(result);
}
}
}
async fn handle_check_rate_limited(
client: &AsyncRedisClient,
limiter: &mut GlobalRateLimiter,
check_rate_limited: CheckRateLimited,
) -> Result<Vec<OwnedRedisQuota>, RateLimitingError> {
let quotas = check_rate_limited
.global_quotas
.iter()
.map(|q| q.build_ref())
.collect::<Vec<_>>();
limiter
.filter_rate_limited(client, "as, check_rate_limited.quantity)
.await
.map(|q| q.into_iter().map(|q| q.build_owned()).collect::<Vec<_>>())
}
}
impl Service for GlobalRateLimitsService {
    type Interface = GlobalRateLimits;

    /// Receives and handles messages sequentially until the receiver yields `None`.
    ///
    /// Each message is awaited to completion before the next one is received.
    async fn run(mut self, mut rx: Receiver<Self::Interface>) {
        while let Some(msg) = rx.recv().await {
            Self::handle_message(&self.client, &mut self.limiter, msg).await;
        }
    }
}
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use relay_base_schema::data_category::DataCategory;
    use relay_base_schema::organization::OrganizationId;
    use relay_base_schema::project::{ProjectId, ProjectKey};
    use relay_common::time::UnixTimestamp;
    use relay_quotas::{DataCategories, Quota, QuotaScope, RedisQuota, Scoping};
    use relay_redis::{AsyncRedisClient, RedisConfigOptions};
    use relay_system::Service;

    use crate::services::global_rate_limits::{CheckRateLimited, GlobalRateLimitsService};

    /// Connects to the Redis instance at `RELAY_REDIS_URL`, defaulting to
    /// `redis://127.0.0.1:6379`. Panics if the client cannot be created.
    async fn build_redis_client() -> AsyncRedisClient {
        let url = std::env::var("RELAY_REDIS_URL")
            .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_owned());
        AsyncRedisClient::single(&url, &RedisConfigOptions::default())
            .await
            .unwrap()
    }

    /// Builds a global quota with the given window and limit.
    ///
    /// Each quota gets a fresh random id, presumably so that test runs do not share Redis
    /// counters.
    fn build_quota(window: u64, limit: impl Into<Option<u64>>) -> Quota {
        Quota {
            id: Some(uuid::Uuid::new_v4().to_string()),
            categories: DataCategories::new(),
            scope: QuotaScope::Global,
            scope_id: None,
            window: Some(window),
            limit: limit.into(),
            reason_code: None,
            namespace: None,
        }
    }

    /// Builds a [`RedisQuota`] for `quota`, scoped to `DataCategory::MetricBucket` at the
    /// current time.
    fn build_redis_quota<'a>(quota: &'a Quota, scoping: &'a Scoping) -> RedisQuota<'a> {
        let scoping = scoping.item(DataCategory::MetricBucket);
        RedisQuota::new(quota, scoping, UnixTimestamp::now()).unwrap()
    }

    #[tokio::test]
    async fn test_global_rate_limits_service() {
        let client = build_redis_client().await;
        let service = GlobalRateLimitsService::new(client);
        let tx = service.start_detached();

        let scoping = Scoping {
            organization_id: OrganizationId::new(69420),
            project_id: ProjectId::new(42),
            project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
            key_id: Some(4711),
        };

        let quota1 = build_quota(10, 100);
        let quota2 = build_quota(10, 150);
        let quota3 = build_quota(10, 200);
        let quantity = 175;

        // `quota2` is intentionally included twice.
        let redis_quota_2 = build_redis_quota(&quota2, &scoping);
        let redis_quotas = [
            build_redis_quota(&quota1, &scoping),
            redis_quota_2.clone(),
            redis_quota_2,
            build_redis_quota(&quota3, &scoping),
        ]
        .iter()
        .map(|q| q.build_owned())
        .collect();

        let check_rate_limited = CheckRateLimited {
            global_quotas: redis_quotas,
            quantity,
        };

        let rate_limited_quotas = tx.send(check_rate_limited).await.unwrap().unwrap();

        // Quotas with a limit below `quantity` (100 and 150) are expected to be rate limited;
        // the quota with limit 200 is not. Limits are compared as a set, so duplicates
        // collapse and this does not verify how often the duplicated quota is reported.
        assert_eq!(
            BTreeSet::from([100, 150]),
            rate_limited_quotas
                .iter()
                .map(|quota| quota.build_ref().limit())
                .collect()
        );
    }
}