relay_event_normalization/normalize/span/
ai.rs1use crate::normalize::AiOperationTypeMap;
4use crate::{ModelCostV2, ModelCosts};
5use relay_event_schema::protocol::{Event, Span, SpanData};
6use relay_protocol::{Annotated, Getter, Value};
7
8#[derive(Debug, Copy, Clone)]
10pub struct UsedTokens {
11 pub input_tokens: f64,
13 pub input_cached_tokens: f64,
17 pub output_tokens: f64,
19 pub output_reasoning_tokens: f64,
23}
24
25impl UsedTokens {
26 pub fn from_span_data(data: &SpanData) -> Self {
28 macro_rules! get_value {
29 ($e:expr) => {
30 $e.value().and_then(Value::as_f64).unwrap_or(0.0)
31 };
32 }
33
34 Self {
35 input_tokens: get_value!(data.gen_ai_usage_input_tokens),
36 output_tokens: get_value!(data.gen_ai_usage_output_tokens),
37 output_reasoning_tokens: get_value!(data.gen_ai_usage_output_tokens_reasoning),
38 input_cached_tokens: get_value!(data.gen_ai_usage_input_tokens_cached),
39 }
40 }
41
42 pub fn has_usage(&self) -> bool {
44 self.input_tokens > 0.0 || self.output_tokens > 0.0
45 }
46
47 pub fn raw_input_tokens(&self) -> f64 {
51 self.input_tokens - self.input_cached_tokens
52 }
53
54 pub fn raw_output_tokens(&self) -> f64 {
58 self.output_tokens - self.output_reasoning_tokens
59 }
60}
61
/// Cost of an AI call, split into input and output components.
///
/// Values are in the same unit as the per-token prices they were computed from
/// (see [`calculate_costs`]).
#[derive(Debug, Copy, Clone)]
pub struct CalculatedCost {
    /// Cost attributed to input tokens, cached and uncached.
    pub input: f64,
    /// Cost attributed to output tokens, including reasoning tokens.
    pub output: f64,
}

