1use std::fmt;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7
8use relay_protocol::RuleCondition;
9
/// The current version of the sampling configuration format.
///
/// Configs that declare a higher `version` are reported by `SamplingConfig::unsupported`;
/// legacy configs are migrated to this version by `SamplingConfig::normalize`.
const SAMPLING_CONFIG_VERSION: u16 = 2;
16
/// Dynamic sampling configuration: a versioned list of sampling rules.
///
/// Field names are (de)serialized in camelCase. Configs in the legacy format (version `1`)
/// may carry their rules in a separate `rulesV2` field; [`SamplingConfig::normalize`] merges
/// those into [`SamplingConfig::rules`]. [`SamplingConfig::unsupported`] expects a normalized
/// config.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingConfig {
    /// Version of the configuration format.
    ///
    /// When the field is missing from the input it deserializes as the legacy version `1`.
    /// Note that [`Default`] uses the current format version instead.
    #[serde(default = "SamplingConfig::legacy_version")]
    pub version: u16,

    /// The sampling rules, in configured order. Empty when missing from the input.
    #[serde(default)]
    pub rules: Vec<SamplingRule>,

    /// Rules read from the legacy `rulesV2` field.
    ///
    /// Only populated during deserialization and never serialized. For legacy-version configs
    /// [`SamplingConfig::normalize`] moves these into `rules`.
    #[serde(default, skip_serializing)]
    pub rules_v2: Vec<SamplingRule>,
}
42
43impl SamplingConfig {
44 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn unsupported(&self) -> bool {
51 debug_assert!(self.version > 1, "SamplingConfig not normalized");
52 self.version > SAMPLING_CONFIG_VERSION || !self.rules.iter().all(SamplingRule::supported)
53 }
54
55 pub fn filter_rules(&self, rule_type: RuleType) -> impl Iterator<Item = &SamplingRule> {
57 self.rules.iter().filter(move |rule| rule.ty == rule_type)
58 }
59
60 pub fn normalize(&mut self) {
62 if self.version == Self::legacy_version() {
63 self.rules.append(&mut self.rules_v2);
64 self.version = SAMPLING_CONFIG_VERSION;
65 }
66 }
67
68 const fn legacy_version() -> u16 {
69 1
70 }
71}
72
73impl Default for SamplingConfig {
74 fn default() -> Self {
75 Self {
76 version: SAMPLING_CONFIG_VERSION,
77 rules: vec![],
78 rules_v2: vec![],
79 }
80 }
81}
82
/// A single sampling rule, pairing a [`RuleCondition`] with a [`SamplingValue`].
///
/// Field names are (de)serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingRule {
    /// Condition attached to this rule.
    pub condition: RuleCondition,

    /// The sampling value this rule carries.
    pub sampling_value: SamplingValue,

    /// Kind of rule, (de)serialized under the key `type`.
    #[serde(rename = "type")]
    pub ty: RuleType,

    /// Identifier of the rule.
    pub id: RuleId,

    /// Time range the rule is associated with; it is the interval `decaying_fn` operates over.
    ///
    /// Defaults to an unbounded range and is omitted from the output when unbounded.
    /// NOTE(review): presumably also used by callers to decide when the rule is active — not
    /// visible in this file.
    #[serde(default, skip_serializing_if = "TimeRange::is_empty")]
    pub time_range: TimeRange,

    /// How the sampling value changes over `time_range`.
    ///
    /// Defaults to [`DecayingFunction::Constant`] and is omitted from the output in that case.
    #[serde(default, skip_serializing_if = "is_default")]
    pub decaying_fn: DecayingFunction,
}
113
114impl SamplingRule {
115 fn supported(&self) -> bool {
116 self.condition.supported() && self.ty != RuleType::Unsupported
117 }
118
119 pub fn apply_decaying_fn(&self, sample_rate: f64, now: DateTime<Utc>) -> Option<f64> {
121 self.decaying_fn
122 .adjust_sample_rate(sample_rate, now, self.time_range)
123 }
124}
125
/// Returns `true` if `value` equals the [`Default`] value of its type.
///
/// Used with serde's `skip_serializing_if` to omit fields left at their defaults.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    let default_value = T::default();
    PartialEq::eq(value, &default_value)
}
130
/// The value a sampling rule carries.
///
/// Internally tagged: serialized as an object whose `type` field is `"sampleRate"`,
/// `"factor"` or `"reservoir"`, next to the variant's own field.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(tag = "type")]
pub enum SamplingValue {
    /// A sample rate, stored in `value`.
    ///
    /// NOTE(review): presumably expected in `0.0..=1.0`; not validated here.
    SampleRate {
        value: f64,
    },

