1use std::fmt;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7
8use relay_protocol::RuleCondition;
9
/// The current version of the sampling configuration format.
///
/// Configs deserialized without an explicit `version` default to the legacy version `1` and are
/// upgraded to this version by `SamplingConfig::normalize`. Configs declaring a version above
/// this one are reported as unsupported by `SamplingConfig::unsupported`.
const SAMPLING_CONFIG_VERSION: u16 = 2;
16
/// Sampling configuration containing a list of [`SamplingRule`]s.
///
/// Serialized with camelCase field names. Call [`SamplingConfig::normalize`] after
/// deserializing so that legacy (version `1`) payloads are upgraded to the current format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingConfig {
    /// Version of the configuration format.
    ///
    /// Defaults to the legacy version `1` when the field is absent from the payload.
    #[serde(default = "SamplingConfig::legacy_version")]
    pub version: u16,

    /// The sampling rules of this configuration.
    #[serde(default)]
    pub rules: Vec<SamplingRule>,

    /// Rules read from the legacy `rulesV2` field.
    ///
    /// Only populated during deserialization and never serialized.
    /// [`SamplingConfig::normalize`] moves them into `rules` for legacy configs.
    #[serde(default, skip_serializing)]
    pub rules_v2: Vec<SamplingRule>,
}
42
43impl SamplingConfig {
44 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn unsupported(&self) -> bool {
51 debug_assert!(self.version > 1, "SamplingConfig not normalized");
52 self.version > SAMPLING_CONFIG_VERSION || !self.rules.iter().all(SamplingRule::supported)
53 }
54
55 pub fn filter_rules(&self, rule_type: RuleType) -> impl Iterator<Item = &SamplingRule> {
57 self.rules.iter().filter(move |rule| rule.ty == rule_type)
58 }
59
60 pub fn normalize(&mut self) {
62 if self.version == Self::legacy_version() {
63 self.rules.append(&mut self.rules_v2);
64 self.version = SAMPLING_CONFIG_VERSION;
65 }
66 }
67
68 const fn legacy_version() -> u16 {
69 1
70 }
71}
72
73impl Default for SamplingConfig {
74 fn default() -> Self {
75 Self {
76 version: SAMPLING_CONFIG_VERSION,
77 rules: vec![],
78 rules_v2: vec![],
79 }
80 }
81}
82
/// A single sampling rule, as stored in [`SamplingConfig::rules`].
///
/// Serialized with camelCase field names; `ty` is serialized as `type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingRule {
    /// Condition of this rule.
    pub condition: RuleCondition,

    /// The sampling value (sample rate, factor, reservoir, …) of this rule.
    pub sampling_value: SamplingValue,

    /// The kind of rule; serialized under the key `type`.
    #[serde(rename = "type")]
    pub ty: RuleType,

    /// Identifier of the rule.
    pub id: RuleId,

    /// Time window of the rule.
    ///
    /// Defaults to an unbounded range and is omitted from the serialized form when neither
    /// bound is set.
    #[serde(default, skip_serializing_if = "TimeRange::is_empty")]
    pub time_range: TimeRange,

    /// Function that adjusts the sampling value over `time_range`.
    ///
    /// Defaults to [`DecayingFunction::Constant`] and is omitted from the serialized form when
    /// it equals that default.
    #[serde(default, skip_serializing_if = "is_default")]
    pub decaying_fn: DecayingFunction,
}
113
114impl SamplingRule {
115 fn supported(&self) -> bool {
116 self.condition.supported() && self.ty != RuleType::Unsupported
117 }
118
119 pub fn apply_decaying_fn(&self, sample_rate: f64, now: DateTime<Utc>) -> Option<f64> {
121 self.decaying_fn
122 .adjust_sample_rate(sample_rate, now, self.time_range)
123 }
124}
125
/// Returns `true` when `t` equals the type's [`Default`] value.
///
/// Used as a serde `skip_serializing_if` predicate.
fn is_default<T: Default + PartialEq>(t: &T) -> bool {
    let default_value = T::default();
    t == &default_value
}
130
/// The sampling value carried by a [`SamplingRule`].
///
/// Serialized as an internally tagged object whose `type` is the camelCase variant name, e.g.
/// `{"type": "sampleRate", "value": 0.5}` or `{"type": "factor", "value": 2.0}`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(tag = "type")]
pub enum SamplingValue {
    /// A sample rate (`"type": "sampleRate"`).
    SampleRate {
        /// The sample rate.
        value: f64,
    },

    /// A factor (`"type": "factor"`).
    Factor {
        /// The factor. NOTE(review): presumably multiplied into a sample rate — confirm.
        value: f64,
    },

