use std::collections::BTreeMap;
use std::fmt;
use std::num::ParseIntError;
use std::ops::ControlFlow;
use std::sync::{Arc, Mutex};
use chrono::{DateTime, Utc};
use rand::distributions::Uniform;
use rand::Rng;
use rand_pcg::Pcg32;
#[cfg(feature = "redis")]
use relay_base_schema::organization::OrganizationId;
use relay_protocol::Getter;
#[cfg(feature = "redis")]
use relay_redis::RedisPool;
use serde::Serialize;
use uuid::Uuid;
use crate::config::{RuleId, SamplingRule, SamplingValue};
#[cfg(feature = "redis")]
use crate::redis_sampling::{self, ReservoirRuleKey};
/// Derives a deterministic pseudo-random number in `[0, 1)` from a UUID.
///
/// The 128-bit value of `id` is split into its upper and lower 64 bits, which seed a PCG32
/// generator (as its `state` and `stream` arguments respectively). The first draw from a
/// uniform distribution over the half-open range `[0, 1)` is returned, so the same UUID always
/// produces the same number.
fn pseudo_random_from_uuid(id: Uuid) -> f64 {
    let bits = id.as_u128();
    let upper = (bits >> 64) as u64;
    // Truncating cast on purpose: keeps only the lower 64 bits.
    let lower = bits as u64;

    let unit_interval = Uniform::new(0f64, 1f64);
    Pcg32::new(upper, lower).sample(unit_interval)
}
/// Shared, mutex-protected reservoir counters keyed by rule ID.
///
/// Cloning the [`Arc`] yields another handle to the same map, so every evaluator built from a
/// clone reads and updates the same counts.
pub type ReservoirCounters = Arc<Mutex<BTreeMap<RuleId, i64>>>;
/// Evaluates reservoir sampling rules, which keep at most `limit` events per rule.
///
/// Counts are tracked in the shared [`ReservoirCounters`]. When the `redis` feature is enabled
/// and a pool has been configured via `set_redis`, the count returned by Redis decides the
/// outcome, and the local map only caches the highest Redis count seen for rules already in it.
#[derive(Debug)]
pub struct ReservoirEvaluator<'a> {
    /// Per-rule counts, shared with whoever handed in the map.
    counters: ReservoirCounters,
    /// Organization and Redis pool used for counting in Redis, if configured.
    #[cfg(feature = "redis")]
    org_id_and_redis_pool: Option<(OrganizationId, &'a RedisPool)>,
    /// Keeps the `'a` lifetime in use when the `redis` feature (and the field above) is off.
    _phantom: std::marker::PhantomData<&'a ()>,
}