    /// A factor, stored in `value`.
    ///
    /// NOTE(review): presumably multiplied into a sample rate by the rule evaluator — the
    /// evaluation code is not visible in this file.
    Factor {
        value: f64,
    },

    /// A reservoir with a `limit`.
    ///
    /// NOTE(review): presumably the maximum number of items the reservoir keeps — confirm
    /// against the evaluator.
    Reservoir {
        limit: i64,
    },
}
166
/// The kind of a sampling rule, serialized as `"trace"` or `"transaction"`.
///
/// Any other string deserializes to [`RuleType::Unsupported`] instead of failing. Configs
/// containing rule types unknown to this version can therefore still be parsed.
#[derive(Debug, Copy, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "camelCase")]
pub enum RuleType {
    /// A rule of type `"trace"`.
    Trace,
    /// A rule of type `"transaction"`.
    Transaction,
    /// Catch-all for unrecognized rule types.
    ///
    /// Rules of this type are not supported; see [`SamplingConfig::unsupported`].
    #[serde(other)]
    Unsupported,
}
182
/// Identifier of a sampling rule, serialized as a plain number.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RuleId(pub u32);
189
190impl fmt::Display for RuleId {
191 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
192 write!(f, "{}", self.0)
193 }
194}
195
/// A time range with an inclusive `start` and an exclusive `end`.
///
/// Either bound may be missing, in which case the range is unbounded on that side. The
/// default value has no bounds at all.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct TimeRange {
    /// Inclusive lower bound, or `None` for no lower bound.
    pub start: Option<DateTime<Utc>>,

    /// Exclusive upper bound, or `None` for no upper bound.
    pub end: Option<DateTime<Utc>>,
}
210
211impl TimeRange {
212 pub fn is_empty(&self) -> bool {
214 self.start.is_none() && self.end.is_none()
215 }
216
217 pub fn contains(&self, time: DateTime<Utc>) -> bool {
226 self.start.is_none_or(|s| s <= time) && self.end.is_none_or(|e| time < e)
227 }
228}
229
/// How a rule's sample rate changes over the rule's time range.
///
/// Internally tagged: serialized with a `type` of `"linear"` or `"constant"`. The default is
/// [`DecayingFunction::Constant`].
#[derive(Default, Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(tag = "type")]
pub enum DecayingFunction {
    /// Decays linearly from the rule's sample rate down to `decayed_value` (serialized as
    /// `decayedValue`) over the rule's time range.
    #[serde(rename_all = "camelCase")]
    Linear {
        decayed_value: f64,
    },

    /// Leaves the sample rate unchanged.
    #[default]
    Constant,
}
249
250impl DecayingFunction {
251 pub fn adjust_sample_rate(
253 &self,
254 sample_rate: f64,
255 now: DateTime<Utc>,
256 time_range: TimeRange,
257 ) -> Option<f64> {
258 match self {
259 DecayingFunction::Linear { decayed_value } => {
260 let (Some(start), Some(end)) = (time_range.start, time_range.end) else {
261 return None;
262 };
263
264 if sample_rate < *decayed_value {
265 return None;
266 }
267
268 let now = now.timestamp() as f64;
269 let start = start.timestamp() as f64;
270 let end = end.timestamp() as f64;
271
272 let progress_ratio = ((now - start) / (end - start)).clamp(0.0, 1.0);
273
274 let interval = decayed_value - sample_rate;
276 Some(sample_rate + (interval * progress_ratio))
277 }
278 DecayingFunction::Constant => Some(sample_rate),
279 }
280 }
281}
282
#[cfg(test)]
mod tests {
    use chrono::TimeZone;

