relay_event_normalization/normalize/span/
ai.rs

1//! AI cost calculation.
2
3use crate::statsd::{Counters, map_origin_to_integration, platform_tag};
4use crate::{ModelCostV2, ModelCosts};
5use relay_event_schema::protocol::{
6    Event, Measurements, OperationType, Span, SpanData, TraceContext,
7};
8use relay_protocol::{Annotated, Getter, Value};
9
/// Amount of used tokens for a model call.
///
/// Counts are stored as `f64`; see [`UsedTokens::from_span_data`], which
/// defaults missing attributes to zero.
#[derive(Debug, Copy, Clone)]
pub struct UsedTokens {
    /// Total amount of input tokens used.
    ///
    /// Includes cached and cache-write tokens.
    pub input_tokens: f64,
    /// Amount of cached tokens used.
    ///
    /// This is a subset of [`Self::input_tokens`].
    pub input_cached_tokens: f64,
    /// Amount of cache write tokens used.
    ///
    /// This is a subset of [`Self::input_tokens`].
    pub input_cache_write_tokens: f64,
    /// Total amount of output tokens.
    ///
    /// Includes reasoning tokens.
    pub output_tokens: f64,
    /// Total amount of reasoning tokens.
    ///
    /// This is a subset of [`Self::output_tokens`].
    pub output_reasoning_tokens: f64,
}
30
31impl UsedTokens {
32    /// Extracts [`UsedTokens`] from [`SpanData`] attributes.
33    pub fn from_span_data(data: &SpanData) -> Self {
34        macro_rules! get_value {
35            ($e:expr) => {
36                $e.value().and_then(Value::as_f64).unwrap_or(0.0)
37            };
38        }
39
40        Self {
41            input_tokens: get_value!(data.gen_ai_usage_input_tokens),
42            output_tokens: get_value!(data.gen_ai_usage_output_tokens),
43            output_reasoning_tokens: get_value!(data.gen_ai_usage_output_tokens_reasoning),
44            input_cached_tokens: get_value!(data.gen_ai_usage_input_tokens_cached),
45            input_cache_write_tokens: get_value!(data.gen_ai_usage_input_tokens_cache_write),
46        }
47    }
48
49    /// Returns `true` if any tokens were used.
50    pub fn has_usage(&self) -> bool {
51        self.input_tokens > 0.0 || self.output_tokens > 0.0
52    }
53
54    /// Calculates the total amount of uncached input tokens.
55    ///
56    /// Subtracts cached tokens from the total token count.
57    pub fn raw_input_tokens(&self) -> f64 {
58        self.input_tokens - self.input_cached_tokens
59    }
60
61    /// Calculates the total amount of raw, non-reasoning output tokens.
62    ///
63    /// Subtracts reasoning tokens from the total token count.
64    pub fn raw_output_tokens(&self) -> f64 {
65        self.output_tokens - self.output_reasoning_tokens
66    }
67}
68
/// Calculated model call costs.
#[derive(Debug, Copy, Clone)]
pub struct CalculatedCost {
    /// The cost of input tokens used.
    pub input: f64,
    /// The cost of output tokens used.
    pub output: f64,
}

