1use std::io;
2use std::task::Poll;
3
4use axum::extract::FromRequest;
5use axum::extract::Request;
6use bytes::{Bytes, BytesMut};
7use futures::{StreamExt, TryStreamExt};
8use multer::{Field, Multipart};
9use relay_config::Config;
10use serde::{Deserialize, Serialize};
11
12use crate::envelope::{AttachmentType, ContentType, Item, ItemType, Items};
13use crate::extractors::Remote;
14use crate::service::ServiceState;
15
/// Integer type of the length prefix that precedes every serialized string.
///
/// The prefix is written and read in little-endian byte order by `write_string` and
/// `consume_len`.
type Len = u32;
18
19fn write_string<W>(mut writer: W, string: &str) -> io::Result<()>
21where
22 W: io::Write,
23{
24 writer.write_all(&(string.len() as Len).to_le_bytes())?;
25 writer.write_all(string.as_bytes())?;
26
27 Ok(())
28}
29
/// Removes the first `len` bytes from `data` and returns them.
///
/// On success, `data` is advanced past the returned bytes. If fewer than `len` bytes are
/// available, `data` is left empty and `None` is returned.
fn split_front<'a>(data: &mut &'a [u8], len: usize) -> Option<&'a [u8]> {
    // Taking the slice leaves `data` empty, which is exactly the state required on failure.
    let remaining = std::mem::take(data);
    if len > remaining.len() {
        return None;
    }

    let (front, tail) = remaining.split_at(len);
    *data = tail;
    Some(front)
}
41
42fn consume_len(data: &mut &[u8]) -> Option<usize> {
44 let len = std::mem::size_of::<Len>();
45 let slice = split_front(data, len)?;
46 let bytes = slice.try_into().ok();
47 bytes.map(|b| Len::from_le_bytes(b) as usize)
48}
49
50fn consume_string<'a>(data: &mut &'a [u8]) -> Option<&'a str> {
52 let len = consume_len(data)?;
53 let bytes = split_front(data, len)?;
54 std::str::from_utf8(bytes).ok()
55}
56
/// A single key/value pair of form data that borrows both of its strings.
///
/// The first field holds the key and the second field holds the value.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct FormDataEntry<'a>(&'a str, &'a str);
60
61impl<'a> FormDataEntry<'a> {
62 pub fn new(key: &'a str, value: &'a str) -> Self {
63 Self(key, value)
64 }
65
66 pub fn key(&self) -> &'a str {
67 self.0
68 }
69
70 pub fn value(&self) -> &'a str {
71 self.1
72 }
73
74 fn to_writer<W: io::Write>(&self, mut writer: W) {
75 write_string(&mut writer, self.key()).ok();
76 write_string(&mut writer, self.value()).ok();
77 }
78
79 fn read(data: &mut &'a [u8]) -> Option<Self> {
80 let key = consume_string(data)?;
81 let value = consume_string(data)?;
82 Some(Self::new(key, value))
83 }
84}
85
/// Accumulates form data entries in their serialized form.
struct FormDataWriter {
    /// Serialized entries, appended in the order in which they were added.
    data: Vec<u8>,
}
93
94impl FormDataWriter {
95 pub fn new() -> Self {
96 Self { data: Vec::new() }
97 }
98
99 pub fn append(&mut self, key: &str, value: &str) {
100 let entry = FormDataEntry::new(key, value);
101 entry.to_writer(&mut self.data);
102 }
103
104 pub fn into_inner(self) -> Vec<u8> {
105 self.data
106 }
107}
108
/// Iterator over the entries of a serialized form data payload.
///
/// Yields one [`FormDataEntry`] for every key/value pair that can be decoded.
pub struct FormDataIter<'a> {
    /// Bytes that have not been decoded yet.
    data: &'a [u8],
}
113
114impl<'a> FormDataIter<'a> {
115 pub fn new(data: &'a [u8]) -> Self {
116 Self { data }
117 }
118}
119
120impl<'a> Iterator for FormDataIter<'a> {
121 type Item = FormDataEntry<'a>;
122
123 fn next(&mut self) -> Option<Self::Item> {
124 while !self.data.is_empty() {
125 match FormDataEntry::read(&mut self.data) {
126 Some(entry) => return Some(entry),
127 None => relay_log::error!("form data deserialization failed"),
128 }
129 }
130
131 None
132 }
133}
134
/// Extracts the multipart boundary from the beginning of a raw multipart body.
///
/// The first non-empty line (lines are separated by `\r` or `\n`) must start with `--` followed
/// by at least one more byte; everything after the `--` is returned as the boundary. Returns
/// `None` if there is no such line or if the boundary is not valid UTF-8.
pub fn get_multipart_boundary(data: &[u8]) -> Option<&str> {
    let first_line = data
        .split(|&byte| matches!(byte, b'\r' | b'\n'))
        .find(|line| !line.is_empty())?;

    match first_line {
        [b'-', b'-', boundary @ ..] if !boundary.is_empty() => std::str::from_utf8(boundary).ok(),
        _ => None,
    }
}
157
158async fn multipart_items<F>(
159 mut multipart: Multipart<'_>,
160 mut infer_type: F,
161 config: &Config,
162 ignore_large_fields: bool,
163) -> Result<Items, multer::Error>
164where
165 F: FnMut(Option<&str>, &str) -> AttachmentType,
166{
167 let mut items = Items::new();
168 let mut form_data = FormDataWriter::new();
169 let mut attachments_size = 0;
170
171 while let Some(field) = multipart.next_field().await? {
172 if let Some(file_name) = field.file_name() {
173 let mut item = Item::new(ItemType::Attachment);
174 item.set_attachment_type(infer_type(field.name(), file_name));
175 item.set_filename(file_name);
176
177 let content_type = field.content_type().cloned();
178 let field = LimitedField::new(field, config.max_attachment_size());
179 match field.bytes().await {
180 Err(multer::Error::FieldSizeExceeded { .. }) if ignore_large_fields => continue,
181 Err(err) => return Err(err),
182 Ok(bytes) => {
183 attachments_size += bytes.len();
184
185 if attachments_size > config.max_attachments_size() {
186 return Err(multer::Error::StreamSizeExceeded {
187 limit: config.max_attachments_size() as u64,
188 });
189 }
190
191 if let Some(content_type) = content_type {
192 item.set_payload(content_type.as_ref().into(), bytes);
193 } else {
194 item.set_payload_without_content_type(bytes);
195 }
196 }
197 }
198
199 items.push(item);
200 } else if let Some(field_name) = field.name().map(str::to_owned) {
201 let string = field.text().await?;
204 form_data.append(&field_name, &string);
205 } else {
206 relay_log::trace!("multipart content without name or file_name");
207 }
208 }
209
210 let form_data = form_data.into_inner();
211 if !form_data.is_empty() {
212 let mut item = Item::new(ItemType::FormData);
213 item.set_payload(ContentType::Text, form_data);
216 items.push(item);
217 }
218
219 Ok(items)
220}
221
/// Wraps a multipart [`Field`] and enforces a size limit on its content.
///
/// As a stream, it passes chunks through while the field stays within the limit. Once the limit
/// is exceeded, further chunks are discarded while the field is still read to its end, and a
/// `multer::Error::FieldSizeExceeded` error is emitted afterwards.
struct LimitedField<'a> {
    /// The wrapped multipart field.
    field: Field<'a>,
    /// Number of bytes received from the wrapped field so far, including discarded chunks.
    consumed_size: usize,
    /// Maximum number of bytes the field may contain.
    size_limit: usize,
    /// Set once the wrapped field has reported its end; the stream yields nothing afterwards.
    inner_finished: bool,
}
232
233impl<'a> LimitedField<'a> {
234 fn new(field: Field<'a>, limit: usize) -> Self {
235 LimitedField {
236 field,
237 consumed_size: 0,
238 size_limit: limit,
239 inner_finished: false,
240 }
241 }
242
243 async fn bytes(self) -> Result<Bytes, multer::Error> {
244 self.try_fold(BytesMut::new(), |mut acc, x| async move {
245 acc.extend_from_slice(&x);
246 Ok(acc)
247 })
248 .await
249 .map(|x| x.freeze())
250 }
251}
252
253impl futures::Stream for LimitedField<'_> {
254 type Item = Result<Bytes, multer::Error>;
255
256 fn poll_next(
257 mut self: std::pin::Pin<&mut Self>,
258 cx: &mut std::task::Context<'_>,
259 ) -> std::task::Poll<Option<Self::Item>> {
260 if self.inner_finished {
261 return Poll::Ready(None);
262 }
263
264 match self.field.poll_next_unpin(cx) {
265 err @ Poll::Ready(Some(Err(_))) => err,
266 Poll::Ready(Some(Ok(t))) => {
267 self.consumed_size += t.len();
268 match self.consumed_size <= self.size_limit {
269 true => Poll::Ready(Some(Ok(t))),
270 false => {
271 cx.waker().wake_by_ref();
272 Poll::Pending
273 }
274 }
275 }
276 Poll::Ready(None) if self.consumed_size > self.size_limit => {
277 self.inner_finished = true;
278 Poll::Ready(Some(Err(multer::Error::FieldSizeExceeded {
279 limit: self.size_limit as u64,
280 field_name: self.field.name().map(Into::into),
281 })))
282 }
283 Poll::Ready(None) => {
284 self.inner_finished = true;
285 Poll::Ready(None)
286 }
287 Poll::Pending => Poll::Pending,
288 }
289 }
290}
291
/// Multipart request body that is subject to the configured attachment size limits.
///
/// When extracted from a request, the whole stream is limited to the configured
/// `max_attachments_size`. Converting the body into items fails if a single file field exceeds
/// the per-attachment limit.
pub struct ConstrainedMultipart(pub Multipart<'static>);
296
297impl FromRequest<ServiceState> for ConstrainedMultipart {
298 type Rejection = Remote<multer::Error>;
299
300 async fn from_request(request: Request, state: &ServiceState) -> Result<Self, Self::Rejection> {
301 let limits =
303 multer::SizeLimit::new().whole_stream(state.config().max_attachments_size() as u64);
304
305 multipart_from_request(request, multer::Constraints::new().size_limit(limits))
306 .map(Self)
307 .map_err(Remote)
308 }
309}
310
311impl ConstrainedMultipart {
312 pub async fn items<F>(self, infer_type: F, config: &Config) -> Result<Items, multer::Error>
313 where
314 F: FnMut(Option<&str>, &str) -> AttachmentType,
315 {
316 multipart_items(self.0, infer_type, config, false).await
317 }
318}
319
/// Multipart request body that is parsed without a limit on the size of the whole stream.
///
/// Converting the body into items skips file fields that exceed the per-attachment size limit
/// instead of failing.
// NOTE(review): `dead_code` is allowed here, presumably because the type is not used in every
// build configuration — confirm.
#[allow(dead_code)]
pub struct UnconstrainedMultipart(pub Multipart<'static>);
325
326impl FromRequest<ServiceState> for UnconstrainedMultipart {
327 type Rejection = Remote<multer::Error>;
328
329 async fn from_request(
330 request: Request,
331 _state: &ServiceState,
332 ) -> Result<Self, Self::Rejection> {
333 multipart_from_request(request, multer::Constraints::new())
334 .map(Self)
335 .map_err(Remote)
336 }
337}
338
339#[cfg_attr(not(any(test, sentry)), expect(dead_code))]
340impl UnconstrainedMultipart {
341 pub async fn items<F>(self, infer_type: F, config: &Config) -> Result<Items, multer::Error>
342 where
343 F: FnMut(Option<&str>, &str) -> AttachmentType,
344 {
345 multipart_items(self.0, infer_type, config, true).await
346 }
347}
348
349pub fn multipart_from_request(
350 request: Request,
351 constraints: multer::Constraints,
352) -> Result<Multipart<'static>, multer::Error> {
353 let content_type = request
354 .headers()
355 .get("content-type")
356 .and_then(|v| v.to_str().ok())
357 .unwrap_or("");
358 let boundary = multer::parse_boundary(content_type)?;
359
360 Ok(Multipart::with_constraints(
361 request.into_body().into_data_stream(),
362 boundary,
363 constraints,
364 ))
365}
366
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use super::*;

    /// Multipart body containing a large file field followed by a small one.
    const TWO_FILES_BODY: &str = "--X-BOUNDARY\r\n\
        Content-Disposition: form-data; name=\"file\"; filename=\"large.txt\"\r\n\
        Content-Type: text/plain\r\n\
        \r\n\
        content too large for limit\r\n\
        --X-BOUNDARY\r\n\
        Content-Disposition: form-data; name=\"small_file\"; filename=\"small.txt\"\r\n\
        Content-Type: text/plain\r\n\
        \r\n\
        ok\r\n\
        --X-BOUNDARY--\r\n";

    /// Wraps [`TWO_FILES_BODY`] into an [`UnconstrainedMultipart`].
    fn two_files_multipart() -> UnconstrainedMultipart {
        let stream = futures::stream::once(async { Ok::<_, Infallible>(TWO_FILES_BODY) });
        UnconstrainedMultipart(Multipart::new(stream, "X-BOUNDARY"))
    }

    #[test]
    fn test_get_boundary() {
        let cases: &[(&[u8], Option<&str>)] = &[
            (b"--some_val", Some("some_val")),
            (b"--\nsecond line", None),
            (b"\n\r--some_val", Some("some_val")),
            (b"\n\r--some_val\nadfa", Some("some_val")),
            (b"\n\r--some_val\rfasdf", Some("some_val")),
            (b"\n\r--some_val\r\nfasdf", Some("some_val")),
            (b"\n\rsome_val", None),
            (b"", None),
            (b"--", None),
        ];

        for &(input, expected) in cases {
            assert_eq!(expected, get_multipart_boundary(input));
        }
    }

    #[test]
    fn test_formdata() {
        let mut writer = FormDataWriter::new();
        for (key, value) in [("foo", "foo"), ("bar", ""), ("blub", "blub")] {
            writer.append(key, value);
        }

        let payload = writer.into_inner();
        let entries = FormDataIter::new(&payload).collect::<Vec<_>>();

        let expected = vec![
            FormDataEntry::new("foo", "foo"),
            FormDataEntry::new("bar", ""),
            FormDataEntry::new("blub", "blub"),
        ];
        assert_eq!(entries, expected);
    }

    #[test]
    fn test_empty_formdata() {
        let payload = FormDataWriter::new().into_inner();
        let entries: Vec<FormDataEntry<'_>> = FormDataIter::new(&payload).collect();

        assert_eq!(entries, vec![]);
    }

    #[tokio::test]
    async fn missing_trailing_newline() -> anyhow::Result<()> {
        // The closing boundary is not followed by a line break.
        let body = "--X-BOUNDARY\r\nContent-Disposition: form-data; \
            name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY--";
        let stream = futures::stream::once(async move { Ok::<_, Infallible>(body) });
        let mut multipart = Multipart::new(stream, "X-BOUNDARY");

        // Each field is dropped at the end of its statement, before the next one is requested.
        let has_first = multipart.next_field().await?.is_some();
        assert!(has_first);

        let has_second = multipart.next_field().await?.is_some();
        assert!(!has_second);

        Ok(())
    }

    #[tokio::test]
    async fn test_individual_size_limit_exceeded() -> anyhow::Result<()> {
        let config = Config::from_json_value(serde_json::json!({
            "limits": {
                "max_attachment_size": 5
            }
        }))?;

        let items = two_files_multipart()
            .items(|_, _| AttachmentType::Attachment, &config)
            .await?;

        // The large field is skipped, only the small one is left.
        assert_eq!(items.len(), 1);
        let item = &items[0];
        assert_eq!(item.filename(), Some("small.txt"));
        assert_eq!(item.payload(), Bytes::from("ok"));

        Ok(())
    }

    #[tokio::test]
    async fn test_collective_size_limit_exceeded() -> anyhow::Result<()> {
        let config = Config::from_json_value(serde_json::json!({
            "limits": {
                "max_attachments_size": 5
            }
        }))?;

        let result = two_files_multipart()
            .items(|_, _| AttachmentType::Attachment, &config)
            .await;

        assert!(matches!(
            result,
            Err(multer::Error::StreamSizeExceeded { limit: _ })
        ));

        Ok(())
    }
}