1use std::fmt;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7
8use relay_protocol::RuleCondition;
9
/// The current schema version of [`SamplingConfig`].
///
/// Version `1` is the legacy format (see `SamplingConfig::legacy_version`); configs at that
/// version are upgraded to this version by `SamplingConfig::normalize`.
const SAMPLING_CONFIG_VERSION: u16 = 2;
16
/// Configuration for dynamic sampling: a versioned list of [`SamplingRule`]s.
///
/// Configs in the legacy format (version `1`) carry their rules in `rulesV2`; call
/// [`SamplingConfig::normalize`] after deserializing to migrate them into `rules`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingConfig {
    /// Schema version of this config.
    ///
    /// Defaults to the legacy version `1` when the field is absent from the input.
    #[serde(default = "SamplingConfig::legacy_version")]
    pub version: u16,

    /// The sampling rules. Defaults to an empty list when absent.
    #[serde(default)]
    pub rules: Vec<SamplingRule>,

    /// Rules of a legacy (version `1`) config, read from `rulesV2`.
    ///
    /// This field is only ever deserialized, never serialized. [`SamplingConfig::normalize`]
    /// moves its contents into `rules`.
    #[serde(default, skip_serializing)]
    pub rules_v2: Vec<SamplingRule>,
}
42
43impl SamplingConfig {
44 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn unsupported(&self) -> bool {
51 debug_assert!(self.version > 1, "SamplingConfig not normalized");
52 self.version > SAMPLING_CONFIG_VERSION || !self.rules.iter().all(SamplingRule::supported)
53 }
54
55 pub fn filter_rules(&self, rule_type: RuleType) -> impl Iterator<Item = &SamplingRule> {
57 self.rules.iter().filter(move |rule| rule.ty == rule_type)
58 }
59
60 pub fn normalize(&mut self) {
62 if self.version == Self::legacy_version() {
63 self.rules.append(&mut self.rules_v2);
64 self.version = SAMPLING_CONFIG_VERSION;
65 }
66 }
67
68 const fn legacy_version() -> u16 {
69 1
70 }
71}
72
73impl Default for SamplingConfig {
74 fn default() -> Self {
75 Self {
76 version: SAMPLING_CONFIG_VERSION,
77 rules: vec![],
78 rules_v2: vec![],
79 }
80 }
81}
82
/// A single dynamic sampling rule.
///
/// Serialized with camelCase field names; `ty` is serialized as `type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingRule {
    /// The condition attached to this rule.
    ///
    /// If the condition is not supported by this relay, the rule (and therefore the whole
    /// config) is reported as unsupported.
    pub condition: RuleCondition,

    /// The sampling value this rule contributes (see [`SamplingValue`]).
    pub sampling_value: SamplingValue,

    /// The rule type, serialized as `type`.
    ///
    /// Unknown types deserialize to [`RuleType::Unsupported`].
    #[serde(rename = "type")]
    pub ty: RuleType,

    /// Identifier of this rule.
    pub id: RuleId,

    /// Optional time window of the rule.
    ///
    /// It also provides the bounds for [`DecayingFunction::Linear`]. The field defaults to an
    /// empty range and is omitted from the serialized output when both bounds are unset.
    #[serde(default, skip_serializing_if = "TimeRange::is_empty")]
    pub time_range: TimeRange,

    /// How the sample rate changes over `time_range`.
    ///
    /// Defaults to [`DecayingFunction::Constant`] and is omitted from the serialized output
    /// when it has that default value.
    #[serde(default, skip_serializing_if = "is_default")]
    pub decaying_fn: DecayingFunction,
}
113
114impl SamplingRule {
115 fn supported(&self) -> bool {
116 self.condition.supported() && self.ty != RuleType::Unsupported
117 }
118
119 pub fn apply_decaying_fn(&self, sample_rate: f64, now: DateTime<Utc>) -> Option<f64> {
121 self.decaying_fn
122 .adjust_sample_rate(sample_rate, now, self.time_range)
123 }
124}
125
/// Returns `true` if `t` equals `T::default()`.
///
/// Used as a serde `skip_serializing_if` predicate so that default-valued fields are left out
/// of the serialized output.
fn is_default<T: Default + PartialEq>(t: &T) -> bool {
    let default_value = T::default();
    t == &default_value
}
130
/// The value a [`SamplingRule`] contributes.
///
/// Internally tagged by `type`, with camelCase variant tags: `sampleRate`, `factor`,
/// `reservoir` and `minimumSampleRate`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(tag = "type")]
pub enum SamplingValue {
    /// A sample rate given directly by `value`.
    ///
    /// NOTE(review): presumably expected to lie in `[0.0, 1.0]`; this is not validated here.
    SampleRate {
        /// The sample rate.
        value: f64,
    },

