use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::AtomicU64;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};

use bytes::Bytes;
use futures_util::Stream;
use objectstore_service::PayloadStream;
use objectstore_service::id::ObjectContext;
use objectstore_types::scope::Scopes;
use serde::{Deserialize, Serialize};
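
/// Rate-limiting configuration, combining request-throughput limits and
/// byte-bandwidth limits.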
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct RateLimits {
    pub throughput: ThroughputLimits,
    pub bandwidth: BandwidthLimits,
}
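
/// Limits on request throughput (requests per second).
///
/// `global_rps` caps the overall request rate, with `burst` extra tokens of
/// headroom. `usecase_pct` and `scope_pct` express per-usecase and per-scope
/// limits as a percentage of `global_rps`, and `rules` adds targeted limits
/// for specific usecases or scope values.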
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct ThroughputLimits {
    pub global_rps: Option<u32>,

    pub burst: u32,

    pub usecase_pct: Option<u8>,

    pub scope_pct: Option<u8>,

    pub rules: Vec<ThroughputRule>,
}
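
/// A targeted throughput rule applying to requests that match the given usecase
/// and/or scope values, limited either by an absolute `rps` or by a percentage
/// of the global limit (the lower of the two when both are set).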
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ThroughputRule {
    pub usecase: Option<String>,

    pub scopes: Vec<(String, String)>,

    pub rps: Option<u32>,

    pub pct: Option<u8>,
}

impl ThroughputRule {
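    /// Returns `true` if this rule applies to the given request context, i.e. the
    /// configured usecase (if any) and all of the configured scope values match.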
    pub fn matches(&self, context: &ObjectContext) -> bool {
        if let Some(ref rule_usecase) = self.usecase
            && rule_usecase != &context.usecase
        {
            return false;
        }

        for (scope_name, scope_value) in &self.scopes {
            match context.scopes.get_value(scope_name) {
                Some(value) if value == scope_value => (),
                _ => return false,
            }
        }

        true
    }
}
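
/// Limits on payload bandwidth; `global_bps` caps the total bytes per second.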
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq)]
pub struct BandwidthLimits {
    pub global_bps: Option<u64>,
}
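
/// Combined rate limiter enforcing both throughput and bandwidth limits.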
#[derive(Debug)]
pub struct RateLimiter {
    bandwidth: BandwidthRateLimiter,
    throughput: ThroughputRateLimiter,
}

impl RateLimiter {
    pub fn new(config: RateLimits) -> Self {
        Self {
            bandwidth: BandwidthRateLimiter::new(config.bandwidth),
            throughput: ThroughputRateLimiter::new(config.throughput),
        }
    }
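
    /// Checks whether a request with the given context is currently allowed,
    /// consulting the throughput limiter first and the bandwidth limiter second.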
    pub fn check(&self, context: &ObjectContext) -> bool {
        self.throughput.check(context) && self.bandwidth.check()
    }
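
    /// Returns the shared byte counter that `MeteredPayloadStream`s add to and
    /// the bandwidth estimator periodically drains.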
    pub fn bytes_accumulator(&self) -> Arc<AtomicU64> {
        Arc::clone(&self.bandwidth.accumulator)
    }
}
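
/// Bandwidth limiter that compares an EWMA estimate of recent bytes-per-second,
/// maintained by a spawned estimator task, against the configured global limit.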
#[derive(Debug)]
struct BandwidthRateLimiter {
    config: BandwidthLimits,
    accumulator: Arc<AtomicU64>,
    estimate: Arc<AtomicU64>,
}

impl BandwidthRateLimiter {
    fn new(config: BandwidthLimits) -> Self {
        let accumulator = Arc::new(AtomicU64::new(0));
        let estimate = Arc::new(AtomicU64::new(0));

        let accumulator_clone = Arc::clone(&accumulator);
        let estimate_clone = Arc::clone(&estimate);
        tokio::task::spawn(async move {
            Self::estimator(accumulator_clone, estimate_clone).await;
        });

        Self {
            config,
            accumulator,
            estimate,
        }
    }
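
    /// Background task that drains the byte accumulator on a fixed tick, folds the
    /// observed bytes-per-second into an exponentially weighted moving average,
    /// stores it in `estimate`, and emits it as a gauge metric.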
    async fn estimator(accumulator: Arc<AtomicU64>, estimate: Arc<AtomicU64>) {
        const TICK: Duration = Duration::from_millis(50);
        const ALPHA: f64 = 0.2;

        let mut interval = tokio::time::interval(TICK);
        // Conversion factor from bytes-per-tick to bytes-per-second.
        let to_bps = 1.0 / TICK.as_secs_f64();
        let mut ewma: f64 = 0.0;

        loop {
            interval.tick().await;
            let current = accumulator.swap(0, std::sync::atomic::Ordering::Relaxed);
            let bps = (current as f64) * to_bps;
            ewma = ALPHA * bps + (1.0 - ALPHA) * ewma;

            let ewma_int = ewma.floor() as u64;
            estimate.store(ewma_int, std::sync::atomic::Ordering::Relaxed);
            merni::gauge!("server.bandwidth.ewma"@b: ewma_int);
        }
    }
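
    /// Returns `true` if the current bandwidth estimate is within the configured
    /// global bytes-per-second limit, or if no limit is configured.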
    fn check(&self) -> bool {
        let Some(bps) = self.config.global_bps else {
            return true;
        };
        self.estimate.load(std::sync::atomic::Ordering::Relaxed) <= bps
    }
}
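
/// Request-throughput limiter backed by token buckets: one global bucket plus
/// lazily created per-usecase, per-scope, and per-rule buckets.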
#[derive(Debug)]
struct ThroughputRateLimiter {
    config: ThroughputLimits,
    global: Option<Mutex<TokenBucket>>,
    usecases: papaya::HashMap<String, Mutex<TokenBucket>>,
    scopes: papaya::HashMap<Scopes, Mutex<TokenBucket>>,
    rules: papaya::HashMap<usize, Mutex<TokenBucket>>,
}

impl ThroughputRateLimiter {
    fn new(config: ThroughputLimits) -> Self {
        let global = config
            .global_rps
            .map(|rps| Mutex::new(TokenBucket::new(rps, config.burst)));

        Self {
            config,
            global,
            usecases: papaya::HashMap::new(),
            scopes: papaya::HashMap::new(),
            rules: papaya::HashMap::new(),
        }
    }
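
    /// Checks the request against the global bucket, then the per-usecase and
    /// per-scope buckets, and finally every matching rule; the request is allowed
    /// only if all applicable buckets have a token available.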
    fn check(&self, context: &ObjectContext) -> bool {
        if let Some(ref global) = self.global
            && !global.lock().unwrap().try_acquire()
        {
            return false;
        }

        if let Some(usecase_rps) = self.usecase_rps() {
            let guard = self.usecases.pin();
            let bucket = guard
                .get_or_insert_with(context.usecase.clone(), || self.create_bucket(usecase_rps));
            if !bucket.lock().unwrap().try_acquire() {
                return false;
            }
        }

        if let Some(scope_rps) = self.scope_rps() {
            let guard = self.scopes.pin();
            let bucket =
                guard.get_or_insert_with(context.scopes.clone(), || self.create_bucket(scope_rps));
            if !bucket.lock().unwrap().try_acquire() {
                return false;
            }
        }

        for (idx, rule) in self.config.rules.iter().enumerate() {
            if !rule.matches(context) {
                continue;
            }
            let Some(rule_rps) = self.rule_rps(rule) else {
                continue;
            };
            let guard = self.rules.pin();
            let bucket = guard.get_or_insert_with(idx, || self.create_bucket(rule_rps));
            if !bucket.lock().unwrap().try_acquire() {
                return false;
            }
        }

        true
    }
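
    /// Creates a token bucket with the given rate and the configured burst headroom.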
    fn create_bucket(&self, rps: u32) -> Mutex<TokenBucket> {
        Mutex::new(TokenBucket::new(rps, self.config.burst))
    }
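
    /// The per-usecase rate, computed as `usecase_pct` percent of `global_rps`.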
    fn usecase_rps(&self) -> Option<u32> {
        let global_rps = self.config.global_rps?;
        let pct = self.config.usecase_pct?;
        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
    }
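
    /// The per-scope rate, computed as `scope_pct` percent of `global_rps`.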
    fn scope_rps(&self) -> Option<u32> {
        let global_rps = self.config.global_rps?;
        let pct = self.config.scope_pct?;
        Some(((global_rps as f64) * (pct as f64 / 100.0)) as u32)
    }
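
    /// The effective rate for a rule: the minimum of its absolute `rps` and its
    /// `pct` of `global_rps`, or `None` if neither is configured.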
    fn rule_rps(&self, rule: &ThroughputRule) -> Option<u32> {
        let pct_limit = rule.pct.and_then(|p| {
            self.config
                .global_rps
                .map(|g| ((g as f64) * (p as f64 / 100.0)) as u32)
        });

        match (rule.rps, pct_limit) {
            (Some(r), Some(p)) => Some(r.min(p)),
            (Some(r), None) => Some(r),
            (None, Some(p)) => Some(p),
            (None, None) => None,
        }
    }
}
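
/// A token bucket that refills continuously at `refill_rate` tokens per second up
/// to `capacity` (the rate plus the configured burst), with each acquired token
/// representing one allowed request.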
#[derive(Debug)]
struct TokenBucket {
    refill_rate: f64,
    capacity: f64,
    tokens: f64,
    last_update: Instant,
}

impl TokenBucket {
    pub fn new(rps: u32, burst: u32) -> Self {
        Self {
            refill_rate: rps as f64,
            capacity: (rps + burst) as f64,
            tokens: (rps + burst) as f64,
            last_update: Instant::now(),
        }
    }
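
    /// Refills the bucket according to the elapsed time and attempts to take a
    /// single token, returning `true` if one was available.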
    pub fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let refill = now.duration_since(self.last_update).as_secs_f64() * self.refill_rate;
        let refilled = (self.tokens + refill).min(self.capacity);

        // Only commit the refill once at least one whole token has accrued; otherwise
        // `last_update` stays put so fractional refill is not discarded.
        if refilled.floor() > self.tokens.floor() {
            self.last_update = now;
            self.tokens = refilled;
        }

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
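
/// Wraps a `PayloadStream` and counts every byte it yields into a shared
/// accumulator, which feeds the bandwidth estimator.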
pub(crate) struct MeteredPayloadStream {
    inner: PayloadStream,
    accumulator: Arc<AtomicU64>,
}

impl MeteredPayloadStream {
    pub fn from(inner: PayloadStream, accumulator: Arc<AtomicU64>) -> Self {
        Self { inner, accumulator }
    }
}

impl std::fmt::Debug for MeteredPayloadStream {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MeteredPayloadStream")
            .field("accumulator", &self.accumulator)
            .finish()
    }
}

impl Stream for MeteredPayloadStream {
    type Item = std::io::Result<Bytes>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let res = this.inner.as_mut().poll_next(cx);
        // Count every chunk that passes through towards the bandwidth accumulator.
        if let Poll::Ready(Some(Ok(ref bytes))) = res {
            this.accumulator
                .fetch_add(bytes.len() as u64, std::sync::atomic::Ordering::Relaxed);
        }
        res
    }
}