    /// A reservoir (`"type": "reservoir"`).
    Reservoir {
        /// The reservoir limit. NOTE(review): presumably a cap on kept items — confirm.
        limit: i64,
    },

    /// A minimum sample rate (`"type": "minimumSampleRate"`).
    MinimumSampleRate {
        /// The minimum sample rate.
        value: f64,
    },
}
177
/// The kind of a [`SamplingRule`], serialized in camelCase (`"trace"`, `"transaction"`,
/// `"project"`).
#[derive(Debug, Copy, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "camelCase")]
pub enum RuleType {
    /// A trace rule.
    Trace,
    /// A transaction rule.
    Transaction,
    /// A project rule.
    Project,
    /// Any rule type string not known to this version.
    ///
    /// Unknown `type` values deserialize to this variant (via `#[serde(other)]`) instead of
    /// failing; rules with this type are reported as not supported.
    #[serde(other)]
    Unsupported,
}
199
/// Identifier of a sampling rule.
///
/// A newtype around `u32`; serde (de)serializes it as the bare integer (e.g. `"id": 1`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RuleId(pub u32);
206
207impl fmt::Display for RuleId {
208 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209 write!(f, "{}", self.0)
210 }
211}
212
/// A time window with optional bounds.
///
/// The window is half-open, `[start, end)`; an unset bound leaves that side unbounded (see
/// [`TimeRange::contains`]).
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct TimeRange {
    /// Inclusive start of the window, if any.
    pub start: Option<DateTime<Utc>>,

    /// Exclusive end of the window, if any.
    pub end: Option<DateTime<Utc>>,
}
227
228impl TimeRange {
229 pub fn is_empty(&self) -> bool {
231 self.start.is_none() && self.end.is_none()
232 }
233
234 pub fn contains(&self, time: DateTime<Utc>) -> bool {
243 self.start.is_none_or(|s| s <= time) && self.end.is_none_or(|e| time < e)
244 }
245}
246
/// How a rule's sample rate changes over its time range.
///
/// Serialized as an internally tagged object, e.g. `{"type": "linear", "decayedValue": 0.9}`
/// or `{"type": "constant"}`.
#[derive(Default, Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(tag = "type")]
pub enum DecayingFunction {
    /// Linearly decays the sample rate towards `decayed_value` across the time range.
    ///
    /// Fields are serialized in camelCase (`decayedValue`).
    #[serde(rename_all = "camelCase")]
    Linear {
        /// The value reached at the end of the time range.
        decayed_value: f64,
    },

    /// Leaves the sample rate unchanged. This is the default.
    #[default]
    Constant,
}
266
267impl DecayingFunction {
268 pub fn adjust_sample_rate(
270 &self,
271 sample_rate: f64,
272 now: DateTime<Utc>,
273 time_range: TimeRange,
274 ) -> Option<f64> {
275 match self {
276 DecayingFunction::Linear { decayed_value } => {
277 let (Some(start), Some(end)) = (time_range.start, time_range.end) else {
278 return None;
279 };
280
281 if sample_rate < *decayed_value {
282 return None;
283 }
284
285 let now = now.timestamp() as f64;
286 let start = start.timestamp() as f64;
287 let end = end.timestamp() as f64;
288
289 let progress_ratio = ((now - start) / (end - start)).clamp(0.0, 1.0);
290
291 let interval = decayed_value - sample_rate;
293 Some(sample_rate + (interval * progress_ratio))
294 }
295 DecayingFunction::Constant => Some(sample_rate),
296 }
297 }
298}
299
#[cfg(test)]
mod tests {
    use chrono::TimeZone;

    use super::*;

    #[test]
    fn config_deserialize() {
        let json = include_str!("../tests/fixtures/sampling_config.json");
        serde_json::from_str::<SamplingConfig>(json).unwrap();
    }

    #[test]
    fn test_supported() {
        let rule: SamplingRule = serde_json::from_value(serde_json::json!({
            "id": 1,
            "type": "trace",
            "samplingValue": {"type": "sampleRate", "value": 1.0},
            "condition": {"op": "and", "inner": []}
        }))
        .unwrap();
        assert!(rule.supported());
    }

    #[test]
    fn test_unsupported_rule_type() {
        let rule: SamplingRule = serde_json::from_value(serde_json::json!({
            "id": 1,
            "type": "new_rule_type_unknown_to_this_relay",
            "samplingValue": {"type": "sampleRate", "value": 1.0},
            "condition": {"op": "and", "inner": []}
        }))
        .unwrap();
        assert!(!rule.supported());
    }

