1use std::collections::BTreeMap;
4use std::fmt;
5use std::num::ParseIntError;
6use std::ops::ControlFlow;
7use std::sync::{Arc, Mutex};
8
9use chrono::{DateTime, Utc};
10use rand::Rng;
11use rand::distributions::Uniform;
12use rand_pcg::Pcg32;
13#[cfg(feature = "redis")]
14use relay_base_schema::organization::OrganizationId;
15use relay_protocol::Getter;
16#[cfg(feature = "redis")]
17use relay_redis::AsyncRedisClient;
18use serde::Serialize;
19use uuid::Uuid;
20
21use crate::config::{RuleId, SamplingRule, SamplingValue};
22#[cfg(feature = "redis")]
23use crate::redis_sampling::{self, ReservoirRuleKey};
24
25fn pseudo_random_from_seed(seed: Uuid) -> f64 {
29 let seed_number = seed.as_u128();
30 let mut generator = Pcg32::new((seed_number >> 64) as u64, seed_number as u64);
31 let dist = Uniform::new(0f64, 1f64);
32 generator.sample(dist)
33}
34
35pub type ReservoirCounters = Arc<Mutex<BTreeMap<RuleId, i64>>>;
37
38#[derive(Debug)]
51pub struct ReservoirEvaluator<'a> {
52 counters: ReservoirCounters,
53 #[cfg(feature = "redis")]
54 org_id_and_client: Option<(OrganizationId, &'a AsyncRedisClient)>,
55 _phantom: std::marker::PhantomData<&'a ()>,
57}
58
59impl ReservoirEvaluator<'_> {
60 pub fn new(counters: ReservoirCounters) -> Self {
62 Self {
63 counters,
64 #[cfg(feature = "redis")]
65 org_id_and_client: None,
66 _phantom: std::marker::PhantomData,
67 }
68 }
69
70 pub fn counters(&self) -> ReservoirCounters {
72 Arc::clone(&self.counters)
73 }
74
75 #[cfg(feature = "redis")]
76 async fn redis_incr(
77 &self,
78 key: &ReservoirRuleKey,
79 client: &AsyncRedisClient,
80 rule_expiry: Option<&DateTime<Utc>>,
81 ) -> anyhow::Result<i64> {
82 let mut connection = client.get_connection().await?;
83
84 let val = redis_sampling::increment_redis_reservoir_count(&mut connection, key).await?;
85 redis_sampling::set_redis_expiry(&mut connection, key, rule_expiry).await?;
86
87 Ok(val)
88 }
89
90 pub fn incr_local(&self, rule: RuleId, limit: i64) -> bool {
92 let Ok(mut map_guard) = self.counters.lock() else {
93 relay_log::error!("failed to lock reservoir counter mutex");
94 return false;
95 };
96
97 let counter_value = map_guard.entry(rule).or_insert(0);
98
99 if *counter_value < limit {
100 *counter_value += 1;
101 true
102 } else {
103 false
104 }
105 }
106
107 pub async fn evaluate(
109 &self,
110 rule: RuleId,
111 limit: i64,
112 _rule_expiry: Option<&DateTime<Utc>>,
113 ) -> bool {
114 #[cfg(feature = "redis")]
115 if let Some((org_id, client)) = self.org_id_and_client {
116 if let Ok(guard) = self.counters.lock() {
117 if *guard.get(&rule).unwrap_or(&0) > limit {
118 return false;
119 }
120 }
121
122 let key = ReservoirRuleKey::new(org_id, rule);
123 let redis_count = match self.redis_incr(&key, client, _rule_expiry).await {
124 Ok(redis_count) => redis_count,
125 Err(e) => {
126 relay_log::error!(error = &*e, "failed to increment reservoir rule");
127 return false;
128 }
129 };
130
131 if let Ok(mut map_guard) = self.counters.lock() {
132 if let Some(value) = map_guard.get_mut(&rule) {
135 *value = redis_count.max(*value);
136 }
137 }
138 return redis_count <= limit;
139 }
140
141 self.incr_local(rule, limit)
142 }
143}
144
#[cfg(feature = "redis")]
impl<'a> ReservoirEvaluator<'a> {
    /// Enables Redis-backed reservoir counting for the organization `org_id`.
    ///
    /// Subsequent calls to [`ReservoirEvaluator::evaluate`] increment the reservoir count in
    /// Redis through `client`, replacing any previously configured organization/client.
    pub fn set_redis(&mut self, org_id: OrganizationId, client: &'a AsyncRedisClient) {
        let backend = (org_id, client);
        self.org_id_and_client = Some(backend);
    }
}
154
155#[derive(Debug)]
157pub struct SamplingEvaluator<'a> {
158 now: DateTime<Utc>,
159 rule_ids: Vec<RuleId>,
160 factor: f64,
161 minimum_sample_rate: Option<f64>,
162 reservoir: Option<&'a ReservoirEvaluator<'a>>,
163}
164
165impl<'a> SamplingEvaluator<'a> {
166 pub fn new_with_reservoir(now: DateTime<Utc>, reservoir: &'a ReservoirEvaluator<'a>) -> Self {
168 Self {
169 now,
170 rule_ids: vec![],
171 factor: 1.0,
172 minimum_sample_rate: None,
173 reservoir: Some(reservoir),
174 }
175 }
176
177 pub fn new(now: DateTime<Utc>) -> Self {
179 Self {
180 now,
181 rule_ids: vec![],
182 factor: 1.0,
183 minimum_sample_rate: None,
184 reservoir: None,
185 }
186 }
187
188 pub async fn match_rules<'b, I, G>(
200 mut self,
201 seed: Uuid,
202 instance: &G,
203 rules: I,
204 ) -> ControlFlow<SamplingMatch, Self>
205 where
206 G: Getter,
207 I: Iterator<Item = &'b SamplingRule>,
208 {
209 for rule in rules {
210 if !rule.time_range.contains(self.now) || !rule.condition.matches(instance) {
211 continue;
212 };
213
214 if let Some(sample_rate) = self.try_compute_sample_rate(rule).await {
215 return ControlFlow::Break(SamplingMatch::new(sample_rate, seed, self.rule_ids));
216 };
217 }
218
219 ControlFlow::Continue(self)
220 }
221
222 async fn try_compute_sample_rate(&mut self, rule: &SamplingRule) -> Option<f64> {
230 match rule.sampling_value {
231 SamplingValue::Factor { value } => {
232 self.factor *= rule.apply_decaying_fn(value, self.now)?;
233 self.rule_ids.push(rule.id);
234 None
235 }
236 SamplingValue::SampleRate { value } => {
237 let sample_rate = rule.apply_decaying_fn(value, self.now)?;
238 let minimum_sample_rate = self.minimum_sample_rate.unwrap_or(0.0);
239 let adjusted = (sample_rate.max(minimum_sample_rate) * self.factor).clamp(0.0, 1.0);
240
241 self.rule_ids.push(rule.id);
242 Some(adjusted)
243 }
244 SamplingValue::Reservoir { limit } => {
245 let reservoir = self.reservoir?;
246 if !reservoir
247 .evaluate(rule.id, limit, rule.time_range.end.as_ref())
248 .await
249 {
250 return None;
251 }
252
253 self.rule_ids.clear();
255 self.rule_ids.push(rule.id);
256 Some(1.0)
258 }
259 SamplingValue::MinimumSampleRate { value } => {
260 if self.minimum_sample_rate.is_none() {
261 self.minimum_sample_rate = Some(rule.apply_decaying_fn(value, self.now)?);
262 self.rule_ids.push(rule.id);
263 }
264 None
265 }
266 }
267 }
268}
269
270fn sampling_match(sample_rate: f64, seed: Uuid) -> SamplingDecision {
271 if sample_rate <= 0.0 {
272 return SamplingDecision::Drop;
273 } else if sample_rate >= 1.0 {
274 return SamplingDecision::Keep;
275 }
276
277 let random_number = pseudo_random_from_seed(seed);
278 relay_log::trace!(
279 sample_rate,
280 random_number,
281 "applying dynamic sampling to matching event"
282 );
283
284 if random_number >= sample_rate {
285 relay_log::trace!("dropping event that matched the configuration");
286 SamplingDecision::Drop
287 } else {
288 relay_log::trace!("keeping event that matched the configuration");
289 SamplingDecision::Keep
290 }
291}
292
/// Outcome of dynamic sampling for a single event.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SamplingDecision {
    /// The event is retained.
    Keep,
    /// The event is discarded.
    Drop,
}

