// relay_event_normalization/normalize/span/ai.rs
1//! AI cost calculation.
2
3use crate::statsd::{Counters, map_origin_to_integration, platform_tag};
4use crate::{ModelCostV2, ModelMetadata};
5use relay_event_schema::protocol::{
6    Event, Measurements, OperationType, Span, SpanData, TraceContext,
7};
8use relay_protocol::{Annotated, Getter, Value};
9
/// Amount of used tokens for a model call.
///
/// All counts are `f64`, matching the numeric span attribute representation
/// they are extracted from (see [`Self::from_span_data`]).
#[derive(Debug, Copy, Clone)]
pub struct UsedTokens {
    /// Total amount of input tokens used.
    pub input_tokens: f64,
    /// Amount of cached tokens used.
    ///
    /// This is a subset of [`Self::input_tokens`].
    pub input_cached_tokens: f64,
    /// Amount of cache write tokens used.
    ///
    /// This is a subset of [`Self::input_tokens`].
    pub input_cache_write_tokens: f64,
    /// Total amount of output tokens.
    pub output_tokens: f64,
    /// Total amount of reasoning tokens.
    ///
    /// This is a subset of [`Self::output_tokens`].
    pub output_reasoning_tokens: f64,
}
30
31impl UsedTokens {
32    /// Extracts [`UsedTokens`] from [`SpanData`] attributes.
33    pub fn from_span_data(data: &SpanData) -> Self {
34        macro_rules! get_value {
35            ($e:expr) => {
36                $e.value().and_then(Value::as_f64).unwrap_or(0.0)
37            };
38        }
39
40        Self {
41            input_tokens: get_value!(data.gen_ai_usage_input_tokens),
42            output_tokens: get_value!(data.gen_ai_usage_output_tokens),
43            output_reasoning_tokens: get_value!(data.gen_ai_usage_output_tokens_reasoning),
44            input_cached_tokens: get_value!(data.gen_ai_usage_input_tokens_cached),
45            input_cache_write_tokens: get_value!(data.gen_ai_usage_input_tokens_cache_write),
46        }
47    }
48
49    /// Returns `true` if any tokens were used.
50    pub fn has_usage(&self) -> bool {
51        self.input_tokens > 0.0 || self.output_tokens > 0.0
52    }
53
54    /// Calculates the total amount of uncached input tokens.
55    ///
56    /// Subtracts cached tokens from the total token count.
57    pub fn raw_input_tokens(&self) -> f64 {
58        self.input_tokens - self.input_cached_tokens
59    }
60
61    /// Calculates the total amount of raw, non-reasoning output tokens.
62    ///
63    /// Subtracts reasoning tokens from the total token count.
64    pub fn raw_output_tokens(&self) -> f64 {
65        self.output_tokens - self.output_reasoning_tokens
66    }
67}
68
/// Calculated model call costs.
///
/// All costs are denominated in US dollars (see [`extract_ai_model_cost_data`]).
#[derive(Debug, Copy, Clone)]
pub struct CalculatedCost {
    /// The cost of input tokens used.
    pub input: f64,
    /// The cost of output tokens used.
    pub output: f64,
}
77
impl CalculatedCost {
    /// The total, input and output, cost.
    ///
    /// Sum of [`Self::input`] and [`Self::output`].
    pub fn total(&self) -> f64 {
        self.input + self.output
    }
}
84
85/// Calculates the total cost for a model call.
86///
87/// Returns `None` if no tokens were used.
88pub fn calculate_costs(
89    model_cost: &ModelCostV2,
90    tokens: UsedTokens,
91    integration: &str,
92    platform: &str,
93) -> Option<CalculatedCost> {
94    if !tokens.has_usage() {
95        relay_statsd::metric!(
96            counter(Counters::GenAiCostCalculationResult) += 1,
97            result = "calculation_no_tokens",
98            integration = integration,
99            platform = platform,
100        );
101        return None;
102    }
103
104    let input = (tokens.raw_input_tokens() * model_cost.input_per_token)
105        + (tokens.input_cached_tokens * model_cost.input_cached_per_token)
106        + (tokens.input_cache_write_tokens * model_cost.input_cache_write_per_token);
107
108    // For now most of the models do not differentiate between reasoning and output token cost,
109    // it costs the same.
110    let reasoning_cost = match model_cost.output_reasoning_per_token {
111        reasoning_cost if reasoning_cost > 0.0 => reasoning_cost,
112        _ => model_cost.output_per_token,
113    };
114
115    let output = (tokens.raw_output_tokens() * model_cost.output_per_token)
116        + (tokens.output_reasoning_tokens * reasoning_cost);
117
118    let metric_label = match (input, output) {
119        (x, y) if x < 0.0 || y < 0.0 => "calculation_negative",
120        (0.0, 0.0) => "calculation_zero",
121        _ => "calculation_positive",
122    };
123
124    relay_statsd::metric!(
125        counter(Counters::GenAiCostCalculationResult) += 1,
126        result = metric_label,
127        integration = integration,
128        platform = platform,
129    );
130
131    Some(CalculatedCost { input, output })
132}
133
/// Default AI operation stored in [`GEN_AI_OPERATION_TYPE`](relay_conventions::GEN_AI_OPERATION_TYPE)
/// for AI spans without a well known AI span op.
///
/// Used as the fallback when [`infer_ai_operation_type`] returns `None`.
///
/// See also: [`infer_ai_operation_type`].
pub const DEFAULT_AI_OPERATION: &str = "ai_client";
139
/// Infers the AI operation from an AI operation name.
///
/// The operation name is usually inferred from the
/// [`GEN_AI_OPERATION_NAME`](relay_conventions::GEN_AI_OPERATION_NAME) span attribute and the span
/// operation.
///
/// Sentry expects the operation type in the
/// [`GEN_AI_OPERATION_TYPE`](relay_conventions::GEN_AI_OPERATION_TYPE) attribute.
///
/// The function returns `None` when the op is not a well known AI operation, callers likely want to default
/// the value to [`DEFAULT_AI_OPERATION`] for AI spans.
pub fn infer_ai_operation_type(op_name: &str) -> Option<&'static str> {
    // Well-known exact operation names take precedence over prefix matching.
    let exact = match op_name {
        "ai.run.generateText"
        | "ai.run.generateObject"
        | "gen_ai.invoke_agent"
        | "ai.pipeline.generate_text"
        | "ai.pipeline.generate_object"
        | "ai.pipeline.stream_text"
        | "ai.pipeline.stream_object"
        | "gen_ai.create_agent"
        | "invoke_agent"
        | "create_agent" => Some("agent"),
        "gen_ai.execute_tool" | "execute_tool" => Some("tool"),
        "gen_ai.handoff" | "handoff" => Some("handoff"),
        "ai.processor" | "processor_run" => Some("other"),
        _ => None,
    };
    if exact.is_some() {
        return exact;
    }

    // Prefix matches. Order matters: the more specific `.doStream` /
    // `.doGenerate` client prefixes must be checked before their agent-level
    // counterparts.
    const PREFIXES: &[(&str, &'static str)] = &[
        ("ai.streamText.doStream", "ai_client"),
        ("ai.streamText", "agent"),
        ("ai.generateText.doGenerate", "ai_client"),
        ("ai.generateText", "agent"),
        ("ai.generateObject.doGenerate", "ai_client"),
        ("ai.generateObject", "agent"),
        ("ai.toolCall", "tool"),
    ];

    PREFIXES
        .iter()
        .find(|(prefix, _)| op_name.starts_with(prefix))
        .map(|&(_, op_type)| op_type)
}
184
185/// Calculates the cost of an AI model based on the model cost and the tokens used.
186/// Calculated cost is in US dollars.
187fn extract_ai_model_cost_data(
188    model_cost: Option<&ModelCostV2>,
189    data: &mut SpanData,
190    origin: Option<&str>,
191    platform: Option<&str>,
192) {
193    let integration = map_origin_to_integration(origin);
194    let platform = platform_tag(platform);
195
196    let Some(model_cost) = model_cost else {
197        relay_statsd::metric!(
198            counter(Counters::GenAiCostCalculationResult) += 1,
199            result = "calculation_no_model_cost_available",
200            integration = integration,
201            platform = platform,
202        );
203        return;
204    };
205
206    let used_tokens = UsedTokens::from_span_data(&*data);
207    let Some(costs) = calculate_costs(model_cost, used_tokens, integration, platform) else {
208        return;
209    };
210
211    data.gen_ai_cost_total_tokens
212        .set_value(Value::F64(costs.total()).into());
213
214    // Set individual cost components
215    data.gen_ai_cost_input_tokens
216        .set_value(Value::F64(costs.input).into());
217    data.gen_ai_cost_output_tokens
218        .set_value(Value::F64(costs.output).into());
219}
220
221/// Maps AI-related measurements (legacy) to span data.
222fn map_ai_measurements_to_data(data: &mut SpanData, measurements: Option<&Measurements>) {
223    let set_field_from_measurement = |target_field: &mut Annotated<Value>,
224                                      measurement_key: &str| {
225        if let Some(measurements) = measurements
226            && target_field.value().is_none()
227            && let Some(value) = measurements.get_value(measurement_key)
228        {
229            target_field.set_value(Value::F64(value.to_f64()).into());
230        }
231    };
232
233    set_field_from_measurement(&mut data.gen_ai_usage_total_tokens, "ai_total_tokens_used");
234    set_field_from_measurement(&mut data.gen_ai_usage_input_tokens, "ai_prompt_tokens_used");
235    set_field_from_measurement(
236        &mut data.gen_ai_usage_output_tokens,
237        "ai_completion_tokens_used",
238    );
239}
240
241fn set_total_tokens(data: &mut SpanData) {
242    // It might be that 'total_tokens' is not set in which case we need to calculate it
243    if data.gen_ai_usage_total_tokens.value().is_none() {
244        let input_tokens = data
245            .gen_ai_usage_input_tokens
246            .value()
247            .and_then(Value::as_f64);
248        let output_tokens = data
249            .gen_ai_usage_output_tokens
250            .value()
251            .and_then(Value::as_f64);
252
253        if input_tokens.is_none() && output_tokens.is_none() {
254            // don't set total_tokens if there are no input nor output tokens
255            return;
256        }
257
258        data.gen_ai_usage_total_tokens.set_value(
259            Value::F64(input_tokens.unwrap_or(0.0) + output_tokens.unwrap_or(0.0)).into(),
260        );
261    }
262}
263
264/// Sets the context window size and utilization for the model.
265fn extract_context_utilization(data: &mut SpanData, model_metadata: &ModelMetadata) {
266    let model_id = data
267        .gen_ai_response_model
268        .value()
269        .and_then(|val| val.as_str());
270
271    let context_size = model_id.and_then(|id| model_metadata.context_size(id));
272
273    let Some(context_size) = context_size else {
274        return;
275    };
276
277    data.gen_ai_context_window_size
278        .set_value(Value::U64(context_size).into());
279
280    let total_tokens = data
281        .gen_ai_usage_total_tokens
282        .value()
283        .and_then(Value::as_f64);
284
285    if let Some(total_tokens) = total_tokens {
286        data.gen_ai_context_utilization
287            .set_value(Value::F64(total_tokens / context_size as f64).into());
288    }
289}
290
291/// Extract the additional data into the span
292fn extract_ai_data(
293    data: &mut SpanData,
294    duration: f64,
295    model_metadata: &ModelMetadata,
296    origin: Option<&str>,
297    platform: Option<&str>,
298) {
299    // Extracts the response tokens per second
300    if data.gen_ai_response_tokens_per_second.value().is_none()
301        && duration > 0.0
302        && let Some(output_tokens) = data
303            .gen_ai_usage_output_tokens
304            .value()
305            .and_then(Value::as_f64)
306    {
307        data.gen_ai_response_tokens_per_second
308            .set_value(Value::F64(output_tokens / (duration / 1000.0)).into());
309    }
310
311    extract_context_utilization(data, model_metadata);
312
313    // Extracts the total cost of the AI model used
314    if let Some(model_id) = data
315        .gen_ai_response_model
316        .value()
317        .and_then(|val| val.as_str())
318    {
319        extract_ai_model_cost_data(
320            model_metadata.cost_per_token(model_id),
321            data,
322            origin,
323            platform,
324        )
325    } else {
326        relay_statsd::metric!(
327            counter(Counters::GenAiCostCalculationResult) += 1,
328            result = "calculation_no_model_id_available",
329            integration = map_origin_to_integration(origin),
330            platform = platform_tag(platform),
331        );
332    }
333}
334
/// Enrich the AI span data
///
/// No-op for non-AI spans (see [`is_ai_span`]). For AI spans: maps legacy
/// measurements into span data, computes the token total, defaults the
/// response model and agent name, extracts derived metrics and costs, and
/// finally stores the inferred AI operation type.
fn enrich_ai_span_data(
    span_data: &mut Annotated<SpanData>,
    span_op: &Annotated<OperationType>,
    measurements: &Annotated<Measurements>,
    duration: f64,
    model_metadata: Option<&ModelMetadata>,
    origin: Option<&str>,
    platform: Option<&str>,
) {
    if !is_ai_span(span_data, span_op.value()) {
        return;
    }

    let data = span_data.get_or_insert_with(SpanData::default);

    // Legacy measurements must be mapped before the total is computed from
    // the input/output counts.
    map_ai_measurements_to_data(data, measurements.value());

    set_total_tokens(data);

    // Default response model to request model if not set.
    // Must happen before `extract_ai_data`, which reads the response model.
    if data.gen_ai_response_model.value().is_none()
        && let Some(request_model) = data.gen_ai_request_model.value().cloned()
    {
        data.gen_ai_response_model.set_value(Some(request_model));
    }

    // Default agent name to function_id if not set.
    if data.gen_ai_agent_name.value().is_none()
        && let Some(function_id) = data.gen_ai_function_id.value().cloned()
    {
        data.gen_ai_agent_name.set_value(Some(function_id));
    }

    if let Some(model_metadata) = model_metadata {
        extract_ai_data(data, duration, model_metadata, origin, platform);
    } else {
        relay_statsd::metric!(
            counter(Counters::GenAiCostCalculationResult) += 1,
            result = "calculation_no_model_cost_available",
            integration = map_origin_to_integration(origin),
            platform = platform_tag(platform),
        );
    }

    // The explicit operation name takes precedence over the span op for type
    // inference; unknown AI operations default to `DEFAULT_AI_OPERATION`.
    let ai_op_type = data
        .gen_ai_operation_name
        .value()
        .or(span_op.value())
        .and_then(|op| infer_ai_operation_type(op))
        .unwrap_or(DEFAULT_AI_OPERATION);

    data.gen_ai_operation_type
        .set_value(Some(ai_op_type.to_owned()));
}
390
391/// Enrich the AI span data
392pub fn enrich_ai_span(span: &mut Span, model_metadata: Option<&ModelMetadata>) {
393    let duration = span
394        .get_value("span.duration")
395        .and_then(|v| v.as_f64())
396        .unwrap_or(0.0);
397
398    enrich_ai_span_data(
399        &mut span.data,
400        &span.op,
401        &span.measurements,
402        duration,
403        model_metadata,
404        span.origin.as_str(),
405        span.platform.as_str(),
406    );
407}
408
/// Extract the ai data from all of an event's spans
///
/// Enriches the trace context (using the overall event duration) as well as
/// every span contained in the event (using each span's own duration).
pub fn enrich_ai_event_data(event: &mut Event, model_metadata: Option<&ModelMetadata>) {
    let event_duration = event
        .get_value("event.duration")
        .and_then(|v| v.as_f64())
        .unwrap_or(0.0);

    if let Some(trace_context) = event
        .contexts
        .value_mut()
        .as_mut()
        .and_then(|c| c.get_mut::<TraceContext>())
    {
        enrich_ai_span_data(
            &mut trace_context.data,
            &trace_context.op,
            &event.measurements,
            event_duration,
            model_metadata,
            trace_context.origin.as_str(),
            event.platform.as_str(),
        );
    }
    // Iterate all present spans, skipping entries without a value.
    let spans = event.spans.value_mut().iter_mut().flatten();
    let spans = spans.filter_map(|span| span.value_mut().as_mut());

    for span in spans {
        let span_duration = span
            .get_value("span.duration")
            .and_then(|v| v.as_f64())
            .unwrap_or(0.0);
        // Spans without their own platform inherit the event's platform.
        let span_platform = span.platform.as_str().or_else(|| event.platform.as_str());

        enrich_ai_span_data(
            &mut span.data,
            &span.op,
            &span.measurements,
            span_duration,
            model_metadata,
            span.origin.as_str(),
            span_platform,
        );
    }
}
453
454/// Returns true if the span is an AI span.
455/// AI spans are spans with either a gen_ai.operation.name attribute or op starting with "ai."
456/// (legacy) or "gen_ai." (new).
457fn is_ai_span(span_data: &Annotated<SpanData>, span_op: Option<&OperationType>) -> bool {
458    let has_ai_op = span_data
459        .value()
460        .and_then(|data| data.gen_ai_operation_name.value())
461        .is_some();
462
463    let is_ai_span_op =
464        span_op.is_some_and(|op| op.starts_with("ai.") || op.starts_with("gen_ai."));
465
466    has_ai_op || is_ai_span_op
467}
468
469#[cfg(test)]
470mod tests {
471    use std::collections::HashMap;
472
473    use relay_pattern::Pattern;
474    use relay_protocol::{FromValue, assert_annotated_snapshot};
475    use serde_json::json;
476
477    use super::*;
478    use crate::ModelMetadataEntry;
479
    /// Builds a span with op `gen_ai.test` and the given JSON span data.
    fn ai_span_with_data(data: serde_json::Value) -> Span {
        Span {
            op: "gen_ai.test".to_owned().into(),
            data: SpanData::from_value(data.into()),
            ..Default::default()
        }
    }
487
    /// No cost is calculated when the span records no token usage at all.
    #[test]
    fn test_calculate_cost_no_tokens() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 1.0,
                output_per_token: 1.0,
                output_reasoning_per_token: 1.0,
                input_cached_per_token: 1.0,
                input_cache_write_per_token: 1.0,
            },
            UsedTokens::from_span_data(&SpanData::default()),
            "test",
            "test",
        );
        assert!(cost.is_none());
    }
