relay_event_normalization/normalize/span/
ai.rs

//! AI span token accounting and cost calculation.
2
3use crate::statsd::{Counters, map_origin_to_integration, platform_tag};
4use crate::{ModelCostV2, ModelCosts};
5use relay_event_schema::protocol::{
6    Event, Measurements, OperationType, Span, SpanData, TraceContext,
7};
8use relay_protocol::{Annotated, Getter, Value};
9
/// Amount of used tokens for a model call.
///
/// Counts are stored as `f64` because they are extracted from span attribute
/// values via `Value::as_f64`; missing attributes default to `0.0`.
#[derive(Debug, Copy, Clone)]
pub struct UsedTokens {
    /// Total amount of input tokens used.
    pub input_tokens: f64,
    /// Amount of cached tokens used.
    ///
    /// This is a subset of [`Self::input_tokens`].
    pub input_cached_tokens: f64,
    /// Amount of cache write tokens used.
    ///
    /// This is a subset of [`Self::input_tokens`].
    pub input_cache_write_tokens: f64,
    /// Total amount of output tokens.
    pub output_tokens: f64,
    /// Total amount of reasoning tokens.
    ///
    /// This is a subset of [`Self::output_tokens`].
    pub output_reasoning_tokens: f64,
}
30
31impl UsedTokens {
32    /// Extracts [`UsedTokens`] from [`SpanData`] attributes.
33    pub fn from_span_data(data: &SpanData) -> Self {
34        macro_rules! get_value {
35            ($e:expr) => {
36                $e.value().and_then(Value::as_f64).unwrap_or(0.0)
37            };
38        }
39
40        Self {
41            input_tokens: get_value!(data.gen_ai_usage_input_tokens),
42            output_tokens: get_value!(data.gen_ai_usage_output_tokens),
43            output_reasoning_tokens: get_value!(data.gen_ai_usage_output_tokens_reasoning),
44            input_cached_tokens: get_value!(data.gen_ai_usage_input_tokens_cached),
45            input_cache_write_tokens: get_value!(data.gen_ai_usage_input_tokens_cache_write),
46        }
47    }
48
49    /// Returns `true` if any tokens were used.
50    pub fn has_usage(&self) -> bool {
51        self.input_tokens > 0.0 || self.output_tokens > 0.0
52    }
53
54    /// Calculates the total amount of uncached input tokens.
55    ///
56    /// Subtracts cached tokens from the total token count.
57    pub fn raw_input_tokens(&self) -> f64 {
58        self.input_tokens - self.input_cached_tokens
59    }
60
61    /// Calculates the total amount of raw, non-reasoning output tokens.
62    ///
63    /// Subtracts reasoning tokens from the total token count.
64    pub fn raw_output_tokens(&self) -> f64 {
65        self.output_tokens - self.output_reasoning_tokens
66    }
67}
68
/// Calculated model call costs.
#[derive(Debug, Copy, Clone)]
pub struct CalculatedCost {
    /// The cost of input tokens used.
    pub input: f64,
    /// The cost of output tokens used.
    pub output: f64,
}