impl<'a> ReservoirEvaluator<'a> {
    /// Creates an evaluator that counts in `counters`, with no Redis pool configured.
    pub fn new(counters: ReservoirCounters) -> Self {
        Self {
            counters,
            #[cfg(feature = "redis")]
            org_id_and_redis_pool: None,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Returns another handle to the shared counter map backing this evaluator.
    pub fn counters(&self) -> ReservoirCounters {
        Arc::clone(&self.counters)
    }

    /// Makes subsequent evaluations count in Redis under `org_id`.
    #[cfg(feature = "redis")]
    pub fn set_redis(&mut self, org_id: OrganizationId, redis_pool: &'a RedisPool) {
        self.org_id_and_redis_pool = Some((org_id, redis_pool));
    }

    /// Increments the Redis counter for `key`, sets its expiry from `rule_expiry` and returns
    /// the incremented count.
    #[cfg(feature = "redis")]
    fn redis_incr(
        &self,
        key: &ReservoirRuleKey,
        redis_pool: &RedisPool,
        rule_expiry: Option<&DateTime<Utc>>,
    ) -> anyhow::Result<i64> {
        let mut client = redis_pool.client()?;
        let mut connection = client.connection()?;
        let count = redis_sampling::increment_redis_reservoir_count(&mut connection, key)?;
        redis_sampling::set_redis_expiry(&mut connection, key, rule_expiry)?;
        Ok(count)
    }

    /// Counts one more event for `rule` in the local map unless `limit` has been reached.
    ///
    /// Returns `true` when the event still fits into the reservoir. If the mutex is poisoned,
    /// an error is logged and `false` is returned.
    pub fn incr_local(&self, rule: RuleId, limit: i64) -> bool {
        let Ok(mut counters) = self.counters.lock() else {
            relay_log::error!("failed to lock reservoir counter mutex");
            return false;
        };

        let count = counters.entry(rule).or_insert(0);
        let has_room = *count < limit;
        if has_room {
            *count += 1;
        }
        has_room
    }

    /// Decides whether an event matching reservoir `rule` (capped at `limit`) is kept.
    ///
    /// With a Redis pool configured (`redis` feature), the Redis count decides; otherwise the
    /// local map is used through [`Self::incr_local`]. `_rule_expiry` is only used for Redis.
    pub fn evaluate(&self, rule: RuleId, limit: i64, _rule_expiry: Option<&DateTime<Utc>>) -> bool {
        #[cfg(feature = "redis")]
        if let Some((org_id, redis_pool)) = self.org_id_and_redis_pool {
            return self.evaluate_with_redis(org_id, redis_pool, rule, limit, _rule_expiry);
        }

        self.incr_local(rule, limit)
    }

    /// Redis-backed counterpart of [`Self::incr_local`], used by [`Self::evaluate`].
    #[cfg(feature = "redis")]
    fn evaluate_with_redis(
        &self,
        org_id: OrganizationId,
        redis_pool: &RedisPool,
        rule: RuleId,
        limit: i64,
        rule_expiry: Option<&DateTime<Utc>>,
    ) -> bool {
        // Skip the Redis round-trip when the locally cached count is already past the limit.
        // A poisoned mutex just skips this shortcut.
        if let Ok(counters) = self.counters.lock() {
            let cached = counters.get(&rule).copied().unwrap_or(0);
            if cached > limit {
                return false;
            }
        }

        let key = ReservoirRuleKey::new(org_id, rule);
        let redis_count = match self.redis_incr(&key, redis_pool, rule_expiry) {
            Ok(count) => count,
            Err(error) => {
                relay_log::error!(error = &*error, "failed to increment reservoir rule");
                return false;
            }
        };

        // Only rules already present in the map are refreshed; absent rules are not inserted.
        if let Ok(mut counters) = self.counters.lock() {
            if let Some(cached) = counters.get_mut(&rule) {
                *cached = (*cached).max(redis_count);
            }
        }

        redis_count <= limit
    }
}
/// Walks sampling rules in order and computes the sample rate of the first deciding rule.
///
/// `Factor` rules multiply into an accumulated factor and let evaluation continue,
/// `SampleRate` rules finish it with the factor applied, and `Reservoir` rules finish it with a
/// rate of `1.0` while their reservoir still has room.
#[derive(Debug)]
pub struct SamplingEvaluator<'a> {
    /// Timestamp used for time-range checks and decaying functions.
    now: DateTime<Utc>,
    /// IDs of the rules that contributed to the result so far.
    rule_ids: Vec<RuleId>,
    /// Product of the (decayed) values of all applied `Factor` rules.
    factor: f64,
    /// Reservoir evaluator; without it, `Reservoir` rules never decide the sample rate.
    reservoir: Option<&'a ReservoirEvaluator<'a>>,
}

impl<'a> SamplingEvaluator<'a> {
    /// Creates an evaluator at `now` that can also apply reservoir rules through `reservoir`.
    pub fn new_with_reservoir(now: DateTime<Utc>, reservoir: &'a ReservoirEvaluator<'a>) -> Self {
        Self {
            reservoir: Some(reservoir),
            ..Self::new(now)
        }
    }

    /// Creates an evaluator at `now` without reservoir support.
    pub fn new(now: DateTime<Utc>) -> Self {
        Self {
            now,
            rule_ids: Vec::new(),
            factor: 1.0,
            reservoir: None,
        }
    }

    /// Evaluates `rules` in order against `instance`.
    ///
    /// Rules outside their time range or whose condition does not match are skipped. Returns
    /// [`ControlFlow::Break`] with a [`SamplingMatch`] (decided with `seed`) as soon as a rule
    /// yields a final sample rate. Otherwise it returns [`ControlFlow::Continue`] with the
    /// evaluator, whose accumulated factor and rule IDs can then be carried into further rules.
    pub fn match_rules<'b, I, G>(
        mut self,
        seed: Uuid,
        instance: &G,
        rules: I,
    ) -> ControlFlow<SamplingMatch, Self>
    where
        G: Getter,
        I: Iterator<Item = &'b SamplingRule>,
    {
        let now = self.now;
        let applicable = rules
            .filter(|rule| rule.time_range.contains(now) && rule.condition.matches(instance));

        for rule in applicable {
            if let Some(sample_rate) = self.try_compute_sample_rate(rule) {
                return ControlFlow::Break(SamplingMatch::new(sample_rate, seed, self.rule_ids));
            }
        }

        ControlFlow::Continue(self)
    }

