Skip to main content

relay_event_normalization/normalize/span/
ai.rs

1//! AI cost calculation.
2
3use crate::statsd::{Counters, map_origin_to_integration, platform_tag};
4use crate::{ModelCostV2, ModelMetadata};
5use relay_event_schema::protocol::{
6    Event, Measurements, OperationType, Span, SpanData, TraceContext,
7};
8use relay_protocol::{Annotated, Getter, Value};
9
/// Token counts consumed by a single model call.
#[derive(Clone, Copy, Debug)]
pub struct UsedTokens {
    /// Total number of input tokens.
    pub input_tokens: f64,
    /// Number of cached input tokens.
    ///
    /// Already contained in [`Self::input_tokens`].
    pub input_cached_tokens: f64,
    /// Number of cache write input tokens.
    ///
    /// Already contained in [`Self::input_tokens`].
    pub input_cache_write_tokens: f64,
    /// Total number of output tokens.
    pub output_tokens: f64,
    /// Number of reasoning tokens.
    ///
    /// Already contained in [`Self::output_tokens`].
    pub output_reasoning_tokens: f64,
}
30
31impl UsedTokens {
32    /// Extracts [`UsedTokens`] from [`SpanData`] attributes.
33    pub fn from_span_data(data: &SpanData) -> Self {
34        macro_rules! get_value {
35            ($e:expr) => {
36                $e.value().and_then(Value::as_f64).unwrap_or(0.0)
37            };
38        }
39
40        Self {
41            input_tokens: get_value!(data.gen_ai_usage_input_tokens),
42            output_tokens: get_value!(data.gen_ai_usage_output_tokens),
43            output_reasoning_tokens: get_value!(data.gen_ai_usage_output_tokens_reasoning),
44            input_cached_tokens: get_value!(data.gen_ai_usage_input_tokens_cached),
45            input_cache_write_tokens: get_value!(data.gen_ai_usage_input_tokens_cache_write),
46        }
47    }
48
49    /// Returns `true` if any tokens were used.
50    pub fn has_usage(&self) -> bool {
51        self.input_tokens > 0.0 || self.output_tokens > 0.0
52    }
53
54    /// Calculates the total amount of input tokens billed at the standard rate.
55    ///
56    /// Both [`Self::input_cached_tokens`] and [`Self::input_cache_write_tokens`] are
57    /// subsets of [`Self::input_tokens`] and are billed separately at their own
58    /// (cached / cache-write) rates, so both are subtracted here to avoid charging
59    /// them twice.
60    pub fn raw_input_tokens(&self) -> f64 {
61        self.input_tokens - self.input_cached_tokens - self.input_cache_write_tokens
62    }
63
64    /// Calculates the total amount of raw, non-reasoning output tokens.
65    ///
66    /// Subtracts reasoning tokens from the total token count.
67    pub fn raw_output_tokens(&self) -> f64 {
68        self.output_tokens - self.output_reasoning_tokens
69    }
70}
71
/// Costs calculated for a model call, split into input and output.
#[derive(Clone, Copy, Debug)]
pub struct CalculatedCost {
    /// Cost of the input tokens.
    pub input: f64,
    /// Cost of the output tokens.
    pub output: f64,
}