impl CalculatedCost {
    /// The combined cost of input and output tokens.
    pub fn total(&self) -> f64 {
        let Self { input, output } = *self;
        input + output
    }
}
84
85/// Calculates the total cost for a model call.
86///
87/// Returns `None` if no tokens were used.
88pub fn calculate_costs(
89    model_cost: &ModelCostV2,
90    tokens: UsedTokens,
91    integration: &str,
92    platform: &str,
93) -> Option<CalculatedCost> {
94    if !tokens.has_usage() {
95        relay_statsd::metric!(
96            counter(Counters::GenAiCostCalculationResult) += 1,
97            result = "calculation_no_tokens",
98            integration = integration,
99            platform = platform,
100        );
101        return None;
102    }
103
104    let input = (tokens.raw_input_tokens() * model_cost.input_per_token)
105        + (tokens.input_cached_tokens * model_cost.input_cached_per_token)
106        + (tokens.input_cache_write_tokens * model_cost.input_cache_write_per_token);
107
108    // For now most of the models do not differentiate between reasoning and output token cost,
109    // it costs the same.
110    let reasoning_cost = match model_cost.output_reasoning_per_token {
111        reasoning_cost if reasoning_cost > 0.0 => reasoning_cost,
112        _ => model_cost.output_per_token,
113    };
114
115    let output = (tokens.raw_output_tokens() * model_cost.output_per_token)
116        + (tokens.output_reasoning_tokens * reasoning_cost);
117
118    let metric_label = match (input, output) {
119        (x, y) if x < 0.0 || y < 0.0 => "calculation_negative",
120        (0.0, 0.0) => "calculation_zero",
121        _ => "calculation_positive",
122    };
123
124    relay_statsd::metric!(
125        counter(Counters::GenAiCostCalculationResult) += 1,
126        result = metric_label,
127        integration = integration,
128        platform = platform,
129    );
130
131    Some(CalculatedCost { input, output })
132}
133
/// Default AI operation stored in [`GEN_AI_OPERATION_TYPE`](relay_conventions::GEN_AI_OPERATION_TYPE)
/// for AI spans without a well known AI span op.
///
/// Used as the fallback value when [`infer_ai_operation_type`] returns `None`.
///
/// See also: [`infer_ai_operation_type`].
pub const DEFAULT_AI_OPERATION: &str = "ai_client";
139
/// Infers the AI operation type from an AI operation name.
///
/// The operation name is usually inferred from the
/// [`GEN_AI_OPERATION_NAME`](relay_conventions::GEN_AI_OPERATION_NAME) span attribute and the span
/// operation.
///
/// Sentry expects the operation type in the
/// [`GEN_AI_OPERATION_TYPE`](relay_conventions::GEN_AI_OPERATION_TYPE) attribute.
///
/// The function returns `None` when the op is not a well known AI operation, callers likely want to default
/// the value to [`DEFAULT_AI_OPERATION`] for AI spans.
pub fn infer_ai_operation_type(op_name: &str) -> Option<&'static str> {
    // Exact, well-known operation names.
    const EXACT: &[(&str, &str)] = &[
        ("ai.run.generateText", "agent"),
        ("ai.run.generateObject", "agent"),
        ("gen_ai.invoke_agent", "agent"),
        ("ai.pipeline.generate_text", "agent"),
        ("ai.pipeline.generate_object", "agent"),
        ("ai.pipeline.stream_text", "agent"),
        ("ai.pipeline.stream_object", "agent"),
        ("gen_ai.create_agent", "agent"),
        ("invoke_agent", "agent"),
        ("create_agent", "agent"),
        ("gen_ai.execute_tool", "tool"),
        ("execute_tool", "tool"),
        ("gen_ai.handoff", "handoff"),
        ("handoff", "handoff"),
        ("ai.processor", "other"),
        ("processor_run", "other"),
    ];

    // Ordered prefix rules: the more specific prefix must come before its
    // shorter counterpart (e.g. `doStream` before plain `streamText`).
    const PREFIXES: &[(&str, &str)] = &[
        ("ai.streamText.doStream", "ai_client"),
        ("ai.streamText", "agent"),
        ("ai.generateText.doGenerate", "ai_client"),
        ("ai.generateText", "agent"),
        ("ai.generateObject.doGenerate", "ai_client"),
        ("ai.generateObject", "agent"),
        ("ai.toolCall", "tool"),
    ];

    if let Some(&(_, op_type)) = EXACT.iter().find(|(name, _)| *name == op_name) {
        return Some(op_type);
    }

    PREFIXES
        .iter()
        .find(|(prefix, _)| op_name.starts_with(prefix))
        .map(|&(_, op_type)| op_type)
}
184
185/// Calculates the cost of an AI model based on the model cost and the tokens used.
186/// Calculated cost is in US dollars.
187fn extract_ai_model_cost_data(
188    model_cost: Option<&ModelCostV2>,
189    data: &mut SpanData,
190    origin: Option<&str>,
191    platform: Option<&str>,
192) {
193    let integration = map_origin_to_integration(origin);
194    let platform = platform_tag(platform);
195
196    let Some(model_cost) = model_cost else {
197        relay_statsd::metric!(
198            counter(Counters::GenAiCostCalculationResult) += 1,
199            result = "calculation_no_model_cost_available",
200            integration = integration,
201            platform = platform,
202        );
203        return;
204    };
205
206    let used_tokens = UsedTokens::from_span_data(&*data);
207    let Some(costs) = calculate_costs(model_cost, used_tokens, integration, platform) else {
208        return;
209    };
210
211    data.gen_ai_cost_total_tokens
212        .set_value(Value::F64(costs.total()).into());
213
214    // Set individual cost components
215    data.gen_ai_cost_input_tokens
216        .set_value(Value::F64(costs.input).into());
217    data.gen_ai_cost_output_tokens
218        .set_value(Value::F64(costs.output).into());
219}
220
221/// Maps AI-related measurements (legacy) to span data.
222fn map_ai_measurements_to_data(data: &mut SpanData, measurements: Option<&Measurements>) {
223    let set_field_from_measurement = |target_field: &mut Annotated<Value>,
224                                      measurement_key: &str| {
225        if let Some(measurements) = measurements
226            && target_field.value().is_none()
227            && let Some(value) = measurements.get_value(measurement_key)
228        {
229            target_field.set_value(Value::F64(value.to_f64()).into());
230        }
231    };
232
233    set_field_from_measurement(&mut data.gen_ai_usage_total_tokens, "ai_total_tokens_used");
234    set_field_from_measurement(&mut data.gen_ai_usage_input_tokens, "ai_prompt_tokens_used");
235    set_field_from_measurement(
236        &mut data.gen_ai_usage_output_tokens,
237        "ai_completion_tokens_used",
238    );
239}
240
241fn set_total_tokens(data: &mut SpanData) {
242    // It might be that 'total_tokens' is not set in which case we need to calculate it
243    if data.gen_ai_usage_total_tokens.value().is_none() {
244        let input_tokens = data
245            .gen_ai_usage_input_tokens
246            .value()
247            .and_then(Value::as_f64);
248        let output_tokens = data
249            .gen_ai_usage_output_tokens
250            .value()
251            .and_then(Value::as_f64);
252
253        if input_tokens.is_none() && output_tokens.is_none() {
254            // don't set total_tokens if there are no input nor output tokens
255            return;
256        }
257
258        data.gen_ai_usage_total_tokens.set_value(
259            Value::F64(input_tokens.unwrap_or(0.0) + output_tokens.unwrap_or(0.0)).into(),
260        );
261    }
262}
263
264/// Extract the additional data into the span
265fn extract_ai_data(
266    data: &mut SpanData,
267    duration: f64,
268    ai_model_costs: &ModelCosts,
269    origin: Option<&str>,
270    platform: Option<&str>,
271) {
272    // Extracts the response tokens per second
273    if data.gen_ai_response_tokens_per_second.value().is_none()
274        && duration > 0.0
275        && let Some(output_tokens) = data
276            .gen_ai_usage_output_tokens
277            .value()
278            .and_then(Value::as_f64)
279    {
280        data.gen_ai_response_tokens_per_second
281            .set_value(Value::F64(output_tokens / (duration / 1000.0)).into());
282    }
283
284    // Extracts the total cost of the AI model used
285    if let Some(model_id) = data
286        .gen_ai_response_model
287        .value()
288        .and_then(|val| val.as_str())
289    {
290        extract_ai_model_cost_data(
291            ai_model_costs.cost_per_token(model_id),
292            data,
293            origin,
294            platform,
295        )
296    } else {
297        relay_statsd::metric!(
298            counter(Counters::GenAiCostCalculationResult) += 1,
299            result = "calculation_no_model_id_available",
300            integration = map_origin_to_integration(origin),
301            platform = platform_tag(platform),
302        );
303    }
304}
305
306/// Enrich the AI span data
307fn enrich_ai_span_data(
308    span_data: &mut Annotated<SpanData>,
309    span_op: &Annotated<OperationType>,
310    measurements: &Annotated<Measurements>,
311    duration: f64,
312    model_costs: Option<&ModelCosts>,
313    origin: Option<&str>,
314    platform: Option<&str>,
315) {
316    if !is_ai_span(span_data, span_op.value()) {
317        return;
318    }
319
320    let data = span_data.get_or_insert_with(SpanData::default);
321
322    map_ai_measurements_to_data(data, measurements.value());
323
324    set_total_tokens(data);
325
326    // Default response model to request model if not set.
327    if data.gen_ai_response_model.value().is_none()
328        && let Some(request_model) = data.gen_ai_request_model.value().cloned()
329    {
330        data.gen_ai_response_model.set_value(Some(request_model));
331    }
332
333    if let Some(model_costs) = model_costs {
334        extract_ai_data(data, duration, model_costs, origin, platform);
335    } else {
336        relay_statsd::metric!(
337            counter(Counters::GenAiCostCalculationResult) += 1,
338            result = "calculation_no_model_cost_available",
339            integration = map_origin_to_integration(origin),
340            platform = platform_tag(platform),
341        );
342    }
343
344    let ai_op_type = data
345        .gen_ai_operation_name
346        .value()
347        .or(span_op.value())
348        .and_then(|op| infer_ai_operation_type(op))
349        .unwrap_or(DEFAULT_AI_OPERATION);
350
351    data.gen_ai_operation_type
352        .set_value(Some(ai_op_type.to_owned()));
353}
354
355/// Enrich the AI span data
356pub fn enrich_ai_span(span: &mut Span, model_costs: Option<&ModelCosts>) {
357    let duration = span
358        .get_value("span.duration")
359        .and_then(|v| v.as_f64())
360        .unwrap_or(0.0);
361
362    enrich_ai_span_data(
363        &mut span.data,
364        &span.op,
365        &span.measurements,
366        duration,
367        model_costs,
368        span.origin.as_str(),
369        span.platform.as_str(),
370    );
371}
372
373/// Extract the ai data from all of an event's spans
374pub fn enrich_ai_event_data(event: &mut Event, model_costs: Option<&ModelCosts>) {
375    let event_duration = event
376        .get_value("event.duration")
377        .and_then(|v| v.as_f64())
378        .unwrap_or(0.0);
379
380    if let Some(trace_context) = event
381        .contexts
382        .value_mut()
383        .as_mut()
384        .and_then(|c| c.get_mut::<TraceContext>())
385    {
386        enrich_ai_span_data(
387            &mut trace_context.data,
388            &trace_context.op,
389            &event.measurements,
390            event_duration,
391            model_costs,
392            trace_context.origin.as_str(),
393            event.platform.as_str(),
394        );
395    }
396    let spans = event.spans.value_mut().iter_mut().flatten();
397    let spans = spans.filter_map(|span| span.value_mut().as_mut());
398
399    for span in spans {
400        let span_duration = span
401            .get_value("span.duration")
402            .and_then(|v| v.as_f64())
403            .unwrap_or(0.0);
404        let span_platform = span.platform.as_str().or_else(|| event.platform.as_str());
405
406        enrich_ai_span_data(
407            &mut span.data,
408            &span.op,
409            &span.measurements,
410            span_duration,
411            model_costs,
412            span.origin.as_str(),
413            span_platform,
414        );
415    }
416}
417
418/// Returns true if the span is an AI span.
419/// AI spans are spans with either a gen_ai.operation.name attribute or op starting with "ai."
420/// (legacy) or "gen_ai." (new).
421fn is_ai_span(span_data: &Annotated<SpanData>, span_op: Option<&OperationType>) -> bool {
422    let has_ai_op = span_data
423        .value()
424        .and_then(|data| data.gen_ai_operation_name.value())
425        .is_some();
426
427    let is_ai_span_op =
428        span_op.is_some_and(|op| op.starts_with("ai.") || op.starts_with("gen_ai."));
429
430    has_ai_op || is_ai_span_op
431}
432
433#[cfg(test)]
434mod tests {
435    use relay_protocol::{FromValue, assert_annotated_snapshot};
436    use serde_json::json;
437
438    use super::*;
439
    /// Builds a test span with a generic AI op and the given JSON span data.
    fn ai_span_with_data(data: serde_json::Value) -> Span {
        Span {
            op: "gen_ai.test".to_owned().into(),
            data: SpanData::from_value(data.into()),
            ..Default::default()
        }
    }
447
    /// No cost is calculated (`None` is returned) when no tokens were used.
    #[test]
    fn test_calculate_cost_no_tokens() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 1.0,
                output_per_token: 1.0,
                output_reasoning_per_token: 1.0,
                input_cached_per_token: 1.0,
                input_cache_write_per_token: 1.0,
            },
            UsedTokens::from_span_data(&SpanData::default()),
            "test",
            "test",
        );
        assert!(cost.is_none());
    }