    /// Applies a single matching `rule` to the evaluator state.
    ///
    /// Returns the final sample rate if `rule` decides the outcome, or `None` if evaluation
    /// should move on to the next rule.
    fn try_compute_sample_rate(&mut self, rule: &SamplingRule) -> Option<f64> {
        match rule.sampling_value {
            SamplingValue::Factor { value } => {
                let decayed = rule.apply_decaying_fn(value, self.now)?;
                self.factor *= decayed;
                self.rule_ids.push(rule.id);
                None
            }
            SamplingValue::SampleRate { value } => {
                let decayed = rule.apply_decaying_fn(value, self.now)?;
                self.rule_ids.push(rule.id);
                Some((decayed * self.factor).clamp(0.0, 1.0))
            }
            SamplingValue::Reservoir { limit } => {
                let reservoir = self.reservoir?;
                let has_room = reservoir.evaluate(rule.id, limit, rule.time_range.end.as_ref());
                if has_room {
                    // A reservoir hit overrides everything collected so far: only this rule is
                    // reported, at full rate.
                    self.rule_ids.clear();
                    self.rule_ids.push(rule.id);
                    Some(1.0)
                } else {
                    None
                }
            }
        }
    }
}
/// Makes the keep/drop decision for an event sampled at `sample_rate`.
///
/// A rate at or below `0.0` always drops and a rate at or above `1.0` always keeps, without
/// drawing a random number. In between, a number derived deterministically from `seed` is
/// compared with the rate. The event is dropped if that number is greater than or equal to
/// `sample_rate`, and kept otherwise.
fn sampling_match(sample_rate: f64, seed: Uuid) -> SamplingDecision {
    if sample_rate <= 0.0 {
        SamplingDecision::Drop
    } else if sample_rate >= 1.0 {
        SamplingDecision::Keep
    } else {
        let random_number = pseudo_random_from_uuid(seed);
        relay_log::trace!(
            sample_rate,
            random_number,
            "applying dynamic sampling to matching event"
        );

        // NOTE(review): a NaN `sample_rate` passes both guards above and, because
        // `x >= NaN` is false, ends up kept here. Do not invert this comparison without
        // deciding on that case.
        if random_number >= sample_rate {
            relay_log::trace!("dropping event that matched the configuration");
            SamplingDecision::Drop
        } else {
            relay_log::trace!("keeping event that matched the configuration");
            SamplingDecision::Keep
        }
    }
}
/// Outcome of dynamic sampling for a single event.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SamplingDecision {
    /// The event is kept.
    Keep,
    /// The event is dropped.
    Drop,
}

impl SamplingDecision {
    /// Returns `true` for [`SamplingDecision::Keep`].
    pub fn is_keep(self) -> bool {
        self == SamplingDecision::Keep
    }

    /// Returns `true` for [`SamplingDecision::Drop`].
    pub fn is_drop(self) -> bool {
        self == SamplingDecision::Drop
    }

    /// Returns the lowercase name of the decision: `"keep"` or `"drop"`.
    pub fn as_str(self) -> &'static str {
        match self {
            SamplingDecision::Keep => "keep",
            SamplingDecision::Drop => "drop",
        }
    }
}

impl fmt::Display for SamplingDecision {
    /// Writes [`SamplingDecision::as_str`]; width/fill flags of the caller are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Result of a sampling evaluation in which some rule decided the sample rate.
#[derive(Clone, Debug, PartialEq)]
pub struct SamplingMatch {
    /// Sample rate produced by the deciding rule.
    sample_rate: f64,
    /// Seed from which the pseudo-random number for the decision is derived.
    seed: Uuid,
    /// IDs of the rules that contributed to this match.
    matched_rules: MatchedRuleIds,
    /// Keep/drop decision, computed once at construction.
    decision: SamplingDecision,
}

impl SamplingMatch {
    /// Bundles an evaluation result and immediately makes the keep/drop decision for it.
    fn new(sample_rate: f64, seed: Uuid, matched_rules: Vec<RuleId>) -> Self {
        Self {
            decision: sampling_match(sample_rate, seed),
            matched_rules: MatchedRuleIds(matched_rules),
            sample_rate,
            seed,
        }
    }

