relay_event_normalization/normalize/span/
ai.rs

//! AI span enrichment: token usage extraction, cost calculation and operation type inference.
2
3use crate::normalize::AiOperationTypeMap;
4use crate::{ModelCostV2, ModelCosts};
5use relay_event_schema::protocol::{Event, Span, SpanData};
6use relay_protocol::{Annotated, Getter, Value};
7
/// Token counts consumed by a single model call.
#[derive(Debug, Copy, Clone)]
pub struct UsedTokens {
    /// Number of input tokens in total.
    pub input_tokens: f64,
    /// Number of cached input tokens.
    ///
    /// Already counted as part of [`Self::input_tokens`].
    pub input_cached_tokens: f64,
    /// Number of output tokens in total.
    pub output_tokens: f64,
    /// Number of reasoning tokens.
    ///
    /// Already counted as part of [`Self::output_tokens`].
    pub output_reasoning_tokens: f64,
}
24
25impl UsedTokens {
26    /// Extracts [`UsedTokens`] from [`SpanData`] attributes.
27    pub fn from_span_data(data: &SpanData) -> Self {
28        macro_rules! get_value {
29            ($e:expr) => {
30                $e.value().and_then(Value::as_f64).unwrap_or(0.0)
31            };
32        }
33
34        Self {
35            input_tokens: get_value!(data.gen_ai_usage_input_tokens),
36            output_tokens: get_value!(data.gen_ai_usage_output_tokens),
37            output_reasoning_tokens: get_value!(data.gen_ai_usage_output_tokens_reasoning),
38            input_cached_tokens: get_value!(data.gen_ai_usage_input_tokens_cached),
39        }
40    }
41
42    /// Returns `true` if any tokens were used.
43    pub fn has_usage(&self) -> bool {
44        self.input_tokens > 0.0 || self.output_tokens > 0.0
45    }
46
47    /// Calculates the total amount of uncached input tokens.
48    ///
49    /// Subtracts cached tokens from the total token count.
50    pub fn raw_input_tokens(&self) -> f64 {
51        self.input_tokens - self.input_cached_tokens
52    }
53
54    /// Calculates the total amount of raw, non-reasoning output tokens.
55    ///
56    /// Subtracts reasoning tokens from the total token count.
57    pub fn raw_output_tokens(&self) -> f64 {
58        self.output_tokens - self.output_reasoning_tokens
59    }
60}
61
/// Costs computed for a single model call.
#[derive(Debug, Copy, Clone)]
pub struct CalculatedCost {
    /// Cost attributed to the input tokens.
    pub input: f64,
    /// Cost attributed to the output tokens.
    pub output: f64,
}