    /// A factor given by `value`.
    ///
    /// NOTE(review): presumably multiplied into a sample rate determined by other rules. The
    /// combination logic is not in this file.
    Factor {
        /// The factor.
        value: f64,
    },

    /// A reservoir rule with a `limit`.
    ///
    /// NOTE(review): presumably the maximum number of items kept by the reservoir. The
    /// semantics are implemented elsewhere.
    Reservoir {
        /// The reservoir limit.
        limit: i64,
    },

    /// A minimum sample rate given by `value`.
    ///
    /// NOTE(review): presumably a lower bound on the resulting sample rate. Verify against the
    /// evaluator.
    MinimumSampleRate {
        /// The minimum sample rate.
        value: f64,
    },
}
177
/// The type of a [`SamplingRule`], serialized in camelCase (`"trace"`, `"transaction"`).
///
/// Any unrecognized value deserializes to [`RuleType::Unsupported`] instead of failing. Such a
/// rule is reported as unsupported, and so is the config containing it.
#[derive(Debug, Copy, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "camelCase")]
pub enum RuleType {
    /// A trace rule (`"trace"`).
    Trace,
    /// A transaction rule (`"transaction"`).
    Transaction,
    /// Catch-all for rule types unknown to this version; marks the rule as unsupported.
    #[serde(other)]
    Unsupported,
}
193
/// Numeric identifier of a [`SamplingRule`].
///
/// Serialized as a bare number (e.g. `"id": 1`) and displayed as that number.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RuleId(pub u32);
200
201impl fmt::Display for RuleId {
202 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203 write!(f, "{}", self.0)
204 }
205}
206
/// A time window with optional bounds.
///
/// [`TimeRange::contains`] interprets it as the half-open interval `[start, end)`, where a
/// missing bound leaves that side open. The default value has neither bound set.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct TimeRange {
    /// Inclusive lower bound, if any.
    pub start: Option<DateTime<Utc>>,

    /// Exclusive upper bound, if any.
    pub end: Option<DateTime<Utc>>,
}
221
222impl TimeRange {
223 pub fn is_empty(&self) -> bool {
225 self.start.is_none() && self.end.is_none()
226 }
227
228 pub fn contains(&self, time: DateTime<Utc>) -> bool {
237 self.start.is_none_or(|s| s <= time) && self.end.is_none_or(|e| time < e)
238 }
239}
240
/// How a rule's sample rate changes over its [`TimeRange`].
///
/// Internally tagged by `type`, with camelCase names: `linear` (with a `decayedValue` field)
/// and `constant`.
#[derive(Default, Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(tag = "type")]
pub enum DecayingFunction {
    /// Decays linearly from the rule's sample rate at the range start to `decayed_value` at
    /// the range end.
    #[serde(rename_all = "camelCase")]
    Linear {
        /// The sample rate reached at the end of the time range.
        decayed_value: f64,
    },

    /// The sample rate does not change over time. This is the default.
    #[default]
    Constant,
}
260
261impl DecayingFunction {
262 pub fn adjust_sample_rate(
264 &self,
265 sample_rate: f64,
266 now: DateTime<Utc>,
267 time_range: TimeRange,
268 ) -> Option<f64> {
269 match self {
270 DecayingFunction::Linear { decayed_value } => {
271 let (Some(start), Some(end)) = (time_range.start, time_range.end) else {
272 return None;
273 };
274
275 if sample_rate < *decayed_value {
276 return None;
277 }
278
279 let now = now.timestamp() as f64;
280 let start = start.timestamp() as f64;
281 let end = end.timestamp() as f64;
282
283 let progress_ratio = ((now - start) / (end - start)).clamp(0.0, 1.0);
284
285 let interval = decayed_value - sample_rate;
287 Some(sample_rate + (interval * progress_ratio))
288 }
289 DecayingFunction::Constant => Some(sample_rate),
290 }
291 }
292}
293
#[cfg(test)]
mod tests {
    use chrono::TimeZone;

    use super::*;

    /// The shared fixture config must deserialize without errors.
    #[test]
    fn config_deserialize() {
        let json = include_str!("../tests/fixtures/sampling_config.json");
        serde_json::from_str::<SamplingConfig>(json).unwrap();
    }

