relay_event_normalization/normalize/span/ai.rs

//! AI cost calculation.

use crate::normalize::AiOperationTypeMap;
use crate::{ModelCostV2, ModelCosts};
use relay_event_schema::protocol::{Event, Span, SpanData};
use relay_protocol::{Annotated, Getter, Value};

/// Amount of tokens used for a model call.
#[derive(Debug, Copy, Clone)]
pub struct UsedTokens {
    /// Total amount of input tokens used.
    pub input_tokens: f64,
    /// Amount of cached input tokens used.
    ///
    /// This is a subset of [`Self::input_tokens`].
    pub input_cached_tokens: f64,
    /// Total amount of output tokens used.
    pub output_tokens: f64,
    /// Total amount of reasoning tokens used.
    ///
    /// This is a subset of [`Self::output_tokens`].
    pub output_reasoning_tokens: f64,
}

impl UsedTokens {
    /// Extracts [`UsedTokens`] from [`SpanData`] attributes.
    pub fn from_span_data(data: &SpanData) -> Self {
        macro_rules! get_value {
            ($e:expr) => {
                $e.value().and_then(Value::as_f64).unwrap_or(0.0)
            };
        }

        Self {
            input_tokens: get_value!(data.gen_ai_usage_input_tokens),
            output_tokens: get_value!(data.gen_ai_usage_output_tokens),
            output_reasoning_tokens: get_value!(data.gen_ai_usage_output_tokens_reasoning),
            input_cached_tokens: get_value!(data.gen_ai_usage_input_tokens_cached),
        }
    }

    /// Returns `true` if any tokens were used.
    pub fn has_usage(&self) -> bool {
        self.input_tokens > 0.0 || self.output_tokens > 0.0
    }

    /// Calculates the total amount of uncached input tokens.
    ///
    /// Subtracts cached tokens from the total input token count.
    pub fn raw_input_tokens(&self) -> f64 {
        self.input_tokens - self.input_cached_tokens
    }

    /// Calculates the total amount of raw, non-reasoning output tokens.
    ///
    /// Subtracts reasoning tokens from the total output token count.
    pub fn raw_output_tokens(&self) -> f64 {
        self.output_tokens - self.output_reasoning_tokens
    }
}

/// Calculated model call costs.
#[derive(Debug, Copy, Clone)]
pub struct CalculatedCost {
    /// The cost of input tokens used.
    pub input: f64,
    /// The cost of output tokens used.
    pub output: f64,
}

impl CalculatedCost {
    /// The total cost, the sum of the input and output costs.
    pub fn total(&self) -> f64 {
        self.input + self.output
    }
}

/// Calculates the total cost for a model call.
///
/// Returns `None` if no tokens were used.
pub fn calculate_costs(model_cost: &ModelCostV2, tokens: UsedTokens) -> Option<CalculatedCost> {
    if !tokens.has_usage() {
        return None;
    }

    let input = (tokens.raw_input_tokens() * model_cost.input_per_token)
        + (tokens.input_cached_tokens * model_cost.input_cached_per_token);

    // For now, most models do not price reasoning tokens differently from regular output
    // tokens, so reasoning falls back to the output token cost.
    let reasoning_cost = match model_cost.output_reasoning_per_token {
        reasoning_cost if reasoning_cost > 0.0 => reasoning_cost,
        _ => model_cost.output_per_token,
    };

    let output = (tokens.raw_output_tokens() * model_cost.output_per_token)
        + (tokens.output_reasoning_tokens * reasoning_cost);

    Some(CalculatedCost { input, output })
}

/// Calculates the cost of an AI model call based on the model's per-token costs and the
/// tokens used.
///
/// The calculated cost is in US dollars.
fn extract_ai_model_cost_data(model_cost: Option<&ModelCostV2>, data: &mut SpanData) {
    let Some(model_cost) = model_cost else { return };

    let used_tokens = UsedTokens::from_span_data(&*data);
    let Some(costs) = calculate_costs(model_cost, used_tokens) else {
        return;
    };

    // Double write during the migration period:
    // `gen_ai_usage_total_cost` is deprecated and will be removed in the future.
    data.gen_ai_usage_total_cost
        .set_value(Value::F64(costs.total()).into());
    data.gen_ai_cost_total_tokens
        .set_value(Value::F64(costs.total()).into());

    // Set the individual cost components.
    data.gen_ai_cost_input_tokens
        .set_value(Value::F64(costs.input).into());
    data.gen_ai_cost_output_tokens
        .set_value(Value::F64(costs.output).into());
}

/// Maps AI-related measurements (legacy) to span data.
fn map_ai_measurements_to_data(span: &mut Span) {
    let measurements = span.measurements.value();
    let data = span.data.get_or_insert_with(SpanData::default);

    let set_field_from_measurement = |target_field: &mut Annotated<Value>,
                                      measurement_key: &str| {
        if let Some(measurements) = measurements
            && target_field.value().is_none()
            && let Some(value) = measurements.get_value(measurement_key)
        {
            target_field.set_value(Value::F64(value.to_f64()).into());
        }
    };

    set_field_from_measurement(&mut data.gen_ai_usage_total_tokens, "ai_total_tokens_used");
    set_field_from_measurement(&mut data.gen_ai_usage_input_tokens, "ai_prompt_tokens_used");
    set_field_from_measurement(
        &mut data.gen_ai_usage_output_tokens,
        "ai_completion_tokens_used",
    );
}

/// Sets the total token count from the input and output token counts if it is missing.
fn set_total_tokens(span: &mut Span) {
    let data = span.data.get_or_insert_with(SpanData::default);

    // 'total_tokens' might not be set, in which case it needs to be calculated.
    if data.gen_ai_usage_total_tokens.value().is_none() {
        let input_tokens = data
            .gen_ai_usage_input_tokens
            .value()
            .and_then(Value::as_f64);
        let output_tokens = data
            .gen_ai_usage_output_tokens
            .value()
            .and_then(Value::as_f64);

        if input_tokens.is_none() && output_tokens.is_none() {
            // Don't set total_tokens if there are neither input nor output tokens.
            return;
        }

        data.gen_ai_usage_total_tokens.set_value(
            Value::F64(input_tokens.unwrap_or(0.0) + output_tokens.unwrap_or(0.0)).into(),
        );
    }
}

/// Extracts additional AI data (response tokens per second and model costs) into the span data.
fn extract_ai_data(span: &mut Span, ai_model_costs: &ModelCosts) {
    let duration = span
        .get_value("span.duration")
        .and_then(|v| v.as_f64())
        .unwrap_or(0.0);

    let data = span.data.get_or_insert_with(SpanData::default);

    // Extracts the response tokens per second.
    if data.gen_ai_response_tokens_per_second.value().is_none()
        && duration > 0.0
        && let Some(output_tokens) = data
            .gen_ai_usage_output_tokens
            .value()
            .and_then(Value::as_f64)
    {
        data.gen_ai_response_tokens_per_second
            .set_value(Value::F64(output_tokens / (duration / 1000.0)).into());
    }

    // Extracts the cost of the AI model used.
    if let Some(model_id) = data
        .gen_ai_request_model
        .value()
        .and_then(|val| val.as_str())
        .or_else(|| {
            data.gen_ai_response_model
                .value()
                .and_then(|val| val.as_str())
        })
    {
        extract_ai_model_cost_data(ai_model_costs.cost_per_token(model_id), data)
    }
}

/// Enriches the data of an AI span with token totals, costs, and the operation type.
pub fn enrich_ai_span_data(
    span: &mut Span,
    model_costs: Option<&ModelCosts>,
    operation_type_map: Option<&AiOperationTypeMap>,
) {
    if !is_ai_span(span) {
        return;
    }

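    // Legacy measurements are mapped into span data first, so that the token total and
    // cost extraction below only need to read the span data fields.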
    map_ai_measurements_to_data(span);
    set_total_tokens(span);

    if let Some(model_costs) = model_costs {
        extract_ai_data(span, model_costs);
    }
    if let Some(operation_type_map) = operation_type_map {
        infer_ai_operation_type(span, operation_type_map);
    }
}

/// Enriches the AI data of all of an event's spans.
pub fn enrich_ai_event_data(
    event: &mut Event,
    model_costs: Option<&ModelCosts>,
    operation_type_map: Option<&AiOperationTypeMap>,
) {
    let spans = event.spans.value_mut().iter_mut().flatten();
    let spans = spans.filter_map(|span| span.value_mut().as_mut());

    for span in spans {
        enrich_ai_span_data(span, model_costs, operation_type_map);
    }
}

/// Infers the AI operation type for a span.
///
/// This function maps `span.op` values to `gen_ai.operation.type` based on the provided
/// operation type map configuration.
fn infer_ai_operation_type(span: &mut Span, operation_type_map: &AiOperationTypeMap) {
    let data = span.data.get_or_insert_with(SpanData::default);

    if let Some(op) = span.op.value()
        && let Some(operation_type) = operation_type_map.get_operation_type(op)
    {
        data.gen_ai_operation_type
            .set_value(Some(operation_type.to_owned()));
    }
}

/// Returns `true` if the span is an AI span.
///
/// AI spans are spans with an op starting with `"ai."` (legacy) or `"gen_ai."` (new).
fn is_ai_span(span: &Span) -> bool {
    span.op
        .value()
        .is_some_and(|op| op.starts_with("ai.") || op.starts_with("gen_ai."))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calculate_cost_no_tokens() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 1.0,
                output_per_token: 1.0,
                output_reasoning_per_token: 1.0,
                input_cached_per_token: 1.0,
            },
            UsedTokens::from_span_data(&SpanData::default()),
        );
        assert!(cost.is_none());
    }

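    /// Illustrative sketch (not part of the original suite): cached and reasoning tokens are
    /// subsets of the input and output counts and are subtracted by the raw accessors.
    #[test]
    fn test_used_tokens_raw_counts() {
        let tokens = UsedTokens {
            input_tokens: 10.0,
            input_cached_tokens: 4.0,
            output_tokens: 20.0,
            output_reasoning_tokens: 5.0,
        };

        assert!(tokens.has_usage());
        // 10 input tokens, of which 4 are cached -> 6 uncached input tokens.
        assert_eq!(tokens.raw_input_tokens(), 6.0);
        // 20 output tokens, of which 5 are reasoning -> 15 raw output tokens.
        assert_eq!(tokens.raw_output_tokens(), 15.0);
    }
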
    #[test]
    fn test_calculate_cost_full() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 1.0,
                output_per_token: 2.0,
                output_reasoning_per_token: 3.0,
                input_cached_per_token: 0.5,
            },
            UsedTokens {
                input_tokens: 8.0,
                input_cached_tokens: 5.0,
                output_tokens: 15.0,
                output_reasoning_tokens: 9.0,
            },
        )
        .unwrap();

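        // input  = (8 - 5) * 1.0 + 5 * 0.5 = 5.5
        // output = (15 - 9) * 2.0 + 9 * 3.0 = 39.0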
        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: 5.5,
            output: 39.0,
        }
        ");
    }

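    /// Illustrative sketch (not part of the original suite): the total cost is the sum of
    /// the input and output cost components.
    #[test]
    fn test_calculated_cost_total() {
        let cost = CalculatedCost {
            input: 5.5,
            output: 39.0,
        };
        assert_eq!(cost.total(), 44.5);
    }
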
    #[test]
    fn test_calculate_cost_no_reasoning_cost() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 1.0,
                output_per_token: 2.0,
                // Should fall back to the output token cost for reasoning tokens.
                output_reasoning_per_token: 0.0,
                input_cached_per_token: 0.5,
            },
            UsedTokens {
                input_tokens: 8.0,
                input_cached_tokens: 5.0,
                output_tokens: 15.0,
                output_reasoning_tokens: 9.0,
            },
        )
        .unwrap();

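        // input  = (8 - 5) * 1.0 + 5 * 0.5 = 5.5
        // output = (15 - 9) * 2.0 + 9 * 2.0 = 30.0 (reasoning billed at the output rate)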
        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: 5.5,
            output: 30.0,
        }
        ");
    }

    /// This test shows that it is possible to produce negative costs if the token counts are
    /// inconsistent, e.g. more cached tokens than total input tokens.
    ///
    /// This behaviour was intended when initially implemented.
    #[test]
    fn test_calculate_cost_negative() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 2.0,
                output_per_token: 2.0,
                output_reasoning_per_token: 1.0,
                input_cached_per_token: 1.0,
            },
            UsedTokens {
                input_tokens: 1.0,
                input_cached_tokens: 11.0,
                output_tokens: 1.0,
                output_reasoning_tokens: 9.0,
            },
        )
        .unwrap();

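        // input  = (1 - 11) * 2.0 + 11 * 1.0 = -9.0
        // output = (1 - 9) * 2.0 + 9 * 1.0 = -7.0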
        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: -9.0,
            output: -7.0,
        }
        ");
    }
}