464
    /// Exercises all cost components at once:
    /// input = (8 - 5) * 1.0 + 5 * 0.5 = 5.5, output = (15 - 9) * 2.0 + 9 * 3.0 = 39.0.
    #[test]
    fn test_calculate_cost_full() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 1.0,
                output_per_token: 2.0,
                output_reasoning_per_token: 3.0,
                input_cached_per_token: 0.5,
                input_cache_write_per_token: 0.75,
            },
            UsedTokens {
                input_tokens: 8.0,
                input_cached_tokens: 5.0,
                input_cache_write_tokens: 0.0,
                output_tokens: 15.0,
                output_reasoning_tokens: 9.0,
            },
            "test",
            "test",
        )
        .unwrap();

        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: 5.5,
            output: 39.0,
        }
        ");
    }
494
    /// Reasoning tokens are billed at the regular output rate when no
    /// dedicated reasoning rate is configured: output = (15 - 9 + 9) * 2.0 = 30.0.
    #[test]
    fn test_calculate_cost_no_reasoning_cost() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 1.0,
                output_per_token: 2.0,
                // Should fallback to output token cost for reasoning.
                output_reasoning_per_token: 0.0,
                input_cached_per_token: 0.5,
                input_cache_write_per_token: 0.0,
            },
            UsedTokens {
                input_tokens: 8.0,
                input_cached_tokens: 5.0,
                input_cache_write_tokens: 0.0,
                output_tokens: 15.0,
                output_reasoning_tokens: 9.0,
            },
            "test",
            "test",
        )
        .unwrap();

        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: 5.5,
            output: 30.0,
        }
        ");
    }