504
    /// Exercises all price components at once.
    ///
    /// input: (8 - 5) * 1.0 + 5 * 0.5 = 5.5
    /// output: (15 - 9) * 2.0 + 9 * 3.0 = 39.0
    #[test]
    fn test_calculate_cost_full() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 1.0,
                output_per_token: 2.0,
                output_reasoning_per_token: 3.0,
                input_cached_per_token: 0.5,
                input_cache_write_per_token: 0.75,
            },
            UsedTokens {
                input_tokens: 8.0,
                input_cached_tokens: 5.0,
                input_cache_write_tokens: 0.0,
                output_tokens: 15.0,
                output_reasoning_tokens: 9.0,
            },
            "test",
            "test",
        )
        .unwrap();

        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: 5.5,
            output: 39.0,
        }
        ");
    }
534
    /// A zero reasoning price falls back to the regular output token price.
    ///
    /// output: (15 - 9) * 2.0 + 9 * 2.0 = 30.0
    #[test]
    fn test_calculate_cost_no_reasoning_cost() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 1.0,
                output_per_token: 2.0,
                // Should fallback to output token cost for reasoning.
                output_reasoning_per_token: 0.0,
                input_cached_per_token: 0.5,
                input_cache_write_per_token: 0.0,
            },
            UsedTokens {
                input_tokens: 8.0,
                input_cached_tokens: 5.0,
                input_cache_write_tokens: 0.0,
                output_tokens: 15.0,
                output_reasoning_tokens: 9.0,
            },
            "test",
            "test",
        )
        .unwrap();

        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: 5.5,
            output: 30.0,
        }
        ");
    }