impl CalculatedCost {
    /// Sum of the input and the output cost.
    pub fn total(&self) -> f64 {
        let Self { input, output } = *self;
        input + output
    }
}
87
88/// Calculates the total cost for a model call.
89///
90/// Returns `None` if no tokens were used.
91pub fn calculate_costs(
92    model_cost: &ModelCostV2,
93    tokens: UsedTokens,
94    integration: &str,
95    platform: &str,
96) -> Option<CalculatedCost> {
97    if !tokens.has_usage() {
98        relay_statsd::metric!(
99            counter(Counters::GenAiCostCalculationResult) += 1,
100            result = "calculation_no_tokens",
101            integration = integration,
102            platform = platform,
103        );
104        return None;
105    }
106
107    let input = (tokens.raw_input_tokens() * model_cost.input_per_token)
108        + (tokens.input_cached_tokens * model_cost.input_cached_per_token)
109        + (tokens.input_cache_write_tokens * model_cost.input_cache_write_per_token);
110
111    // For now most of the models do not differentiate between reasoning and output token cost,
112    // it costs the same.
113    let reasoning_cost = match model_cost.output_reasoning_per_token {
114        reasoning_cost if reasoning_cost > 0.0 => reasoning_cost,
115        _ => model_cost.output_per_token,
116    };
117
118    let output = (tokens.raw_output_tokens() * model_cost.output_per_token)
119        + (tokens.output_reasoning_tokens * reasoning_cost);
120
121    let metric_label = match (input, output) {
122        (x, y) if x < 0.0 || y < 0.0 => "calculation_negative",
123        (0.0, 0.0) => "calculation_zero",
124        _ => "calculation_positive",
125    };
126
127    relay_statsd::metric!(
128        counter(Counters::GenAiCostCalculationResult) += 1,
129        result = metric_label,
130        integration = integration,
131        platform = platform,
132    );
133
134    Some(CalculatedCost { input, output })
135}
136
/// Fallback AI operation type for AI spans whose operation is not a well known one.
///
/// The value is meant for the
/// [`GEN_AI__OPERATION__TYPE`](relay_conventions::attributes::GEN_AI__OPERATION__TYPE) attribute.
///
/// See also: [`infer_ai_operation_type`].
pub const DEFAULT_AI_OPERATION: &str = "ai_client";
142
/// Maps an AI operation name to the AI operation type.
///
/// The operation name usually comes from the
/// [`GEN_AI__OPERATION__NAME`](relay_conventions::attributes::GEN_AI__OPERATION__NAME) span
/// attribute or from the span operation.
///
/// Sentry expects the resulting type in the
/// [`GEN_AI__OPERATION__TYPE`](relay_conventions::attributes::GEN_AI__OPERATION__TYPE) attribute.
///
/// Exact names are checked first, then known prefixes. Returns `None` if nothing matches,
/// in which case callers likely want to use [`DEFAULT_AI_OPERATION`] for AI spans.
pub fn infer_ai_operation_type(op_name: &str) -> Option<&'static str> {
    /// Prefix to operation type, checked in order.
    ///
    /// The longer `do*` prefixes must stay in front of the shorter prefixes they extend.
    const PREFIX_TYPES: [(&str, &str); 7] = [
        ("ai.streamText.doStream", "ai_client"),
        ("ai.streamText", "agent"),
        ("ai.generateText.doGenerate", "ai_client"),
        ("ai.generateText", "agent"),
        ("ai.generateObject.doGenerate", "ai_client"),
        ("ai.generateObject", "agent"),
        ("ai.toolCall", "tool"),
    ];

    let exact_match = match op_name {
        "ai.run.generateText"
        | "ai.run.generateObject"
        | "gen_ai.invoke_agent"
        | "ai.pipeline.generate_text"
        | "ai.pipeline.generate_object"
        | "ai.pipeline.stream_text"
        | "ai.pipeline.stream_object"
        | "gen_ai.create_agent"
        | "invoke_agent"
        | "create_agent" => Some("agent"),
        "gen_ai.execute_tool" | "execute_tool" => Some("tool"),
        "gen_ai.handoff" | "handoff" => Some("handoff"),
        "ai.processor" | "processor_run" => Some("other"),
        _ => None,
    };

    exact_match.or_else(|| {
        PREFIX_TYPES
            .iter()
            .find(|(prefix, _)| op_name.starts_with(prefix))
            .map(|&(_, op_type)| op_type)
    })
}
187
188/// Calculates the cost of an AI model based on the model cost and the tokens used.
189/// Calculated cost is in US dollars.
190fn extract_ai_model_cost_data(
191    model_cost: Option<&ModelCostV2>,
192    data: &mut SpanData,
193    origin: Option<&str>,
194    platform: Option<&str>,
195) {
196    let integration = map_origin_to_integration(origin);
197    let platform = platform_tag(platform);
198
199    let Some(model_cost) = model_cost else {
200        relay_statsd::metric!(
201            counter(Counters::GenAiCostCalculationResult) += 1,
202            result = "calculation_no_model_cost_available",
203            integration = integration,
204            platform = platform,
205        );
206        return;
207    };
208
209    let used_tokens = UsedTokens::from_span_data(&*data);
210    let Some(costs) = calculate_costs(model_cost, used_tokens, integration, platform) else {
211        return;
212    };
213
214    data.gen_ai_cost_total_tokens
215        .set_value(Value::F64(costs.total()).into());
216
217    // Set individual cost components
218    data.gen_ai_cost_input_tokens
219        .set_value(Value::F64(costs.input).into());
220    data.gen_ai_cost_output_tokens
221        .set_value(Value::F64(costs.output).into());
222}
223
224/// Maps AI-related measurements (legacy) to span data.
225fn map_ai_measurements_to_data(data: &mut SpanData, measurements: Option<&Measurements>) {
226    let set_field_from_measurement = |target_field: &mut Annotated<Value>,
227                                      measurement_key: &str| {
228        if let Some(measurements) = measurements
229            && target_field.value().is_none()
230            && let Some(value) = measurements.get_value(measurement_key)
231        {
232            target_field.set_value(Value::F64(value.to_f64()).into());
233        }
234    };
235
236    set_field_from_measurement(&mut data.gen_ai_usage_total_tokens, "ai_total_tokens_used");
237    set_field_from_measurement(&mut data.gen_ai_usage_input_tokens, "ai_prompt_tokens_used");
238    set_field_from_measurement(
239        &mut data.gen_ai_usage_output_tokens,
240        "ai_completion_tokens_used",
241    );
242}
243
244fn set_total_tokens(data: &mut SpanData) {
245    // It might be that 'total_tokens' is not set in which case we need to calculate it
246    if data.gen_ai_usage_total_tokens.value().is_none() {
247        let input_tokens = data
248            .gen_ai_usage_input_tokens
249            .value()
250            .and_then(Value::as_f64);
251        let output_tokens = data
252            .gen_ai_usage_output_tokens
253            .value()
254            .and_then(Value::as_f64);
255
256        if input_tokens.is_none() && output_tokens.is_none() {
257            // don't set total_tokens if there are no input nor output tokens
258            return;
259        }
260
261        data.gen_ai_usage_total_tokens.set_value(
262            Value::F64(input_tokens.unwrap_or(0.0) + output_tokens.unwrap_or(0.0)).into(),
263        );
264    }
265}
266
267/// Sets the context window size and utilization for the model.
268fn extract_context_utilization(data: &mut SpanData, model_metadata: &ModelMetadata) {
269    let model_id = data
270        .gen_ai_response_model
271        .value()
272        .and_then(|val| val.as_str());
273
274    let context_size = model_id.and_then(|id| model_metadata.context_size(id));
275
276    let Some(context_size) = context_size else {
277        return;
278    };
279
280    data.gen_ai_context_window_size
281        .set_value(Value::U64(context_size).into());
282
283    let total_tokens = data
284        .gen_ai_usage_total_tokens
285        .value()
286        .and_then(Value::as_f64);
287
288    if let Some(total_tokens) = total_tokens {
289        data.gen_ai_context_utilization
290            .set_value(Value::F64(total_tokens / context_size as f64).into());
291    }
292}
293
294/// Extract the additional data into the span
295fn extract_ai_data(
296    data: &mut SpanData,
297    duration: f64,
298    model_metadata: &ModelMetadata,
299    origin: Option<&str>,
300    platform: Option<&str>,
301) {
302    // Extracts the response tokens per second
303    if data.gen_ai_response_tokens_per_second.value().is_none()
304        && duration > 0.0
305        && let Some(output_tokens) = data
306            .gen_ai_usage_output_tokens
307            .value()
308            .and_then(Value::as_f64)
309    {
310        data.gen_ai_response_tokens_per_second
311            .set_value(Value::F64(output_tokens / (duration / 1000.0)).into());
312    }
313
314    extract_context_utilization(data, model_metadata);
315
316    // Extracts the total cost of the AI model used
317    if let Some(model_id) = data
318        .gen_ai_response_model
319        .value()
320        .and_then(|val| val.as_str())
321    {
322        extract_ai_model_cost_data(
323            model_metadata.cost_per_token(model_id),
324            data,
325            origin,
326            platform,
327        )
328    } else {
329        relay_statsd::metric!(
330            counter(Counters::GenAiCostCalculationResult) += 1,
331            result = "calculation_no_model_id_available",
332            integration = map_origin_to_integration(origin),
333            platform = platform_tag(platform),
334        );
335    }
336}
337
338/// Enrich the AI span data
339fn enrich_ai_span_data(
340    span_data: &mut Annotated<SpanData>,
341    span_op: &Annotated<OperationType>,
342    measurements: &Annotated<Measurements>,
343    duration: f64,
344    model_metadata: Option<&ModelMetadata>,
345    origin: Option<&str>,
346    platform: Option<&str>,
347) {
348    if !is_ai_span(span_data, span_op.value()) {
349        return;
350    }
351
352    let data = span_data.get_or_insert_with(SpanData::default);
353
354    map_ai_measurements_to_data(data, measurements.value());
355
356    set_total_tokens(data);
357
358    // Default response model to request model if not set.
359    if data.gen_ai_response_model.value().is_none()
360        && let Some(request_model) = data.gen_ai_request_model.value().cloned()
361    {
362        data.gen_ai_response_model.set_value(Some(request_model));
363    }
364
365    // Default agent name to function_id if not set.
366    if data.gen_ai_agent_name.value().is_none()
367        && let Some(function_id) = data.gen_ai_function_id.value().cloned()
368    {
369        data.gen_ai_agent_name.set_value(Some(function_id));
370    }
371
372    if let Some(model_metadata) = model_metadata {
373        extract_ai_data(data, duration, model_metadata, origin, platform);
374    } else {
375        relay_statsd::metric!(
376            counter(Counters::GenAiCostCalculationResult) += 1,
377            result = "calculation_no_model_cost_available",
378            integration = map_origin_to_integration(origin),
379            platform = platform_tag(platform),
380        );
381    }
382
383    let ai_op_type = data
384        .gen_ai_operation_name
385        .value()
386        .or(span_op.value())
387        .and_then(|op| infer_ai_operation_type(op))
388        .unwrap_or(DEFAULT_AI_OPERATION);
389
390    data.gen_ai_operation_type
391        .set_value(Some(ai_op_type.to_owned()));
392}
393
394/// Enrich the AI span data
395pub fn enrich_ai_span(span: &mut Span, model_metadata: Option<&ModelMetadata>) {
396    let duration = span
397        .get_value("span.duration")
398        .and_then(|v| v.as_f64())
399        .unwrap_or(0.0);
400
401    enrich_ai_span_data(
402        &mut span.data,
403        &span.op,
404        &span.measurements,
405        duration,
406        model_metadata,
407        span.origin.as_str(),
408        span.platform.as_str(),
409    );
410}
411
412/// Extract the ai data from all of an event's spans
413pub fn enrich_ai_event_data(event: &mut Event, model_metadata: Option<&ModelMetadata>) {
414    let event_duration = event
415        .get_value("event.duration")
416        .and_then(|v| v.as_f64())
417        .unwrap_or(0.0);
418
419    if let Some(trace_context) = event
420        .contexts
421        .value_mut()
422        .as_mut()
423        .and_then(|c| c.get_mut::<TraceContext>())
424    {
425        enrich_ai_span_data(
426            &mut trace_context.data,
427            &trace_context.op,
428            &event.measurements,
429            event_duration,
430            model_metadata,
431            trace_context.origin.as_str(),
432            event.platform.as_str(),
433        );
434    }
435    let spans = event.spans.value_mut().iter_mut().flatten();
436    let spans = spans.filter_map(|span| span.value_mut().as_mut());
437
438    for span in spans {
439        let span_duration = span
440            .get_value("span.duration")
441            .and_then(|v| v.as_f64())
442            .unwrap_or(0.0);
443        let span_platform = span.platform.as_str().or_else(|| event.platform.as_str());
444
445        enrich_ai_span_data(
446            &mut span.data,
447            &span.op,
448            &span.measurements,
449            span_duration,
450            model_metadata,
451            span.origin.as_str(),
452            span_platform,
453        );
454    }
455}
456
457/// Returns true if the span is an AI span.
458/// AI spans are spans with either a gen_ai.operation.name attribute or op starting with "ai."
459/// (legacy) or "gen_ai." (new).
460fn is_ai_span(span_data: &Annotated<SpanData>, span_op: Option<&OperationType>) -> bool {
461    let has_ai_op = span_data
462        .value()
463        .and_then(|data| data.gen_ai_operation_name.value())
464        .is_some();
465
466    let is_ai_span_op =
467        span_op.is_some_and(|op| op.starts_with("ai.") || op.starts_with("gen_ai."));
468
469    has_ai_op || is_ai_span_op
470}
471
472#[cfg(test)]
473mod tests {
474    use std::collections::HashMap;
475
476    use relay_pattern::Pattern;
477    use relay_protocol::{FromValue, assert_annotated_snapshot};
478    use serde_json::json;
479
480    use super::*;
481    use crate::ModelMetadataEntry;
482
483    fn ai_span_with_data(data: serde_json::Value) -> Span {
484        Span {
485            op: "gen_ai.test".to_owned().into(),
486            data: SpanData::from_value(data.into()),
487            ..Default::default()
488        }
489    }
490
491    #[test]
492    fn test_calculate_cost_no_tokens() {
493        let cost = calculate_costs(
494            &ModelCostV2 {
495                input_per_token: 1.0,
496                output_per_token: 1.0,
497                output_reasoning_per_token: 1.0,
498                input_cached_per_token: 1.0,
499                input_cache_write_per_token: 1.0,
500            },
501            UsedTokens::from_span_data(&SpanData::default()),
502            "test",
503            "test",
504        );
505        assert!(cost.is_none());
506    }
507
508    #[test]
509    fn test_calculate_cost_full() {
510        let cost = calculate_costs(
511            &ModelCostV2 {
512                input_per_token: 1.0,
513                output_per_token: 2.0,
514                output_reasoning_per_token: 3.0,
515                input_cached_per_token: 0.5,
516                input_cache_write_per_token: 0.75,
517            },
518            UsedTokens {
519                input_tokens: 8.0,
520                input_cached_tokens: 5.0,
521                input_cache_write_tokens: 0.0,
522                output_tokens: 15.0,
523                output_reasoning_tokens: 9.0,
524            },
525            "test",
526            "test",
527        )
528        .unwrap();
529
530        insta::assert_debug_snapshot!(cost, @r"
531        CalculatedCost {
532            input: 5.5,
533            output: 39.0,
534        }
535        ");
536    }
537
538    #[test]
539    fn test_calculate_cost_no_reasoning_cost() {
540        let cost = calculate_costs(
541            &ModelCostV2 {
542                input_per_token: 1.0,
543                output_per_token: 2.0,
544                // Should fallback to output token cost for reasoning.
545                output_reasoning_per_token: 0.0,
546                input_cached_per_token: 0.5,
547                input_cache_write_per_token: 0.0,
548            },
549            UsedTokens {
550                input_tokens: 8.0,
551                input_cached_tokens: 5.0,
552                input_cache_write_tokens: 0.0,
553                output_tokens: 15.0,
554                output_reasoning_tokens: 9.0,
555            },
556            "test",
557            "test",
558        )
559        .unwrap();
560
561        insta::assert_debug_snapshot!(cost, @r"
562        CalculatedCost {
563            input: 5.5,
564            output: 30.0,
565        }
566        ");
567    }
568
569    /// This test shows it is possible to produce negative costs if tokens are not aligned properly.
570    ///
571    /// The behaviour was desired when initially implemented.
572    #[test]
573    fn test_calculate_cost_negative() {
574        let cost = calculate_costs(
575            &ModelCostV2 {
576                input_per_token: 2.0,
577                output_per_token: 2.0,
578                output_reasoning_per_token: 1.0,
579                input_cached_per_token: 1.0,
580                input_cache_write_per_token: 1.5,
581            },
582            UsedTokens {
583                input_tokens: 1.0,
584                input_cached_tokens: 11.0,
585                input_cache_write_tokens: 0.0,
586                output_tokens: 1.0,
587                output_reasoning_tokens: 9.0,
588            },
589            "test",
590            "test",
591        )
592        .unwrap();
593
594        insta::assert_debug_snapshot!(cost, @r"
595        CalculatedCost {
596            input: -9.0,
597            output: -7.0,
598        }
599        ");
600    }
601
602    #[test]
603    fn test_calculate_cost_with_cache_writes() {
604        let cost = calculate_costs(
605            &ModelCostV2 {
606                input_per_token: 1.0,
607                output_per_token: 2.0,
608                output_reasoning_per_token: 3.0,
609                input_cached_per_token: 0.5,
610                input_cache_write_per_token: 0.75,
611            },
612            UsedTokens {
613                input_tokens: 100.0,
614                input_cached_tokens: 20.0,
615                input_cache_write_tokens: 30.0,
616                output_tokens: 50.0,
617                output_reasoning_tokens: 10.0,
618            },
619            "test",
620            "test",
621        )
622        .unwrap();
623
624        // input: (100 - 20 - 30) * 1.0 + 20 * 0.5 + 30 * 0.75 = 50 + 10 + 22.5 = 82.5
625        //   (cache-write tokens are billed once at the cache-write rate, not also at
626        //    the standard input rate). output: 40 * 2.0 + 10 * 3.0 = 110.0
627        insta::assert_debug_snapshot!(cost, @r"
628        CalculatedCost {
629            input: 82.5,
630            output: 110.0,
631        }
632        ");
633    }
634
635    #[test]
636    fn test_calculate_cost_backward_compatibility_no_cache_write() {
637        // Test that cost calculation works when cache_write field is missing (backward compatibility)
638        let span_data = SpanData {
639            gen_ai_usage_input_tokens: Annotated::new(100.0.into()),
640            gen_ai_usage_input_tokens_cached: Annotated::new(20.0.into()),
641            gen_ai_usage_output_tokens: Annotated::new(50.0.into()),
642            // Note: gen_ai_usage_input_tokens_cache_write is NOT set (simulating old data)
643            ..Default::default()
644        };
645
646        let tokens = UsedTokens::from_span_data(&span_data);
647
648        // Verify cache_write_tokens defaults to 0.0
649        assert_eq!(tokens.input_cache_write_tokens, 0.0);
650
651        let cost = calculate_costs(
652            &ModelCostV2 {
653                input_per_token: 1.0,
654                output_per_token: 2.0,
655                output_reasoning_per_token: 0.0,
656                input_cached_per_token: 0.5,
657                input_cache_write_per_token: 0.75,
658            },
659            tokens,
660            "test",
661            "test",
662        )
663        .unwrap();
664
665        // Cost should be calculated without cache_write_tokens
666        // input: (100 - 20) * 1.0 + 20 * 0.5 + 0 * 0.75 = 80 + 10 + 0 = 90
667        // output: 50 * 2.0 = 100
668        insta::assert_debug_snapshot!(cost, @r"
669        CalculatedCost {
670            input: 90.0,
671            output: 100.0,
672        }
673        ");
674    }
675
676    /// Test that the AI operation type is inferred from a gen_ai.operation.name attribute.
677    #[test]
678    fn test_infer_ai_operation_type_from_gen_ai_operation_name() {
679        let mut span = ai_span_with_data(json!({
680            "gen_ai.operation.name": "invoke_agent"
681        }));
682
683        enrich_ai_span(&mut span, None);
684
685        assert_annotated_snapshot!(&span.data, @r#"
686        {
687          "gen_ai.operation.name": "invoke_agent",
688          "gen_ai.operation.type": "agent"
689        }
690        "#);
691    }
692
693    /// Test that the AI operation type is inferred from a span.op attribute.
694    #[test]
695    fn test_infer_ai_operation_type_from_span_op() {
696        let mut span = Span {
697            op: "gen_ai.invoke_agent".to_owned().into(),
698            ..Default::default()
699        };
700
701        enrich_ai_span(&mut span, None);
702
703        assert_annotated_snapshot!(span.data, @r#"
704        {
705          "gen_ai.operation.type": "agent"
706        }
707        "#);
708    }
709
710    /// Test that the AI operation type is inferred from a fallback.
711    #[test]
712    fn test_infer_ai_operation_type_from_fallback() {
713        let mut span = ai_span_with_data(json!({
714            "gen_ai.operation.name": "embeddings"
715        }));
716
717        enrich_ai_span(&mut span, None);
718
719        assert_annotated_snapshot!(&span.data, @r#"
720        {
721          "gen_ai.operation.name": "embeddings",
722          "gen_ai.operation.type": "ai_client"
723        }
724        "#);
725    }
726
727    /// Test that the response model is defaulted to the request model if not set.
728    #[test]
729    fn test_default_response_model_from_request_model() {
730        let mut span = ai_span_with_data(json!({
731            "gen_ai.request.model": "gpt-4",
732        }));
733
734        enrich_ai_span(&mut span, None);
735
736        assert_annotated_snapshot!(&span.data, @r#"
737        {
738          "gen_ai.response.model": "gpt-4",
739          "gen_ai.request.model": "gpt-4",
740          "gen_ai.operation.type": "ai_client"
741        }
742        "#);
743    }
744
745    /// Test that the response model is defaulted to the request model if not set.
746    #[test]
747    fn test_default_response_model_not_overridden() {
748        let mut span = ai_span_with_data(json!({
749            "gen_ai.request.model": "gpt-4",
750            "gen_ai.response.model": "gpt-4-abcd",
751        }));
752
753        enrich_ai_span(&mut span, None);
754
755        assert_annotated_snapshot!(&span.data, @r#"
756        {
757          "gen_ai.response.model": "gpt-4-abcd",
758          "gen_ai.request.model": "gpt-4",
759          "gen_ai.operation.type": "ai_client"
760        }
761        "#);
762    }
763
764    /// Test that gen_ai.agent.name is defaulted from gen_ai.function_id.
765    #[test]
766    fn test_default_agent_name_from_function_id() {
767        let mut span = ai_span_with_data(json!({
768            "gen_ai.function_id": "my-agent",
769        }));
770
771        enrich_ai_span(&mut span, None);
772
773        assert_annotated_snapshot!(&span.data, @r#"
774        {
775          "gen_ai.operation.type": "ai_client",
776          "gen_ai.agent.name": "my-agent",
777          "gen_ai.function_id": "my-agent"
778        }
779        "#);
780    }
781
782    /// Test that gen_ai.agent.name is not overridden when already set.
783    #[test]
784    fn test_default_agent_name_not_overridden() {
785        let mut span = ai_span_with_data(json!({
786            "gen_ai.function_id": "my-function",
787            "gen_ai.agent.name": "my-agent",
788        }));
789
790        enrich_ai_span(&mut span, None);
791
792        assert_annotated_snapshot!(&span.data, @r#"
793        {
794          "gen_ai.operation.type": "ai_client",
795          "gen_ai.agent.name": "my-agent",
796          "gen_ai.function_id": "my-function"
797        }
798        "#);
799    }
800
801    /// Test that an AI span is detected from a gen_ai.operation.name attribute.
802    #[test]
803    fn test_is_ai_span_from_gen_ai_operation_name() {
804        let mut span_data = Annotated::default();
805        span_data
806            .get_or_insert_with(SpanData::default)
807            .gen_ai_operation_name
808            .set_value(Some("chat".into()));
809        assert!(is_ai_span(&span_data, None));
810    }
811
812    /// Test that an AI span is detected from a span.op starting with "ai.".
813    #[test]
814    fn test_is_ai_span_from_span_op_ai() {
815        let span_op: OperationType = "ai.chat".into();
816        assert!(is_ai_span(&Annotated::default(), Some(&span_op)));
817    }
818
819    /// Test that an AI span is detected from a span.op starting with "gen_ai.".
820    #[test]
821    fn test_is_ai_span_from_span_op_gen_ai() {
822        let span_op: OperationType = "gen_ai.chat".into();
823        assert!(is_ai_span(&Annotated::default(), Some(&span_op)));
824    }
825
826    /// Test that a non-AI span is detected.
827    #[test]
828    fn test_is_ai_span_negative() {
829        assert!(!is_ai_span(&Annotated::default(), None));
830    }
831
832    /// Test enrich_ai_event_data with invoke_agent in trace context and a chat child span.
833    #[test]
834    fn test_enrich_ai_event_data_invoke_agent_trace_with_chat_span() {
835        let event_json = r#"{
836            "type": "transaction",
837            "timestamp": 1234567892.0,
838            "start_timestamp": 1234567889.0,
839            "contexts": {
840                "trace": {
841                    "op": "gen_ai.invoke_agent",
842                    "trace_id": "12345678901234567890123456789012",
843                    "span_id": "1234567890123456",
844                    "data": {
845                        "gen_ai.operation.name": "gen_ai.invoke_agent",
846                        "gen_ai.usage.input_tokens": 500,
847                        "gen_ai.usage.output_tokens": 200
848                    }
849                }
850            },
851            "spans": [
852                {
853                    "op": "gen_ai.chat.completions",
854                    "span_id": "1234567890123457",
855                    "start_timestamp": 1234567889.5,
856                    "timestamp": 1234567890.5,
857                    "data": {
858                        "gen_ai.operation.name": "chat",
859                        "gen_ai.usage.input_tokens": 100,
860                        "gen_ai.usage.output_tokens": 50
861                    }
862                }
863            ]
864        }"#;
865
866        let mut annotated_event: Annotated<Event> = Annotated::from_json(event_json).unwrap();
867        let event = annotated_event.value_mut().as_mut().unwrap();
868
869        enrich_ai_event_data(event, None);
870
871        assert_annotated_snapshot!(&annotated_event, @r#"
872        {
873          "type": "transaction",
874          "timestamp": 1234567892.0,
875          "start_timestamp": 1234567889.0,
876          "contexts": {
877            "trace": {
878              "trace_id": "12345678901234567890123456789012",
879              "span_id": "1234567890123456",
880              "op": "gen_ai.invoke_agent",
881              "data": {
882                "gen_ai.usage.total_tokens": 700.0,
883                "gen_ai.usage.input_tokens": 500,
884                "gen_ai.usage.output_tokens": 200,
885                "gen_ai.operation.name": "gen_ai.invoke_agent",
886                "gen_ai.operation.type": "agent"
887              },
888              "type": "trace"
889            }
890          },
891          "spans": [
892            {
893              "timestamp": 1234567890.5,
894              "start_timestamp": 1234567889.5,
895              "op": "gen_ai.chat.completions",
896              "span_id": "1234567890123457",
897              "data": {
898                "gen_ai.usage.total_tokens": 150.0,
899                "gen_ai.usage.input_tokens": 100,
900                "gen_ai.usage.output_tokens": 50,
901                "gen_ai.operation.name": "chat",
902                "gen_ai.operation.type": "ai_client"
903              }
904            }
905          ]
906        }
907        "#);
908    }
909
910    /// Test enrich_ai_event_data with non-AI trace context, invoke_agent parent span, and chat child span.
911    #[test]
912    fn test_enrich_ai_event_data_nested_agent_and_chat_spans() {
913        let event_json = r#"{
914            "type": "transaction",
915            "timestamp": 1234567892.0,
916            "start_timestamp": 1234567889.0,
917            "contexts": {
918                "trace": {
919                    "op": "http.server",
920                    "trace_id": "12345678901234567890123456789012",
921                    "span_id": "1234567890123456"
922                }
923            },
924            "spans": [
925                {
926                    "op": "gen_ai.invoke_agent",
927                    "span_id": "1234567890123457",
928                    "parent_span_id": "1234567890123456",
929                    "start_timestamp": 1234567889.5,
930                    "timestamp": 1234567891.5,
931                    "data": {
932                        "gen_ai.operation.name": "invoke_agent",
933                        "gen_ai.usage.input_tokens": 500,
934                        "gen_ai.usage.output_tokens": 200
935                    }
936                },
937                {
938                    "op": "gen_ai.chat.completions",
939                    "span_id": "1234567890123458",
940                    "parent_span_id": "1234567890123457",
941                    "start_timestamp": 1234567890.0,
942                    "timestamp": 1234567891.0,
943                    "data": {
944                        "gen_ai.operation.name": "chat",
945                        "gen_ai.usage.input_tokens": 100,
946                        "gen_ai.usage.output_tokens": 50
947                    }
948                }
949            ]
950        }"#;
951
952        let mut annotated_event: Annotated<Event> = Annotated::from_json(event_json).unwrap();
953        let event = annotated_event.value_mut().as_mut().unwrap();
954
955        enrich_ai_event_data(event, None);
956
957        assert_annotated_snapshot!(&annotated_event, @r#"
958        {
959          "type": "transaction",
960          "timestamp": 1234567892.0,
961          "start_timestamp": 1234567889.0,
962          "contexts": {
963            "trace": {
964              "trace_id": "12345678901234567890123456789012",
965              "span_id": "1234567890123456",
966              "op": "http.server",
967              "type": "trace"
968            }
969          },
970          "spans": [
971            {
972              "timestamp": 1234567891.5,
973              "start_timestamp": 1234567889.5,
974              "op": "gen_ai.invoke_agent",
975              "span_id": "1234567890123457",
976              "parent_span_id": "1234567890123456",
977              "data": {
978                "gen_ai.usage.total_tokens": 700.0,
979                "gen_ai.usage.input_tokens": 500,
980                "gen_ai.usage.output_tokens": 200,
981                "gen_ai.operation.name": "invoke_agent",
982                "gen_ai.operation.type": "agent"
983              }
984            },
985            {
986              "timestamp": 1234567891.0,
987              "start_timestamp": 1234567890.0,
988              "op": "gen_ai.chat.completions",
989              "span_id": "1234567890123458",
990              "parent_span_id": "1234567890123457",
991              "data": {
992                "gen_ai.usage.total_tokens": 150.0,
993                "gen_ai.usage.input_tokens": 100,
994                "gen_ai.usage.output_tokens": 50,
995                "gen_ai.operation.name": "chat",
996                "gen_ai.operation.type": "ai_client"
997              }
998            }
999          ]
1000        }
1001        "#);
1002    }
1003
1004    /// Test enrich_ai_event_data with legacy measurements and span op for operation type.
1005    #[test]
1006    fn test_enrich_ai_event_data_legacy_measurements_and_span_op() {
1007        let event_json = r#"{
1008            "type": "transaction",
1009            "timestamp": 1234567892.0,
1010            "start_timestamp": 1234567889.0,
1011            "contexts": {
1012                "trace": {
1013                    "op": "http.server",
1014                    "trace_id": "12345678901234567890123456789012",
1015                    "span_id": "1234567890123456"
1016                }
1017            },
1018            "spans": [
1019                {
1020                    "op": "gen_ai.invoke_agent",
1021                    "span_id": "1234567890123457",
1022                    "parent_span_id": "1234567890123456",
1023                    "start_timestamp": 1234567889.5,
1024                    "timestamp": 1234567891.5,
1025                    "measurements": {
1026                        "ai_prompt_tokens_used": {"value": 500.0},
1027                        "ai_completion_tokens_used": {"value": 200.0}
1028                    }
1029                },
1030                {
1031                    "op": "ai.chat_completions.create.langchain.ChatOpenAI",
1032                    "span_id": "1234567890123458",
1033                    "parent_span_id": "1234567890123457",
1034                    "start_timestamp": 1234567890.0,
1035                    "timestamp": 1234567891.0,
1036                    "measurements": {
1037                        "ai_prompt_tokens_used": {"value": 100.0},
1038                        "ai_completion_tokens_used": {"value": 50.0}
1039                    }
1040                }
1041            ]
1042        }"#;
1043
1044        let mut annotated_event: Annotated<Event> = Annotated::from_json(event_json).unwrap();
1045        let event = annotated_event.value_mut().as_mut().unwrap();
1046
1047        enrich_ai_event_data(event, None);
1048
1049        assert_annotated_snapshot!(&annotated_event, @r#"
1050        {
1051          "type": "transaction",
1052          "timestamp": 1234567892.0,
1053          "start_timestamp": 1234567889.0,
1054          "contexts": {
1055            "trace": {
1056              "trace_id": "12345678901234567890123456789012",
1057              "span_id": "1234567890123456",
1058              "op": "http.server",
1059              "type": "trace"
1060            }
1061          },
1062          "spans": [
1063            {
1064              "timestamp": 1234567891.5,
1065              "start_timestamp": 1234567889.5,
1066              "op": "gen_ai.invoke_agent",
1067              "span_id": "1234567890123457",
1068              "parent_span_id": "1234567890123456",
1069              "data": {
1070                "gen_ai.usage.total_tokens": 700.0,
1071                "gen_ai.usage.input_tokens": 500.0,
1072                "gen_ai.usage.output_tokens": 200.0,
1073                "gen_ai.operation.type": "agent"
1074              },
1075              "measurements": {
1076                "ai_completion_tokens_used": {
1077                  "value": 200.0
1078                },
1079                "ai_prompt_tokens_used": {
1080                  "value": 500.0
1081                }
1082              }
1083            },
1084            {
1085              "timestamp": 1234567891.0,
1086              "start_timestamp": 1234567890.0,
1087              "op": "ai.chat_completions.create.langchain.ChatOpenAI",
1088              "span_id": "1234567890123458",
1089              "parent_span_id": "1234567890123457",
1090              "data": {
1091                "gen_ai.usage.total_tokens": 150.0,
1092                "gen_ai.usage.input_tokens": 100.0,
1093                "gen_ai.usage.output_tokens": 50.0,
1094                "gen_ai.operation.type": "ai_client"
1095              },
1096              "measurements": {
1097                "ai_completion_tokens_used": {
1098                  "value": 50.0
1099                },
1100                "ai_prompt_tokens_used": {
1101                  "value": 100.0
1102                }
1103              }
1104            }
1105          ]
1106        }
1107        "#);
1108    }
1109
1110    fn metadata_with_context_size() -> ModelMetadata {
1111        ModelMetadata {
1112            version: 1,
1113            models: HashMap::from([(
1114                Pattern::new("claude-2.1").unwrap(),
1115                ModelMetadataEntry {
1116                    costs: Some(ModelCostV2 {
1117                        input_per_token: 0.01,
1118                        output_per_token: 0.02,
1119                        output_reasoning_per_token: 0.0,
1120                        input_cached_per_token: 0.0,
1121                        input_cache_write_per_token: 0.0,
1122                    }),
1123                    context_size: Some(100_000),
1124                },
1125            )]),
1126        }
1127    }
1128
1129    #[test]
1130    fn test_context_utilization_with_total_tokens() {
1131        let mut span = Span {
1132            op: "gen_ai.test".to_owned().into(),
1133            data: SpanData::from_value(
1134                json!({
1135                    "gen_ai.response.model": "claude-2.1",
1136                    "gen_ai.usage.input_tokens": 30000.0,
1137                    "gen_ai.usage.output_tokens": 12000.0,
1138                    "gen_ai.usage.total_tokens": 42000.0,
1139                })
1140                .into(),
1141            ),
1142            ..Default::default()
1143        };
1144
1145        enrich_ai_span(&mut span, Some(&metadata_with_context_size()));
1146
1147        let data = span.data.value().unwrap();
1148        assert_eq!(
1149            data.gen_ai_context_window_size
1150                .value()
1151                .and_then(Value::as_f64),
1152            Some(100_000.0)
1153        );
1154        assert_eq!(
1155            data.gen_ai_context_utilization
1156                .value()
1157                .and_then(Value::as_f64),
1158            Some(0.42)
1159        );
1160    }
1161
1162    #[test]
1163    fn test_context_utilization_no_context_size() {
1164        let metadata = ModelMetadata {
1165            version: 1,
1166            models: HashMap::from([(
1167                Pattern::new("claude-2.1").unwrap(),
1168                ModelMetadataEntry {
1169                    costs: None,
1170                    context_size: None,
1171                },
1172            )]),
1173        };
1174
1175        let mut span = Span {
1176            op: "gen_ai.test".to_owned().into(),
1177            data: SpanData::from_value(
1178                json!({
1179                    "gen_ai.response.model": "claude-2.1",
1180                    "gen_ai.usage.total_tokens": 1000.0,
1181                })
1182                .into(),
1183            ),
1184            ..Default::default()
1185        };
1186
1187        enrich_ai_span(&mut span, Some(&metadata));
1188
1189        let data = span.data.value().unwrap();
1190        assert!(data.gen_ai_context_window_size.value().is_none());
1191        assert!(data.gen_ai_context_utilization.value().is_none());
1192    }
1193
1194    #[test]
1195    fn test_context_utilization_no_total_tokens() {
1196        let mut span = Span {
1197            op: "gen_ai.test".to_owned().into(),
1198            data: SpanData::from_value(
1199                json!({
1200                    "gen_ai.response.model": "claude-2.1",
1201                })
1202                .into(),
1203            ),
1204            ..Default::default()
1205        };
1206
1207        enrich_ai_span(&mut span, Some(&metadata_with_context_size()));
1208
1209        let data = span.data.value().unwrap();
1210        // window_size should still be set even without tokens.
1211        assert_eq!(
1212            data.gen_ai_context_window_size
1213                .value()
1214                .and_then(Value::as_f64),
1215            Some(100_000.0)
1216        );
1217        // But utilization cannot be computed without total_tokens.
1218        assert!(data.gen_ai_context_utilization.value().is_none());
1219    }
1220
1221    #[test]
1222    fn test_context_utilization_unknown_model() {
1223        let mut span = Span {
1224            op: "gen_ai.test".to_owned().into(),
1225            data: SpanData::from_value(
1226                json!({
1227                    "gen_ai.response.model": "unknown-model",
1228                    "gen_ai.usage.total_tokens": 1000.0,
1229                })
1230                .into(),
1231            ),
1232            ..Default::default()
1233        };
1234
1235        enrich_ai_span(&mut span, Some(&metadata_with_context_size()));
1236
1237        let data = span.data.value().unwrap();
1238        assert!(data.gen_ai_context_window_size.value().is_none());
1239        assert!(data.gen_ai_context_utilization.value().is_none());
1240    }
1241}