use std::fmt::{self, Debug};
use relay_common::time::UnixTimestamp;
use relay_log::protocol::value;
use relay_redis::redis::Script;
use relay_redis::{RedisError, RedisPool, RedisScripts};
use thiserror::Error;
use crate::global::GlobalRateLimits;
use crate::quota::{ItemScoping, Quota, QuotaScope};
use crate::rate_limit::{RateLimit, RateLimits, RetryAfter};
use crate::REJECT_ALL_SECS;
/// Extra seconds a quota counter key is kept alive in Redis past the end of its window.
///
/// Added to the window end in `RedisQuota::key_expiry`.
const GRACE: u64 = 60;
#[derive(Debug, Error)]
pub enum RateLimitingError {
#[error("failed to communicate with redis")]
Redis(#[source] RedisError),
}
/// Derives the Redis key holding refunds for the counter stored at `counter_key`.
///
/// The refund key is simply the counter key prefixed with `r:`.
fn get_refunded_quota_key(counter_key: &str) -> String {
    ["r:", counter_key].concat()
}
/// Adapter that displays the wrapped value if present and nothing at all if absent.
///
/// Used to splice optional key segments into a `format!` string.
struct OptionalDisplay<T>(Option<T>);

impl<T: fmt::Display> fmt::Display for OptionalDisplay<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(inner) = &self.0 {
            // Formatted through a fresh `write!`, so outer width/precision flags are not forwarded.
            write!(f, "{inner}")
        } else {
            Ok(())
        }
    }
}
/// A [`Quota`] bundled with the item scoping and timestamp needed to track it in Redis.
///
/// Built with `RedisQuota::new`, which only succeeds for quotas that have an `id` and a `window`.
/// Dereferences to the wrapped [`Quota`].
#[derive(Debug, Clone)]
pub(crate) struct RedisQuota<'a> {
    /// The quota being tracked.
    quota: &'a Quota,
    /// Scoping of the checked item; provides the organization id and sub-scope ids used in keys.
    scoping: ItemScoping<'a>,
    /// Redis key prefix, taken from the quota's `id`.
    prefix: &'a str,
    /// Window length in seconds, taken from the quota's `window`.
    window: u64,
    /// Time at which the quota is evaluated; selects the current window slot.
    timestamp: UnixTimestamp,
}
impl<'a> RedisQuota<'a> {
    /// Wraps `quota` so it can be tracked in Redis for the item described by `scoping` at
    /// `timestamp`.
    ///
    /// Returns `None` if the quota cannot be tracked in Redis. This is the case when it has no
    /// `id` (used as the key prefix), no `window`, or a `window` of zero. A zero window would
    /// otherwise cause a division by zero in `shift()` / `slot()`.
    pub(crate) fn new(
        quota: &'a Quota,
        scoping: ItemScoping<'a>,
        timestamp: UnixTimestamp,
    ) -> Option<Self> {
        // These fields indicate that we *can* track this quota.
        let prefix = quota.id.as_deref()?;
        // Window arithmetic below divides by `window`; reject a zero window up front instead of
        // panicking later while building the key.
        let window = quota.window.filter(|&window| window > 0)?;

        Some(Self {
            quota,
            scoping,
            prefix,
            window,
            timestamp,
        })
    }

    /// Returns the window length in seconds (always non-zero).
    pub fn window(&self) -> u64 {
        self.window
    }