    #[test]
    fn test_unsupported_config_version() {
        // A freshly created config is at the current version and has no rules.
        assert!(!SamplingConfig::new().unsupported());

        // A version newer than the one this code understands must be flagged.
        let newer_version = SAMPLING_CONFIG_VERSION + 1;
        let mut config: SamplingConfig = serde_json::from_value(serde_json::json!({
            "version": newer_version,
            "rules": []
        }))
        .unwrap();
        config.normalize();
        assert!(config.unsupported());
    }

    #[test]
    fn test_filter_rules() {
        let config: SamplingConfig = serde_json::from_value(serde_json::json!({
            "version": 2,
            "rules": [
                {
                    "id": 1,
                    "type": "trace",
                    "samplingValue": {"type": "sampleRate", "value": 1.0},
                    "condition": {"op": "and", "inner": []}
                },
                {
                    "id": 2,
                    "type": "transaction",
                    "samplingValue": {"type": "sampleRate", "value": 0.5},
                    "condition": {"op": "and", "inner": []}
                }
            ]
        }))
        .unwrap();

        let trace_ids: Vec<RuleId> = config
            .filter_rules(RuleType::Trace)
            .map(|rule| rule.id)
            .collect();
        assert_eq!(trace_ids, vec![RuleId(1)]);

        let project_rules = config.filter_rules(RuleType::Project).count();
        assert_eq!(project_rules, 0);
    }

    #[test]
    fn test_time_range_contains() {
        let start = Utc.with_ymd_and_hms(2022, 10, 10, 0, 0, 0).unwrap();
        let end = Utc.with_ymd_and_hms(2022, 10, 20, 0, 0, 0).unwrap();
        let range = TimeRange {
            start: Some(start),
            end: Some(end),
        };

        // The window is half-open: `[start, end)`.
        assert!(range.contains(start));
        assert!(!range.contains(end));
        assert!(!range.is_empty());

        // Without bounds, every point in time is contained.
        assert!(TimeRange::default().is_empty());
        assert!(TimeRange::default().contains(end));
    }

    #[test]
    fn test_non_decaying_sampling_rule_deserialization() {
        let serialized_rule = r#"{
            "condition":{
                "op":"and",
                "inner": [
                    { "op" : "glob", "name": "releases", "value":["1.1.1", "1.1.2"]}
                ]
            },
            "samplingValue": {"type": "sampleRate", "value": 0.7},
            "type": "trace",
            "id": 1
        }"#;