525
    /// This test shows it is possible to produce negative costs if tokens are not aligned properly.
    ///
    /// The behaviour was desired when initially implemented.
    #[test]
    fn test_calculate_cost_negative() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 2.0,
                output_per_token: 2.0,
                output_reasoning_per_token: 1.0,
                input_cached_per_token: 1.0,
                input_cache_write_per_token: 1.5,
            },
            // Cached/reasoning counts exceed the totals they should be subsets of.
            UsedTokens {
                input_tokens: 1.0,
                input_cached_tokens: 11.0,
                input_cache_write_tokens: 0.0,
                output_tokens: 1.0,
                output_reasoning_tokens: 9.0,
            },
            "test",
            "test",
        )
        .unwrap();

        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: -9.0,
            output: -7.0,
        }
        ");
    }
558
    /// Cache write tokens are billed at their own rate on top of the input cost:
    /// input = (100 - 20) * 1.0 + 20 * 0.5 + 30 * 0.75 = 112.5.
    #[test]
    fn test_calculate_cost_with_cache_writes() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 1.0,
                output_per_token: 2.0,
                output_reasoning_per_token: 3.0,
                input_cached_per_token: 0.5,
                input_cache_write_per_token: 0.75,
            },
            UsedTokens {
                input_tokens: 100.0,
                input_cached_tokens: 20.0,
                input_cache_write_tokens: 30.0,
                output_tokens: 50.0,
                output_reasoning_tokens: 10.0,
            },
            "test",
            "test",
        )
        .unwrap();

        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: 112.5,
            output: 110.0,
        }
        ");
    }
