1use std::fmt;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7
8use relay_protocol::RuleCondition;
9
/// The latest sampling config schema version understood by this code.
///
/// Configs declaring a higher version are reported by `SamplingConfig::unsupported`,
/// and legacy configs are upgraded to this version by `SamplingConfig::normalize`.
const SAMPLING_CONFIG_VERSION: u16 = 2;
16
/// Dynamic sampling configuration: a versioned list of [`SamplingRule`]s.
///
/// Field names are serialized in camelCase (e.g. `rules_v2` is read from `"rulesV2"`).
/// Call [`SamplingConfig::normalize`] after deserializing to upgrade legacy configs.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingConfig {
    /// Schema version of this config.
    ///
    /// When the field is absent from the payload it defaults to the legacy version
    /// (`1`); `normalize` then upgrades the config to the current version.
    #[serde(default = "SamplingConfig::legacy_version")]
    pub version: u16,

    /// The sampling rules; empty when absent from the payload.
    #[serde(default)]
    pub rules: Vec<SamplingRule>,

    /// Rules read from the legacy `"rulesV2"` field.
    ///
    /// This field is only deserialized, never serialized. For legacy-version configs,
    /// `normalize` moves its contents into `rules`.
    #[serde(default, skip_serializing)]
    pub rules_v2: Vec<SamplingRule>,
}
42
43impl SamplingConfig {
44 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn unsupported(&self) -> bool {
51 debug_assert!(self.version > 1, "SamplingConfig not normalized");
52 self.version > SAMPLING_CONFIG_VERSION || !self.rules.iter().all(SamplingRule::supported)
53 }
54
55 pub fn filter_rules(&self, rule_type: RuleType) -> impl Iterator<Item = &SamplingRule> {
57 self.rules.iter().filter(move |rule| rule.ty == rule_type)
58 }
59
60 pub fn normalize(&mut self) {
62 if self.version == Self::legacy_version() {
63 self.rules.append(&mut self.rules_v2);
64 self.version = SAMPLING_CONFIG_VERSION;
65 }
66 }
67
68 const fn legacy_version() -> u16 {
69 1
70 }
71}
72
73impl Default for SamplingConfig {
74 fn default() -> Self {
75 Self {
76 version: SAMPLING_CONFIG_VERSION,
77 rules: vec![],
78 rules_v2: vec![],
79 }
80 }
81}
82
/// A single sampling rule: a condition paired with the sampling value it carries.
///
/// Field names are serialized in camelCase; `ty` is serialized as `"type"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingRule {
    /// The condition of this rule.
    pub condition: RuleCondition,

    /// The sample rate, factor, or minimum sample rate carried by this rule.
    pub sampling_value: SamplingValue,

    /// The kind of rule; unknown kinds deserialize to [`RuleType::Unsupported`].
    #[serde(rename = "type")]
    pub ty: RuleType,

    /// Identifier of this rule.
    pub id: RuleId,

    /// Optional time window of the rule.
    ///
    /// Defaults to an unbounded range when absent and is omitted from the serialized
    /// form when neither bound is set.
    #[serde(default, skip_serializing_if = "TimeRange::is_empty")]
    pub time_range: TimeRange,

    /// How the sample rate changes over `time_range`.
    ///
    /// Defaults to [`DecayingFunction::Constant`] and is omitted from the serialized
    /// form when it has that default value.
    #[serde(default, skip_serializing_if = "is_default")]
    pub decaying_fn: DecayingFunction,
}
113
114impl SamplingRule {
115 fn supported(&self) -> bool {
116 self.condition.supported() && self.ty != RuleType::Unsupported
117 }
118
119 pub fn apply_decaying_fn(&self, sample_rate: f64, now: DateTime<Utc>) -> Option<f64> {
121 self.decaying_fn
122 .adjust_sample_rate(sample_rate, now, self.time_range)
123 }
124}
125
/// Returns `true` if `t` compares equal to `T::default()`.
///
/// Used as a serde `skip_serializing_if` predicate so default-valued fields are
/// left out of the serialized output.
fn is_default<T: Default + PartialEq>(t: &T) -> bool {
    let default_value = T::default();
    t == &default_value
}
130
/// The sampling value carried by a [`SamplingRule`].
///
/// Serialized as an internally tagged object whose `"type"` is the camelCase variant
/// name, e.g. `{"type": "sampleRate", "value": 0.7}` or `{"type": "factor", "value": 5.0}`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(tag = "type")]
pub enum SamplingValue {
    /// A sample rate (`"type": "sampleRate"`).
    ///
    /// NOTE(review): presumably a probability in `[0, 1]`; the range is not validated here.
    SampleRate {
        /// The sample rate.
        value: f64,
    },