impl CalculatedCost {
    /// Sum of the input and output cost.
    pub fn total(&self) -> f64 {
        let Self { input, output } = *self;
        input + output
    }
}
77
78pub fn calculate_costs(model_cost: &ModelCostV2, tokens: UsedTokens) -> Option<CalculatedCost> {
82 if !tokens.has_usage() {
83 return None;
84 }
85
86 let input = (tokens.raw_input_tokens() * model_cost.input_per_token)
87 + (tokens.input_cached_tokens * model_cost.input_cached_per_token);
88
89 let reasoning_cost = match model_cost.output_reasoning_per_token {
92 reasoning_cost if reasoning_cost > 0.0 => reasoning_cost,
93 _ => model_cost.output_per_token,
94 };
95
96 let output = (tokens.raw_output_tokens() * model_cost.output_per_token)
97 + (tokens.output_reasoning_tokens * reasoning_cost);
98
99 Some(CalculatedCost { input, output })
100}
101
102fn extract_ai_model_cost_data(model_cost: Option<&ModelCostV2>, data: &mut SpanData) {
105 let Some(model_cost) = model_cost else { return };
106
107 let used_tokens = UsedTokens::from_span_data(&*data);
108 let Some(costs) = calculate_costs(model_cost, used_tokens) else {
109 return;
110 };
111
112 data.gen_ai_usage_total_cost
115 .set_value(Value::F64(costs.total()).into());
116 data.gen_ai_cost_total_tokens
117 .set_value(Value::F64(costs.total()).into());
118
119 data.gen_ai_cost_input_tokens
121 .set_value(Value::F64(costs.input).into());
122 data.gen_ai_cost_output_tokens
123 .set_value(Value::F64(costs.output).into());
124}
125
126fn map_ai_measurements_to_data(span: &mut Span) {
128 let measurements = span.measurements.value();
129 let data = span.data.get_or_insert_with(SpanData::default);
130
131 let set_field_from_measurement = |target_field: &mut Annotated<Value>,
132 measurement_key: &str| {
133 if let Some(measurements) = measurements
134 && target_field.value().is_none()
135 && let Some(value) = measurements.get_value(measurement_key)
136 {
137 target_field.set_value(Value::F64(value.to_f64()).into());
138 }
139 };
140
141 set_field_from_measurement(&mut data.gen_ai_usage_total_tokens, "ai_total_tokens_used");
142 set_field_from_measurement(&mut data.gen_ai_usage_input_tokens, "ai_prompt_tokens_used");
143 set_field_from_measurement(
144 &mut data.gen_ai_usage_output_tokens,
145 "ai_completion_tokens_used",
146 );
147}
148
149fn set_total_tokens(span: &mut Span) {
150 let data = span.data.get_or_insert_with(SpanData::default);
151
152 if data.gen_ai_usage_total_tokens.value().is_none() {
154 let input_tokens = data
155 .gen_ai_usage_input_tokens
156 .value()
157 .and_then(Value::as_f64);
158 let output_tokens = data
159 .gen_ai_usage_output_tokens
160 .value()
161 .and_then(Value::as_f64);
162
163 if input_tokens.is_none() && output_tokens.is_none() {
164 return;
166 }
167
168 data.gen_ai_usage_total_tokens.set_value(
169 Value::F64(input_tokens.unwrap_or(0.0) + output_tokens.unwrap_or(0.0)).into(),
170 );
171 }
172}
173
174fn extract_ai_data(span: &mut Span, ai_model_costs: &ModelCosts) {
176 let duration = span
177 .get_value("span.duration")
178 .and_then(|v| v.as_f64())
179 .unwrap_or(0.0);
180
181 let data = span.data.get_or_insert_with(SpanData::default);
182
183 if data.gen_ai_response_tokens_per_second.value().is_none()
185 && duration > 0.0
186 && let Some(output_tokens) = data
187 .gen_ai_usage_output_tokens
188 .value()
189 .and_then(Value::as_f64)
190 {
191 data.gen_ai_response_tokens_per_second
192 .set_value(Value::F64(output_tokens / (duration / 1000.0)).into());
193 }
194
195 if let Some(model_id) = data
197 .gen_ai_request_model
198 .value()
199 .and_then(|val| val.as_str())
200 .or_else(|| {
201 data.gen_ai_response_model
202 .value()
203 .and_then(|val| val.as_str())
204 })
205 {
206 extract_ai_model_cost_data(ai_model_costs.cost_per_token(model_id), data)
207 }
208}
209
210pub fn enrich_ai_span_data(
212 span: &mut Span,
213 model_costs: Option<&ModelCosts>,
214 operation_type_map: Option<&AiOperationTypeMap>,
215) {
216 if !is_ai_span(span) {
217 return;
218 }
219
220 map_ai_measurements_to_data(span);
221 set_total_tokens(span);
222
223 if let Some(model_costs) = model_costs {
224 extract_ai_data(span, model_costs);
225 }
226 if let Some(operation_type_map) = operation_type_map {
227 infer_ai_operation_type(span, operation_type_map);
228 }
229}
230
231pub fn enrich_ai_event_data(
233 event: &mut Event,
234 model_costs: Option<&ModelCosts>,
235 operation_type_map: Option<&AiOperationTypeMap>,
236) {
237 let spans = event.spans.value_mut().iter_mut().flatten();
238 let spans = spans.filter_map(|span| span.value_mut().as_mut());
239
240 for span in spans {
241 enrich_ai_span_data(span, model_costs, operation_type_map);
242 }
243}
244
245fn infer_ai_operation_type(span: &mut Span, operation_type_map: &AiOperationTypeMap) {
250 let data = span.data.get_or_insert_with(SpanData::default);
251
252 if let Some(op) = span.op.value()
253 && let Some(operation_type) = operation_type_map.get_operation_type(op)
254 {
255 data.gen_ai_operation_type
256 .set_value(Some(operation_type.to_owned()));
257 }
258}
259
260fn is_ai_span(span: &Span) -> bool {
263 span.op
264 .value()
265 .is_some_and(|op| op.starts_with("ai.") || op.starts_with("gen_ai."))
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`ModelCostV2`] from its per-token prices.
    fn prices(
        input_per_token: f64,
        output_per_token: f64,
        output_reasoning_per_token: f64,
        input_cached_per_token: f64,
    ) -> ModelCostV2 {
        ModelCostV2 {
            input_per_token,
            output_per_token,
            output_reasoning_per_token,
            input_cached_per_token,
        }
    }

    /// Builds a [`UsedTokens`] from explicit token counts.
    fn usage(
        input_tokens: f64,
        input_cached_tokens: f64,
        output_tokens: f64,
        output_reasoning_tokens: f64,
    ) -> UsedTokens {
        UsedTokens {
            input_tokens,
            input_cached_tokens,
            output_tokens,
            output_reasoning_tokens,
        }
    }

    #[test]
    fn test_calculate_cost_no_tokens() {
        // Empty span data yields all-zero usage, so there is nothing to bill.
        let tokens = UsedTokens::from_span_data(&SpanData::default());
        assert!(calculate_costs(&prices(1.0, 1.0, 1.0, 1.0), tokens).is_none());
    }

    #[test]
    fn test_calculate_cost_full() {
        let cost = calculate_costs(&prices(1.0, 2.0, 3.0, 0.5), usage(8.0, 5.0, 15.0, 9.0))
            .unwrap();

        // input:  (8 - 5) * 1.0 + 5 * 0.5 = 5.5
        // output: (15 - 9) * 2.0 + 9 * 3.0 = 39.0
        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: 5.5,
            output: 39.0,
        }
        ");
    }

    #[test]
    fn test_calculate_cost_no_reasoning_cost() {
        // A zero reasoning price falls back to the regular output price.
        let cost = calculate_costs(&prices(1.0, 2.0, 0.0, 0.5), usage(8.0, 5.0, 15.0, 9.0))
            .unwrap();

        // input:  (8 - 5) * 1.0 + 5 * 0.5 = 5.5
        // output: (15 - 9) * 2.0 + 9 * 2.0 = 30.0
        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: 5.5,
            output: 30.0,
        }
        ");
    }

    #[test]
    fn test_calculate_cost_negative() {
        // Cached/reasoning counts larger than the totals they belong to are not clamped, so
        // the resulting costs are negative.
        let cost = calculate_costs(&prices(2.0, 2.0, 1.0, 1.0), usage(1.0, 11.0, 1.0, 9.0))
            .unwrap();

        // input:  (1 - 11) * 2.0 + 11 * 1.0 = -9.0
        // output: (1 - 9) * 2.0 + 9 * 1.0 = -7.0
        insta::assert_debug_snapshot!(cost, @r"
        CalculatedCost {
            input: -9.0,
            output: -7.0,
        }
        ");
    }
}
367}