    use super::*;

    /// Builds a catch-all rule (empty `and` condition, sample rate 1.0) with the given id and
    /// type, for tests that only care about those two fields.
    fn rule_with(id: u32, ty: RuleType) -> SamplingRule {
        SamplingRule {
            condition: RuleCondition::all(),
            sampling_value: SamplingValue::SampleRate { value: 1.0 },
            ty,
            id: RuleId(id),
            time_range: TimeRange::default(),
            decaying_fn: DecayingFunction::default(),
        }
    }

    #[test]
    fn config_deserialize() {
        let json = include_str!("../tests/fixtures/sampling_config.json");
        serde_json::from_str::<SamplingConfig>(json).unwrap();
    }

    #[test]
    fn test_supported() {
        let rule: SamplingRule = serde_json::from_value(serde_json::json!({
            "id": 1,
            "type": "trace",
            "samplingValue": {"type": "sampleRate", "value": 1.0},
            "condition": {"op": "and", "inner": []}
        }))
        .unwrap();
        assert!(rule.supported());
    }

    #[test]
    fn test_unsupported_rule_type() {
        let rule: SamplingRule = serde_json::from_value(serde_json::json!({
            "id": 1,
            "type": "new_rule_type_unknown_to_this_relay",
            "samplingValue": {"type": "sampleRate", "value": 1.0},
            "condition": {"op": "and", "inner": []}
        }))
        .unwrap();
        assert!(!rule.supported());
    }

    #[test]
    fn test_config_unsupported() {
        // A fresh config is at the current version and has no rules.
        assert!(!SamplingConfig::new().unsupported());

        let newer_version = SamplingConfig {
            version: SAMPLING_CONFIG_VERSION + 1,
            ..SamplingConfig::new()
        };
        assert!(newer_version.unsupported());

        let with_unknown_rule_type = SamplingConfig {
            rules: vec![
                rule_with(1, RuleType::Trace),
                rule_with(2, RuleType::Unsupported),
            ],
            ..SamplingConfig::new()
        };
        assert!(with_unknown_rule_type.unsupported());
    }

    #[test]
    fn test_filter_rules() {
        let config = SamplingConfig {
            rules: vec![
                rule_with(1, RuleType::Trace),
                rule_with(2, RuleType::Transaction),
                rule_with(3, RuleType::Trace),
            ],
            ..SamplingConfig::new()
        };

        let ids = |ty: RuleType| {
            config
                .filter_rules(ty)
                .map(|rule| rule.id)
                .collect::<Vec<_>>()
        };
        assert_eq!(ids(RuleType::Trace), [RuleId(1), RuleId(3)]);
        assert_eq!(ids(RuleType::Transaction), [RuleId(2)]);
        assert!(ids(RuleType::Unsupported).is_empty());
    }

    #[test]
    fn test_rule_id_display() {
        assert_eq!(RuleId(42).to_string(), "42");
    }

    #[test]
    fn test_time_range_contains() {
        let start = Utc.with_ymd_and_hms(2022, 10, 10, 0, 0, 0).unwrap();
        let end = Utc.with_ymd_and_hms(2022, 10, 20, 0, 0, 0).unwrap();
        let before = Utc.with_ymd_and_hms(2022, 10, 9, 23, 59, 59).unwrap();
        let inside = Utc.with_ymd_and_hms(2022, 10, 15, 0, 0, 0).unwrap();

        let bounded = TimeRange {
            start: Some(start),
            end: Some(end),
        };
        assert!(!bounded.is_empty());
        assert!(!bounded.contains(before));
        assert!(bounded.contains(start), "start is inclusive");
        assert!(bounded.contains(inside));
        assert!(!bounded.contains(end), "end is exclusive");

        let unbounded = TimeRange::default();
        assert!(unbounded.is_empty());
        assert!(unbounded.contains(before));

        let open_end = TimeRange {
            start: Some(start),
            end: None,
        };
        assert!(open_end.contains(end));
        assert!(!open_end.contains(before));

        let open_start = TimeRange {
            start: None,
            end: Some(end),
        };
        assert!(open_start.contains(before));
        assert!(!open_start.contains(end));
    }