    /// A factor (`"type": "factor"`); values above `1.0` are accepted.
    ///
    /// NOTE(review): presumably multiplied into a sample rate by the rule evaluator,
    /// which lives outside this file — confirm there.
    Factor {
        /// The factor.
        value: f64,
    },

    /// A minimum sample rate (`"type": "minimumSampleRate"`).
    ///
    /// NOTE(review): presumably a lower bound on the computed sample rate; the
    /// evaluation is not part of this file.
    MinimumSampleRate {
        /// The minimum sample rate.
        value: f64,
    },
}
168
/// The kind of a [`SamplingRule`], serialized as a camelCase string (e.g. `"trace"`).
#[derive(Debug, Copy, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "camelCase")]
pub enum RuleType {
    /// A rule of type `"trace"`.
    Trace,
    /// A rule of type `"transaction"`.
    Transaction,
    /// A rule of type `"project"`.
    Project,
    /// Catch-all for rule types unknown to this code.
    ///
    /// Unknown type strings deserialize to this variant instead of failing, and rules
    /// carrying it are reported as unsupported.
    #[serde(other)]
    Unsupported,
}
190
/// Numeric identifier of a [`SamplingRule`].
///
/// Serialized as a plain integer (e.g. `"id": 1`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RuleId(pub u32);
197
198impl fmt::Display for RuleId {
199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200 write!(f, "{}", self.0)
201 }
202}
203
/// A time window with optional bounds on either side.
///
/// Bounds are (de)serialized as timestamp strings such as
/// `"2022-10-10T00:00:00.000000Z"`. A missing bound leaves that side of the window
/// open; the default has no bounds at all.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
pub struct TimeRange {
    /// Inclusive start of the window, if bounded.
    pub start: Option<DateTime<Utc>>,

    /// Exclusive end of the window, if bounded.
    pub end: Option<DateTime<Utc>>,
}
218
219impl TimeRange {
220 pub fn is_empty(&self) -> bool {
222 self.start.is_none() && self.end.is_none()
223 }
224
225 pub fn contains(&self, time: DateTime<Utc>) -> bool {
234 self.start.is_none_or(|s| s <= time) && self.end.is_none_or(|e| time < e)
235 }
236}
237
/// How a rule's sample rate changes over its [`TimeRange`].
///
/// Serialized as an internally tagged object, e.g.
/// `{"type": "linear", "decayedValue": 0.9}`. Defaults to [`DecayingFunction::Constant`].
#[derive(Default, Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[serde(tag = "type")]
pub enum DecayingFunction {
    /// Linearly moves the sample rate from its base value at the range start to
    /// `decayed_value` at the range end (`"type": "linear"`).
    #[serde(rename_all = "camelCase")]
    Linear {
        /// Sample rate reached at the end of the time range (`"decayedValue"`).
        decayed_value: f64,
    },

    /// Leaves the sample rate unchanged (`"type": "constant"`). This is the default.
    #[default]
    Constant,
}
257
258impl DecayingFunction {
259 pub fn adjust_sample_rate(
261 &self,
262 sample_rate: f64,
263 now: DateTime<Utc>,
264 time_range: TimeRange,
265 ) -> Option<f64> {
266 match self {
267 DecayingFunction::Linear { decayed_value } => {
268 let (Some(start), Some(end)) = (time_range.start, time_range.end) else {
269 return None;
270 };
271
272 if sample_rate < *decayed_value {
273 return None;
274 }
275
276 let now = now.timestamp() as f64;
277 let start = start.timestamp() as f64;
278 let end = end.timestamp() as f64;
279
280 let progress_ratio = ((now - start) / (end - start)).clamp(0.0, 1.0);
281
282 let interval = decayed_value - sample_rate;
284 Some(sample_rate + (interval * progress_ratio))
285 }
286 DecayingFunction::Constant => Some(sample_rate),
287 }
288 }
289}
290
#[cfg(test)]
mod tests {
    use chrono::TimeZone;

    use super::*;

    #[test]
    fn config_deserialize() {
        let json = include_str!("../tests/fixtures/sampling_config.json");
        serde_json::from_str::<SamplingConfig>(json).unwrap();
    }

    #[test]
    fn test_supported() {
        let rule: SamplingRule = serde_json::from_value(serde_json::json!({
            "id": 1,
            "type": "trace",
            "samplingValue": {"type": "sampleRate", "value": 1.0},
            "condition": {"op": "and", "inner": []}
        }))
        .unwrap();
        assert!(rule.supported());
    }