588
    /// Cost calculation still works when the cache write attribute is absent
    /// from the span data (backward compatibility with older SDKs).
    #[test]
    fn test_calculate_cost_backward_compatibility_no_cache_write() {
        // Test that cost calculation works when cache_write field is missing (backward compatibility)
        let span_data = SpanData {
            gen_ai_usage_input_tokens: Annotated::new(100.0.into()),
            gen_ai_usage_input_tokens_cached: Annotated::new(20.0.into()),
            gen_ai_usage_output_tokens: Annotated::new(50.0.into()),
            // Note: gen_ai_usage_input_tokens_cache_write is NOT set (simulating old data)
            ..Default::default()
        };

        let tokens = UsedTokens::from_span_data(&span_data);

        // Verify cache_write_tokens defaults to 0.0
        assert_eq!(tokens.input_cache_write_tokens, 0.0);

        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 1.0,
                output_per_token: 2.0,
                output_reasoning_per_token: 0.0,
                input_cached_per_token: 0.5,
                input_cache_write_per_token: 0.75,
            },
            tokens,
            "test",
            "test",
        )
        .unwrap();

        // Cost should be calculated without cache_write_tokens
        // input: (100 - 20) * 1.0 + 20 * 0.5 + 0 * 0.75 = 80 + 10 + 0 = 90
        // output: 50 * 2.0 = 100
        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: 90.0,
            output: 100.0,
        }
        ");
    }