    /// A known rule type with an (empty) `and` condition is supported.
    #[test]
    fn test_supported() {
        let rule: SamplingRule = serde_json::from_value(serde_json::json!({
            "id": 1,
            "type": "trace",
            "samplingValue": {"type": "sampleRate", "value": 1.0},
            "condition": {"op": "and", "inner": []}
        }))
        .unwrap();
        assert!(rule.supported());
    }

    /// An unknown rule type still deserializes (as `RuleType::Unsupported`), but the rule is
    /// reported as unsupported.
    #[test]
    fn test_unsupported_rule_type() {
        let rule: SamplingRule = serde_json::from_value(serde_json::json!({
            "id": 1,
            "type": "new_rule_type_unknown_to_this_relay",
            "samplingValue": {"type": "sampleRate", "value": 1.0},
            "condition": {"op": "and", "inner": []}
        }))
        .unwrap();
        assert!(!rule.supported());
    }

    /// A rule without time range or decaying function deserializes with a sample rate.
    #[test]
    fn test_non_decaying_sampling_rule_deserialization() {
        let serialized_rule = r#"{
            "condition":{
                "op":"and",
                "inner": [
                    { "op" : "glob", "name": "releases", "value":["1.1.1", "1.1.2"]}
                ]
            },
            "samplingValue": {"type": "sampleRate", "value": 0.7},
            "type": "trace",
            "id": 1
        }"#;