    /// Returns the Redis key prefix, which is the quota's `id`.
    pub fn prefix(&self) -> &'a str {
        self.prefix
    }

    /// Returns the quota limit as a signed integer, as passed to the Redis script.
    ///
    /// Returns `-1` when the quota has no limit or when the limit does not fit into an `i64`.
    pub fn limit(&self) -> i64 {
        self.limit
            // Limits above `i64::MAX` are not representable and fall back to `-1`.
            .and_then(|limit| limit.try_into().ok())
            .unwrap_or(-1)
    }

    /// Returns the offset, in seconds, applied to window boundaries.
    ///
    /// Global quotas are not shifted. All other quotas are shifted by the organization id modulo
    /// the window length, so window boundaries vary between organizations.
    fn shift(&self) -> u64 {
        if self.quota.scope == QuotaScope::Global {
            0
        } else {
            self.scoping.organization_id.value() % self.window
        }
    }

    /// Returns the index of the (shifted) window that contains `timestamp`.
    pub fn slot(&self) -> u64 {
        (self.timestamp.as_secs() - self.shift()) / self.window
    }

    /// Returns the timestamp at which the current window ends, i.e. the start of the next slot.
    pub fn expiry(&self) -> UnixTimestamp {
        let next_slot = self.slot() + 1;
        let next_start = next_slot * self.window + self.shift();
        UnixTimestamp::from_secs(next_start)
    }

    /// Returns the Unix time in seconds at which the Redis key may expire: the end of the
    /// current window plus `GRACE` seconds.
    pub fn key_expiry(&self) -> u64 {
        self.expiry().as_secs() + GRACE
    }

    /// Builds the Redis counter key for this quota and window slot.
    ///
    /// The format is `quota:{id}{{org}}{subscope}{namespace}:{slot}`. The sub-scope id is only
    /// included for scopes other than global and organization. The organization id is wrapped
    /// in `{}`, which Redis Cluster treats as a hash tag.
    pub fn key(&self) -> String {
        let subscope = match self.quota.scope {
            QuotaScope::Global => None,
            QuotaScope::Organization => None,
            scope => self.scoping.scope_id(scope),
        };

        format!(
            "quota:{id}{{{org}}}{subscope}{namespace}:{slot}",
            id = self.prefix,
            org = self.scoping.organization_id,
            subscope = OptionalDisplay(subscope),
            namespace = OptionalDisplay(self.namespace),
            slot = self.slot(),
        )
    }
}
/// Gives direct read access to the fields of the wrapped [`Quota`] (e.g. `scope`, `limit`).
impl std::ops::Deref for RedisQuota<'_> {
    type Target = Quota;

    fn deref(&self) -> &Quota {
        let inner: &Quota = self.quota;
        inner
    }
}
/// Enforces quotas by tracking their consumption in Redis.
///
/// Non-global quotas are evaluated with the script from `RedisScripts::load_is_rate_limited`.
/// Quotas with global scope are delegated to [`GlobalRateLimits`].
pub struct RedisRateLimiter {
    /// Pool providing Redis clients and connections.
    pool: RedisPool,
    /// Rate limiting script invoked for non-global quotas.
    script: &'static Script,
    /// Optional upper bound, in seconds, applied to every computed retry-after.
    max_limit: Option<u64>,
    /// Tracker used for quotas with `QuotaScope::Global`.
    global_limits: GlobalRateLimits,
}
impl RedisRateLimiter {
    /// Creates a rate limiter on top of `pool`, with no cap on retry-after values.
    pub fn new(pool: RedisPool) -> Self {
        Self {
            pool,
            script: RedisScripts::load_is_rate_limited(),
            max_limit: None,
            global_limits: GlobalRateLimits::default(),
        }
    }

    /// Caps every reported retry-after at `max_limit` seconds; `None` removes the cap.
    pub fn max_limit(self, max_limit: Option<u64>) -> Self {
        let mut limiter = self;
        limiter.max_limit = max_limit;
        limiter
    }