629
    /// Test that the AI operation type is inferred from a gen_ai.operation.name attribute.
    ///
    /// The attribute takes precedence over the span op (`gen_ai.test` here).
    #[test]
    fn test_infer_ai_operation_type_from_gen_ai_operation_name() {
        let mut span = ai_span_with_data(json!({
            "gen_ai.operation.name": "invoke_agent"
        }));

        enrich_ai_span(&mut span, None);

        assert_annotated_snapshot!(&span.data, @r#"
        {
          "gen_ai.operation.name": "invoke_agent",
          "gen_ai.operation.type": "agent"
        }
        "#);
    }
646
    /// Test that the AI operation type is inferred from a span.op attribute.
    ///
    /// No `gen_ai.operation.name` attribute is present, so the span op is used.
    #[test]
    fn test_infer_ai_operation_type_from_span_op() {
        let mut span = Span {
            op: "gen_ai.invoke_agent".to_owned().into(),
            ..Default::default()
        };

        enrich_ai_span(&mut span, None);

        assert_annotated_snapshot!(span.data, @r#"
        {
          "gen_ai.operation.type": "agent"
        }
        "#);
    }
663
    /// Test that the AI operation type is inferred from a fallback.
    ///
    /// `embeddings` is not a well known operation, so the default applies.
    #[test]
    fn test_infer_ai_operation_type_from_fallback() {
        let mut span = ai_span_with_data(json!({
            "gen_ai.operation.name": "embeddings"
        }));

        enrich_ai_span(&mut span, None);

        assert_annotated_snapshot!(&span.data, @r#"
        {
          "gen_ai.operation.name": "embeddings",
          "gen_ai.operation.type": "ai_client"
        }
        "#);
    }
680
    /// Test that the response model is defaulted to the request model if not set.
    #[test]
    fn test_default_response_model_from_request_model() {
        let mut span = ai_span_with_data(json!({
            "gen_ai.request.model": "gpt-4",
        }));

        enrich_ai_span(&mut span, None);

        // The response model is copied from the request model.
        assert_annotated_snapshot!(&span.data, @r#"
        {
          "gen_ai.response.model": "gpt-4",
          "gen_ai.request.model": "gpt-4",
          "gen_ai.operation.type": "ai_client"
        }
        "#);
    }
698
    /// Test that an already set response model is not overridden by the request model.
    #[test]
    fn test_default_response_model_not_overridden() {
        let mut span = ai_span_with_data(json!({
            "gen_ai.request.model": "gpt-4",
            "gen_ai.response.model": "gpt-4-abcd",
        }));

        enrich_ai_span(&mut span, None);

        assert_annotated_snapshot!(&span.data, @r#"
        {
          "gen_ai.response.model": "gpt-4-abcd",
          "gen_ai.request.model": "gpt-4",
          "gen_ai.operation.type": "ai_client"
        }
        "#);
    }
717
    /// Test that an AI span is detected from a gen_ai.operation.name attribute.
    #[test]
    fn test_is_ai_span_from_gen_ai_operation_name() {
        // Only the attribute is set; no span op is provided.
        let mut span_data = Annotated::default();
        span_data
            .get_or_insert_with(SpanData::default)
            .gen_ai_operation_name
            .set_value(Some("chat".into()));
        assert!(is_ai_span(&span_data, None));
    }
728
    /// Test that an AI span is detected from a span.op starting with "ai.".
    #[test]
    fn test_is_ai_span_from_span_op_ai() {
        // No span data at all; detection relies on the op prefix alone.
        let span_op: OperationType = "ai.chat".into();
        assert!(is_ai_span(&Annotated::default(), Some(&span_op)));
    }
735
    /// Test that an AI span is detected from a span.op starting with "gen_ai.".
    #[test]
    fn test_is_ai_span_from_span_op_gen_ai() {
        // No span data at all; detection relies on the op prefix alone.
        let span_op: OperationType = "gen_ai.chat".into();
        assert!(is_ai_span(&Annotated::default(), Some(&span_op)));
    }
742
    /// Test that a non-AI span is detected.
    #[test]
    fn test_is_ai_span_negative() {
        // Neither span data nor span op present.
        assert!(!is_ai_span(&Annotated::default(), None));
    }
748
    /// Test enrich_ai_event_data with invoke_agent in trace context and a chat child span.
    #[test]
    fn test_enrich_ai_event_data_invoke_agent_trace_with_chat_span() {
        let event_json = r#"{
            "type": "transaction",
            "timestamp": 1234567892.0,
            "start_timestamp": 1234567889.0,
            "contexts": {
                "trace": {
                    "op": "gen_ai.invoke_agent",
                    "trace_id": "12345678901234567890123456789012",
                    "span_id": "1234567890123456",
                    "data": {
                        "gen_ai.operation.name": "gen_ai.invoke_agent",
                        "gen_ai.usage.input_tokens": 500,
                        "gen_ai.usage.output_tokens": 200
                    }
                }
            },
            "spans": [
                {
                    "op": "gen_ai.chat.completions",
                    "span_id": "1234567890123457",
                    "start_timestamp": 1234567889.5,
                    "timestamp": 1234567890.5,
                    "data": {
                        "gen_ai.operation.name": "chat",
                        "gen_ai.usage.input_tokens": 100,
                        "gen_ai.usage.output_tokens": 50
                    }
                }
            ]
        }"#;

        let mut annotated_event: Annotated<Event> = Annotated::from_json(event_json).unwrap();
        let event = annotated_event.value_mut().as_mut().unwrap();

        // No model costs: only token totals and operation types are enriched.
        enrich_ai_event_data(event, None);

        assert_annotated_snapshot!(&annotated_event, @r#"
        {
          "type": "transaction",
          "timestamp": 1234567892.0,
          "start_timestamp": 1234567889.0,
          "contexts": {
            "trace": {
              "trace_id": "12345678901234567890123456789012",
              "span_id": "1234567890123456",
              "op": "gen_ai.invoke_agent",
              "data": {
                "gen_ai.usage.total_tokens": 700.0,
                "gen_ai.usage.input_tokens": 500,
                "gen_ai.usage.output_tokens": 200,
                "gen_ai.operation.name": "gen_ai.invoke_agent",
                "gen_ai.operation.type": "agent"
              },
              "type": "trace"
            }
          },
          "spans": [
            {
              "timestamp": 1234567890.5,
              "start_timestamp": 1234567889.5,
              "op": "gen_ai.chat.completions",
              "span_id": "1234567890123457",
              "data": {
                "gen_ai.usage.total_tokens": 150.0,
                "gen_ai.usage.input_tokens": 100,
                "gen_ai.usage.output_tokens": 50,
                "gen_ai.operation.name": "chat",
                "gen_ai.operation.type": "ai_client"
              }
            }
          ]
        }
        "#);
    }
826
    /// Test enrich_ai_event_data with non-AI trace context, invoke_agent parent span, and chat child span.
    #[test]
    fn test_enrich_ai_event_data_nested_agent_and_chat_spans() {
        let event_json = r#"{
            "type": "transaction",
            "timestamp": 1234567892.0,
            "start_timestamp": 1234567889.0,
            "contexts": {
                "trace": {
                    "op": "http.server",
                    "trace_id": "12345678901234567890123456789012",
                    "span_id": "1234567890123456"
                }
            },
            "spans": [
                {
                    "op": "gen_ai.invoke_agent",
                    "span_id": "1234567890123457",
                    "parent_span_id": "1234567890123456",
                    "start_timestamp": 1234567889.5,
                    "timestamp": 1234567891.5,
                    "data": {
                        "gen_ai.operation.name": "invoke_agent",
                        "gen_ai.usage.input_tokens": 500,
                        "gen_ai.usage.output_tokens": 200
                    }
                },
                {
                    "op": "gen_ai.chat.completions",
                    "span_id": "1234567890123458",
                    "parent_span_id": "1234567890123457",
                    "start_timestamp": 1234567890.0,
                    "timestamp": 1234567891.0,
                    "data": {
                        "gen_ai.operation.name": "chat",
                        "gen_ai.usage.input_tokens": 100,
                        "gen_ai.usage.output_tokens": 50
                    }
                }
            ]
        }"#;

        let mut annotated_event: Annotated<Event> = Annotated::from_json(event_json).unwrap();
        let event = annotated_event.value_mut().as_mut().unwrap();

        // The non-AI trace context stays untouched; only the spans are enriched.
        enrich_ai_event_data(event, None);

        assert_annotated_snapshot!(&annotated_event, @r#"
        {
          "type": "transaction",
          "timestamp": 1234567892.0,
          "start_timestamp": 1234567889.0,
          "contexts": {
            "trace": {
              "trace_id": "12345678901234567890123456789012",
              "span_id": "1234567890123456",
              "op": "http.server",
              "type": "trace"
            }
          },
          "spans": [
            {
              "timestamp": 1234567891.5,
              "start_timestamp": 1234567889.5,
              "op": "gen_ai.invoke_agent",
              "span_id": "1234567890123457",
              "parent_span_id": "1234567890123456",
              "data": {
                "gen_ai.usage.total_tokens": 700.0,
                "gen_ai.usage.input_tokens": 500,
                "gen_ai.usage.output_tokens": 200,
                "gen_ai.operation.name": "invoke_agent",
                "gen_ai.operation.type": "agent"
              }
            },
            {
              "timestamp": 1234567891.0,
              "start_timestamp": 1234567890.0,
              "op": "gen_ai.chat.completions",
              "span_id": "1234567890123458",
              "parent_span_id": "1234567890123457",
              "data": {
                "gen_ai.usage.total_tokens": 150.0,
                "gen_ai.usage.input_tokens": 100,
                "gen_ai.usage.output_tokens": 50,
                "gen_ai.operation.name": "chat",
                "gen_ai.operation.type": "ai_client"
              }
            }
          ]
        }
        "#);
    }
920
921    /// Test enrich_ai_event_data with legacy measurements and span op for operation type.
922    #[test]
923    fn test_enrich_ai_event_data_legacy_measurements_and_span_op() {
924        let event_json = r#"{
925            "type": "transaction",
926            "timestamp": 1234567892.0,
927            "start_timestamp": 1234567889.0,
928            "contexts": {
929                "trace": {
930                    "op": "http.server",
931                    "trace_id": "12345678901234567890123456789012",
932                    "span_id": "1234567890123456"
933                }
934            },
935            "spans": [
936                {
937                    "op": "gen_ai.invoke_agent",
938                    "span_id": "1234567890123457",
939                    "parent_span_id": "1234567890123456",
940                    "start_timestamp": 1234567889.5,
941                    "timestamp": 1234567891.5,
942                    "measurements": {
943                        "ai_prompt_tokens_used": {"value": 500.0},
944                        "ai_completion_tokens_used": {"value": 200.0}
945                    }
946                },
947                {
948                    "op": "ai.chat_completions.create.langchain.ChatOpenAI",
949                    "span_id": "1234567890123458",
950                    "parent_span_id": "1234567890123457",
951                    "start_timestamp": 1234567890.0,
952                    "timestamp": 1234567891.0,
953                    "measurements": {
954                        "ai_prompt_tokens_used": {"value": 100.0},
955                        "ai_completion_tokens_used": {"value": 50.0}
956                    }
957                }
958            ]
959        }"#;
960
961        let mut annotated_event: Annotated<Event> = Annotated::from_json(event_json).unwrap();
962        let event = annotated_event.value_mut().as_mut().unwrap();
963
964        enrich_ai_event_data(event, None);
965
966        assert_annotated_snapshot!(&annotated_event, @r#"
967        {
968          "type": "transaction",
969          "timestamp": 1234567892.0,
970          "start_timestamp": 1234567889.0,
971          "contexts": {
972            "trace": {
973              "trace_id": "12345678901234567890123456789012",
974              "span_id": "1234567890123456",
975              "op": "http.server",
976              "type": "trace"
977            }
978          },
979          "spans": [
980            {
981              "timestamp": 1234567891.5,
982              "start_timestamp": 1234567889.5,
983              "op": "gen_ai.invoke_agent",
984              "span_id": "1234567890123457",
985              "parent_span_id": "1234567890123456",
986              "data": {
987                "gen_ai.usage.total_tokens": 700.0,
988                "gen_ai.usage.input_tokens": 500.0,
989                "gen_ai.usage.output_tokens": 200.0,
990                "gen_ai.operation.type": "agent"
991              },
992              "measurements": {
993                "ai_completion_tokens_used": {
994                  "value": 200.0
995                },
996                "ai_prompt_tokens_used": {
997                  "value": 500.0
998                }
999              }
1000            },
1001            {
1002              "timestamp": 1234567891.0,
1003              "start_timestamp": 1234567890.0,
1004              "op": "ai.chat_completions.create.langchain.ChatOpenAI",
1005              "span_id": "1234567890123458",
1006              "parent_span_id": "1234567890123457",
1007              "data": {
1008                "gen_ai.usage.total_tokens": 150.0,
1009                "gen_ai.usage.input_tokens": 100.0,
1010                "gen_ai.usage.output_tokens": 50.0,
1011                "gen_ai.operation.type": "ai_client"
1012              },
1013              "measurements": {
1014                "ai_completion_tokens_used": {
1015                  "value": 50.0
1016                },
1017                "ai_prompt_tokens_used": {
1018                  "value": 100.0
1019                }
1020              }
1021            }
1022          ]
1023        }
1024        "#);
1025    }
1026}