        let rule: SamplingRule = serde_json::from_str(serialized_rule).unwrap();
        assert_eq!(
            rule.sampling_value,
            SamplingValue::SampleRate { value: 0.7f64 }
        );
        assert_eq!(rule.ty, RuleType::Trace);
    }

    /// Same as above, but with a `factor` sampling value.
    #[test]
    fn test_non_decaying_sampling_rule_deserialization_with_factor() {
        let serialized_rule = r#"{
            "condition":{
                "op":"and",
                "inner": [
                    { "op" : "glob", "name": "releases", "value":["1.1.1", "1.1.2"]}
                ]
            },
            "samplingValue": {"type": "factor", "value": 5.0},
            "type": "trace",
            "id": 1
        }"#;

        let rule: SamplingRule = serde_json::from_str(serialized_rule).unwrap();
        assert_eq!(rule.sampling_value, SamplingValue::Factor { value: 5.0 });
        assert_eq!(rule.ty, RuleType::Trace);
    }

    /// A time range without `decayingFn` parses both bounds and falls back to the constant
    /// decaying function.
    #[test]
    fn test_sampling_rule_with_constant_decaying_function_deserialization() {
        let serialized_rule = r#"{
            "condition":{
                "op":"and",
                "inner": [
                    { "op" : "glob", "name": "releases", "value":["1.1.1", "1.1.2"]}
                ]
            },
            "samplingValue": {"type": "factor", "value": 5.0},
            "type": "trace",
            "id": 1,
            "timeRange": {
                "start": "2022-10-10T00:00:00.000000Z",
                "end": "2022-10-20T00:00:00.000000Z"
            }
        }"#;
        let rule: Result<SamplingRule, _> = serde_json::from_str(serialized_rule);
        let rule = rule.unwrap();
        let time_range = rule.time_range;
        let decaying_function = rule.decaying_fn;

        assert_eq!(
            time_range.start,
            Some(Utc.with_ymd_and_hms(2022, 10, 10, 0, 0, 0).unwrap())
        );
        assert_eq!(
            time_range.end,
            Some(Utc.with_ymd_and_hms(2022, 10, 20, 0, 0, 0).unwrap())
        );
        assert_eq!(decaying_function, DecayingFunction::Constant);
    }

    /// An explicit `linear` decaying function parses its camelCase `decayedValue` field.
    #[test]
    fn test_sampling_rule_with_linear_decaying_function_deserialization() {
        let serialized_rule = r#"{
            "condition":{
                "op":"and",
                "inner": [
                    { "op" : "glob", "name": "releases", "value":["1.1.1", "1.1.2"]}
                ]
            },
            "samplingValue": {"type": "sampleRate", "value": 1.0},
            "type": "trace",
            "id": 1,
            "timeRange": {
                "start": "2022-10-10T00:00:00.000000Z",
                "end": "2022-10-20T00:00:00.000000Z"
            },
            "decayingFn": {
                "type": "linear",
                "decayedValue": 0.9
            }
        }"#;
        let rule: Result<SamplingRule, _> = serde_json::from_str(serialized_rule);
        let rule = rule.unwrap();
        let decaying_function = rule.decaying_fn;

        assert_eq!(
            decaying_function,
            DecayingFunction::Linear { decayed_value: 0.9 }
        );
    }

    /// A legacy config (no `version`, rules in `rulesV2`) is migrated by `normalize`.
    ///
    /// Unknown fields such as `active` and `mode` are ignored during deserialization.
    #[test]
    fn test_legacy_deserialization() {
        let serialized_rule = r#"{
            "rules": [],
            "rulesV2": [
                {
                    "samplingValue":{
                        "type": "sampleRate",
                        "value": 0.5
                    },
                    "type": "trace",
                    "active": true,
                    "condition": {
                        "op": "and",
                        "inner": []
                    },
                    "id": 1000
                }
            ],
            "mode": "received"
        }"#;
        let mut config: SamplingConfig = serde_json::from_str(serialized_rule).unwrap();
        config.normalize();

        // The version is bumped and the legacy rules now live in `rules`.
        assert_eq!(config.version, SAMPLING_CONFIG_VERSION);
        assert_eq!(
            config.rules[0].sampling_value,
            SamplingValue::SampleRate { value: 0.5 }
        );
        assert!(config.rules_v2.is_empty());
    }

    /// Serialization emits `version` and `rules` only. `rulesV2` is skipped, as are the empty
    /// time range and the default decaying function.
    #[test]
    fn test_sampling_config_with_rules_and_rules_v2_serialization() {
        let config = SamplingConfig {
            rules: vec![SamplingRule {
                condition: RuleCondition::all(),
                sampling_value: SamplingValue::Factor { value: 2.0 },
                ty: RuleType::Transaction,
                id: RuleId(1),
                time_range: Default::default(),
                decaying_fn: Default::default(),
            }],
            ..SamplingConfig::new()
        };

        let serialized_config = serde_json::to_string_pretty(&config).unwrap();
        // Matches serde_json's pretty printer exactly (two-space indentation).
        let expected_serialized_config = r#"{
  "version": 2,
  "rules": [
    {
      "condition": {
        "op": "and",
        "inner": []
      },
      "samplingValue": {
        "type": "factor",
        "value": 2.0
      },
      "type": "transaction",
      "id": 1
    }
  ]
}"#;

        assert_eq!(serialized_config, expected_serialized_config)
    }

    /// The constant decaying function returns the sample rate unchanged, even without bounds.
    #[test]
    fn test_decay_fn_constant() {
        let sample_rate = 0.5;

        assert_eq!(
            DecayingFunction::Constant.adjust_sample_rate(
                sample_rate,
                Utc::now(),
                TimeRange::default()
            ),
            Some(sample_rate)
        );
    }

    /// Linear decay from 1.0 to 0.5 over a two-day window.
    #[test]
    fn test_decay_fn_linear() {
        let decaying_fn = DecayingFunction::Linear { decayed_value: 0.5 };
        let time_range = TimeRange {
            start: Some(Utc.with_ymd_and_hms(1970, 10, 10, 0, 0, 0).unwrap()),
            end: Some(Utc.with_ymd_and_hms(1970, 10, 12, 0, 0, 0).unwrap()),
        };

        let start = Utc.with_ymd_and_hms(1970, 10, 10, 0, 0, 0).unwrap();
        let halfway = Utc.with_ymd_and_hms(1970, 10, 11, 0, 0, 0).unwrap();
        let end = Utc.with_ymd_and_hms(1970, 10, 11, 23, 59, 59).unwrap();

        // At the start of the window, nothing has decayed yet.
        assert_eq!(
            decaying_fn.adjust_sample_rate(1.0, start, time_range),
            Some(1.0)
        );

        // Halfway through, the rate is midway between 1.0 and 0.5.
        assert_eq!(
            decaying_fn.adjust_sample_rate(1.0, halfway, time_range),
            Some(0.75)
        );

        // One second before the end, the rate is just above the decayed value:
        // 0.5 + 0.5 / 172_800 seconds.
        assert_eq!(
            decaying_fn.adjust_sample_rate(1.0, end, time_range),
            Some(0.5000028935185186)
        );

        let mut time_range_without_start = time_range;
        // Linear decay needs both bounds; a missing start yields `None`.
        time_range_without_start.start = None;

        assert!(
            decaying_fn
                .adjust_sample_rate(1.0, halfway, time_range_without_start)
                .is_none()
        );

        // Likewise, a missing end yields `None`.
        let mut time_range_without_end = time_range;
        time_range_without_end.end = None;

        assert!(
            decaying_fn
                .adjust_sample_rate(1.0, halfway, time_range_without_end)
                .is_none()
        );
    }
}