565
    /// This test shows it is possible to produce negative costs if tokens are not aligned properly.
    ///
    /// The behaviour was desired when initially implemented.
    ///
    /// input: (1 - 11) * 2.0 + 11 * 1.0 = -9.0
    /// output: (1 - 9) * 2.0 + 9 * 1.0 = -7.0
    #[test]
    fn test_calculate_cost_negative() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 2.0,
                output_per_token: 2.0,
                output_reasoning_per_token: 1.0,
                input_cached_per_token: 1.0,
                input_cache_write_per_token: 1.5,
            },
            // Cached exceeds total input and reasoning exceeds total output,
            // producing negative raw token counts.
            UsedTokens {
                input_tokens: 1.0,
                input_cached_tokens: 11.0,
                input_cache_write_tokens: 0.0,
                output_tokens: 1.0,
                output_reasoning_tokens: 9.0,
            },
            "test",
            "test",
        )
        .unwrap();

        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: -9.0,
            output: -7.0,
        }
        ");
    }
598
    /// Cache write tokens are billed at their dedicated rate.
    ///
    /// input: (100 - 20) * 1.0 + 20 * 0.5 + 30 * 0.75 = 112.5
    /// output: (50 - 10) * 2.0 + 10 * 3.0 = 110.0
    #[test]
    fn test_calculate_cost_with_cache_writes() {
        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 1.0,
                output_per_token: 2.0,
                output_reasoning_per_token: 3.0,
                input_cached_per_token: 0.5,
                input_cache_write_per_token: 0.75,
            },
            UsedTokens {
                input_tokens: 100.0,
                input_cached_tokens: 20.0,
                input_cache_write_tokens: 30.0,
                output_tokens: 50.0,
                output_reasoning_tokens: 10.0,
            },
            "test",
            "test",
        )
        .unwrap();

        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: 112.5,
            output: 110.0,
        }
        ");
    }