    #[test]
    fn test_non_decaying_sampling_rule_deserialization() {
        let serialized_rule = r#"{
            "condition":{
                "op":"and",
                "inner": [
                    { "op" : "glob", "name": "releases", "value":["1.1.1", "1.1.2"]}
                ]
            },
            "samplingValue": {"type": "sampleRate", "value": 0.7},
            "type": "trace",
            "id": 1
        }"#;

        let rule: SamplingRule = serde_json::from_str(serialized_rule).unwrap();
        assert_eq!(
            rule.sampling_value,
            SamplingValue::SampleRate { value: 0.7f64 }
        );
        assert_eq!(rule.ty, RuleType::Trace);
    }

    #[test]
    fn test_non_decaying_sampling_rule_deserialization_with_factor() {
        let serialized_rule = r#"{
            "condition":{
                "op":"and",
                "inner": [
                    { "op" : "glob", "name": "releases", "value":["1.1.1", "1.1.2"]}
                ]
            },
            "samplingValue": {"type": "factor", "value": 5.0},
            "type": "trace",
            "id": 1
        }"#;

        let rule: SamplingRule = serde_json::from_str(serialized_rule).unwrap();
        assert_eq!(rule.sampling_value, SamplingValue::Factor { value: 5.0 });
        assert_eq!(rule.ty, RuleType::Trace);
    }

    #[test]
    fn test_sampling_rule_with_constant_decaying_function_deserialization() {
        let serialized_rule = r#"{
            "condition":{
                "op":"and",
                "inner": [
                    { "op" : "glob", "name": "releases", "value":["1.1.1", "1.1.2"]}
                ]
            },
            "samplingValue": {"type": "factor", "value": 5.0},
            "type": "trace",
            "id": 1,
            "timeRange": {
                "start": "2022-10-10T00:00:00.000000Z",
                "end": "2022-10-20T00:00:00.000000Z"
            }
        }"#;
        let rule: SamplingRule = serde_json::from_str(serialized_rule).unwrap();
        let time_range = rule.time_range;
        let decaying_function = rule.decaying_fn;

        assert_eq!(
            time_range.start,
            Some(Utc.with_ymd_and_hms(2022, 10, 10, 0, 0, 0).unwrap())
        );
        assert_eq!(
            time_range.end,
            Some(Utc.with_ymd_and_hms(2022, 10, 20, 0, 0, 0).unwrap())
        );
        // No `decayingFn` in the input means the default, constant function.
        assert_eq!(decaying_function, DecayingFunction::Constant);
    }

    #[test]
    fn test_sampling_rule_with_linear_decaying_function_deserialization() {
        let serialized_rule = r#"{
            "condition":{
                "op":"and",
                "inner": [
                    { "op" : "glob", "name": "releases", "value":["1.1.1", "1.1.2"]}
                ]
            },
            "samplingValue": {"type": "sampleRate", "value": 1.0},
            "type": "trace",
            "id": 1,
            "timeRange": {
                "start": "2022-10-10T00:00:00.000000Z",
                "end": "2022-10-20T00:00:00.000000Z"
            },
            "decayingFn": {
                "type": "linear",
                "decayedValue": 0.9
            }
        }"#;
        let rule: SamplingRule = serde_json::from_str(serialized_rule).unwrap();

        assert_eq!(
            rule.decaying_fn,
            DecayingFunction::Linear { decayed_value: 0.9 }
        );
    }

    #[test]
    fn test_legacy_deserialization() {
        // No `version` field: the config is legacy, and its rules live in `rulesV2`.
        let serialized_rule = r#"{
            "rules": [],
            "rulesV2": [
                {
                    "samplingValue":{
                        "type": "sampleRate",
                        "value": 0.5
                    },
                    "type": "trace",
                    "active": true,
                    "condition": {
                        "op": "and",
                        "inner": []
                    },
                    "id": 1000
                }
            ],
            "mode": "received"
        }"#;
        let mut config: SamplingConfig = serde_json::from_str(serialized_rule).unwrap();
        config.normalize();

        assert_eq!(config.version, SAMPLING_CONFIG_VERSION);
        assert_eq!(config.rules.len(), 1);
        assert_eq!(config.rules[0].id, RuleId(1000));
        assert_eq!(
            config.rules[0].sampling_value,
            SamplingValue::SampleRate { value: 0.5 }
        );
        assert!(config.rules_v2.is_empty());
    }

    #[test]
    fn test_normalize_ignores_rules_v2_for_current_version() {
        let mut config: SamplingConfig = serde_json::from_value(serde_json::json!({
            "version": SAMPLING_CONFIG_VERSION,
            "rules": [],
            "rulesV2": [{
                "id": 1,
                "type": "trace",
                "samplingValue": {"type": "sampleRate", "value": 1.0},
                "condition": {"op": "and", "inner": []}
            }]
        }))
        .unwrap();
        config.normalize();

        assert_eq!(config.version, SAMPLING_CONFIG_VERSION);
        assert!(config.rules.is_empty());
        assert_eq!(config.rules_v2.len(), 1);
    }

    #[test]
    fn test_sampling_config_with_rules_and_rules_v2_serialization() {
        let config = SamplingConfig {
            rules: vec![SamplingRule {
                condition: RuleCondition::all(),
                sampling_value: SamplingValue::Factor { value: 2.0 },
                ty: RuleType::Transaction,
                id: RuleId(1),
                time_range: Default::default(),
                decaying_fn: Default::default(),
            }],
            ..SamplingConfig::new()
        };

        // `rulesV2`, an empty time range and the default decaying function are all omitted.
        let serialized_config = serde_json::to_string_pretty(&config).unwrap();
        let expected_serialized_config = r#"{
  "version": 2,
  "rules": [
    {
      "condition": {
        "op": "and",
        "inner": []
      },
      "samplingValue": {
        "type": "factor",
        "value": 2.0
      },
      "type": "transaction",
      "id": 1
    }
  ]
}"#;

        assert_eq!(serialized_config, expected_serialized_config)
    }

    #[test]
    fn test_decay_fn_constant() {
        let sample_rate = 0.5;

        assert_eq!(
            DecayingFunction::Constant.adjust_sample_rate(
                sample_rate,
                Utc::now(),
                TimeRange::default()
            ),
            Some(sample_rate)
        );
    }

    #[test]
    fn test_decay_fn_linear() {
        let decaying_fn = DecayingFunction::Linear { decayed_value: 0.5 };
        let time_range = TimeRange {
            start: Some(Utc.with_ymd_and_hms(1970, 10, 10, 0, 0, 0).unwrap()),
            end: Some(Utc.with_ymd_and_hms(1970, 10, 12, 0, 0, 0).unwrap()),
        };

        let start = Utc.with_ymd_and_hms(1970, 10, 10, 0, 0, 0).unwrap();
        let halfway = Utc.with_ymd_and_hms(1970, 10, 11, 0, 0, 0).unwrap();
        let end = Utc.with_ymd_and_hms(1970, 10, 11, 23, 59, 59).unwrap();

        // At the start of the range the rate is not yet decayed.
        assert_eq!(
            decaying_fn.adjust_sample_rate(1.0, start, time_range),
            Some(1.0)
        );

        // Halfway through, the rate is halfway between 1.0 and 0.5.
        assert_eq!(
            decaying_fn.adjust_sample_rate(1.0, halfway, time_range),
            Some(0.75)
        );

        // One second before the end, the rate is just above the decayed value.
        assert_eq!(
            decaying_fn.adjust_sample_rate(1.0, end, time_range),
            Some(0.5000028935185186)
        );

        // A linear decay needs both bounds.
        let mut time_range_without_start = time_range;
        time_range_without_start.start = None;

        assert!(decaying_fn
            .adjust_sample_rate(1.0, halfway, time_range_without_start)
            .is_none());

        let mut time_range_without_end = time_range;
        time_range_without_end.end = None;

        assert!(decaying_fn
            .adjust_sample_rate(1.0, halfway, time_range_without_end)
            .is_none());
    }
}