    /// Checks `quotas` for an item with the given scoping and consumes `quantity` from them.
    ///
    /// The quotas are handled as follows:
    /// - Quotas that do not match `item_scoping` are ignored.
    /// - Quotas with a limit of zero reject immediately.
    /// - Global quotas go through [`GlobalRateLimits`].
    /// - All other trackable quotas are evaluated in a single invocation of the Redis script.
    /// - Quotas that cannot be tracked are logged and skipped.
    ///
    /// The script is not invoked when nothing is left to track or when a rejection is already
    /// known.
    ///
    /// # Errors
    ///
    /// Returns [`RateLimitingError::Redis`] if a client or connection cannot be obtained, or if
    /// global limiting or the script invocation fails.
    pub fn is_rate_limited<'a>(
        &self,
        quotas: impl IntoIterator<Item = &'a Quota>,
        item_scoping: ItemScoping<'_>,
        quantity: usize,
        over_accept_once: bool,
    ) -> Result<RateLimits, RateLimitingError> {
        let mut client = self.pool.client().map_err(RateLimitingError::Redis)?;
        let timestamp = UnixTimestamp::now();

        let mut invocation = self.script.prepare_invoke();
        let mut scripted = Vec::new();
        let mut rate_limits = RateLimits::new();
        let mut global_quotas = Vec::new();

        for quota in quotas {
            // Quotas that do not apply to this item are skipped silently.
            if !quota.matches(item_scoping) {
                continue;
            }

            // A zero limit rejects outright; no Redis state is touched for it.
            if quota.limit == Some(0) {
                let retry_after = self.retry_after(REJECT_ALL_SECS);
                rate_limits.add(RateLimit::from_quota(quota, &item_scoping, retry_after));
                continue;
            }

            match RedisQuota::new(quota, item_scoping, timestamp) {
                Some(redis_quota) if redis_quota.scope == QuotaScope::Global => {
                    global_quotas.push(redis_quota);
                }
                Some(redis_quota) => {
                    // Per quota the script receives two keys (counter, refund) and four args
                    // (limit, key expiry, quantity, over-accept flag), in this exact order.
                    let counter_key = redis_quota.key();
                    let refund_key = get_refunded_quota_key(&counter_key);
                    invocation.key(counter_key).key(refund_key);
                    invocation
                        .arg(redis_quota.limit())
                        .arg(redis_quota.key_expiry())
                        .arg(quantity)
                        .arg(over_accept_once);
                    scripted.push(redis_quota);
                }
                None => {
                    // Neither reject-all nor trackable: skip it for forward-compatibility.
                    relay_log::with_scope(
                        |scope| scope.set_extra("quota", value::to_value(quota).unwrap()),
                        || relay_log::warn!("skipping unsupported quota"),
                    )
                }
            }
        }

        // Global quotas are resolved first; any rejection here short-circuits the script below.
        let limited_globals = self
            .global_limits
            .filter_rate_limited(&mut client, &global_quotas, quantity)
            .map_err(RateLimitingError::Redis)?;

        for quota in limited_globals {
            let retry_after = self.retry_after((quota.expiry() - timestamp).as_secs());
            rate_limits.add(RateLimit::from_quota(quota, &item_scoping, retry_after));
        }

        // Nothing for the script to evaluate, or the outcome is already a rejection.
        if scripted.is_empty() || rate_limits.is_limited() {
            return Ok(rate_limits);
        }

        let rejections: Vec<bool> = invocation
            .invoke(&mut client.connection().map_err(RateLimitingError::Redis)?)
            .map_err(RedisError::Redis)
            .map_err(RateLimitingError::Redis)?;

        // The script answers with one flag per quota, in the order the quotas were added.
        for (quota, _) in scripted
            .iter()
            .zip(rejections)
            .filter(|&(_, rejected)| rejected)
        {
            let retry_after = self.retry_after((quota.expiry() - timestamp).as_secs());
            rate_limits.add(RateLimit::from_quota(quota, &item_scoping, retry_after));
        }

        Ok(rate_limits)
    }

    /// Turns `seconds` into a [`RetryAfter`], capped at `max_limit` when one is configured.
    fn retry_after(&self, seconds: u64) -> RetryAfter {
        let capped = match self.max_limit {
            Some(max_limit) => seconds.min(max_limit),
            None => seconds,
        };
        RetryAfter::from_secs(capped)
    }
}
/// Integration tests for the Redis rate limiter.
///
/// Most tests talk to a live Redis instance (see `build_rate_limiter`). Quota ids are usually
/// suffixed with a random UUID so that counters do not leak between runs.
#[cfg(test)]
mod tests {
    use std::time::{SystemTime, UNIX_EPOCH};

    use relay_base_schema::metrics::MetricNamespace;
    use relay_base_schema::organization::OrganizationId;
    use relay_base_schema::project::{ProjectId, ProjectKey};
    use relay_redis::redis::Commands;
    use relay_redis::RedisConfigOptions;
    use smallvec::smallvec;

    use super::*;
    use crate::quota::{DataCategories, DataCategory, ReasonCode, Scoping};
    use crate::rate_limit::RateLimitScope;
    use crate::MetricNamespaceScoping;

    /// Builds a rate limiter connected to `RELAY_REDIS_URL`.
    ///
    /// Falls back to `redis://127.0.0.1:6379` when the variable cannot be read.
    fn build_rate_limiter() -> RedisRateLimiter {
        let url = std::env::var("RELAY_REDIS_URL")
            .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_owned());