    /// Returns the sample rate of this match.
    pub fn sample_rate(&self) -> f64 {
        self.sample_rate
    }

    /// Consumes the match, returning the IDs of the rules that contributed to it.
    pub fn into_matched_rules(self) -> MatchedRuleIds {
        self.matched_rules
    }

    /// Returns whether the matched event is kept or dropped.
    pub fn decision(&self) -> SamplingDecision {
        self.decision
    }
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)]
pub struct MatchedRuleIds(pub Vec<RuleId>);
impl MatchedRuleIds {
pub fn parse(value: &str) -> Result<MatchedRuleIds, ParseIntError> {
let mut rule_ids = vec![];
for rule_id in value.split(',') {
rule_ids.push(RuleId(rule_id.parse()?));
}
Ok(MatchedRuleIds(rule_ids))
}
}
impl fmt::Display for MatchedRuleIds {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
for (i, rule_id) in self.0.iter().enumerate() {
if i > 0 {
write!(f, ",")?;
}
write!(f, "{rule_id}")?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
// Unit tests for rule matching, sample-rate computation, reservoir counting and
// `MatchedRuleIds` parsing/formatting.
use std::str::FromStr;
use chrono::TimeZone;
use relay_protocol::RuleCondition;
use similar_asserts::assert_eq;
use crate::config::{DecayingFunction, RuleType, TimeRange};
use crate::dsc::TraceUserContext;
use crate::DynamicSamplingContext;
use super::*;
/// Builds a `ReservoirEvaluator` without Redis whose counters start at the given
/// `(rule id, count)` pairs.
fn mock_reservoir_evaluator(vals: Vec<(u32, i64)>) -> ReservoirEvaluator<'static> {
let mut map = BTreeMap::default();
for (rule_id, count) in vals {
map.insert(RuleId(rule_id), count);
}
let map = Arc::new(Mutex::new(map));
ReservoirEvaluator::new(map)
}
/// Evaluates `rules` against `instance` at the current time (no reservoir) and returns the
/// resulting match; panics if no rule decided a sample rate.
fn get_sampling_match(rules: &[SamplingRule], instance: &impl Getter) -> SamplingMatch {
match SamplingEvaluator::new(Utc::now()).match_rules(
Uuid::default(),
instance,
rules.iter(),
) {
ControlFlow::Break(sampling_match) => sampling_match,
ControlFlow::Continue(_) => panic!("no match found"),
}
}
/// Returns `true` if the evaluation produced a `SamplingMatch`.
fn evaluation_is_match(res: ControlFlow<SamplingMatch, SamplingEvaluator>) -> bool {
matches!(res, ControlFlow::Break(_))
}
/// Returns `true` if evaluating `rules` against `instance` reports exactly `rule_ids`, in order.
fn matches_rule_ids(rule_ids: &[u32], rules: &[SamplingRule], instance: &impl Getter) -> bool {
let matched_rule_ids = MatchedRuleIds(rule_ids.iter().map(|num| RuleId(*num)).collect());
let sampling_match = get_sampling_match(rules, instance);
matched_rule_ids == sampling_match.matched_rules
}
/// Extracts the raw matched rule IDs from an evaluation result; panics if there was no match.
fn get_matched_rules(
sampling_evaluator: &ControlFlow<SamplingMatch, SamplingEvaluator>,
) -> Vec<u32> {
match sampling_evaluator {
ControlFlow::Continue(_) => panic!("expected a sampling match"),
ControlFlow::Break(m) => m.matched_rules.0.iter().map(|rule_id| rule_id.0).collect(),
}
}
/// Builds a `DynamicSamplingContext` with a random trace ID and the given getter paths set.
/// Panics on a path that is not handled below.
fn mocked_dsc_with_getter_values(
paths_and_values: Vec<(&str, &str)>,
) -> DynamicSamplingContext {
let mut dsc = DynamicSamplingContext {
trace_id: Uuid::new_v4(),
public_key: "12345678123456781234567812345678".parse().unwrap(),
release: None,
environment: None,
transaction: None,
sample_rate: None,
user: TraceUserContext::default(),
replay_id: None,
sampled: None,
other: Default::default(),
};
for (path, value) in paths_and_values {
match path {
"trace.release" => dsc.release = Some(value.to_owned()),
"trace.environment" => dsc.environment = Some(value.to_owned()),
"trace.user.id" => value.clone_into(&mut dsc.user.user_id),
"trace.user.segment" => value.clone_into(&mut dsc.user.user_segment),
"trace.transaction" => dsc.transaction = Some(value.to_owned()),
"trace.replay_id" => dsc.replay_id = Some(Uuid::from_str(value).unwrap()),
_ => panic!("invalid path"),
}
}
dsc
}
/// Without Redis, `evaluate` counts locally: the first `limit` calls are accepted, later ones rejected.
#[test]
fn test_reservoir_evaluator_limit() {
let evaluator = mock_reservoir_evaluator(vec![(1, 0)]);
let rule = RuleId(1);
let limit = 3;
assert!(evaluator.evaluate(rule, limit, None));
assert!(evaluator.evaluate(rule, limit, None));
assert!(evaluator.evaluate(rule, limit, None));
assert!(!evaluator.evaluate(rule, limit, None));
assert!(!evaluator.evaluate(rule, limit, None));
}
/// Factor rules multiply into the final sample rate: 0.8 * 0.5 * 0.25 = 0.1.
#[test]
fn test_sample_rate_compounding() {
let rules = simple_sampling_rules(vec![
(RuleCondition::all(), SamplingValue::Factor { value: 0.8 }),
(RuleCondition::all(), SamplingValue::Factor { value: 0.5 }),
(
RuleCondition::all(),
SamplingValue::SampleRate { value: 0.25 },
),
]);
let dsc = mocked_dsc_with_getter_values(vec![]);
assert_eq!(get_sampling_match(&rules, &dsc).sample_rate(), 0.1);
}
/// A catch-all trace rule with id 0 and a sample rate of 1.0.
fn mocked_sampling_rule() -> SamplingRule {
SamplingRule {
condition: RuleCondition::all(),
sampling_value: SamplingValue::SampleRate { value: 1.0 },
ty: RuleType::Trace,
id: RuleId(0),
time_range: Default::default(),
decaying_fn: Default::default(),
}
}
/// Builds trace rules from `(condition, value)` pairs, assigning rule IDs by position.
fn simple_sampling_rules(vals: Vec<(RuleCondition, SamplingValue)>) -> Vec<SamplingRule> {
let mut vec = vec![];
for (i, val) in vals.into_iter().enumerate() {
let (condition, sampling_value) = val;
vec.push(SamplingRule {
condition,
sampling_value,
ty: RuleType::Trace,
id: RuleId(i as u32),
time_range: Default::default(),
decaying_fn: Default::default(),
});
}
vec
}
/// While the reservoir (limit 2) has room, it overrides the preceding factor rule and only
/// rule 1 is reported; once it is full, evaluation falls through to rules 0 and 2.
#[test]
fn test_reservoir_override() {
let dsc = mocked_dsc_with_getter_values(vec![]);
let rules = simple_sampling_rules(vec![
(RuleCondition::all(), SamplingValue::Factor { value: 0.5 }),
(RuleCondition::all(), SamplingValue::Reservoir { limit: 2 }),
(
RuleCondition::all(),
SamplingValue::SampleRate { value: 0.5 },
),
]);
let reservoir = mock_reservoir_evaluator(vec![]);
let evaluator = SamplingEvaluator::new_with_reservoir(Utc::now(), &reservoir);
let matched_rules =
get_matched_rules(&evaluator.match_rules(Uuid::default(), &dsc, rules.iter()));
assert_eq!(&matched_rules, &[1]);
let evaluator = SamplingEvaluator::new_with_reservoir(Utc::now(), &reservoir);
let matched_rules =
get_matched_rules(&evaluator.match_rules(Uuid::default(), &dsc, rules.iter()));
assert_eq!(&matched_rules, &[1]);
// Third event: the reservoir is exhausted, so the factor and sample-rate rules decide.
let evaluator = SamplingEvaluator::new_with_reservoir(Utc::now(), &reservoir);
let matched_rules =
get_matched_rules(&evaluator.match_rules(Uuid::default(), &dsc, rules.iter()));
assert_eq!(&matched_rules, &[0, 2]);
}
/// A rule only matches while the evaluation time lies inside its time range.
#[test]
fn test_expired_rules() {
let rule = SamplingRule {
condition: RuleCondition::all(),
sampling_value: SamplingValue::SampleRate { value: 1.0 },
ty: RuleType::Trace,
id: RuleId(0),
time_range: TimeRange {
start: Some(Utc.with_ymd_and_hms(1970, 10, 10, 0, 0, 0).unwrap()),
end: Some(Utc.with_ymd_and_hms(1970, 10, 12, 0, 0, 0).unwrap()),
},
decaying_fn: Default::default(),
};
let dsc = mocked_dsc_with_getter_values(vec![]);
let within_timerange = Utc.with_ymd_and_hms(1970, 10, 11, 0, 0, 0).unwrap();
let res = SamplingEvaluator::new(within_timerange).match_rules(
Uuid::default(),
&dsc,
[rule.clone()].iter(),
);
assert!(evaluation_is_match(res));
let before_timerange = Utc.with_ymd_and_hms(1969, 1, 1, 0, 0, 0).unwrap();
let res = SamplingEvaluator::new(before_timerange).match_rules(
Uuid::default(),
&dsc,
[rule.clone()].iter(),
);
assert!(!evaluation_is_match(res));
let after_timerange = Utc.with_ymd_and_hms(1971, 1, 1, 0, 0, 0).unwrap();
let res = SamplingEvaluator::new(after_timerange).match_rules(
Uuid::default(),
&dsc,
[rule].iter(),
);
assert!(!evaluation_is_match(res));
}
/// Checks which rules are reported for various contexts. Factor rules (2 and 4) are recorded
/// and evaluation continues to the next matching rule; sample-rate rules end it.
#[test]
fn test_condition_matching() {
let rules = simple_sampling_rules(vec![
(
RuleCondition::glob("trace.transaction", "*healthcheck*"),
SamplingValue::SampleRate { value: 1.0 },
),
(
RuleCondition::glob("trace.environment", "*dev*"),
SamplingValue::SampleRate { value: 1.0 },
),
(
RuleCondition::eq_ignore_case("trace.transaction", "raboof"),
SamplingValue::Factor { value: 1.0 },
),
(
RuleCondition::glob("trace.release", "1.1.1")
& RuleCondition::eq_ignore_case("trace.user.segment", "vip"),
SamplingValue::SampleRate { value: 1.0 },
),
(
RuleCondition::eq_ignore_case("trace.release", "1.1.1")
& RuleCondition::eq_ignore_case("trace.environment", "prod"),
SamplingValue::Factor { value: 1.0 },
),
(
RuleCondition::all(),
SamplingValue::SampleRate { value: 1.0 },
),
]);
let dsc = mocked_dsc_with_getter_values(vec![("trace.transaction", "foohealthcheckbar")]);
assert!(matches_rule_ids(&[0], &rules, &dsc));
let dsc = mocked_dsc_with_getter_values(vec![("trace.environment", "dev")]);
assert!(matches_rule_ids(&[1], &rules, &dsc));
let dsc = mocked_dsc_with_getter_values(vec![("trace.transaction", "raboof")]);
assert!(matches_rule_ids(&[2, 5], &rules, &dsc));
let dsc = mocked_dsc_with_getter_values(vec![
("trace.transaction", "raboof"),
("trace.release", "1.1.1"),
("trace.user.segment", "vip"),
]);
assert!(matches_rule_ids(&[2, 3], &rules, &dsc));
let dsc = mocked_dsc_with_getter_values(vec![
("trace.transaction", "raboof"),
("trace.release", "1.1.1"),
("trace.environment", "prod"),
]);
assert!(matches_rule_ids(&[2, 4, 5], &rules, &dsc));
let dsc = mocked_dsc_with_getter_values(vec![
("trace.release", "1.1.1"),
("trace.environment", "prod"),
]);
assert!(matches_rule_ids(&[4, 5], &rules, &dsc));
}
/// The same UUID always yields the same pseudo-random number.
#[test]
fn test_repeatable_seed() {
let id = "4a106cf6-b151-44eb-9131-ae7db1a157a3".parse().unwrap();
let val1 = pseudo_random_from_uuid(id);
let val2 = pseudo_random_from_uuid(id);
assert!(val1 + f64::EPSILON > val2 && val2 + f64::EPSILON > val1);
}
/// Rule IDs are formatted comma-separated, with no trailing separator.
#[test]
fn matched_rule_ids_display() {
let matched_rule_ids = MatchedRuleIds(vec![RuleId(123), RuleId(456)]);
assert_eq!(matched_rule_ids.to_string(), "123,456");
let matched_rule_ids = MatchedRuleIds(vec![RuleId(123)]);
assert_eq!(matched_rule_ids.to_string(), "123");
let matched_rule_ids = MatchedRuleIds(vec![]);
assert_eq!(matched_rule_ids.to_string(), "")
}
/// Parsing accepts comma-separated integers and rejects empty or non-integer segments.
#[test]
fn matched_rule_ids_parse() {
assert_eq!(
MatchedRuleIds::parse("123,456"),
Ok(MatchedRuleIds(vec![RuleId(123), RuleId(456)]))
);
assert_eq!(
MatchedRuleIds::parse("123"),
Ok(MatchedRuleIds(vec![RuleId(123)]))
);
assert!(MatchedRuleIds::parse("").is_err());
assert!(MatchedRuleIds::parse(",").is_err());
assert!(MatchedRuleIds::parse("123.456").is_err());
assert!(MatchedRuleIds::parse("a,b").is_err());
}
/// An empty rule list yields no match.
#[test]
fn test_get_sampling_match_result_with_no_match() {
let dsc = mocked_dsc_with_getter_values(vec![]);
let res = SamplingEvaluator::new(Utc::now()).match_rules(Uuid::default(), &dsc, [].iter());
assert!(!evaluation_is_match(res));
}
/// Open-ended time ranges: a missing start or end bound is unbounded on that side.
#[test]
fn test_sample_rate_valid_time_range() {
let dsc = mocked_dsc_with_getter_values(vec![]);
let time_range = TimeRange {
start: Some(Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).unwrap()),
end: Some(Utc.with_ymd_and_hms(1980, 1, 1, 0, 0, 0).unwrap()),
};
let before_time_range = Utc.with_ymd_and_hms(1969, 1, 1, 0, 0, 0).unwrap();
let during_time_range = Utc.with_ymd_and_hms(1975, 1, 1, 0, 0, 0).unwrap();
let after_time_range = Utc.with_ymd_and_hms(1981, 1, 1, 0, 0, 0).unwrap();
let rule = SamplingRule {
condition: RuleCondition::all(),
sampling_value: SamplingValue::SampleRate { value: 1.0 },
ty: RuleType::Trace,
id: RuleId(0),
time_range,
decaying_fn: DecayingFunction::Constant,
};
let is_match = |now: DateTime<Utc>, rule: &SamplingRule| -> bool {
SamplingEvaluator::new(now)
.match_rules(Uuid::default(), &dsc, [rule.clone()].iter())
.is_break()
};
assert!(!is_match(before_time_range, &rule));
assert!(is_match(during_time_range, &rule));
assert!(!is_match(after_time_range, &rule));
let mut rule_without_end = rule.clone();
rule_without_end.time_range.end = None;
assert!(!is_match(before_time_range, &rule_without_end));
assert!(is_match(during_time_range, &rule_without_end));
assert!(is_match(after_time_range, &rule_without_end));
let mut rule_without_start = rule.clone();
rule_without_start.time_range.start = None;
assert!(is_match(before_time_range, &rule_without_start));
assert!(is_match(during_time_range, &rule_without_start));
assert!(!is_match(after_time_range, &rule_without_start));
let mut rule_without_range = rule.clone();
rule_without_range.time_range = TimeRange::default();
assert!(is_match(before_time_range, &rule_without_range));
assert!(is_match(during_time_range, &rule_without_range));
assert!(is_match(after_time_range, &rule_without_range));
}
/// `try_compute_sample_rate` per rule kind: a sample rate decides, a factor does not, and a
/// reservoir with room decides at full rate.
#[test]
fn test_validate_match() {
let mut rule = mocked_sampling_rule();
let reservoir = ReservoirEvaluator::new(ReservoirCounters::default());
let mut eval = SamplingEvaluator::new_with_reservoir(Utc::now(), &reservoir);
rule.sampling_value = SamplingValue::SampleRate { value: 1.0 };
assert_eq!(eval.try_compute_sample_rate(&rule), Some(1.0));
rule.sampling_value = SamplingValue::Factor { value: 1.0 };
assert_eq!(eval.try_compute_sample_rate(&rule), None);
rule.sampling_value = SamplingValue::Reservoir { limit: 1 };
assert_eq!(eval.try_compute_sample_rate(&rule), Some(1.0));
}
}