relay_event_normalization/normalize/span/
ai.rs

1//! AI cost calculation.
2
3use crate::ModelCosts;
4use relay_base_schema::metrics::MetricUnit;
5use relay_event_schema::protocol::{Event, Measurement, Span};
6
7/// Calculated cost is in US dollars.
8fn calculate_ai_model_cost(
9    model_id: &str,
10    prompt_tokens_used: Option<f64>,
11    completion_tokens_used: Option<f64>,
12    total_tokens_used: Option<f64>,
13    ai_model_costs: &ModelCosts,
14) -> Option<f64> {
15    if let Some(prompt_tokens) = prompt_tokens_used {
16        if let Some(completion_tokens) = completion_tokens_used {
17            let mut result = 0.0;
18            if let Some(cost_per_1k) = ai_model_costs.cost_per_1k_tokens(model_id, false) {
19                result += cost_per_1k * (prompt_tokens / 1000.0)
20            }
21            if let Some(cost_per_1k) = ai_model_costs.cost_per_1k_tokens(model_id, true) {
22                result += cost_per_1k * (completion_tokens / 1000.0)
23            }
24            return Some(result);
25        }
26    }
27    if let Some(total_tokens) = total_tokens_used {
28        ai_model_costs
29            .cost_per_1k_tokens(model_id, false)
30            .map(|cost| cost * (total_tokens / 1000.0))
31    } else {
32        None
33    }
34}
35
36/// Extract the ai_total_cost measurement into the span.
37pub fn extract_ai_measurements(span: &mut Span, ai_model_costs: &ModelCosts) {
38    let Some(span_op) = span.op.value() else {
39        return;
40    };
41
42    if !span_op.starts_with("ai.") {
43        return;
44    }
45
46    let Some(measurements) = span.measurements.value() else {
47        return;
48    };
49
50    let total_tokens_used = measurements.get_value("ai_total_tokens_used");
51    let prompt_tokens_used = measurements.get_value("ai_prompt_tokens_used");
52    let completion_tokens_used = measurements.get_value("ai_completion_tokens_used");
53    if let Some(model_id) = span
54        .data
55        .value()
56        .and_then(|d| d.ai_model_id.value())
57        .and_then(|val| val.as_str())
58    {
59        if let Some(total_cost) = calculate_ai_model_cost(
60            model_id,
61            prompt_tokens_used,
62            completion_tokens_used,
63            total_tokens_used,
64            ai_model_costs,
65        ) {
66            span.measurements
67                .get_or_insert_with(Default::default)
68                .insert(
69                    "ai_total_cost".to_owned(),
70                    Measurement {
71                        value: total_cost.into(),
72                        unit: MetricUnit::None.into(),
73                    }
74                    .into(),
75                );
76        }
77    }
78}
79
80/// Extract the ai_total_cost measurements from all of an event's spans
81pub fn normalize_ai_measurements(event: &mut Event, model_costs: Option<&ModelCosts>) {
82    if let Some(model_costs) = model_costs {
83        if let Some(spans) = event.spans.value_mut() {
84            for span in spans {
85                if let Some(mut_span) = span.value_mut() {
86                    extract_ai_measurements(mut_span, model_costs);
87                }
88            }
89        }
90    }
91}