impl CalculatedCost {
    /// Returns the sum of the input and the output cost.
    pub fn total(&self) -> f64 {
        let Self { input, output } = *self;
        input + output
    }
}
77
78/// Calculates the total cost for a model call.
79///
80/// Returns `None` if no tokens were used.
81pub fn calculate_costs(model_cost: &ModelCostV2, tokens: UsedTokens) -> Option<CalculatedCost> {
82    if !tokens.has_usage() {
83        return None;
84    }
85
86    let input = (tokens.raw_input_tokens() * model_cost.input_per_token)
87        + (tokens.input_cached_tokens * model_cost.input_cached_per_token);
88
89    // For now most of the models do not differentiate between reasoning and output token cost,
90    // it costs the same.
91    let reasoning_cost = match model_cost.output_reasoning_per_token {
92        reasoning_cost if reasoning_cost > 0.0 => reasoning_cost,
93        _ => model_cost.output_per_token,
94    };
95
96    let output = (tokens.raw_output_tokens() * model_cost.output_per_token)
97        + (tokens.output_reasoning_tokens * reasoning_cost);
98
99    Some(CalculatedCost { input, output })
100}
101
102/// Calculates the cost of an AI model based on the model cost and the tokens used.
103/// Calculated cost is in US dollars.
104fn extract_ai_model_cost_data(model_cost: Option<&ModelCostV2>, data: &mut SpanData) {
105    let Some(model_cost) = model_cost else { return };
106
107    let used_tokens = UsedTokens::from_span_data(&*data);
108    let Some(costs) = calculate_costs(model_cost, used_tokens) else {
109        return;
110    };
111
112    data.gen_ai_cost_total_tokens
113        .set_value(Value::F64(costs.total()).into());
114
115    // Set individual cost components
116    data.gen_ai_cost_input_tokens
117        .set_value(Value::F64(costs.input).into());
118    data.gen_ai_cost_output_tokens
119        .set_value(Value::F64(costs.output).into());
120}
121
122/// Maps AI-related measurements (legacy) to span data.
123fn map_ai_measurements_to_data(span: &mut Span) {
124    let measurements = span.measurements.value();
125    let data = span.data.get_or_insert_with(SpanData::default);
126
127    let set_field_from_measurement = |target_field: &mut Annotated<Value>,
128                                      measurement_key: &str| {
129        if let Some(measurements) = measurements
130            && target_field.value().is_none()
131            && let Some(value) = measurements.get_value(measurement_key)
132        {
133            target_field.set_value(Value::F64(value.to_f64()).into());
134        }
135    };
136
137    set_field_from_measurement(&mut data.gen_ai_usage_total_tokens, "ai_total_tokens_used");
138    set_field_from_measurement(&mut data.gen_ai_usage_input_tokens, "ai_prompt_tokens_used");
139    set_field_from_measurement(
140        &mut data.gen_ai_usage_output_tokens,
141        "ai_completion_tokens_used",
142    );
143}
144
145fn set_total_tokens(span: &mut Span) {
146    let data = span.data.get_or_insert_with(SpanData::default);
147
148    // It might be that 'total_tokens' is not set in which case we need to calculate it
149    if data.gen_ai_usage_total_tokens.value().is_none() {
150        let input_tokens = data
151            .gen_ai_usage_input_tokens
152            .value()
153            .and_then(Value::as_f64);
154        let output_tokens = data
155            .gen_ai_usage_output_tokens
156            .value()
157            .and_then(Value::as_f64);
158
159        if input_tokens.is_none() && output_tokens.is_none() {
160            // don't set total_tokens if there are no input nor output tokens
161            return;
162        }
163
164        data.gen_ai_usage_total_tokens.set_value(
165            Value::F64(input_tokens.unwrap_or(0.0) + output_tokens.unwrap_or(0.0)).into(),
166        );
167    }
168}
169
170/// Extract the additional data into the span
171fn extract_ai_data(span: &mut Span, ai_model_costs: &ModelCosts) {
172    let duration = span
173        .get_value("span.duration")
174        .and_then(|v| v.as_f64())
175        .unwrap_or(0.0);
176
177    let data = span.data.get_or_insert_with(SpanData::default);
178
179    // Extracts the response tokens per second
180    if data.gen_ai_response_tokens_per_second.value().is_none()
181        && duration > 0.0
182        && let Some(output_tokens) = data
183            .gen_ai_usage_output_tokens
184            .value()
185            .and_then(Value::as_f64)
186    {
187        data.gen_ai_response_tokens_per_second
188            .set_value(Value::F64(output_tokens / (duration / 1000.0)).into());
189    }
190
191    // Extracts the total cost of the AI model used
192    if let Some(model_id) = data
193        .gen_ai_request_model
194        .value()
195        .and_then(|val| val.as_str())
196        .or_else(|| {
197            data.gen_ai_response_model
198                .value()
199                .and_then(|val| val.as_str())
200        })
201    {
202        extract_ai_model_cost_data(ai_model_costs.cost_per_token(model_id), data)
203    }
204}
205
206/// Enrich the AI span data
207pub fn enrich_ai_span_data(
208    span: &mut Span,
209    model_costs: Option<&ModelCosts>,
210    operation_type_map: Option<&AiOperationTypeMap>,
211) {
212    if !is_ai_span(span) {
213        return;
214    }
215
216    map_ai_measurements_to_data(span);
217    set_total_tokens(span);
218
219    if let Some(model_costs) = model_costs {
220        extract_ai_data(span, model_costs);
221    }
222    if let Some(operation_type_map) = operation_type_map {
223        infer_ai_operation_type(span, operation_type_map);
224    }
225}
226
227/// Extract the ai data from all of an event's spans
228pub fn enrich_ai_event_data(
229    event: &mut Event,
230    model_costs: Option<&ModelCosts>,
231    operation_type_map: Option<&AiOperationTypeMap>,
232) {
233    let spans = event.spans.value_mut().iter_mut().flatten();
234    let spans = spans.filter_map(|span| span.value_mut().as_mut());
235
236    for span in spans {
237        enrich_ai_span_data(span, model_costs, operation_type_map);
238    }
239}
240
241///  Infer AI operation type mapping to a span.
242///
243/// This function sets the gen_ai.operation.type attribute based on the value of either
244/// gen_ai.operation.name or span.op based on the provided operation type map configuration.
245fn infer_ai_operation_type(span: &mut Span, operation_type_map: &AiOperationTypeMap) {
246    let data = span.data.get_or_insert_with(SpanData::default);
247    let op_type = data
248        .gen_ai_operation_name
249        .value()
250        .or(span.op.value())
251        .and_then(|op| operation_type_map.get_operation_type(op));
252
253    if let Some(operation_type) = op_type {
254        data.gen_ai_operation_type
255            .set_value(Some(operation_type.to_owned()));
256    }
257}
258
259/// Returns true if the span is an AI span.
260/// AI spans are spans with either a gen_ai.operation.name attribute or op starting with "ai."
261/// (legacy) or "gen_ai." (new).
262fn is_ai_span(span: &Span) -> bool {
263    let has_ai_op = span
264        .data
265        .value()
266        .and_then(|data| data.gen_ai_operation_name.value())
267        .is_some();
268
269    let is_ai_span_op = span
270        .op
271        .value()
272        .is_some_and(|op| op.starts_with("ai.") || op.starts_with("gen_ai."));
273
274    has_ai_op || is_ai_span_op
275}
276
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use relay_pattern::Pattern;
    use relay_protocol::get_value;

    use super::*;

    /// Builds the operation type map shared by the `infer_ai_operation_type` tests.
    ///
    /// `*` maps to `ai_client`, while `invoke_agent` and `gen_ai.invoke_agent` map to `agent`.
    fn operation_type_map() -> AiOperationTypeMap {
        let operation_types = HashMap::from([
            (Pattern::new("*").unwrap(), "ai_client".to_owned()),
            (Pattern::new("invoke_agent").unwrap(), "agent".to_owned()),
            (
                Pattern::new("gen_ai.invoke_agent").unwrap(),
                "agent".to_owned(),
            ),
        ]);

        AiOperationTypeMap {
            version: 1,
            operation_types,
        }
    }

    /// Test that no cost is calculated when the span data carries no token usage.
    #[test]
    fn test_calculate_cost_no_tokens() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 1.0,
                output_per_token: 1.0,
                output_reasoning_per_token: 1.0,
                input_cached_per_token: 1.0,
            },
            UsedTokens::from_span_data(&SpanData::default()),
        );
        assert!(cost.is_none());
    }

    /// Test the cost calculation with distinct prices for all four token kinds.
    #[test]
    fn test_calculate_cost_full() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 1.0,
                output_per_token: 2.0,
                output_reasoning_per_token: 3.0,
                input_cached_per_token: 0.5,
            },
            UsedTokens {
                input_tokens: 8.0,
                input_cached_tokens: 5.0,
                output_tokens: 15.0,
                output_reasoning_tokens: 9.0,
            },
        )
        .unwrap();

        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: 5.5,
            output: 39.0,
        }
        ");
    }

    /// Test that reasoning tokens are billed at the output price if no reasoning price is set.
    #[test]
    fn test_calculate_cost_no_reasoning_cost() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 1.0,
                output_per_token: 2.0,
                // Should fallback to output token cost for reasoning.
                output_reasoning_per_token: 0.0,
                input_cached_per_token: 0.5,
            },
            UsedTokens {
                input_tokens: 8.0,
                input_cached_tokens: 5.0,
                output_tokens: 15.0,
                output_reasoning_tokens: 9.0,
            },
        )
        .unwrap();

        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: 5.5,
            output: 30.0,
        }
        ");
    }

    /// This test shows it is possible to produce negative costs if tokens are not aligned properly.
    ///
    /// The behaviour was desired when initially implemented.
    #[test]
    fn test_calculate_cost_negative() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 2.0,
                output_per_token: 2.0,
                output_reasoning_per_token: 1.0,
                input_cached_per_token: 1.0,
            },
            UsedTokens {
                input_tokens: 1.0,
                input_cached_tokens: 11.0,
                output_tokens: 1.0,
                output_reasoning_tokens: 9.0,
            },
        )
        .unwrap();

        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: -9.0,
            output: -7.0,
        }
        ");
    }

    /// Test that the AI operation type is inferred from a gen_ai.operation.name attribute.
    #[test]
    fn test_infer_ai_operation_type_from_gen_ai_operation_name() {
        let operation_type_map = operation_type_map();

        let span = r#"{
            "data": {
                "gen_ai.operation.name": "invoke_agent"
            }
        }"#;
        let mut span = Annotated::from_json(span).unwrap();
        infer_ai_operation_type(span.value_mut().as_mut().unwrap(), &operation_type_map);
        assert_eq!(
            get_value!(span.data.gen_ai_operation_type!).as_str(),
            "agent"
        );
    }

    /// Test that the AI operation type is inferred from a span.op attribute.
    #[test]
    fn test_infer_ai_operation_type_from_span_op() {
        let operation_type_map = operation_type_map();

        let span = r#"{
            "op": "gen_ai.invoke_agent"
        }"#;
        let mut span = Annotated::from_json(span).unwrap();
        infer_ai_operation_type(span.value_mut().as_mut().unwrap(), &operation_type_map);
        assert_eq!(
            get_value!(span.data.gen_ai_operation_type!).as_str(),
            "agent"
        );
    }

    /// Test that the AI operation type is inferred from a fallback.
    #[test]
    fn test_infer_ai_operation_type_from_fallback() {
        let operation_type_map = operation_type_map();

        let span = r#"{
            "data": {
                "gen_ai.operation.name": "embeddings"
            }
        }"#;
        let mut span = Annotated::from_json(span).unwrap();
        infer_ai_operation_type(span.value_mut().as_mut().unwrap(), &operation_type_map);
        assert_eq!(
            get_value!(span.data.gen_ai_operation_type!).as_str(),
            "ai_client"
        );
    }

    /// Test that an AI span is detected from a gen_ai.operation.name attribute.
    #[test]
    fn test_is_ai_span_from_gen_ai_operation_name() {
        let span = r#"{
            "data": {
                "gen_ai.operation.name": "chat"
            }
        }"#;
        let span: Span = Annotated::from_json(span).unwrap().into_value().unwrap();
        assert!(is_ai_span(&span));
    }

    /// Test that an AI span is detected from a span.op starting with "ai.".
    #[test]
    fn test_is_ai_span_from_span_op_ai() {
        let span = r#"{
            "op": "ai.chat"
        }"#;
        let span: Span = Annotated::from_json(span).unwrap().into_value().unwrap();
        assert!(is_ai_span(&span));
    }

    /// Test that an AI span is detected from a span.op starting with "gen_ai.".
    #[test]
    fn test_is_ai_span_from_span_op_gen_ai() {
        let span = r#"{
            "op": "gen_ai.chat"
        }"#;
        let span: Span = Annotated::from_json(span).unwrap().into_value().unwrap();
        assert!(is_ai_span(&span));
    }

    /// Test that a non-AI span is detected.
    #[test]
    fn test_is_ai_span_negative() {
        let span = r#"{
        }"#;
        let span: Span = Annotated::from_json(span).unwrap().into_value().unwrap();
        assert!(!is_ai_span(&span));
    }
}
509}