relay_server/utils/
multipart.rs

1use std::io;
2use std::task::Poll;
3
4use axum::extract::FromRequest;
5use axum::extract::Request;
6use bytes::{Bytes, BytesMut};
7use futures::{StreamExt, TryStreamExt};
8use multer::{Field, Multipart};
9use relay_config::Config;
10use serde::{Deserialize, Serialize};
11
12use crate::envelope::{AttachmentType, ContentType, Item, ItemType, Items};
13use crate::extractors::Remote;
14use crate::service::ServiceState;
15
/// Type used for encoding string lengths.
type Len = u32;

/// Serializes a Pascal-style string with a 4 byte little-endian length prefix.
///
/// # Errors
///
/// Returns an [`io::ErrorKind::InvalidInput`] error, without writing anything, if the byte length
/// of `string` does not fit into the length prefix. Any error reported by `writer` is passed
/// through unchanged.
fn write_string<W>(mut writer: W, string: &str) -> io::Result<()>
where
    W: io::Write,
{
    // A plain `as` cast would silently truncate the prefix while the full string is still written
    // afterwards, which corrupts everything that follows in the stream.
    let len = Len::try_from(string.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            "string length exceeds the 4 byte length prefix",
        )
    })?;

    writer.write_all(&len.to_le_bytes())?;
    writer.write_all(string.as_bytes())
}
29
/// Splits `len` bytes off the front of `data` and advances `data` past them.
///
/// If fewer than `len` bytes are available, `data` is reset to an empty slice and `None` is
/// returned.
fn split_front<'a>(data: &mut &'a [u8], len: usize) -> Option<&'a [u8]> {
    let buffer: &'a [u8] = *data;

    if buffer.len() >= len {
        let (front, remainder) = buffer.split_at(len);
        *data = remainder;
        Some(front)
    } else {
        // Not enough input left: drop whatever remains so callers stop reading.
        *data = &[];
        None
    }
}
41
42/// Consumes the 4-byte length prefix of a string.
43fn consume_len(data: &mut &[u8]) -> Option<usize> {
44    let len = std::mem::size_of::<Len>();
45    let slice = split_front(data, len)?;
46    let bytes = slice.try_into().ok();
47    bytes.map(|b| Len::from_le_bytes(b) as usize)
48}
49
50/// Consumes a Pascal-style string with a 4 byte little-endian length prefix.
51fn consume_string<'a>(data: &mut &'a [u8]) -> Option<&'a str> {
52    let len = consume_len(data)?;
53    let bytes = split_front(data, len)?;
54    std::str::from_utf8(bytes).ok()
55}
56
57/// An entry in a serialized form data item.
58#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
59pub struct FormDataEntry<'a>(&'a str, &'a str);
60
61impl<'a> FormDataEntry<'a> {
62    pub fn new(key: &'a str, value: &'a str) -> Self {
63        Self(key, value)
64    }
65
66    pub fn key(&self) -> &'a str {
67        self.0
68    }
69
70    pub fn value(&self) -> &'a str {
71        self.1
72    }
73
74    fn to_writer<W: io::Write>(&self, mut writer: W) {
75        write_string(&mut writer, self.key()).ok();
76        write_string(&mut writer, self.value()).ok();
77    }
78
79    fn read(data: &mut &'a [u8]) -> Option<Self> {
80        let key = consume_string(data)?;
81        let value = consume_string(data)?;
82        Some(Self::new(key, value))
83    }
84}
85
86/// A writer for serialized form data.
87///
88/// This writer is used to serialize multiple plain fields from a multipart form data request into a
89/// single envelope item. Use `FormDataIter` to iterate all entries.
90struct FormDataWriter {
91    data: Vec<u8>,
92}
93
94impl FormDataWriter {
95    pub fn new() -> Self {
96        Self { data: Vec::new() }
97    }
98
99    pub fn append(&mut self, key: &str, value: &str) {
100        let entry = FormDataEntry::new(key, value);
101        entry.to_writer(&mut self.data);
102    }
103
104    pub fn into_inner(self) -> Vec<u8> {
105        self.data
106    }
107}
108
109/// Iterates through serialized form data written with `FormDataWriter`.
110pub struct FormDataIter<'a> {
111    data: &'a [u8],
112}
113
114impl<'a> FormDataIter<'a> {
115    pub fn new(data: &'a [u8]) -> Self {
116        Self { data }
117    }
118}
119
120impl<'a> Iterator for FormDataIter<'a> {
121    type Item = FormDataEntry<'a>;
122
123    fn next(&mut self) -> Option<Self::Item> {
124        while !self.data.is_empty() {
125            match FormDataEntry::read(&mut self.data) {
126                Some(entry) => return Some(entry),
127                None => relay_log::error!("form data deserialization failed"),
128            }
129        }
130
131        None
132    }
133}
134
/// Extracts the multipart boundary from the beginning of `data`.
///
/// The boundary is expected on the first non-empty line (lines are separated by `\r` or `\n`),
/// prefixed by `--` (two dashes). The returned string does not include the dashes. `None` is
/// returned if that line does not start with `--`, if nothing follows the dashes, or if the
/// boundary is not valid UTF-8.
///
/// ```ignore
/// let boundary = get_multipart_boundary(b"--The boundary\r\n next line");
/// assert_eq!(Some("The boundary"), boundary);
///
/// let invalid_boundary = get_multipart_boundary(b"The boundary\r\n next line");
/// assert_eq!(None, invalid_boundary);
/// ```
pub fn get_multipart_boundary(data: &[u8]) -> Option<&str> {
    // Skip leading blank lines and look only at the first line with content.
    let first_line = data
        .split(|&byte| matches!(byte, b'\r' | b'\n'))
        .find(|line| !line.is_empty())?;

    // The form boundary indicator must be followed by at least one byte.
    let boundary = first_line.strip_prefix(b"--")?;
    if boundary.is_empty() {
        return None;
    }

    // Form boundaries must be valid UTF-8 strings.
    std::str::from_utf8(boundary).ok()
}
157
158async fn multipart_items<F>(
159    mut multipart: Multipart<'_>,
160    mut infer_type: F,
161    config: &Config,
162    ignore_large_fields: bool,
163) -> Result<Items, multer::Error>
164where
165    F: FnMut(Option<&str>, &str) -> AttachmentType,
166{
167    let mut items = Items::new();
168    let mut form_data = FormDataWriter::new();
169    let mut attachments_size = 0;
170
171    while let Some(field) = multipart.next_field().await? {
172        if let Some(file_name) = field.file_name() {
173            let mut item = Item::new(ItemType::Attachment);
174            item.set_attachment_type(infer_type(field.name(), file_name));
175            item.set_filename(file_name);
176
177            let content_type = field.content_type().cloned();
178            let field = LimitedField::new(field, config.max_attachment_size());
179            match field.bytes().await {
180                Err(multer::Error::FieldSizeExceeded { .. }) if ignore_large_fields => continue,
181                Err(err) => return Err(err),
182                Ok(bytes) => {
183                    attachments_size += bytes.len();
184
185                    if attachments_size > config.max_attachments_size() {
186                        return Err(multer::Error::StreamSizeExceeded {
187                            limit: config.max_attachments_size() as u64,
188                        });
189                    }
190
191                    if let Some(content_type) = content_type {
192                        item.set_payload(content_type.as_ref().into(), bytes);
193                    } else {
194                        item.set_payload_without_content_type(bytes);
195                    }
196                }
197            }
198
199            items.push(item);
200        } else if let Some(field_name) = field.name().map(str::to_owned) {
201            // Ensure to decode this SAFELY to match Django's POST data behavior. This allows us to
202            // process sentry event payloads even if they contain invalid encoding.
203            let string = field.text().await?;
204            form_data.append(&field_name, &string);
205        } else {
206            relay_log::trace!("multipart content without name or file_name");
207        }
208    }
209
210    let form_data = form_data.into_inner();
211    if !form_data.is_empty() {
212        let mut item = Item::new(ItemType::FormData);
213        // Content type is `Text` (since it is not a json object but multiple
214        // json arrays serialized one after the other).
215        item.set_payload(ContentType::Text, form_data);
216        items.push(item);
217    }
218
219    Ok(items)
220}
221
222/// Wrapper around `multer::Field` which consumes the entire underlying stream even when the
223/// size limit is exceeded.
224///
225/// The idea being that you can process fields in a multi-part form even if one fields is too large.
226struct LimitedField<'a> {
227    field: Field<'a>,
228    consumed_size: usize,
229    size_limit: usize,
230    inner_finished: bool,
231}
232
233impl<'a> LimitedField<'a> {
234    fn new(field: Field<'a>, limit: usize) -> Self {
235        LimitedField {
236            field,
237            consumed_size: 0,
238            size_limit: limit,
239            inner_finished: false,
240        }
241    }
242
243    async fn bytes(self) -> Result<Bytes, multer::Error> {
244        self.try_fold(BytesMut::new(), |mut acc, x| async move {
245            acc.extend_from_slice(&x);
246            Ok(acc)
247        })
248        .await
249        .map(|x| x.freeze())
250    }
251}
252
253impl futures::Stream for LimitedField<'_> {
254    type Item = Result<Bytes, multer::Error>;
255
256    fn poll_next(
257        mut self: std::pin::Pin<&mut Self>,
258        cx: &mut std::task::Context<'_>,
259    ) -> std::task::Poll<Option<Self::Item>> {
260        if self.inner_finished {
261            return Poll::Ready(None);
262        }
263
264        match self.field.poll_next_unpin(cx) {
265            err @ Poll::Ready(Some(Err(_))) => err,
266            Poll::Ready(Some(Ok(t))) => {
267                self.consumed_size += t.len();
268                match self.consumed_size <= self.size_limit {
269                    true => Poll::Ready(Some(Ok(t))),
270                    false => {
271                        cx.waker().wake_by_ref();
272                        Poll::Pending
273                    }
274                }
275            }
276            Poll::Ready(None) if self.consumed_size > self.size_limit => {
277                self.inner_finished = true;
278                Poll::Ready(Some(Err(multer::Error::FieldSizeExceeded {
279                    limit: self.size_limit as u64,
280                    field_name: self.field.name().map(Into::into),
281                })))
282            }
283            Poll::Ready(None) => {
284                self.inner_finished = true;
285                Poll::Ready(None)
286            }
287            Poll::Pending => Poll::Pending,
288        }
289    }
290}
291
292/// Wrapper around [`multer::Multipart`] that checks each field is smaller than
293/// `max_attachment_size` and that the combined size of all fields is smaller than
294/// 'max_attachments_size'.
295pub struct ConstrainedMultipart(pub Multipart<'static>);
296
297impl FromRequest<ServiceState> for ConstrainedMultipart {
298    type Rejection = Remote<multer::Error>;
299
300    async fn from_request(request: Request, state: &ServiceState) -> Result<Self, Self::Rejection> {
301        // Still want to enforce multer limits here so that we avoid parsing large fields.
302        let limits =
303            multer::SizeLimit::new().whole_stream(state.config().max_attachments_size() as u64);
304
305        multipart_from_request(request, multer::Constraints::new().size_limit(limits))
306            .map(Self)
307            .map_err(Remote)
308    }
309}
310
311impl ConstrainedMultipart {
312    pub async fn items<F>(self, infer_type: F, config: &Config) -> Result<Items, multer::Error>
313    where
314        F: FnMut(Option<&str>, &str) -> AttachmentType,
315    {
316        multipart_items(self.0, infer_type, config, false).await
317    }
318}
319
320/// Wrapper around [`multer::Multipart`] that skips over fields which are larger than
321/// `max_attachment_size`. These fields are also not taken into account when checking that the
322/// combined size of all fields is smaller than `max_attachments_size`.
323#[allow(dead_code)]
324pub struct UnconstrainedMultipart(pub Multipart<'static>);
325
326impl FromRequest<ServiceState> for UnconstrainedMultipart {
327    type Rejection = Remote<multer::Error>;
328
329    async fn from_request(
330        request: Request,
331        _state: &ServiceState,
332    ) -> Result<Self, Self::Rejection> {
333        multipart_from_request(request, multer::Constraints::new())
334            .map(Self)
335            .map_err(Remote)
336    }
337}
338
339#[cfg_attr(not(any(test, sentry)), expect(dead_code))]
340impl UnconstrainedMultipart {
341    pub async fn items<F>(self, infer_type: F, config: &Config) -> Result<Items, multer::Error>
342    where
343        F: FnMut(Option<&str>, &str) -> AttachmentType,
344    {
345        multipart_items(self.0, infer_type, config, true).await
346    }
347}
348
349pub fn multipart_from_request(
350    request: Request,
351    constraints: multer::Constraints,
352) -> Result<Multipart<'static>, multer::Error> {
353    let content_type = request
354        .headers()
355        .get("content-type")
356        .and_then(|v| v.to_str().ok())
357        .unwrap_or("");
358    let boundary = multer::parse_boundary(content_type)?;
359
360    Ok(Multipart::with_constraints(
361        request.into_body().into_data_stream(),
362        boundary,
363        constraints,
364    ))
365}
366
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use super::*;

    /// Multipart body with a larger file (`large.txt`) followed by a two byte file (`small.txt`).
    const TWO_FILES: &str = "--X-BOUNDARY\r\n\
              Content-Disposition: form-data; name=\"file\"; filename=\"large.txt\"\r\n\
              Content-Type: text/plain\r\n\
              \r\n\
              content too large for limit\r\n\
              --X-BOUNDARY\r\n\
              Content-Disposition: form-data; name=\"small_file\"; filename=\"small.txt\"\r\n\
              Content-Type: text/plain\r\n\
              \r\n\
              ok\r\n\
              --X-BOUNDARY--\r\n";

    #[test]
    fn test_get_boundary() {
        let cases: [(&[u8], Option<&str>); 9] = [
            (b"--some_val", Some("some_val")),
            (b"--\nsecond line", None),
            (b"\n\r--some_val", Some("some_val")),
            (b"\n\r--some_val\nadfa", Some("some_val")),
            (b"\n\r--some_val\rfasdf", Some("some_val")),
            (b"\n\r--some_val\r\nfasdf", Some("some_val")),
            (b"\n\rsome_val", None),
            (b"", None),
            (b"--", None),
        ];

        for (input, expected) in cases.iter() {
            assert_eq!(*expected, get_multipart_boundary(input));
        }
    }

    #[test]
    fn test_formdata() {
        let pairs = [("foo", "foo"), ("bar", ""), ("blub", "blub")];

        let mut writer = FormDataWriter::new();
        for (key, value) in pairs {
            writer.append(key, value);
        }
        let payload = writer.into_inner();

        let expected: Vec<_> = pairs
            .iter()
            .map(|(key, value)| FormDataEntry::new(key, value))
            .collect();
        let actual: Vec<_> = FormDataIter::new(&payload).collect();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_empty_formdata() {
        let payload = FormDataWriter::new().into_inner();

        // Nothing was written, so nothing must be read back.
        assert_eq!(FormDataIter::new(&payload).count(), 0);
    }

    /// Regression test for multipart payloads without a trailing newline.
    #[tokio::test]
    async fn missing_trailing_newline() -> anyhow::Result<()> {
        let data = "--X-BOUNDARY\r\nContent-Disposition: form-data; \
        name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY--"; // No trailing newline

        let body = futures::stream::once(async { Ok::<_, Infallible>(data) });
        let mut multipart = Multipart::new(body, "X-BOUNDARY");

        let first = multipart.next_field().await?;
        assert!(first.is_some());
        drop(first);

        let second = multipart.next_field().await?;
        assert!(second.is_none());

        Ok(())
    }

    #[tokio::test]
    async fn test_individual_size_limit_exceeded() -> anyhow::Result<()> {
        let body = futures::stream::once(async move { Ok::<_, Infallible>(TWO_FILES) });
        let multipart = Multipart::new(body, "X-BOUNDARY");

        let config = Config::from_json_value(serde_json::json!({
            "limits": {
                "max_attachment_size": 5
            }
        }))?;

        let items = UnconstrainedMultipart(multipart)
            .items(|_, _| AttachmentType::Attachment, &config)
            .await?;

        // The large field is skipped so only the small one should make it through.
        assert_eq!(items.len(), 1);
        let small = &items[0];
        assert_eq!(small.filename(), Some("small.txt"));
        assert_eq!(small.payload(), Bytes::from("ok"));

        Ok(())
    }

    #[tokio::test]
    async fn test_collective_size_limit_exceeded() -> anyhow::Result<()> {
        let body = futures::stream::once(async move { Ok::<_, Infallible>(TWO_FILES) });

        let config = Config::from_json_value(serde_json::json!({
            "limits": {
                "max_attachments_size": 5
            }
        }))?;

        let multipart = Multipart::new(body, "X-BOUNDARY");

        let outcome = UnconstrainedMultipart(multipart)
            .items(|_, _| AttachmentType::Attachment, &config)
            .await;

        // Should be warned if the overall stream limit is being breached.
        assert!(matches!(
            outcome,
            Err(multer::Error::StreamSizeExceeded { .. })
        ));

        Ok(())
    }
}