628
    /// Cost calculation still works when the cache write attribute is absent
    /// (backward compatibility with older SDK payloads).
    #[test]
    fn test_calculate_cost_backward_compatibility_no_cache_write() {
        // Test that cost calculation works when cache_write field is missing (backward compatibility)
        let span_data = SpanData {
            gen_ai_usage_input_tokens: Annotated::new(100.0.into()),
            gen_ai_usage_input_tokens_cached: Annotated::new(20.0.into()),
            gen_ai_usage_output_tokens: Annotated::new(50.0.into()),
            // Note: gen_ai_usage_input_tokens_cache_write is NOT set (simulating old data)
            ..Default::default()
        };

        let tokens = UsedTokens::from_span_data(&span_data);

        // Verify cache_write_tokens defaults to 0.0
        assert_eq!(tokens.input_cache_write_tokens, 0.0);

        let cost = calculate_costs(
            &ModelCostV2 {
                input_per_token: 1.0,
                output_per_token: 2.0,
                output_reasoning_per_token: 0.0,
                input_cached_per_token: 0.5,
                input_cache_write_per_token: 0.75,
            },
            tokens,
            "test",
            "test",
        )
        .unwrap();

        // Cost should be calculated without cache_write_tokens
        // input: (100 - 20) * 1.0 + 20 * 0.5 + 0 * 0.75 = 80 + 10 + 0 = 90
        // output: 50 * 2.0 = 100
        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: 90.0,
            output: 100.0,
        }
        ");
    }