impl CalculatedCost {
    /// The total, input and output, cost.
    pub fn total(&self) -> f64 {
        let Self { input, output } = *self;
        input + output
    }
}
84
85/// Calculates the total cost for a model call.
86///
87/// Returns `None` if no tokens were used.
88pub fn calculate_costs(
89    model_cost: &ModelCostV2,
90    tokens: UsedTokens,
91    integration: &str,
92    platform: &str,
93) -> Option<CalculatedCost> {
94    if !tokens.has_usage() {
95        relay_statsd::metric!(
96            counter(Counters::GenAiCostCalculationResult) += 1,
97            result = "calculation_no_tokens",
98            integration = integration,
99            platform = platform,
100        );
101        return None;
102    }
103
104    let input = (tokens.raw_input_tokens() * model_cost.input_per_token)
105        + (tokens.input_cached_tokens * model_cost.input_cached_per_token)
106        + (tokens.input_cache_write_tokens * model_cost.input_cache_write_per_token);
107
108    // For now most of the models do not differentiate between reasoning and output token cost,
109    // it costs the same.
110    let reasoning_cost = match model_cost.output_reasoning_per_token {
111        reasoning_cost if reasoning_cost > 0.0 => reasoning_cost,
112        _ => model_cost.output_per_token,
113    };
114
115    let output = (tokens.raw_output_tokens() * model_cost.output_per_token)
116        + (tokens.output_reasoning_tokens * reasoning_cost);
117
118    let metric_label = match (input, output) {
119        (x, y) if x < 0.0 || y < 0.0 => "calculation_negative",
120        (0.0, 0.0) => "calculation_zero",
121        _ => "calculation_positive",
122    };
123
124    relay_statsd::metric!(
125        counter(Counters::GenAiCostCalculationResult) += 1,
126        result = metric_label,
127        integration = integration,
128        platform = platform,
129    );
130
131    Some(CalculatedCost { input, output })
132}
133
/// Default AI operation stored in [`GEN_AI_OPERATION_TYPE`](relay_conventions::GEN_AI_OPERATION_TYPE)
/// for AI spans without a well known AI span op.
///
/// See also: [`infer_ai_operation_type`], which returns `None` for unknown
/// operations so that callers can fall back to this value.
pub const DEFAULT_AI_OPERATION: &str = "ai_client";
139
/// Infers the AI operation from an AI operation name.
///
/// The operation name is usually inferred from the
/// [`GEN_AI_OPERATION_NAME`](relay_conventions::GEN_AI_OPERATION_NAME) span attribute and the span
/// operation.
///
/// Sentry expects the operation type in the
/// [`GEN_AI_OPERATION_TYPE`](relay_conventions::GEN_AI_OPERATION_TYPE) attribute.
///
/// The function returns `None` when the op is not a well known AI operation, callers likely want to default
/// the value to [`DEFAULT_AI_OPERATION`] for AI spans.
pub fn infer_ai_operation_type(op_name: &str) -> Option<&'static str> {
    // Exact operation names take precedence over prefix rules.
    let exact = match op_name {
        "ai.run.generateText"
        | "ai.run.generateObject"
        | "gen_ai.invoke_agent"
        | "ai.pipeline.generate_text"
        | "ai.pipeline.generate_object"
        | "ai.pipeline.stream_text"
        | "ai.pipeline.stream_object"
        | "gen_ai.create_agent"
        | "invoke_agent"
        | "create_agent" => Some("agent"),
        "gen_ai.execute_tool" | "execute_tool" => Some("tool"),
        "gen_ai.handoff" | "handoff" => Some("handoff"),
        "ai.processor" | "processor_run" => Some("other"),
        _ => None,
    };
    if exact.is_some() {
        return exact;
    }

    // Prefix rules, checked in order: a more specific prefix must come before
    // the generic prefix it extends (e.g. `ai.streamText.doStream` before
    // `ai.streamText`).
    const PREFIX_RULES: [(&str, &str); 7] = [
        ("ai.streamText.doStream", "ai_client"),
        ("ai.streamText", "agent"),
        ("ai.generateText.doGenerate", "ai_client"),
        ("ai.generateText", "agent"),
        ("ai.generateObject.doGenerate", "ai_client"),
        ("ai.generateObject", "agent"),
        ("ai.toolCall", "tool"),
    ];

    PREFIX_RULES
        .iter()
        .find(|(prefix, _)| op_name.starts_with(prefix))
        .map(|&(_, ai_op)| ai_op)
}
184
185/// Calculates the cost of an AI model based on the model cost and the tokens used.
186/// Calculated cost is in US dollars.
187fn extract_ai_model_cost_data(
188    model_cost: Option<&ModelCostV2>,
189    data: &mut SpanData,
190    origin: Option<&str>,
191    platform: Option<&str>,
192) {
193    let integration = map_origin_to_integration(origin);
194    let platform = platform_tag(platform);
195
196    let Some(model_cost) = model_cost else {
197        relay_statsd::metric!(
198            counter(Counters::GenAiCostCalculationResult) += 1,
199            result = "calculation_no_model_cost_available",
200            integration = integration,
201            platform = platform,
202        );
203        return;
204    };
205
206    let used_tokens = UsedTokens::from_span_data(&*data);
207    let Some(costs) = calculate_costs(model_cost, used_tokens, integration, platform) else {
208        return;
209    };
210
211    data.gen_ai_cost_total_tokens
212        .set_value(Value::F64(costs.total()).into());
213
214    // Set individual cost components
215    data.gen_ai_cost_input_tokens
216        .set_value(Value::F64(costs.input).into());
217    data.gen_ai_cost_output_tokens
218        .set_value(Value::F64(costs.output).into());
219}
220
221/// Maps AI-related measurements (legacy) to span data.
222fn map_ai_measurements_to_data(data: &mut SpanData, measurements: Option<&Measurements>) {
223    let set_field_from_measurement = |target_field: &mut Annotated<Value>,
224                                      measurement_key: &str| {
225        if let Some(measurements) = measurements
226            && target_field.value().is_none()
227            && let Some(value) = measurements.get_value(measurement_key)
228        {
229            target_field.set_value(Value::F64(value.to_f64()).into());
230        }
231    };
232
233    set_field_from_measurement(&mut data.gen_ai_usage_total_tokens, "ai_total_tokens_used");
234    set_field_from_measurement(&mut data.gen_ai_usage_input_tokens, "ai_prompt_tokens_used");
235    set_field_from_measurement(
236        &mut data.gen_ai_usage_output_tokens,
237        "ai_completion_tokens_used",
238    );
239}
240
241fn set_total_tokens(data: &mut SpanData) {
242    // It might be that 'total_tokens' is not set in which case we need to calculate it
243    if data.gen_ai_usage_total_tokens.value().is_none() {
244        let input_tokens = data
245            .gen_ai_usage_input_tokens
246            .value()
247            .and_then(Value::as_f64);
248        let output_tokens = data
249            .gen_ai_usage_output_tokens
250            .value()
251            .and_then(Value::as_f64);
252
253        if input_tokens.is_none() && output_tokens.is_none() {
254            // don't set total_tokens if there are no input nor output tokens
255            return;
256        }
257
258        data.gen_ai_usage_total_tokens.set_value(
259            Value::F64(input_tokens.unwrap_or(0.0) + output_tokens.unwrap_or(0.0)).into(),
260        );
261    }
262}
263
264/// Extract the additional data into the span
265fn extract_ai_data(
266    data: &mut SpanData,
267    duration: f64,
268    ai_model_costs: &ModelCosts,
269    origin: Option<&str>,
270    platform: Option<&str>,
271) {
272    // Extracts the response tokens per second
273    if data.gen_ai_response_tokens_per_second.value().is_none()
274        && duration > 0.0
275        && let Some(output_tokens) = data
276            .gen_ai_usage_output_tokens
277            .value()
278            .and_then(Value::as_f64)
279    {
280        data.gen_ai_response_tokens_per_second
281            .set_value(Value::F64(output_tokens / (duration / 1000.0)).into());
282    }
283
284    // Extracts the total cost of the AI model used
285    if let Some(model_id) = data
286        .gen_ai_response_model
287        .value()
288        .and_then(|val| val.as_str())
289    {
290        extract_ai_model_cost_data(
291            ai_model_costs.cost_per_token(model_id),
292            data,
293            origin,
294            platform,
295        )
296    } else {
297        relay_statsd::metric!(
298            counter(Counters::GenAiCostCalculationResult) += 1,
299            result = "calculation_no_model_id_available",
300            integration = map_origin_to_integration(origin),
301            platform = platform_tag(platform),
302        );
303    }
304}
305
306/// Enrich the AI span data
307fn enrich_ai_span_data(
308    span_data: &mut Annotated<SpanData>,
309    span_op: &Annotated<OperationType>,
310    measurements: &Annotated<Measurements>,
311    duration: f64,
312    model_costs: Option<&ModelCosts>,
313    origin: Option<&str>,
314    platform: Option<&str>,
315) {
316    if !is_ai_span(span_data, span_op.value()) {
317        return;
318    }
319
320    let data = span_data.get_or_insert_with(SpanData::default);
321
322    map_ai_measurements_to_data(data, measurements.value());
323
324    set_total_tokens(data);
325
326    // Default response model to request model if not set.
327    if data.gen_ai_response_model.value().is_none()
328        && let Some(request_model) = data.gen_ai_request_model.value().cloned()
329    {
330        data.gen_ai_response_model.set_value(Some(request_model));
331    }
332
333    // Default agent name to function_id if not set.
334    if data.gen_ai_agent_name.value().is_none()
335        && let Some(function_id) = data.gen_ai_function_id.value().cloned()
336    {
337        data.gen_ai_agent_name.set_value(Some(function_id));
338    }
339
340    if let Some(model_costs) = model_costs {
341        extract_ai_data(data, duration, model_costs, origin, platform);
342    } else {
343        relay_statsd::metric!(
344            counter(Counters::GenAiCostCalculationResult) += 1,
345            result = "calculation_no_model_cost_available",
346            integration = map_origin_to_integration(origin),
347            platform = platform_tag(platform),
348        );
349    }
350
351    let ai_op_type = data
352        .gen_ai_operation_name
353        .value()
354        .or(span_op.value())
355        .and_then(|op| infer_ai_operation_type(op))
356        .unwrap_or(DEFAULT_AI_OPERATION);
357
358    data.gen_ai_operation_type
359        .set_value(Some(ai_op_type.to_owned()));
360}
361
362/// Enrich the AI span data
363pub fn enrich_ai_span(span: &mut Span, model_costs: Option<&ModelCosts>) {
364    let duration = span
365        .get_value("span.duration")
366        .and_then(|v| v.as_f64())
367        .unwrap_or(0.0);
368
369    enrich_ai_span_data(
370        &mut span.data,
371        &span.op,
372        &span.measurements,
373        duration,
374        model_costs,
375        span.origin.as_str(),
376        span.platform.as_str(),
377    );
378}
379
380/// Extract the ai data from all of an event's spans
381pub fn enrich_ai_event_data(event: &mut Event, model_costs: Option<&ModelCosts>) {
382    let event_duration = event
383        .get_value("event.duration")
384        .and_then(|v| v.as_f64())
385        .unwrap_or(0.0);
386
387    if let Some(trace_context) = event
388        .contexts
389        .value_mut()
390        .as_mut()
391        .and_then(|c| c.get_mut::<TraceContext>())
392    {
393        enrich_ai_span_data(
394            &mut trace_context.data,
395            &trace_context.op,
396            &event.measurements,
397            event_duration,
398            model_costs,
399            trace_context.origin.as_str(),
400            event.platform.as_str(),
401        );
402    }
403    let spans = event.spans.value_mut().iter_mut().flatten();
404    let spans = spans.filter_map(|span| span.value_mut().as_mut());
405
406    for span in spans {
407        let span_duration = span
408            .get_value("span.duration")
409            .and_then(|v| v.as_f64())
410            .unwrap_or(0.0);
411        let span_platform = span.platform.as_str().or_else(|| event.platform.as_str());
412
413        enrich_ai_span_data(
414            &mut span.data,
415            &span.op,
416            &span.measurements,
417            span_duration,
418            model_costs,
419            span.origin.as_str(),
420            span_platform,
421        );
422    }
423}
424
425/// Returns true if the span is an AI span.
426/// AI spans are spans with either a gen_ai.operation.name attribute or op starting with "ai."
427/// (legacy) or "gen_ai." (new).
428fn is_ai_span(span_data: &Annotated<SpanData>, span_op: Option<&OperationType>) -> bool {
429    let has_ai_op = span_data
430        .value()
431        .and_then(|data| data.gen_ai_operation_name.value())
432        .is_some();
433
434    let is_ai_span_op =
435        span_op.is_some_and(|op| op.starts_with("ai.") || op.starts_with("gen_ai."));
436
437    has_ai_op || is_ai_span_op
438}
439
440#[cfg(test)]
441mod tests {
442    use relay_protocol::{FromValue, assert_annotated_snapshot};
443    use serde_json::json;
444
445    use super::*;
446
447    fn ai_span_with_data(data: serde_json::Value) -> Span {
448        Span {
449            op: "gen_ai.test".to_owned().into(),
450            data: SpanData::from_value(data.into()),
451            ..Default::default()
452        }
453    }
454
455    #[test]
456    fn test_calculate_cost_no_tokens() {
457        let cost = calculate_costs(
458            &ModelCostV2 {
459                input_per_token: 1.0,
460                output_per_token: 1.0,
461                output_reasoning_per_token: 1.0,
462                input_cached_per_token: 1.0,
463                input_cache_write_per_token: 1.0,
464            },
465            UsedTokens::from_span_data(&SpanData::default()),
466            "test",
467            "test",
468        );
469        assert!(cost.is_none());
470    }
471
472    #[test]
473    fn test_calculate_cost_full() {
474        let cost = calculate_costs(
475            &ModelCostV2 {
476                input_per_token: 1.0,
477                output_per_token: 2.0,
478                output_reasoning_per_token: 3.0,
479                input_cached_per_token: 0.5,
480                input_cache_write_per_token: 0.75,
481            },
482            UsedTokens {
483                input_tokens: 8.0,
484                input_cached_tokens: 5.0,
485                input_cache_write_tokens: 0.0,
486                output_tokens: 15.0,
487                output_reasoning_tokens: 9.0,
488            },
489            "test",
490            "test",
491        )
492        .unwrap();
493
494        insta::assert_debug_snapshot!(cost, @r"
495        CalculatedCost {
496            input: 5.5,
497            output: 39.0,
498        }
499        ");
500    }
501
502    #[test]
503    fn test_calculate_cost_no_reasoning_cost() {
504        let cost = calculate_costs(
505            &ModelCostV2 {
506                input_per_token: 1.0,
507                output_per_token: 2.0,
508                // Should fallback to output token cost for reasoning.
509                output_reasoning_per_token: 0.0,
510                input_cached_per_token: 0.5,
511                input_cache_write_per_token: 0.0,
512            },
513            UsedTokens {
514                input_tokens: 8.0,
515                input_cached_tokens: 5.0,
516                input_cache_write_tokens: 0.0,
517                output_tokens: 15.0,
518                output_reasoning_tokens: 9.0,
519            },
520            "test",
521            "test",
522        )
523        .unwrap();
524
525        insta::assert_debug_snapshot!(cost, @r"
526        CalculatedCost {
527            input: 5.5,
528            output: 30.0,
529        }
530        ");
531    }
532
533    /// This test shows it is possible to produce negative costs if tokens are not aligned properly.
534    ///
535    /// The behaviour was desired when initially implemented.
536    #[test]
537    fn test_calculate_cost_negative() {
538        let cost = calculate_costs(
539            &ModelCostV2 {
540                input_per_token: 2.0,
541                output_per_token: 2.0,
542                output_reasoning_per_token: 1.0,
543                input_cached_per_token: 1.0,
544                input_cache_write_per_token: 1.5,
545            },
546            UsedTokens {
547                input_tokens: 1.0,
548                input_cached_tokens: 11.0,
549                input_cache_write_tokens: 0.0,
550                output_tokens: 1.0,
551                output_reasoning_tokens: 9.0,
552            },
553            "test",
554            "test",
555        )
556        .unwrap();
557
558        insta::assert_debug_snapshot!(cost, @r"
559        CalculatedCost {
560            input: -9.0,
561            output: -7.0,
562        }
563        ");
564    }
565
566    #[test]
567    fn test_calculate_cost_with_cache_writes() {
568        let cost = calculate_costs(
569            &ModelCostV2 {
570                input_per_token: 1.0,
571                output_per_token: 2.0,
572                output_reasoning_per_token: 3.0,
573                input_cached_per_token: 0.5,
574                input_cache_write_per_token: 0.75,
575            },
576            UsedTokens {
577                input_tokens: 100.0,
578                input_cached_tokens: 20.0,
579                input_cache_write_tokens: 30.0,
580                output_tokens: 50.0,
581                output_reasoning_tokens: 10.0,
582            },
583            "test",
584            "test",
585        )
586        .unwrap();
587
588        insta::assert_debug_snapshot!(cost, @r"
589        CalculatedCost {
590            input: 112.5,
591            output: 110.0,
592        }
593        ");
594    }
595
596    #[test]
597    fn test_calculate_cost_backward_compatibility_no_cache_write() {
598        // Test that cost calculation works when cache_write field is missing (backward compatibility)
599        let span_data = SpanData {
600            gen_ai_usage_input_tokens: Annotated::new(100.0.into()),
601            gen_ai_usage_input_tokens_cached: Annotated::new(20.0.into()),
602            gen_ai_usage_output_tokens: Annotated::new(50.0.into()),
603            // Note: gen_ai_usage_input_tokens_cache_write is NOT set (simulating old data)
604            ..Default::default()
605        };
606
607        let tokens = UsedTokens::from_span_data(&span_data);
608
609        // Verify cache_write_tokens defaults to 0.0
610        assert_eq!(tokens.input_cache_write_tokens, 0.0);
611
612        let cost = calculate_costs(
613            &ModelCostV2 {
614                input_per_token: 1.0,
615                output_per_token: 2.0,
616                output_reasoning_per_token: 0.0,
617                input_cached_per_token: 0.5,
618                input_cache_write_per_token: 0.75,
619            },
620            tokens,
621            "test",
622            "test",
623        )
624        .unwrap();
625
626        // Cost should be calculated without cache_write_tokens
627        // input: (100 - 20) * 1.0 + 20 * 0.5 + 0 * 0.75 = 80 + 10 + 0 = 90
628        // output: 50 * 2.0 = 100
629        insta::assert_debug_snapshot!(cost, @r"
630        CalculatedCost {
631            input: 90.0,
632            output: 100.0,
633        }
634        ");
635    }
636
637    /// Test that the AI operation type is inferred from a gen_ai.operation.name attribute.
638    #[test]
639    fn test_infer_ai_operation_type_from_gen_ai_operation_name() {
640        let mut span = ai_span_with_data(json!({
641            "gen_ai.operation.name": "invoke_agent"
642        }));
643
644        enrich_ai_span(&mut span, None);
645
646        assert_annotated_snapshot!(&span.data, @r#"
647        {
648          "gen_ai.operation.name": "invoke_agent",
649          "gen_ai.operation.type": "agent"
650        }
651        "#);
652    }
653
654    /// Test that the AI operation type is inferred from a span.op attribute.
655    #[test]
656    fn test_infer_ai_operation_type_from_span_op() {
657        let mut span = Span {
658            op: "gen_ai.invoke_agent".to_owned().into(),
659            ..Default::default()
660        };
661
662        enrich_ai_span(&mut span, None);
663
664        assert_annotated_snapshot!(span.data, @r#"
665        {
666          "gen_ai.operation.type": "agent"
667        }
668        "#);
669    }
670
671    /// Test that the AI operation type is inferred from a fallback.
672    #[test]
673    fn test_infer_ai_operation_type_from_fallback() {
674        let mut span = ai_span_with_data(json!({
675            "gen_ai.operation.name": "embeddings"
676        }));
677
678        enrich_ai_span(&mut span, None);
679
680        assert_annotated_snapshot!(&span.data, @r#"
681        {
682          "gen_ai.operation.name": "embeddings",
683          "gen_ai.operation.type": "ai_client"
684        }
685        "#);
686    }
687
688    /// Test that the response model is defaulted to the request model if not set.
689    #[test]
690    fn test_default_response_model_from_request_model() {
691        let mut span = ai_span_with_data(json!({
692            "gen_ai.request.model": "gpt-4",
693        }));
694
695        enrich_ai_span(&mut span, None);
696
697        assert_annotated_snapshot!(&span.data, @r#"
698        {
699          "gen_ai.response.model": "gpt-4",
700          "gen_ai.request.model": "gpt-4",
701          "gen_ai.operation.type": "ai_client"
702        }
703        "#);
704    }
705
706    /// Test that the response model is defaulted to the request model if not set.
707    #[test]
708    fn test_default_response_model_not_overridden() {
709        let mut span = ai_span_with_data(json!({
710            "gen_ai.request.model": "gpt-4",
711            "gen_ai.response.model": "gpt-4-abcd",
712        }));
713
714        enrich_ai_span(&mut span, None);
715
716        assert_annotated_snapshot!(&span.data, @r#"
717        {
718          "gen_ai.response.model": "gpt-4-abcd",
719          "gen_ai.request.model": "gpt-4",
720          "gen_ai.operation.type": "ai_client"
721        }
722        "#);
723    }
724
725    /// Test that gen_ai.agent.name is defaulted from gen_ai.function_id.
726    #[test]
727    fn test_default_agent_name_from_function_id() {
728        let mut span = ai_span_with_data(json!({
729            "gen_ai.function_id": "my-agent",
730        }));
731
732        enrich_ai_span(&mut span, None);
733
734        assert_annotated_snapshot!(&span.data, @r#"
735        {
736          "gen_ai.operation.type": "ai_client",
737          "gen_ai.agent.name": "my-agent",
738          "gen_ai.function_id": "my-agent"
739        }
740        "#);
741    }
742
743    /// Test that gen_ai.agent.name is not overridden when already set.
744    #[test]
745    fn test_default_agent_name_not_overridden() {
746        let mut span = ai_span_with_data(json!({
747            "gen_ai.function_id": "my-function",
748            "gen_ai.agent.name": "my-agent",
749        }));
750
751        enrich_ai_span(&mut span, None);
752
753        assert_annotated_snapshot!(&span.data, @r#"
754        {
755          "gen_ai.operation.type": "ai_client",
756          "gen_ai.agent.name": "my-agent",
757          "gen_ai.function_id": "my-function"
758        }
759        "#);
760    }
761
762    /// Test that an AI span is detected from a gen_ai.operation.name attribute.
763    #[test]
764    fn test_is_ai_span_from_gen_ai_operation_name() {
765        let mut span_data = Annotated::default();
766        span_data
767            .get_or_insert_with(SpanData::default)
768            .gen_ai_operation_name
769            .set_value(Some("chat".into()));
770        assert!(is_ai_span(&span_data, None));
771    }
772
773    /// Test that an AI span is detected from a span.op starting with "ai.".
774    #[test]
775    fn test_is_ai_span_from_span_op_ai() {
776        let span_op: OperationType = "ai.chat".into();
777        assert!(is_ai_span(&Annotated::default(), Some(&span_op)));
778    }
779
780    /// Test that an AI span is detected from a span.op starting with "gen_ai.".
781    #[test]
782    fn test_is_ai_span_from_span_op_gen_ai() {
783        let span_op: OperationType = "gen_ai.chat".into();
784        assert!(is_ai_span(&Annotated::default(), Some(&span_op)));
785    }
786
787    /// Test that a non-AI span is detected.
788    #[test]
789    fn test_is_ai_span_negative() {
790        assert!(!is_ai_span(&Annotated::default(), None));
791    }
792
793    /// Test enrich_ai_event_data with invoke_agent in trace context and a chat child span.
794    #[test]
795    fn test_enrich_ai_event_data_invoke_agent_trace_with_chat_span() {
796        let event_json = r#"{
797            "type": "transaction",
798            "timestamp": 1234567892.0,
799            "start_timestamp": 1234567889.0,
800            "contexts": {
801                "trace": {
802                    "op": "gen_ai.invoke_agent",
803                    "trace_id": "12345678901234567890123456789012",
804                    "span_id": "1234567890123456",
805                    "data": {
806                        "gen_ai.operation.name": "gen_ai.invoke_agent",
807                        "gen_ai.usage.input_tokens": 500,
808                        "gen_ai.usage.output_tokens": 200
809                    }
810                }
811            },
812            "spans": [
813                {
814                    "op": "gen_ai.chat.completions",
815                    "span_id": "1234567890123457",
816                    "start_timestamp": 1234567889.5,
817                    "timestamp": 1234567890.5,
818                    "data": {
819                        "gen_ai.operation.name": "chat",
820                        "gen_ai.usage.input_tokens": 100,
821                        "gen_ai.usage.output_tokens": 50
822                    }
823                }
824            ]
825        }"#;
826
827        let mut annotated_event: Annotated<Event> = Annotated::from_json(event_json).unwrap();
828        let event = annotated_event.value_mut().as_mut().unwrap();
829
830        enrich_ai_event_data(event, None);
831
832        assert_annotated_snapshot!(&annotated_event, @r#"
833        {
834          "type": "transaction",
835          "timestamp": 1234567892.0,
836          "start_timestamp": 1234567889.0,
837          "contexts": {
838            "trace": {
839              "trace_id": "12345678901234567890123456789012",
840              "span_id": "1234567890123456",
841              "op": "gen_ai.invoke_agent",
842              "data": {
843                "gen_ai.usage.total_tokens": 700.0,
844                "gen_ai.usage.input_tokens": 500,
845                "gen_ai.usage.output_tokens": 200,
846                "gen_ai.operation.name": "gen_ai.invoke_agent",
847                "gen_ai.operation.type": "agent"
848              },
849              "type": "trace"
850            }
851          },
852          "spans": [
853            {
854              "timestamp": 1234567890.5,
855              "start_timestamp": 1234567889.5,
856              "op": "gen_ai.chat.completions",
857              "span_id": "1234567890123457",
858              "data": {
859                "gen_ai.usage.total_tokens": 150.0,
860                "gen_ai.usage.input_tokens": 100,
861                "gen_ai.usage.output_tokens": 50,
862                "gen_ai.operation.name": "chat",
863                "gen_ai.operation.type": "ai_client"
864              }
865            }
866          ]
867        }
868        "#);
869    }
870
    /// Test enrich_ai_event_data with non-AI trace context, invoke_agent parent span, and chat child span.
    ///
    /// Verifies (via the inline snapshot) that for each AI span:
    /// - `gen_ai.usage.total_tokens` is computed as input + output tokens
    ///   (500 + 200 = 700.0 for the agent span, 100 + 50 = 150.0 for the chat span);
    /// - `gen_ai.operation.type` is derived per span: `"agent"` for the
    ///   `invoke_agent` span and `"ai_client"` for the `chat` span;
    /// - the non-AI `http.server` trace context is left untouched.
    ///
    /// NOTE(review): the agent span's total (700) is independent of its child's
    /// tokens — no parent/child aggregation or deduplication happens here.
    #[test]
    fn test_enrich_ai_event_data_nested_agent_and_chat_spans() {
        // Transaction with a non-AI root trace context and two nested AI spans:
        // an invoke_agent parent and a chat child (parent_span_id chains them).
        let event_json = r#"{
            "type": "transaction",
            "timestamp": 1234567892.0,
            "start_timestamp": 1234567889.0,
            "contexts": {
                "trace": {
                    "op": "http.server",
                    "trace_id": "12345678901234567890123456789012",
                    "span_id": "1234567890123456"
                }
            },
            "spans": [
                {
                    "op": "gen_ai.invoke_agent",
                    "span_id": "1234567890123457",
                    "parent_span_id": "1234567890123456",
                    "start_timestamp": 1234567889.5,
                    "timestamp": 1234567891.5,
                    "data": {
                        "gen_ai.operation.name": "invoke_agent",
                        "gen_ai.usage.input_tokens": 500,
                        "gen_ai.usage.output_tokens": 200
                    }
                },
                {
                    "op": "gen_ai.chat.completions",
                    "span_id": "1234567890123458",
                    "parent_span_id": "1234567890123457",
                    "start_timestamp": 1234567890.0,
                    "timestamp": 1234567891.0,
                    "data": {
                        "gen_ai.operation.name": "chat",
                        "gen_ai.usage.input_tokens": 100,
                        "gen_ai.usage.output_tokens": 50
                    }
                }
            ]
        }"#;

        let mut annotated_event: Annotated<Event> = Annotated::from_json(event_json).unwrap();
        let event = annotated_event.value_mut().as_mut().unwrap();

        // Second argument `None` — presumably no model-cost config, so only
        // token totals and operation types are enriched, no cost attributes.
        // TODO confirm against the `enrich_ai_event_data` signature.
        enrich_ai_event_data(event, None);

        // Snapshot pins the full enriched event, including the derived
        // `gen_ai.usage.total_tokens` and `gen_ai.operation.type` attributes.
        assert_annotated_snapshot!(&annotated_event, @r#"
        {
          "type": "transaction",
          "timestamp": 1234567892.0,
          "start_timestamp": 1234567889.0,
          "contexts": {
            "trace": {
              "trace_id": "12345678901234567890123456789012",
              "span_id": "1234567890123456",
              "op": "http.server",
              "type": "trace"
            }
          },
          "spans": [
            {
              "timestamp": 1234567891.5,
              "start_timestamp": 1234567889.5,
              "op": "gen_ai.invoke_agent",
              "span_id": "1234567890123457",
              "parent_span_id": "1234567890123456",
              "data": {
                "gen_ai.usage.total_tokens": 700.0,
                "gen_ai.usage.input_tokens": 500,
                "gen_ai.usage.output_tokens": 200,
                "gen_ai.operation.name": "invoke_agent",
                "gen_ai.operation.type": "agent"
              }
            },
            {
              "timestamp": 1234567891.0,
              "start_timestamp": 1234567890.0,
              "op": "gen_ai.chat.completions",
              "span_id": "1234567890123458",
              "parent_span_id": "1234567890123457",
              "data": {
                "gen_ai.usage.total_tokens": 150.0,
                "gen_ai.usage.input_tokens": 100,
                "gen_ai.usage.output_tokens": 50,
                "gen_ai.operation.name": "chat",
                "gen_ai.operation.type": "ai_client"
              }
            }
          ]
        }
        "#);
    }
964
    /// Test enrich_ai_event_data with legacy measurements and span op for operation type.
    ///
    /// Covers the legacy input path (via the inline snapshot):
    /// - spans carry no `gen_ai.usage.*` data attributes; token counts come from
    ///   the deprecated `ai_prompt_tokens_used` / `ai_completion_tokens_used`
    ///   measurements and are copied into `gen_ai.usage.input_tokens` /
    ///   `gen_ai.usage.output_tokens` (as floats), with `total_tokens` computed
    ///   from their sum (700.0 and 150.0);
    /// - with no `gen_ai.operation.name` present, `gen_ai.operation.type` is
    ///   derived from the span `op` instead: `gen_ai.invoke_agent` → `"agent"`,
    ///   `ai.chat_completions.create.langchain.ChatOpenAI` → `"ai_client"`;
    /// - the original legacy measurements are preserved, not removed.
    #[test]
    fn test_enrich_ai_event_data_legacy_measurements_and_span_op() {
        // Two AI spans that only provide token usage through legacy
        // measurements; the second uses the old `ai.*` op naming scheme.
        let event_json = r#"{
            "type": "transaction",
            "timestamp": 1234567892.0,
            "start_timestamp": 1234567889.0,
            "contexts": {
                "trace": {
                    "op": "http.server",
                    "trace_id": "12345678901234567890123456789012",
                    "span_id": "1234567890123456"
                }
            },
            "spans": [
                {
                    "op": "gen_ai.invoke_agent",
                    "span_id": "1234567890123457",
                    "parent_span_id": "1234567890123456",
                    "start_timestamp": 1234567889.5,
                    "timestamp": 1234567891.5,
                    "measurements": {
                        "ai_prompt_tokens_used": {"value": 500.0},
                        "ai_completion_tokens_used": {"value": 200.0}
                    }
                },
                {
                    "op": "ai.chat_completions.create.langchain.ChatOpenAI",
                    "span_id": "1234567890123458",
                    "parent_span_id": "1234567890123457",
                    "start_timestamp": 1234567890.0,
                    "timestamp": 1234567891.0,
                    "measurements": {
                        "ai_prompt_tokens_used": {"value": 100.0},
                        "ai_completion_tokens_used": {"value": 50.0}
                    }
                }
            ]
        }"#;

        let mut annotated_event: Annotated<Event> = Annotated::from_json(event_json).unwrap();
        let event = annotated_event.value_mut().as_mut().unwrap();

        // `None` presumably means no model-cost config — TODO confirm; only
        // token normalization and operation-type derivation are exercised here.
        enrich_ai_event_data(event, None);

        // Snapshot pins both the data attributes synthesized from the legacy
        // measurements and the untouched measurements themselves.
        assert_annotated_snapshot!(&annotated_event, @r#"
        {
          "type": "transaction",
          "timestamp": 1234567892.0,
          "start_timestamp": 1234567889.0,
          "contexts": {
            "trace": {
              "trace_id": "12345678901234567890123456789012",
              "span_id": "1234567890123456",
              "op": "http.server",
              "type": "trace"
            }
          },
          "spans": [
            {
              "timestamp": 1234567891.5,
              "start_timestamp": 1234567889.5,
              "op": "gen_ai.invoke_agent",
              "span_id": "1234567890123457",
              "parent_span_id": "1234567890123456",
              "data": {
                "gen_ai.usage.total_tokens": 700.0,
                "gen_ai.usage.input_tokens": 500.0,
                "gen_ai.usage.output_tokens": 200.0,
                "gen_ai.operation.type": "agent"
              },
              "measurements": {
                "ai_completion_tokens_used": {
                  "value": 200.0
                },
                "ai_prompt_tokens_used": {
                  "value": 500.0
                }
              }
            },
            {
              "timestamp": 1234567891.0,
              "start_timestamp": 1234567890.0,
              "op": "ai.chat_completions.create.langchain.ChatOpenAI",
              "span_id": "1234567890123458",
              "parent_span_id": "1234567890123457",
              "data": {
                "gen_ai.usage.total_tokens": 150.0,
                "gen_ai.usage.input_tokens": 100.0,
                "gen_ai.usage.output_tokens": 50.0,
                "gen_ai.operation.type": "ai_client"
              },
              "measurements": {
                "ai_completion_tokens_used": {
                  "value": 50.0
                },
                "ai_prompt_tokens_used": {
                  "value": 100.0
                }
              }
            }
          ]
        }
        "#);
    }
1070}