impl SamplingDecision {
    /// Returns `true` if this is [`SamplingDecision::Keep`].
    pub fn is_keep(self) -> bool {
        self == Self::Keep
    }

    /// Returns `true` if this is [`SamplingDecision::Drop`].
    pub fn is_drop(self) -> bool {
        self == Self::Drop
    }

    /// Returns the lowercase name of the decision, `"keep"` or `"drop"`.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Drop => "drop",
            Self::Keep => "keep",
        }
    }
}

impl fmt::Display for SamplingDecision {
    /// Writes the same text as [`SamplingDecision::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
327
328#[derive(Clone, Debug, PartialEq)]
330pub struct SamplingMatch {
331 sample_rate: f64,
333 seed: Uuid,
339 matched_rules: MatchedRuleIds,
341 decision: SamplingDecision,
345}
346
347impl SamplingMatch {
348 fn new(sample_rate: f64, seed: Uuid, matched_rules: Vec<RuleId>) -> Self {
349 let matched_rules = MatchedRuleIds(matched_rules);
350 let decision = sampling_match(sample_rate, seed);
351
352 Self {
353 sample_rate,
354 seed,
355 matched_rules,
356 decision,
357 }
358 }
359
360 pub fn sample_rate(&self) -> f64 {
362 self.sample_rate
363 }
364
365 pub fn into_matched_rules(self) -> MatchedRuleIds {
370 self.matched_rules
371 }
372
373 pub fn decision(&self) -> SamplingDecision {
375 self.decision
376 }
377}
378
379#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)]
381pub struct MatchedRuleIds(pub Vec<RuleId>);
382
383impl MatchedRuleIds {
384 pub fn parse(value: &str) -> Result<MatchedRuleIds, ParseIntError> {
392 let mut rule_ids = vec![];
393
394 for rule_id in value.split(',') {
395 rule_ids.push(RuleId(rule_id.parse()?));
396 }
397
398 Ok(MatchedRuleIds(rule_ids))
399 }
400}
401
402impl fmt::Display for MatchedRuleIds {
403 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404 for (i, rule_id) in self.0.iter().enumerate() {
405 if i > 0 {
406 write!(f, ",")?;
407 }
408 write!(f, "{rule_id}")?;
409 }
410
411 Ok(())
412 }
413}
414
#[cfg(test)]
mod tests {
    use chrono::TimeZone;
    use relay_protocol::RuleCondition;
    use similar_asserts::assert_eq;
    use std::str::FromStr;
    use uuid::Uuid;