    #[test]
    fn test_unsupported_rule_type() {
        let rule: SamplingRule = serde_json::from_value(serde_json::json!({
            "id": 1,
            "type": "new_rule_type_unknown_to_this_relay",
            "samplingValue": {"type": "sampleRate", "value": 1.0},
            "condition": {"op": "and", "inner": []}
        }))
        .unwrap();
        assert!(!rule.supported());
    }

    #[test]
    fn test_unsupported_config_version() {
        // A config at the current version without rules is fully supported.
        assert!(!SamplingConfig::new().unsupported());

        // A config from a newer schema version cannot be evaluated.
        let config = SamplingConfig {
            version: SAMPLING_CONFIG_VERSION + 1,
            ..SamplingConfig::new()
        };
        assert!(config.unsupported());
    }

    #[test]
    fn test_non_decaying_sampling_rule_deserialization() {
        let serialized_rule = r#"{
            "condition":{
                "op":"and",
                "inner": [
                    { "op" : "glob", "name": "releases", "value":["1.1.1", "1.1.2"]}
                ]
            },
            "samplingValue": {"type": "sampleRate", "value": 0.7},
            "type": "trace",
            "id": 1
        }"#;

        let rule: SamplingRule = serde_json::from_str(serialized_rule).unwrap();
        assert_eq!(
            rule.sampling_value,
            SamplingValue::SampleRate { value: 0.7f64 }
        );
        assert_eq!(rule.ty, RuleType::Trace);
    }

    #[test]
    fn test_non_decaying_sampling_rule_deserialization_with_factor() {
        let serialized_rule = r#"{
            "condition":{
                "op":"and",
                "inner": [
                    { "op" : "glob", "name": "releases", "value":["1.1.1", "1.1.2"]}
                ]
            },
            "samplingValue": {"type": "factor", "value": 5.0},
            "type": "trace",
            "id": 1
        }"#;

        let rule: SamplingRule = serde_json::from_str(serialized_rule).unwrap();
        assert_eq!(rule.sampling_value, SamplingValue::Factor { value: 5.0 });
        assert_eq!(rule.ty, RuleType::Trace);
    }

    #[test]
    fn test_sampling_rule_with_constant_decaying_function_deserialization() {
        let serialized_rule = r#"{
            "condition":{
                "op":"and",
                "inner": [
                    { "op" : "glob", "name": "releases", "value":["1.1.1", "1.1.2"]}
                ]
            },
            "samplingValue": {"type": "factor", "value": 5.0},
            "type": "trace",
            "id": 1,
            "timeRange": {
                "start": "2022-10-10T00:00:00.000000Z",
                "end": "2022-10-20T00:00:00.000000Z"
            }
        }"#;
        let rule: Result<SamplingRule, _> = serde_json::from_str(serialized_rule);
        let rule = rule.unwrap();
        let time_range = rule.time_range;
        let decaying_function = rule.decaying_fn;

        assert_eq!(
            time_range.start,
            Some(Utc.with_ymd_and_hms(2022, 10, 10, 0, 0, 0).unwrap())
        );
        assert_eq!(
            time_range.end,
            Some(Utc.with_ymd_and_hms(2022, 10, 20, 0, 0, 0).unwrap())
        );
        assert_eq!(decaying_function, DecayingFunction::Constant);
    }

    #[test]
    fn test_sampling_rule_with_linear_decaying_function_deserialization() {
        let serialized_rule = r#"{
            "condition":{
                "op":"and",
                "inner": [
                    { "op" : "glob", "name": "releases", "value":["1.1.1", "1.1.2"]}
                ]
            },
            "samplingValue": {"type": "sampleRate", "value": 1.0},
            "type": "trace",
            "id": 1,
            "timeRange": {
                "start": "2022-10-10T00:00:00.000000Z",
                "end": "2022-10-20T00:00:00.000000Z"
            },
            "decayingFn": {
                "type": "linear",
                "decayedValue": 0.9
            }
        }"#;
        let rule: Result<SamplingRule, _> = serde_json::from_str(serialized_rule);
        let rule = rule.unwrap();
        let decaying_function = rule.decaying_fn;

        assert_eq!(
            decaying_function,
            DecayingFunction::Linear { decayed_value: 0.9 }
        );
    }

    #[test]
    fn test_legacy_deserialization() {
        let serialized_rule = r#"{
            "rules": [],
            "rulesV2": [
                {
                    "samplingValue":{
                        "type": "sampleRate",
                        "value": 0.5
                    },
                    "type": "trace",
                    "active": true,
                    "condition": {
                        "op": "and",
                        "inner": []
                    },
                    "id": 1000
                }
            ],
            "mode": "received"
        }"#;
        let mut config: SamplingConfig = serde_json::from_str(serialized_rule).unwrap();
        config.normalize();

        assert_eq!(config.version, SAMPLING_CONFIG_VERSION);
        // Exactly the one legacy rule must have been migrated.
        assert_eq!(config.rules.len(), 1);
        assert_eq!(
            config.rules[0].sampling_value,
            SamplingValue::SampleRate { value: 0.5 }
        );
        assert!(config.rules_v2.is_empty());
    }

    #[test]
    fn test_sampling_config_with_rules_and_rules_v2_serialization() {
        let rule = SamplingRule {
            condition: RuleCondition::all(),
            sampling_value: SamplingValue::Factor { value: 2.0 },
            ty: RuleType::Transaction,
            id: RuleId(1),
            time_range: Default::default(),
            decaying_fn: Default::default(),
        };
        let config = SamplingConfig {
            rules: vec![rule.clone()],
            // Populated so the test actually verifies that `rules_v2` is never serialized.
            rules_v2: vec![rule],
            ..SamplingConfig::new()
        };

        let serialized_config = serde_json::to_string_pretty(&config).unwrap();
        let expected_serialized_config = r#"{
  "version": 2,
  "rules": [
    {
      "condition": {
        "op": "and",
        "inner": []
      },
      "samplingValue": {
        "type": "factor",
        "value": 2.0
      },
      "type": "transaction",
      "id": 1
    }
  ]
}"#;

        assert_eq!(serialized_config, expected_serialized_config)
    }

    #[test]
    fn test_time_range_contains() {
        let start = Utc.with_ymd_and_hms(2022, 10, 10, 0, 0, 0).unwrap();
        let end = Utc.with_ymd_and_hms(2022, 10, 20, 0, 0, 0).unwrap();
        let before_start = Utc.with_ymd_and_hms(2022, 10, 9, 23, 59, 59).unwrap();
        let inside = Utc.with_ymd_and_hms(2022, 10, 15, 0, 0, 0).unwrap();

        let range = TimeRange {
            start: Some(start),
            end: Some(end),
        };
        assert!(!range.is_empty());
        // The range is half-open: `start` is included, `end` is not.
        assert!(range.contains(start));
        assert!(range.contains(inside));
        assert!(!range.contains(end));
        assert!(!range.contains(before_start));

        // Missing bounds are unbounded on that side.
        let unbounded = TimeRange::default();
        assert!(unbounded.is_empty());
        assert!(unbounded.contains(before_start));
        assert!(unbounded.contains(end));
    }

    #[test]
    fn test_decay_fn_constant() {
        let sample_rate = 0.5;

        assert_eq!(
            DecayingFunction::Constant.adjust_sample_rate(
                sample_rate,
                Utc::now(),
                TimeRange::default()
            ),
            Some(sample_rate)
        );
    }

    #[test]
    fn test_decay_fn_linear() {
        let decaying_fn = DecayingFunction::Linear { decayed_value: 0.5 };
        let time_range = TimeRange {
            start: Some(Utc.with_ymd_and_hms(1970, 10, 10, 0, 0, 0).unwrap()),
            end: Some(Utc.with_ymd_and_hms(1970, 10, 12, 0, 0, 0).unwrap()),
        };

        let start = Utc.with_ymd_and_hms(1970, 10, 10, 0, 0, 0).unwrap();
        let halfway = Utc.with_ymd_and_hms(1970, 10, 11, 0, 0, 0).unwrap();
        let end = Utc.with_ymd_and_hms(1970, 10, 11, 23, 59, 59).unwrap();

        assert_eq!(
            decaying_fn.adjust_sample_rate(1.0, start, time_range),
            Some(1.0)
        );

        assert_eq!(
            decaying_fn.adjust_sample_rate(1.0, halfway, time_range),
            Some(0.75)
        );

        assert_eq!(
            decaying_fn.adjust_sample_rate(1.0, end, time_range),
            Some(0.5000028935185186)
        );

        // No decay applies when the base rate is already below the decayed value.
        assert!(
            decaying_fn
                .adjust_sample_rate(0.4, halfway, time_range)
                .is_none()
        );

        let mut time_range_without_start = time_range;
        time_range_without_start.start = None;

        assert!(
            decaying_fn
                .adjust_sample_rate(1.0, halfway, time_range_without_start)
                .is_none()
        );

        let mut time_range_without_end = time_range;
        time_range_without_end.end = None;

        assert!(
            decaying_fn
                .adjust_sample_rate(1.0, halfway, time_range_without_end)
                .is_none()
        );
    }
}