        RedisRateLimiter {
            pool: RedisPool::single(&url, RedisConfigOptions::default()).unwrap(),
            script: RedisScripts::load_is_rate_limited(),
            max_limit: None,
            global_limits: GlobalRateLimits::default(),
        }
    }

    /// A zero-limit quota rejects immediately with its own reason code.
    ///
    /// Because a rejection is already present, `is_rate_limited` returns before invoking the
    /// script. Only the zero-size quota therefore shows up, not the unlimited one.
    #[test]
    fn test_zero_size_quotas() {
        let quotas = &[
            Quota {
                id: None,
                categories: DataCategories::new(),
                scope: QuotaScope::Organization,
                scope_id: None,
                limit: Some(0),
                window: None,
                reason_code: Some(ReasonCode::new("get_lost")),
                namespace: None,
            },
            Quota {
                id: Some("42".to_owned()),
                categories: DataCategories::new(),
                scope: QuotaScope::Organization,
                scope_id: None,
                limit: None,
                window: Some(42),
                reason_code: Some(ReasonCode::new("unlimited")),
                namespace: None,
            },
        ];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: &Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limits: Vec<RateLimit> = build_rate_limiter()
            .is_rate_limited(quotas, scoping, 1, false)
            .expect("rate limiting failed")
            .into_iter()
            .collect();

        assert_eq!(
            rate_limits,
            vec![RateLimit {
                categories: DataCategories::new(),
                scope: RateLimitScope::Organization(OrganizationId::new(42)),
                reason_code: Some(ReasonCode::new("get_lost")),
                // Copied from the result because it depends on the current time.
                retry_after: rate_limits[0].retry_after,
                namespaces: smallvec![],
            }]
        );
    }

    /// Quotas with and without a metric namespace are both enforced for a namespaced item.
    ///
    /// Each quota gets a fresh UUID id, so each is counted independently. The first
    /// `quota_limit` calls pass, and later calls are rejected with that quota's reason code.
    /// NOTE(review): that a `namespace: None` quota matches a namespaced item is inferred from
    /// the assertions, not from `Quota::matches`, which is not visible here.
    #[test]
    fn test_non_global_namespace_quota() {
        let quota_limit = 5;
        let get_quota = |namespace: Option<MetricNamespace>| -> Quota {
            Quota {
                id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4())),
                categories: DataCategories::new(),
                scope: QuotaScope::Organization,
                scope_id: None,
                limit: Some(quota_limit),
                window: Some(600),
                reason_code: Some(ReasonCode::new(format!("ns: {:?}", namespace))),
                namespace,
            }
        };

        let quotas = &[get_quota(None)];
        let quota_with_namespace = &[get_quota(Some(MetricNamespace::Transactions))];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: &Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::Some(MetricNamespace::Transactions),
        };

        let rate_limiter = build_rate_limiter();

        // Quota without a namespace.
        for i in 0..10 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i < quota_limit {
                assert_eq!(rate_limits, vec![]);
            } else {
                assert_eq!(
                    rate_limits[0].reason_code,
                    Some(ReasonCode::new("ns: None"))
                );
            }
        }

        // Quota restricted to the transactions namespace.
        for i in 0..10 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quota_with_namespace, scoping, 1, false)
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i < quota_limit {
                assert_eq!(rate_limits, vec![]);
            } else {
                assert_eq!(
                    rate_limits[0].reason_code,
                    Some(ReasonCode::new("ns: Some(Transactions)"))
                );
            }
        }
    }

    /// An organization quota with limit 5 accepts five single-item requests and rejects the rest.
    #[test]
    fn test_simple_quota() {
        let quotas = &[Quota {
            id: Some(format!("test_simple_quota_{}", uuid::Uuid::new_v4())),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            limit: Some(5),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: &Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        for i in 0..10 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i >= 5 {
                assert_eq!(
                    rate_limits,
                    vec![RateLimit {
                        categories: DataCategories::new(),
                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
                        reason_code: Some(ReasonCode::new("get_lost")),
                        retry_after: rate_limits[0].retry_after,
                        namespaces: smallvec![],
                    }]
                );
            } else {
                assert_eq!(rate_limits, vec![]);
            }
        }
    }

    /// Same as `test_simple_quota`, but with a global quota, which goes through
    /// `GlobalRateLimits` instead of the script.
    #[test]
    fn test_simple_global_quota() {
        let quotas = &[Quota {
            id: Some(format!("test_simple_global_quota_{}", uuid::Uuid::new_v4())),
            categories: DataCategories::new(),
            scope: QuotaScope::Global,
            scope_id: None,
            limit: Some(5),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: &Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        for i in 0..10 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i >= 5 {
                assert_eq!(
                    rate_limits,
                    vec![RateLimit {
                        categories: DataCategories::new(),
                        scope: RateLimitScope::Global,
                        reason_code: Some(ReasonCode::new("get_lost")),
                        retry_after: rate_limits[0].retry_after,
                        namespaces: smallvec![],
                    }]
                );
            } else {
                assert_eq!(rate_limits, vec![]);
            }
        }
    }

    /// With limit 1, the first request consumes the whole quota.
    ///
    /// Afterwards the quota reports limited for every request, including one with quantity 0.
    #[test]
    fn test_quantity_0() {
        let quotas = &[Quota {
            id: Some(format!("test_quantity_0_{}", uuid::Uuid::new_v4())),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            limit: Some(1),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: &Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        // Consumes the only unit of the quota.
        assert!(!rate_limiter
            .is_rate_limited(quotas, scoping, 1, false)
            .unwrap()
            .is_limited());

        // Quota exhausted.
        assert!(rate_limiter
            .is_rate_limited(quotas, scoping, 1, false)
            .unwrap()
            .is_limited());

        // Even a request for zero items reports the exhausted quota.
        assert!(rate_limiter
            .is_rate_limited(quotas, scoping, 0, false)
            .unwrap()
            .is_limited());

        assert!(rate_limiter
            .is_rate_limited(quotas, scoping, 1, false)
            .unwrap()
            .is_limited());
    }

    /// With `over_accept_once = true`, a request that pushes consumption past the limit is still
    /// accepted once (1 + 2 > 2).
    ///
    /// Every later request is limited, including one with quantity 0.
    #[test]
    fn test_quota_go_over() {
        let quotas = &[Quota {
            id: Some(format!("test_quota_go_over{}", uuid::Uuid::new_v4())),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            limit: Some(2),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: &Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        // Within the limit.
        let is_limited = rate_limiter
            .is_rate_limited(quotas, scoping, 1, true)
            .unwrap()
            .is_limited();
        assert!(!is_limited);

        // Exceeds the limit, but is accepted once because of `over_accept_once`.
        let is_limited = rate_limiter
            .is_rate_limited(quotas, scoping, 2, true)
            .unwrap()
            .is_limited();
        assert!(!is_limited);

        // Already over the limit: even quantity 0 is limited.
        let is_limited = rate_limiter
            .is_rate_limited(quotas, scoping, 0, true)
            .unwrap()
            .is_limited();
        assert!(is_limited);

        let is_limited = rate_limiter
            .is_rate_limited(quotas, scoping, 1, true)
            .unwrap()
            .is_limited();
        assert!(is_limited);
    }

    /// Without any quotas, no rate limits are produced (the script is not invoked).
    #[test]
    fn test_bails_immediately_without_any_quota() {
        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: &Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limits: Vec<RateLimit> = build_rate_limiter()
            .is_rate_limited(&[], scoping, 1, false)
            .expect("rate limiting failed")
            .into_iter()
            .collect();

        assert_eq!(rate_limits, vec![]);
    }

    /// An unlimited quota (`limit: None`) next to a limited one does not cause a rejection on
    /// the first request.
    #[test]
    fn test_limited_with_unlimited_quota() {
        let quotas = &[
            Quota {
                id: Some("q0".to_string()),
                categories: DataCategories::new(),
                scope: QuotaScope::Organization,
                scope_id: None,
                limit: None,
                window: Some(1),
                reason_code: Some(ReasonCode::new("project_quota0")),
                namespace: None,
            },
            Quota {
                id: Some("q1".to_string()),
                categories: DataCategories::new(),
                scope: QuotaScope::Organization,
                scope_id: None,
                limit: Some(1),
                window: Some(1),
                reason_code: Some(ReasonCode::new("project_quota1")),
                namespace: None,
            },
        ];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: &Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        // NOTE(review): `0..1` runs only once (i == 0), so the `else` branch below is dead code.
        // Extending the range needs care: the window is 1 second and the ids ("q0"/"q1") are
        // fixed rather than random, so counters may roll over or be shared across runs.
        for i in 0..1 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quotas, scoping, 1, false)
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i == 0 {
                assert_eq!(rate_limits, &[]);
            } else {
                assert_eq!(
                    rate_limits,
                    vec![RateLimit {
                        categories: DataCategories::new(),
                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
                        reason_code: Some(ReasonCode::new("project_quota1")),
                        retry_after: rate_limits[0].retry_after,
                        namespaces: smallvec![],
                    }]
                );
            }
        }
    }

    /// Quantities are summed: with limit 500 and 100 per request, five requests pass and the
    /// rest are rejected.
    #[test]
    fn test_quota_with_quantity() {
        let quotas = &[Quota {
            id: Some(format!("test_quantity_quota_{}", uuid::Uuid::new_v4())),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            limit: Some(500),
            window: Some(60),
            reason_code: Some(ReasonCode::new("get_lost")),
            namespace: None,
        }];

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: &Scoping {
                organization_id: OrganizationId::new(42),
                project_id: ProjectId::new(43),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(44),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let rate_limiter = build_rate_limiter();

        for i in 0..10 {
            let rate_limits: Vec<RateLimit> = rate_limiter
                .is_rate_limited(quotas, scoping, 100, false)
                .expect("rate limiting failed")
                .into_iter()
                .collect();

            if i >= 5 {
                assert_eq!(
                    rate_limits,
                    vec![RateLimit {
                        categories: DataCategories::new(),
                        scope: RateLimitScope::Organization(OrganizationId::new(42)),
                        reason_code: Some(ReasonCode::new("get_lost")),
                        retry_after: rate_limits[0].retry_after,
                        namespaces: smallvec![],
                    }]
                );
            } else {
                assert_eq!(rate_limits, vec![]);
            }
        }
    }

    /// A project-scoped quota includes the project id as sub-scope in its key.
    ///
    /// shift = 69420 % 2 = 0, so slot = 123_123_123 / 2 = 61561561.
    #[test]
    fn test_get_redis_key_scoped() {
        let quota = Quota {
            id: Some("foo".to_owned()),
            categories: DataCategories::new(),
            scope: QuotaScope::Project,
            scope_id: Some("42".to_owned()),
            window: Some(2),
            limit: Some(0),
            reason_code: None,
            namespace: None,
        };

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: &Scoping {
                organization_id: OrganizationId::new(69420),
                project_id: ProjectId::new(42),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(4711),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let timestamp = UnixTimestamp::from_secs(123_123_123);
        let redis_quota = RedisQuota::new(&quota, scoping, timestamp).unwrap();
        assert_eq!(redis_quota.key(), "quota:foo{69420}42:61561561");
    }

    /// An organization-scoped quota has no sub-scope segment in its key.
    ///
    /// shift = 69420 % 10 = 0, so slot = 234_531 / 10 = 23453.
    #[test]
    fn test_get_redis_key_unscoped() {
        let quota = Quota {
            id: Some("foo".to_owned()),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            window: Some(10),
            limit: Some(0),
            reason_code: None,
            namespace: None,
        };

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: &Scoping {
                organization_id: OrganizationId::new(69420),
                project_id: ProjectId::new(42),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(4711),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let timestamp = UnixTimestamp::from_secs(234_531);
        let redis_quota = RedisQuota::new(&quota, scoping, timestamp).unwrap();
        assert_eq!(redis_quota.key(), "quota:foo{69420}:23453");
    }

    /// A limit that does not fit into `i64` is reported as `-1` by `RedisQuota::limit`.
    #[test]
    fn test_large_redis_limit_large() {
        let quota = Quota {
            id: Some("foo".to_owned()),
            categories: DataCategories::new(),
            scope: QuotaScope::Organization,
            scope_id: None,
            window: Some(10),
            // i64::MAX + 1
            limit: Some(9223372036854775808),
            reason_code: None,
            namespace: None,
        };

        let scoping = ItemScoping {
            category: DataCategory::Error,
            scoping: &Scoping {
                organization_id: OrganizationId::new(69420),
                project_id: ProjectId::new(42),
                project_key: ProjectKey::parse("a94ae32be2584e0bbd7a4cbb95971fee").unwrap(),
                key_id: Some(4711),
            },
            namespace: MetricNamespaceScoping::None,
        };

        let timestamp = UnixTimestamp::from_secs(234_531);
        let redis_quota = RedisQuota::new(&quota, scoping, timestamp).unwrap();
        assert_eq!(redis_quota.limit(), -1);
    }

    /// Exercises the rate limiting script directly.
    ///
    /// Keys come in pairs per quota (counter key, refund key). Each quota passes four args:
    /// limit, key expiry timestamp, quantity, and the over-accept flag. This is the same order
    /// `RedisRateLimiter::is_rate_limited` uses.
    // `disallowed_names`: the test uses the placeholder names `foo`/`bar`/`baz`.
    // `let_unit_value`: `let () = ...` is used to discard GET replies.
    #[test]
    #[allow(clippy::disallowed_names, clippy::let_unit_value)]
    fn test_is_rate_limited_script() {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|duration| duration.as_secs())
            .unwrap();

        let rate_limiter = build_rate_limiter();
        let mut client = rate_limiter.pool.client().expect("get client");
        let mut conn = client.connection().expect("Redis connection");

        // Keys are suffixed with the current time so reruns do not collide.
        let foo = format!("foo___{now}");
        let r_foo = format!("r:foo___{now}");
        let bar = format!("bar___{now}");
        let r_bar = format!("r:bar___{now}");
        let apple = format!("apple___{now}");
        let orange = format!("orange___{now}");
        let baz = format!("baz___{now}");

        let script = RedisScripts::load_is_rate_limited();

        // Quota `foo`: limit 1, expires at now + 60, quantity 1, no over-accept.
        // Quota `bar`: limit 2, expires at now + 120, quantity 1, no over-accept.
        let mut invocation = script.prepare_invoke();
        invocation
            .key(&foo)
            .key(&r_foo)
            .key(&bar)
            .key(&r_bar)
            .arg(1)
            .arg(now + 60)
            .arg(1)
            .arg(false)
            .arg(2)
            .arg(now + 120)
            .arg(1)
            .arg(false);

        // First call: both quotas have room.
        assert_eq!(
            invocation.invoke::<Vec<bool>>(&mut conn).unwrap(),
            vec![false, false]
        );

        // `foo` is exhausted from now on.
        assert_eq!(
            invocation.invoke::<Vec<bool>>(&mut conn).unwrap(),
            vec![true, false]
        );

        assert_eq!(
            invocation.invoke::<Vec<bool>>(&mut conn).unwrap(),
            vec![true, false]
        );

        // Counters and TTLs follow the args.
        // `bar` stays at "1": the rejected invocations did not increment it.
        assert_eq!(conn.get::<_, String>(&foo).unwrap(), "1");
        let ttl: u64 = conn.ttl(&foo).unwrap();
        assert!(ttl >= 59);
        assert!(ttl <= 60);

        assert_eq!(conn.get::<_, String>(&bar).unwrap(), "1");
        let ttl: u64 = conn.ttl(&bar).unwrap();
        assert!(ttl >= 119);
        assert!(ttl <= 120);

        // Read the refund keys.
        // NOTE(review): decoding into `()` presumably accepts any reply, so these lines only
        // check that GET succeeds, not that the keys are absent. TODO confirm.
        let () = conn.get(r_foo).unwrap();
        let () = conn.get(r_bar).unwrap();

        // Pre-populate `apple`; it is used as a refund key below.
        let () = conn.set(&apple, 5).unwrap();

        // Quota `orange` (limit 1) with refund key `baz`: first call passes, second is rejected.
        let mut invocation = script.prepare_invoke();
        invocation
            .key(&orange)
            .key(&baz)
            .arg(1)
            .arg(now + 60)
            .arg(1)
            .arg(false);

        assert_eq!(
            invocation.invoke::<Vec<bool>>(&mut conn).unwrap(),
            vec![false]
        );

        assert_eq!(
            invocation.invoke::<Vec<bool>>(&mut conn).unwrap(),
            vec![true]
        );

        // Same counter, but with refund key `apple` (value 5): the request is accepted again.
        // Presumably the refunded amount is subtracted from the consumed count.
        let mut invocation = script.prepare_invoke();
        invocation
            .key(&orange)
            .key(&apple)
            .arg(1)
            .arg(now + 60)
            .arg(1)
            .arg(false);

        assert_eq!(
            invocation.invoke::<Vec<bool>>(&mut conn).unwrap(),
            vec![false]
        );
    }
}