    use crate::DynamicSamplingContext;
    use crate::config::{DecayingFunction, RuleType, TimeRange};
    use crate::dsc::TraceUserContext;

    use super::*;

    /// Builds a reservoir evaluator without a Redis backend. Its local counters are
    /// pre-populated from `vals`, given as `(rule id, count)` pairs.
    fn mock_reservoir_evaluator(vals: Vec<(u32, i64)>) -> ReservoirEvaluator<'static> {
        let mut map = BTreeMap::default();

        for (rule_id, count) in vals {
            map.insert(RuleId(rule_id), count);
        }

        let map = Arc::new(Mutex::new(map));

        ReservoirEvaluator::new(map)
    }

    /// Evaluates `rules` against `instance` with a fresh evaluator (current time, no
    /// reservoir, default seed) and returns the match. Panics if no rule ends the evaluation.
    async fn get_sampling_match(rules: &[SamplingRule], instance: &impl Getter) -> SamplingMatch {
        match SamplingEvaluator::new(Utc::now())
            .match_rules(Uuid::default(), instance, rules.iter())
            .await
        {
            ControlFlow::Break(sampling_match) => sampling_match,
            ControlFlow::Continue(_) => panic!("no match found"),
        }
    }

    /// Returns `true` if the evaluation ended with a [`SamplingMatch`].
    fn evaluation_is_match(res: ControlFlow<SamplingMatch, SamplingEvaluator>) -> bool {
        matches!(res, ControlFlow::Break(_))
    }

    /// Returns `true` if evaluating `rules` against `instance` reports exactly `rule_ids`
    /// (same IDs, same order) as the matched rules.
    async fn matches_rule_ids(
        rule_ids: &[u32],
        rules: &[SamplingRule],
        instance: &impl Getter,
    ) -> bool {
        let matched_rule_ids = MatchedRuleIds(rule_ids.iter().map(|num| RuleId(*num)).collect());
        let sampling_match = get_sampling_match(rules, instance).await;
        matched_rule_ids == sampling_match.matched_rules
    }

    /// Returns the numeric IDs of the matched rules. Panics if the evaluation did not match.
    fn get_matched_rules(
        sampling_evaluator: &ControlFlow<SamplingMatch, SamplingEvaluator>,
    ) -> Vec<u32> {
        match sampling_evaluator {
            ControlFlow::Continue(_) => panic!("expected a sampling match"),
            ControlFlow::Break(m) => m.matched_rules.0.iter().map(|rule_id| rule_id.0).collect(),
        }
    }

    /// Builds a DSC with a fixed trace ID and public key and all optional fields unset, then
    /// fills in the fields addressed by `paths_and_values`.
    ///
    /// Only the getter paths listed in the `match` below are supported. Any other path panics.
    fn mocked_dsc_with_getter_values(
        paths_and_values: Vec<(&str, &str)>,
    ) -> DynamicSamplingContext {
        let mut dsc = DynamicSamplingContext {
            trace_id: "67e5504410b1426f9247bb680e5fe0c8".parse().unwrap(),
            public_key: "12345678123456781234567812345678".parse().unwrap(),
            release: None,
            environment: None,
            transaction: None,
            sample_rate: None,
            user: TraceUserContext::default(),
            replay_id: None,
            sampled: None,
            other: Default::default(),
        };

        for (path, value) in paths_and_values {
            match path {
                "trace.release" => dsc.release = Some(value.to_owned()),
                "trace.environment" => dsc.environment = Some(value.to_owned()),
                "trace.user.id" => value.clone_into(&mut dsc.user.user_id),
                "trace.user.segment" => value.clone_into(&mut dsc.user.user_segment),
                "trace.transaction" => dsc.transaction = Some(value.to_owned()),
                "trace.replay_id" => dsc.replay_id = Some(Uuid::from_str(value).unwrap()),
                _ => panic!("invalid path"),
            }
        }

        dsc
    }

    /// Returns `true` if evaluating only `rule` against `dsc` at time `now` ends in a match.
    async fn is_match(
        now: DateTime<Utc>,
        rule: &SamplingRule,
        dsc: &DynamicSamplingContext,
    ) -> bool {
        SamplingEvaluator::new(now)
            .match_rules(Uuid::default(), dsc, std::iter::once(rule))
            .await
            .is_break()
    }

    #[tokio::test]
    async fn test_reservoir_evaluator_limit() {
        let evaluator = mock_reservoir_evaluator(vec![(1, 0)]);

        let rule = RuleId(1);
        let limit = 3;

        // Starting from a count of 0, exactly `limit` evaluations are admitted.
        assert!(evaluator.evaluate(rule, limit, None).await);
        assert!(evaluator.evaluate(rule, limit, None).await);
        assert!(evaluator.evaluate(rule, limit, None).await);
        assert!(!evaluator.evaluate(rule, limit, None).await);
        // Once the limit is reached, the reservoir keeps rejecting.
        assert!(!evaluator.evaluate(rule, limit, None).await);
    }

    #[tokio::test]
    async fn test_sample_rate_compounding() {
        // Factor rules multiply into the final rate: 0.8 * 0.5 * 0.25 = 0.1.
        let rules = simple_sampling_rules(vec![
            (RuleCondition::all(), SamplingValue::Factor { value: 0.8 }),
            (RuleCondition::all(), SamplingValue::Factor { value: 0.5 }),
            (
                RuleCondition::all(),
                SamplingValue::SampleRate { value: 0.25 },
            ),
        ]);
        let dsc = mocked_dsc_with_getter_values(vec![]);

        assert_eq!(get_sampling_match(&rules, &dsc).await.sample_rate(), 0.1);
    }

    #[tokio::test]
    async fn test_minimum_sample_rate() {
        // The first minimum (0.5) lifts the sample rate 0.05 to 0.5, which is then scaled by
        // the factor 1.5: 0.5 * 1.5 = 0.75.
        let rules = simple_sampling_rules(vec![
            (RuleCondition::all(), SamplingValue::Factor { value: 1.5 }),
            (
                RuleCondition::all(),
                SamplingValue::MinimumSampleRate { value: 0.5 },
            ),
            (
                // Ignored: a minimum sample rate has already been set by the rule above.
                RuleCondition::all(),
                SamplingValue::MinimumSampleRate { value: 1.0 },
            ),
            (
                RuleCondition::all(),
                SamplingValue::SampleRate { value: 0.05 },
            ),
        ]);
        let dsc = mocked_dsc_with_getter_values(vec![]);

        assert_eq!(get_sampling_match(&rules, &dsc).await.sample_rate(), 0.75);
    }

    /// A trace rule with ID 0 that matches everything with a sample rate of 1.0, using the
    /// default time range and decaying function.
    fn mocked_sampling_rule() -> SamplingRule {
        SamplingRule {
            condition: RuleCondition::all(),
            sampling_value: SamplingValue::SampleRate { value: 1.0 },
            ty: RuleType::Trace,
            id: RuleId(0),
            time_range: Default::default(),
            decaying_fn: Default::default(),
        }
    }

    /// Builds trace rules from `(condition, sampling value)` pairs. Rule IDs follow the
    /// position in `vals` (0, 1, 2, ...); time range and decaying function use their defaults.
    fn simple_sampling_rules(vals: Vec<(RuleCondition, SamplingValue)>) -> Vec<SamplingRule> {
        let mut vec = vec![];

        for (i, val) in vals.into_iter().enumerate() {
            let (condition, sampling_value) = val;
            vec.push(SamplingRule {
                condition,
                sampling_value,
                ty: RuleType::Trace,
                id: RuleId(i as u32),
                time_range: Default::default(),
                decaying_fn: Default::default(),
            });
        }
        vec
    }

    /// While the reservoir has capacity, the reservoir rule (ID 1) replaces the previously
    /// matched factor rule (ID 0). Once its limit of 2 is used up, evaluation falls through to
    /// the factor and sample-rate rules (IDs 0 and 2).
    #[tokio::test]
    async fn test_reservoir_override() {
        let dsc = mocked_dsc_with_getter_values(vec![]);
        let rules = simple_sampling_rules(vec![
            (RuleCondition::all(), SamplingValue::Factor { value: 0.5 }),
            (RuleCondition::all(), SamplingValue::Reservoir { limit: 2 }),
            (
                RuleCondition::all(),
                SamplingValue::SampleRate { value: 0.5 },
            ),
        ]);

        // One reservoir shared by all evaluators below, so its count carries over.
        let reservoir = mock_reservoir_evaluator(vec![]);

        // First event: reservoir count 0 -> 1, reservoir rule wins.
        let evaluator = SamplingEvaluator::new_with_reservoir(Utc::now(), &reservoir);
        let matched_rules = get_matched_rules(
            &evaluator
                .match_rules(Uuid::default(), &dsc, rules.iter())
                .await,
        );
        assert_eq!(&matched_rules, &[1]);

        // Second event: reservoir count 1 -> 2, reservoir rule still wins.
        let evaluator = SamplingEvaluator::new_with_reservoir(Utc::now(), &reservoir);
        let matched_rules = get_matched_rules(
            &evaluator
                .match_rules(Uuid::default(), &dsc, rules.iter())
                .await,
        );
        assert_eq!(&matched_rules, &[1]);

        // Third event: reservoir is full, so the factor and sample-rate rules apply.
        let evaluator = SamplingEvaluator::new_with_reservoir(Utc::now(), &reservoir);
        let matched_rules = get_matched_rules(
            &evaluator
                .match_rules(Uuid::default(), &dsc, rules.iter())
                .await,
        );
        assert_eq!(&matched_rules, &[0, 2]);
    }

    /// A rule only matches while the evaluation time lies inside its time range.
    #[tokio::test]
    async fn test_expired_rules() {
        let rule = SamplingRule {
            condition: RuleCondition::all(),
            sampling_value: SamplingValue::SampleRate { value: 1.0 },
            ty: RuleType::Trace,
            id: RuleId(0),
            time_range: TimeRange {
                start: Some(Utc.with_ymd_and_hms(1970, 10, 10, 0, 0, 0).unwrap()),
                end: Some(Utc.with_ymd_and_hms(1970, 10, 12, 0, 0, 0).unwrap()),
            },
            decaying_fn: Default::default(),
        };

        let dsc = mocked_dsc_with_getter_values(vec![]);

        // Inside the range: matches.
        let within_timerange = Utc.with_ymd_and_hms(1970, 10, 11, 0, 0, 0).unwrap();
        let res = SamplingEvaluator::new(within_timerange)
            .match_rules(Uuid::default(), &dsc, [rule.clone()].iter())
            .await;
        assert!(evaluation_is_match(res));

        // Before the range: no match.
        let before_timerange = Utc.with_ymd_and_hms(1969, 1, 1, 0, 0, 0).unwrap();
        let res = SamplingEvaluator::new(before_timerange)
            .match_rules(Uuid::default(), &dsc, [rule.clone()].iter())
            .await;
        assert!(!evaluation_is_match(res));

        // After the range: no match.
        let after_timerange = Utc.with_ymd_and_hms(1971, 1, 1, 0, 0, 0).unwrap();
        let res = SamplingEvaluator::new(after_timerange)
            .match_rules(Uuid::default(), &dsc, [rule].iter())
            .await;
        assert!(!evaluation_is_match(res));
    }

    /// Conditions select which rules apply. Factor rules (IDs 2 and 4) accumulate, and the
    /// first applicable sample-rate rule ends the evaluation.
    #[tokio::test]
    async fn test_condition_matching() {
        let rules = simple_sampling_rules(vec![
            (
                RuleCondition::glob("trace.transaction", "*healthcheck*"),
                SamplingValue::SampleRate { value: 1.0 },
            ),
            (
                RuleCondition::glob("trace.environment", "*dev*"),
                SamplingValue::SampleRate { value: 1.0 },
            ),
            (
                RuleCondition::eq_ignore_case("trace.transaction", "raboof"),
                SamplingValue::Factor { value: 1.0 },
            ),
            (
                RuleCondition::glob("trace.release", "1.1.1")
                    & RuleCondition::eq_ignore_case("trace.user.segment", "vip"),
                SamplingValue::SampleRate { value: 1.0 },
            ),
            (
                RuleCondition::eq_ignore_case("trace.release", "1.1.1")
                    & RuleCondition::eq_ignore_case("trace.environment", "prod"),
                SamplingValue::Factor { value: 1.0 },
            ),
            (
                RuleCondition::all(),
                SamplingValue::SampleRate { value: 1.0 },
            ),
        ]);

        // Glob on the transaction name hits rule 0 directly.
        let dsc = mocked_dsc_with_getter_values(vec![("trace.transaction", "foohealthcheckbar")]);
        assert!(matches_rule_ids(&[0], &rules, &dsc).await);

        // Glob on the environment hits rule 1 directly.
        let dsc = mocked_dsc_with_getter_values(vec![("trace.environment", "dev")]);
        assert!(matches_rule_ids(&[1], &rules, &dsc).await);

        // Factor rule 2 applies, then the catch-all sample-rate rule 5 ends the evaluation.
        let dsc = mocked_dsc_with_getter_values(vec![("trace.transaction", "raboof")]);
        assert!(matches_rule_ids(&[2, 5], &rules, &dsc).await);

        // Factor rule 2, then both conditions of sample-rate rule 3 hold.
        let dsc = mocked_dsc_with_getter_values(vec![
            ("trace.transaction", "raboof"),
            ("trace.release", "1.1.1"),
            ("trace.user.segment", "vip"),
        ]);
        assert!(matches_rule_ids(&[2, 3], &rules, &dsc).await);

        // Factor rules 2 and 4 apply (rule 3 lacks the segment), then catch-all rule 5.
        let dsc = mocked_dsc_with_getter_values(vec![
            ("trace.transaction", "raboof"),
            ("trace.release", "1.1.1"),
            ("trace.environment", "prod"),
        ]);
        assert!(matches_rule_ids(&[2, 4, 5], &rules, &dsc).await);

        // Without the transaction, only factor rule 4 and catch-all rule 5 apply.
        let dsc = mocked_dsc_with_getter_values(vec![
            ("trace.release", "1.1.1"),
            ("trace.environment", "prod"),
        ]);
        assert!(matches_rule_ids(&[4, 5], &rules, &dsc).await);
    }

    #[test]
    fn test_repeatable_seed() {
        // The same seed must always yield the same pseudo-random number.
        let val1 = pseudo_random_from_seed(Uuid::default());
        let val2 = pseudo_random_from_seed(Uuid::default());
        assert!(val1 + f64::EPSILON > val2 && val2 + f64::EPSILON > val1);
    }

    #[test]
    fn matched_rule_ids_display() {
        // IDs are joined with commas and no spaces; an empty list renders as "".
        let matched_rule_ids = MatchedRuleIds(vec![RuleId(123), RuleId(456)]);
        assert_eq!(matched_rule_ids.to_string(), "123,456");

        let matched_rule_ids = MatchedRuleIds(vec![RuleId(123)]);
        assert_eq!(matched_rule_ids.to_string(), "123");

        let matched_rule_ids = MatchedRuleIds(vec![]);
        assert_eq!(matched_rule_ids.to_string(), "")
    }

    #[test]
    fn matched_rule_ids_parse() {
        // Comma-separated integers parse; empty segments and non-integers are rejected.
        assert_eq!(
            MatchedRuleIds::parse("123,456"),
            Ok(MatchedRuleIds(vec![RuleId(123), RuleId(456)]))
        );

        assert_eq!(
            MatchedRuleIds::parse("123"),
            Ok(MatchedRuleIds(vec![RuleId(123)]))
        );

        assert!(MatchedRuleIds::parse("").is_err());

        assert!(MatchedRuleIds::parse(",").is_err());

        assert!(MatchedRuleIds::parse("123.456").is_err());

        assert!(MatchedRuleIds::parse("a,b").is_err());
    }

    #[tokio::test]
    async fn test_get_sampling_match_result_with_no_match() {
        // With no rules at all, the evaluation cannot end in a match.
        let dsc = mocked_dsc_with_getter_values(vec![]);

        let res = SamplingEvaluator::new(Utc::now())
            .match_rules(Uuid::default(), &dsc, [].iter())
            .await;

        assert!(!evaluation_is_match(res));
    }

    /// A missing `start` or `end` leaves that side of the time range open, and a default
    /// time range matches at any time.
    #[tokio::test]
    async fn test_sample_rate_valid_time_range() {
        let dsc = mocked_dsc_with_getter_values(vec![]);
        let time_range = TimeRange {
            start: Some(Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap()),
            end: Some(Utc.with_ymd_and_hms(1980, 1, 1, 0, 0, 0).unwrap()),
        };

        let before_time_range = Utc.with_ymd_and_hms(1969, 1, 1, 0, 0, 0).unwrap();
        let during_time_range = Utc.with_ymd_and_hms(1975, 1, 1, 0, 0, 0).unwrap();
        let after_time_range = Utc.with_ymd_and_hms(1981, 1, 1, 0, 0, 0).unwrap();

        let rule = SamplingRule {
            condition: RuleCondition::all(),
            sampling_value: SamplingValue::SampleRate { value: 1.0 },
            ty: RuleType::Trace,
            id: RuleId(0),
            time_range,
            decaying_fn: DecayingFunction::Constant,
        };

        // Bounded range: only the point in between matches.
        assert!(!is_match(before_time_range, &rule, &dsc).await);
        assert!(is_match(during_time_range, &rule, &dsc).await);
        assert!(!is_match(after_time_range, &rule, &dsc).await);

        // No end: everything from `start` onwards matches.
        let mut rule_without_end = rule.clone();
        rule_without_end.time_range.end = None;
        assert!(!is_match(before_time_range, &rule_without_end, &dsc).await);
        assert!(is_match(during_time_range, &rule_without_end, &dsc).await);
        assert!(is_match(after_time_range, &rule_without_end, &dsc).await);

        // No start: everything before `end` matches.
        let mut rule_without_start = rule.clone();
        rule_without_start.time_range.start = None;
        assert!(is_match(before_time_range, &rule_without_start, &dsc).await);
        assert!(is_match(during_time_range, &rule_without_start, &dsc).await);
        assert!(!is_match(after_time_range, &rule_without_start, &dsc).await);

        // Default range: always matches.
        let mut rule_without_range = rule.clone();
        rule_without_range.time_range = TimeRange::default();
        assert!(is_match(before_time_range, &rule_without_range, &dsc).await);
        assert!(is_match(during_time_range, &rule_without_range, &dsc).await);
        assert!(is_match(after_time_range, &rule_without_range, &dsc).await);
    }

    /// Sample-rate rules end the evaluation and factor rules do not. A reservoir with spare
    /// capacity ends it with a rate of 1.0.
    #[tokio::test]
    async fn test_validate_match() {
        let mut rule = mocked_sampling_rule();

        let reservoir = ReservoirEvaluator::new(ReservoirCounters::default());
        let mut eval = SamplingEvaluator::new_with_reservoir(Utc::now(), &reservoir);

        rule.sampling_value = SamplingValue::SampleRate { value: 1.0 };
        assert_eq!(eval.try_compute_sample_rate(&rule).await, Some(1.0));

        rule.sampling_value = SamplingValue::Factor { value: 1.0 };
        assert_eq!(eval.try_compute_sample_rate(&rule).await, None);

        rule.sampling_value = SamplingValue::Reservoir { limit: 1 };
        assert_eq!(eval.try_compute_sample_rate(&rule).await, Some(1.0));
    }
}