669
    /// Test that the AI operation type is inferred from a gen_ai.operation.name attribute,
    /// taking precedence over the span op (`gen_ai.test`).
    #[test]
    fn test_infer_ai_operation_type_from_gen_ai_operation_name() {
        let mut span = ai_span_with_data(json!({
            "gen_ai.operation.name": "invoke_agent"
        }));

        enrich_ai_span(&mut span, None);

        assert_annotated_snapshot!(&span.data, @r#"
        {
          "gen_ai.operation.name": "invoke_agent",
          "gen_ai.operation.type": "agent"
        }
        "#);
    }
686
    /// Test that the AI operation type is inferred from a span.op attribute
    /// when no gen_ai.operation.name is present in the span data.
    #[test]
    fn test_infer_ai_operation_type_from_span_op() {
        let mut span = Span {
            op: "gen_ai.invoke_agent".to_owned().into(),
            ..Default::default()
        };

        enrich_ai_span(&mut span, None);

        assert_annotated_snapshot!(span.data, @r#"
        {
          "gen_ai.operation.type": "agent"
        }
        "#);
    }
703
    /// Test that the AI operation type falls back to `DEFAULT_AI_OPERATION`
    /// (`ai_client`) for operation names without a well known mapping.
    #[test]
    fn test_infer_ai_operation_type_from_fallback() {
        let mut span = ai_span_with_data(json!({
            "gen_ai.operation.name": "embeddings"
        }));

        enrich_ai_span(&mut span, None);

        assert_annotated_snapshot!(&span.data, @r#"
        {
          "gen_ai.operation.name": "embeddings",
          "gen_ai.operation.type": "ai_client"
        }
        "#);
    }
720
    /// Test that the response model is defaulted to the request model if not set.
    #[test]
    fn test_default_response_model_from_request_model() {
        let mut span = ai_span_with_data(json!({
            "gen_ai.request.model": "gpt-4",
        }));

        enrich_ai_span(&mut span, None);

        // The request model is copied into gen_ai.response.model.
        assert_annotated_snapshot!(&span.data, @r#"
        {
          "gen_ai.response.model": "gpt-4",
          "gen_ai.request.model": "gpt-4",
          "gen_ai.operation.type": "ai_client"
        }
        "#);
    }
738
    /// Test that an already set response model is not overridden by the request model.
    #[test]
    fn test_default_response_model_not_overridden() {
        let mut span = ai_span_with_data(json!({
            "gen_ai.request.model": "gpt-4",
            "gen_ai.response.model": "gpt-4-abcd",
        }));

        enrich_ai_span(&mut span, None);

        assert_annotated_snapshot!(&span.data, @r#"
        {
          "gen_ai.response.model": "gpt-4-abcd",
          "gen_ai.request.model": "gpt-4",
          "gen_ai.operation.type": "ai_client"
        }
        "#);
    }
757
    /// Test that gen_ai.agent.name is defaulted from gen_ai.function_id.
    #[test]
    fn test_default_agent_name_from_function_id() {
        let mut span = ai_span_with_data(json!({
            "gen_ai.function_id": "my-agent",
        }));

        enrich_ai_span(&mut span, None);

        // The function_id is copied into gen_ai.agent.name.
        assert_annotated_snapshot!(&span.data, @r#"
        {
          "gen_ai.operation.type": "ai_client",
          "gen_ai.agent.name": "my-agent",
          "gen_ai.function_id": "my-agent"
        }
        "#);
    }
775
    /// Test that gen_ai.agent.name is not overridden when already set.
    #[test]
    fn test_default_agent_name_not_overridden() {
        let mut span = ai_span_with_data(json!({
            "gen_ai.function_id": "my-function",
            "gen_ai.agent.name": "my-agent",
        }));

        enrich_ai_span(&mut span, None);

        // The existing agent name wins over the function_id.
        assert_annotated_snapshot!(&span.data, @r#"
        {
          "gen_ai.operation.type": "ai_client",
          "gen_ai.agent.name": "my-agent",
          "gen_ai.function_id": "my-function"
        }
        "#);
    }
794
    /// Test that an AI span is detected from a gen_ai.operation.name attribute,
    /// even without any span op.
    #[test]
    fn test_is_ai_span_from_gen_ai_operation_name() {
        let mut span_data = Annotated::default();
        span_data
            .get_or_insert_with(SpanData::default)
            .gen_ai_operation_name
            .set_value(Some("chat".into()));
        assert!(is_ai_span(&span_data, None));
    }
805
    /// Test that an AI span is detected from a span.op starting with "ai." (legacy prefix).
    #[test]
    fn test_is_ai_span_from_span_op_ai() {
        let span_op: OperationType = "ai.chat".into();
        assert!(is_ai_span(&Annotated::default(), Some(&span_op)));
    }
812
    /// Test that an AI span is detected from a span.op starting with "gen_ai." (new prefix).
    #[test]
    fn test_is_ai_span_from_span_op_gen_ai() {
        let span_op: OperationType = "gen_ai.chat".into();
        assert!(is_ai_span(&Annotated::default(), Some(&span_op)));
    }
819
    /// Test that a non-AI span is detected: neither span data nor a span op present.
    #[test]
    fn test_is_ai_span_negative() {
        assert!(!is_ai_span(&Annotated::default(), None));
    }
825
    /// Test enrich_ai_event_data with invoke_agent in trace context and a chat child span.
    ///
    /// Both the trace context and the child span get a computed
    /// `gen_ai.usage.total_tokens` (500 + 200 = 700 and 100 + 50 = 150) and an
    /// inferred operation type.
    #[test]
    fn test_enrich_ai_event_data_invoke_agent_trace_with_chat_span() {
        let event_json = r#"{
            "type": "transaction",
            "timestamp": 1234567892.0,
            "start_timestamp": 1234567889.0,
            "contexts": {
                "trace": {
                    "op": "gen_ai.invoke_agent",
                    "trace_id": "12345678901234567890123456789012",
                    "span_id": "1234567890123456",
                    "data": {
                        "gen_ai.operation.name": "gen_ai.invoke_agent",
                        "gen_ai.usage.input_tokens": 500,
                        "gen_ai.usage.output_tokens": 200
                    }
                }
            },
            "spans": [
                {
                    "op": "gen_ai.chat.completions",
                    "span_id": "1234567890123457",
                    "start_timestamp": 1234567889.5,
                    "timestamp": 1234567890.5,
                    "data": {
                        "gen_ai.operation.name": "chat",
                        "gen_ai.usage.input_tokens": 100,
                        "gen_ai.usage.output_tokens": 50
                    }
                }
            ]
        }"#;

        let mut annotated_event: Annotated<Event> = Annotated::from_json(event_json).unwrap();
        let event = annotated_event.value_mut().as_mut().unwrap();

        enrich_ai_event_data(event, None);

        assert_annotated_snapshot!(&annotated_event, @r#"
        {
          "type": "transaction",
          "timestamp": 1234567892.0,
          "start_timestamp": 1234567889.0,
          "contexts": {
            "trace": {
              "trace_id": "12345678901234567890123456789012",
              "span_id": "1234567890123456",
              "op": "gen_ai.invoke_agent",
              "data": {
                "gen_ai.usage.total_tokens": 700.0,
                "gen_ai.usage.input_tokens": 500,
                "gen_ai.usage.output_tokens": 200,
                "gen_ai.operation.name": "gen_ai.invoke_agent",
                "gen_ai.operation.type": "agent"
              },
              "type": "trace"
            }
          },
          "spans": [
            {
              "timestamp": 1234567890.5,
              "start_timestamp": 1234567889.5,
              "op": "gen_ai.chat.completions",
              "span_id": "1234567890123457",
              "data": {
                "gen_ai.usage.total_tokens": 150.0,
                "gen_ai.usage.input_tokens": 100,
                "gen_ai.usage.output_tokens": 50,
                "gen_ai.operation.name": "chat",
                "gen_ai.operation.type": "ai_client"
              }
            }
          ]
        }
        "#);
    }
903
    /// Test enrich_ai_event_data with non-AI trace context, invoke_agent parent span, and chat child span.
    ///
    /// The trace context is a plain `http.server` operation, so enrichment must
    /// leave it untouched. Each gen_ai span gets a derived
    /// `gen_ai.usage.total_tokens` (input + output) and an inferred
    /// `gen_ai.operation.type`: `agent` for the invoke_agent span and
    /// `ai_client` for the chat span.
    #[test]
    fn test_enrich_ai_event_data_nested_agent_and_chat_spans() {
        // Nested structure: the chat span (…58) is a child of the invoke_agent
        // span (…57), which is itself a child of the root http.server span (…56).
        let event_json = r#"{
            "type": "transaction",
            "timestamp": 1234567892.0,
            "start_timestamp": 1234567889.0,
            "contexts": {
                "trace": {
                    "op": "http.server",
                    "trace_id": "12345678901234567890123456789012",
                    "span_id": "1234567890123456"
                }
            },
            "spans": [
                {
                    "op": "gen_ai.invoke_agent",
                    "span_id": "1234567890123457",
                    "parent_span_id": "1234567890123456",
                    "start_timestamp": 1234567889.5,
                    "timestamp": 1234567891.5,
                    "data": {
                        "gen_ai.operation.name": "invoke_agent",
                        "gen_ai.usage.input_tokens": 500,
                        "gen_ai.usage.output_tokens": 200
                    }
                },
                {
                    "op": "gen_ai.chat.completions",
                    "span_id": "1234567890123458",
                    "parent_span_id": "1234567890123457",
                    "start_timestamp": 1234567890.0,
                    "timestamp": 1234567891.0,
                    "data": {
                        "gen_ai.operation.name": "chat",
                        "gen_ai.usage.input_tokens": 100,
                        "gen_ai.usage.output_tokens": 50
                    }
                }
            ]
        }"#;

        let mut annotated_event: Annotated<Event> = Annotated::from_json(event_json).unwrap();
        let event = annotated_event.value_mut().as_mut().unwrap();

        // No model metadata: only token totals and operation types are derived,
        // no cost-related attributes appear in the snapshot below.
        enrich_ai_event_data(event, None);

        assert_annotated_snapshot!(&annotated_event, @r#"
        {
          "type": "transaction",
          "timestamp": 1234567892.0,
          "start_timestamp": 1234567889.0,
          "contexts": {
            "trace": {
              "trace_id": "12345678901234567890123456789012",
              "span_id": "1234567890123456",
              "op": "http.server",
              "type": "trace"
            }
          },
          "spans": [
            {
              "timestamp": 1234567891.5,
              "start_timestamp": 1234567889.5,
              "op": "gen_ai.invoke_agent",
              "span_id": "1234567890123457",
              "parent_span_id": "1234567890123456",
              "data": {
                "gen_ai.usage.total_tokens": 700.0,
                "gen_ai.usage.input_tokens": 500,
                "gen_ai.usage.output_tokens": 200,
                "gen_ai.operation.name": "invoke_agent",
                "gen_ai.operation.type": "agent"
              }
            },
            {
              "timestamp": 1234567891.0,
              "start_timestamp": 1234567890.0,
              "op": "gen_ai.chat.completions",
              "span_id": "1234567890123458",
              "parent_span_id": "1234567890123457",
              "data": {
                "gen_ai.usage.total_tokens": 150.0,
                "gen_ai.usage.input_tokens": 100,
                "gen_ai.usage.output_tokens": 50,
                "gen_ai.operation.name": "chat",
                "gen_ai.operation.type": "ai_client"
              }
            }
          ]
        }
        "#);
    }
997
    /// Test enrich_ai_event_data with legacy measurements and span op for operation type.
    ///
    /// Token counts are supplied via the legacy `ai_prompt_tokens_used` /
    /// `ai_completion_tokens_used` measurements rather than `gen_ai.usage.*`
    /// data attributes. Enrichment mirrors them into `gen_ai.usage.input_tokens`
    /// and `gen_ai.usage.output_tokens` (as floats), computes
    /// `gen_ai.usage.total_tokens`, and keeps the original measurements intact.
    /// Since the spans carry no `gen_ai.operation.name`, the operation type is
    /// inferred from the span `op` instead.
    #[test]
    fn test_enrich_ai_event_data_legacy_measurements_and_span_op() {
        // Second span uses a legacy langchain-style op string; it must still be
        // classified as `ai_client` (see snapshot below).
        let event_json = r#"{
            "type": "transaction",
            "timestamp": 1234567892.0,
            "start_timestamp": 1234567889.0,
            "contexts": {
                "trace": {
                    "op": "http.server",
                    "trace_id": "12345678901234567890123456789012",
                    "span_id": "1234567890123456"
                }
            },
            "spans": [
                {
                    "op": "gen_ai.invoke_agent",
                    "span_id": "1234567890123457",
                    "parent_span_id": "1234567890123456",
                    "start_timestamp": 1234567889.5,
                    "timestamp": 1234567891.5,
                    "measurements": {
                        "ai_prompt_tokens_used": {"value": 500.0},
                        "ai_completion_tokens_used": {"value": 200.0}
                    }
                },
                {
                    "op": "ai.chat_completions.create.langchain.ChatOpenAI",
                    "span_id": "1234567890123458",
                    "parent_span_id": "1234567890123457",
                    "start_timestamp": 1234567890.0,
                    "timestamp": 1234567891.0,
                    "measurements": {
                        "ai_prompt_tokens_used": {"value": 100.0},
                        "ai_completion_tokens_used": {"value": 50.0}
                    }
                }
            ]
        }"#;

        let mut annotated_event: Annotated<Event> = Annotated::from_json(event_json).unwrap();
        let event = annotated_event.value_mut().as_mut().unwrap();

        // No model metadata — only token mirroring, totals, and operation type
        // inference take place.
        enrich_ai_event_data(event, None);

        assert_annotated_snapshot!(&annotated_event, @r#"
        {
          "type": "transaction",
          "timestamp": 1234567892.0,
          "start_timestamp": 1234567889.0,
          "contexts": {
            "trace": {
              "trace_id": "12345678901234567890123456789012",
              "span_id": "1234567890123456",
              "op": "http.server",
              "type": "trace"
            }
          },
          "spans": [
            {
              "timestamp": 1234567891.5,
              "start_timestamp": 1234567889.5,
              "op": "gen_ai.invoke_agent",
              "span_id": "1234567890123457",
              "parent_span_id": "1234567890123456",
              "data": {
                "gen_ai.usage.total_tokens": 700.0,
                "gen_ai.usage.input_tokens": 500.0,
                "gen_ai.usage.output_tokens": 200.0,
                "gen_ai.operation.type": "agent"
              },
              "measurements": {
                "ai_completion_tokens_used": {
                  "value": 200.0
                },
                "ai_prompt_tokens_used": {
                  "value": 500.0
                }
              }
            },
            {
              "timestamp": 1234567891.0,
              "start_timestamp": 1234567890.0,
              "op": "ai.chat_completions.create.langchain.ChatOpenAI",
              "span_id": "1234567890123458",
              "parent_span_id": "1234567890123457",
              "data": {
                "gen_ai.usage.total_tokens": 150.0,
                "gen_ai.usage.input_tokens": 100.0,
                "gen_ai.usage.output_tokens": 50.0,
                "gen_ai.operation.type": "ai_client"
              },
              "measurements": {
                "ai_completion_tokens_used": {
                  "value": 50.0
                },
                "ai_prompt_tokens_used": {
                  "value": 100.0
                }
              }
            }
          ]
        }
        "#);
    }
1103
1104    fn metadata_with_context_size() -> ModelMetadata {
1105        ModelMetadata {
1106            version: 1,
1107            models: HashMap::from([(
1108                Pattern::new("claude-2.1").unwrap(),
1109                ModelMetadataEntry {
1110                    costs: Some(ModelCostV2 {
1111                        input_per_token: 0.01,
1112                        output_per_token: 0.02,
1113                        output_reasoning_per_token: 0.0,
1114                        input_cached_per_token: 0.0,
1115                        input_cache_write_per_token: 0.0,
1116                    }),
1117                    context_size: Some(100_000),
1118                },
1119            )]),
1120        }
1121    }
1122
1123    #[test]
1124    fn test_context_utilization_with_total_tokens() {
1125        let mut span = Span {
1126            op: "gen_ai.test".to_owned().into(),
1127            data: SpanData::from_value(
1128                json!({
1129                    "gen_ai.response.model": "claude-2.1",
1130                    "gen_ai.usage.input_tokens": 30000.0,
1131                    "gen_ai.usage.output_tokens": 12000.0,
1132                    "gen_ai.usage.total_tokens": 42000.0,
1133                })
1134                .into(),
1135            ),
1136            ..Default::default()
1137        };
1138
1139        enrich_ai_span(&mut span, Some(&metadata_with_context_size()));
1140
1141        let data = span.data.value().unwrap();
1142        assert_eq!(
1143            data.gen_ai_context_window_size
1144                .value()
1145                .and_then(Value::as_f64),
1146            Some(100_000.0)
1147        );
1148        assert_eq!(
1149            data.gen_ai_context_utilization
1150                .value()
1151                .and_then(Value::as_f64),
1152            Some(0.42)
1153        );
1154    }
1155
1156    #[test]
1157    fn test_context_utilization_no_context_size() {
1158        let metadata = ModelMetadata {
1159            version: 1,
1160            models: HashMap::from([(
1161                Pattern::new("claude-2.1").unwrap(),
1162                ModelMetadataEntry {
1163                    costs: None,
1164                    context_size: None,
1165                },
1166            )]),
1167        };
1168
1169        let mut span = Span {
1170            op: "gen_ai.test".to_owned().into(),
1171            data: SpanData::from_value(
1172                json!({
1173                    "gen_ai.response.model": "claude-2.1",
1174                    "gen_ai.usage.total_tokens": 1000.0,
1175                })
1176                .into(),
1177            ),
1178            ..Default::default()
1179        };
1180
1181        enrich_ai_span(&mut span, Some(&metadata));
1182
1183        let data = span.data.value().unwrap();
1184        assert!(data.gen_ai_context_window_size.value().is_none());
1185        assert!(data.gen_ai_context_utilization.value().is_none());
1186    }
1187
1188    #[test]
1189    fn test_context_utilization_no_total_tokens() {
1190        let mut span = Span {
1191            op: "gen_ai.test".to_owned().into(),
1192            data: SpanData::from_value(
1193                json!({
1194                    "gen_ai.response.model": "claude-2.1",
1195                })
1196                .into(),
1197            ),
1198            ..Default::default()
1199        };
1200
1201        enrich_ai_span(&mut span, Some(&metadata_with_context_size()));
1202
1203        let data = span.data.value().unwrap();
1204        // window_size should still be set even without tokens.
1205        assert_eq!(
1206            data.gen_ai_context_window_size
1207                .value()
1208                .and_then(Value::as_f64),
1209            Some(100_000.0)
1210        );
1211        // But utilization cannot be computed without total_tokens.
1212        assert!(data.gen_ai_context_utilization.value().is_none());
1213    }
1214
1215    #[test]
1216    fn test_context_utilization_unknown_model() {
1217        let mut span = Span {
1218            op: "gen_ai.test".to_owned().into(),
1219            data: SpanData::from_value(
1220                json!({
1221                    "gen_ai.response.model": "unknown-model",
1222                    "gen_ai.usage.total_tokens": 1000.0,
1223                })
1224                .into(),
1225            ),
1226            ..Default::default()
1227        };
1228
1229        enrich_ai_span(&mut span, Some(&metadata_with_context_size()));
1230
1231        let data = span.data.value().unwrap();
1232        assert!(data.gen_ai_context_window_size.value().is_none());
1233        assert!(data.gen_ai_context_utilization.value().is_none());
1234    }
1235}