        let rule: SamplingRule = serde_json::from_str(serialized_rule).unwrap();
        assert_eq!(
            rule.sampling_value,
            SamplingValue::SampleRate { value: 0.7f64 }
        );
        assert_eq!(rule.ty, RuleType::Trace);
    }

    #[test]
    fn test_non_decaying_sampling_rule_deserialization_with_factor() {
        let serialized_rule = r#"{
            "condition":{
                "op":"and",
                "inner": [
                    { "op" : "glob", "name": "releases", "value":["1.1.1", "1.1.2"]}
                ]
            },
            "samplingValue": {"type": "factor", "value": 5.0},
            "type": "trace",
            "id": 1
        }"#;

        let rule: SamplingRule = serde_json::from_str(serialized_rule).unwrap();
        assert_eq!(rule.sampling_value, SamplingValue::Factor { value: 5.0 });
        assert_eq!(rule.ty, RuleType::Trace);
    }

    #[test]
    fn test_sampling_rule_with_constant_decaying_function_deserialization() {
        let serialized_rule = r#"{
            "condition":{
                "op":"and",
                "inner": [
                    { "op" : "glob", "name": "releases", "value":["1.1.1", "1.1.2"]}
                ]
            },
            "samplingValue": {"type": "factor", "value": 5.0},
            "type": "trace",
            "id": 1,
            "timeRange": {
                "start": "2022-10-10T00:00:00.000000Z",
                "end": "2022-10-20T00:00:00.000000Z"
            }
        }"#;
        let rule: SamplingRule = serde_json::from_str(serialized_rule).unwrap();
        let time_range = rule.time_range;
        let decaying_function = rule.decaying_fn;

        assert_eq!(
            time_range.start,
            Some(Utc.with_ymd_and_hms(2022, 10, 10, 0, 0, 0).unwrap())
        );
        assert_eq!(
            time_range.end,
            Some(Utc.with_ymd_and_hms(2022, 10, 20, 0, 0, 0).unwrap())
        );
        // No `decayingFn` in the payload: falls back to the default.
        assert_eq!(decaying_function, DecayingFunction::Constant);
    }

    #[test]
    fn test_sampling_rule_with_linear_decaying_function_deserialization() {
        let serialized_rule = r#"{
            "condition":{
                "op":"and",
                "inner": [
                    { "op" : "glob", "name": "releases", "value":["1.1.1", "1.1.2"]}
                ]
            },
            "samplingValue": {"type": "sampleRate", "value": 1.0},
            "type": "trace",
            "id": 1,
            "timeRange": {
                "start": "2022-10-10T00:00:00.000000Z",
                "end": "2022-10-20T00:00:00.000000Z"
            },
            "decayingFn": {
                "type": "linear",
                "decayedValue": 0.9
            }
        }"#;
        let rule: SamplingRule = serde_json::from_str(serialized_rule).unwrap();
        let decaying_function = rule.decaying_fn;

        assert_eq!(
            decaying_function,
            DecayingFunction::Linear { decayed_value: 0.9 }
        );
    }

    #[test]
    fn test_legacy_deserialization() {
        let serialized_rule = r#"{
            "rules": [],
            "rulesV2": [
                {
                    "samplingValue":{
                        "type": "sampleRate",
                        "value": 0.5
                    },
                    "type": "trace",
                    "active": true,
                    "condition": {
                        "op": "and",
                        "inner": []
                    },
                    "id": 1000
                }
            ],
            "mode": "received"
        }"#;
        let mut config: SamplingConfig = serde_json::from_str(serialized_rule).unwrap();
        config.normalize();

        // The missing `version` defaults to legacy, which `normalize` upgrades and migrates.
        assert_eq!(config.version, SAMPLING_CONFIG_VERSION);
        assert_eq!(config.rules.len(), 1);
        assert_eq!(
            config.rules[0].sampling_value,
            SamplingValue::SampleRate { value: 0.5 }
        );
        assert!(config.rules_v2.is_empty());
    }

    #[test]
    fn test_sampling_config_with_rules_and_rules_v2_serialization() {
        let config = SamplingConfig {
            rules: vec![SamplingRule {
                condition: RuleCondition::all(),
                sampling_value: SamplingValue::Factor { value: 2.0 },
                ty: RuleType::Transaction,
                id: RuleId(1),
                time_range: Default::default(),
                decaying_fn: Default::default(),
            }],
            ..SamplingConfig::new()
        };

        let serialized_config = serde_json::to_string_pretty(&config).unwrap();
        let expected_serialized_config = r#"{
  "version": 2,
  "rules": [
    {
      "condition": {
        "op": "and",
        "inner": []
      },
      "samplingValue": {
        "type": "factor",
        "value": 2.0
      },
      "type": "transaction",
      "id": 1
    }
  ]
}"#;

        assert_eq!(serialized_config, expected_serialized_config)
    }

    #[test]
    fn test_decay_fn_constant() {
        let sample_rate = 0.5;

        assert_eq!(
            DecayingFunction::Constant.adjust_sample_rate(
                sample_rate,
                Utc::now(),
                TimeRange::default()
            ),
            Some(sample_rate)
        );
    }

    #[test]
    fn test_decay_fn_linear() {
        let decaying_fn = DecayingFunction::Linear { decayed_value: 0.5 };
        let time_range = TimeRange {
            start: Some(Utc.with_ymd_and_hms(1970, 10, 10, 0, 0, 0).unwrap()),
            end: Some(Utc.with_ymd_and_hms(1970, 10, 12, 0, 0, 0).unwrap()),
        };

        let start = Utc.with_ymd_and_hms(1970, 10, 10, 0, 0, 0).unwrap();
        let halfway = Utc.with_ymd_and_hms(1970, 10, 11, 0, 0, 0).unwrap();
        let end = Utc.with_ymd_and_hms(1970, 10, 11, 23, 59, 59).unwrap();

        // At the start of the range, the sample rate is not decayed yet.
        assert_eq!(
            decaying_fn.adjust_sample_rate(1.0, start, time_range),
            Some(1.0)
        );

        // Halfway through, the sample rate is halfway towards the decayed value.
        assert_eq!(
            decaying_fn.adjust_sample_rate(1.0, halfway, time_range),
            Some(0.75)
        );

        // One second before the end, the sample rate is almost fully decayed.
        assert_eq!(
            decaying_fn.adjust_sample_rate(1.0, end, time_range),
            Some(0.5000028935185186)
        );

        // Linear decay requires both bounds of the time range.
        let mut time_range_without_start = time_range;
        time_range_without_start.start = None;

        assert!(
            decaying_fn
                .adjust_sample_rate(1.0, halfway, time_range_without_start)
                .is_none()
        );

        let mut time_range_without_end = time_range;
        time_range_without_end.end = None;

        assert!(
            decaying_fn
                .adjust_sample_rate(1.0, halfway, time_range_without_end)
                .is_none()
        );
    }
}