// relay_server/utils/multipart.rs
use std::io;
use axum::extract::Request;
use multer::Multipart;
use relay_config::Config;
use serde::{Deserialize, Serialize};
use crate::envelope::{AttachmentType, ContentType, Item, ItemType, Items};
/// Fixed-width integer used as the little-endian byte-length prefix of every string
/// in the serialized form data.
type Len = u32;

/// Writes `string` to `writer` as a little-endian `Len` byte-length prefix followed by
/// the raw UTF-8 bytes.
///
/// # Errors
///
/// Returns an [`io::ErrorKind::InvalidInput`] error, without writing anything, if the
/// string is longer than `Len::MAX` bytes and therefore cannot be framed. Otherwise
/// propagates any error returned by `writer`.
fn write_string<W>(mut writer: W, string: &str) -> io::Result<()>
where
    W: io::Write,
{
    // A plain `as` cast would silently truncate the prefix for oversized strings while
    // still writing every byte, desynchronizing all entries that follow.
    let len = Len::try_from(string.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            "string too long for form data length prefix",
        )
    })?;
    writer.write_all(&len.to_le_bytes())?;
    writer.write_all(string.as_bytes())?;
    Ok(())
}
/// Detaches the first `len` bytes of `data` and advances `data` past them.
///
/// If fewer than `len` bytes remain, `data` is drained to the empty slice and `None`
/// is returned, so a truncated tail is never offered to the caller again.
fn split_front<'a>(data: &mut &'a [u8], len: usize) -> Option<&'a [u8]> {
    let buffer: &'a [u8] = *data;
    match (buffer.get(..len), buffer.get(len..)) {
        (Some(head), Some(tail)) => {
            *data = tail;
            Some(head)
        }
        _ => {
            *data = &[];
            None
        }
    }
}
/// Reads a little-endian `Len` length prefix from the front of `data` and widens it
/// to `usize`.
///
/// Returns `None` if fewer than `size_of::<Len>()` bytes remain; `split_front` then
/// leaves `data` empty.
fn consume_len(data: &mut &[u8]) -> Option<usize> {
    const PREFIX_WIDTH: usize = std::mem::size_of::<Len>();
    let prefix: [u8; PREFIX_WIDTH] = split_front(data, PREFIX_WIDTH)?.try_into().ok()?;
    Some(Len::from_le_bytes(prefix) as usize)
}
/// Reads one length-prefixed UTF-8 string from the front of `data`.
///
/// Returns `None` if the prefix or the string bytes are truncated, or if the bytes are
/// not valid UTF-8. In the invalid-UTF-8 case the string bytes have still been
/// consumed from `data`.
fn consume_string<'a>(data: &mut &'a [u8]) -> Option<&'a str> {
    let byte_len = consume_len(data)?;
    split_front(data, byte_len).and_then(|raw| std::str::from_utf8(raw).ok())
}
/// A single form-data key/value pair.
///
/// Both strings are borrowed: from the caller when built with `new`, or from the
/// serialized buffer when produced by `read`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct FormDataEntry<'a>(&'a str, &'a str);

impl<'a> FormDataEntry<'a> {
    /// Creates an entry from a field name and its value.
    pub fn new(key: &'a str, value: &'a str) -> Self {
        FormDataEntry(key, value)
    }

    /// Returns the field name.
    pub fn key(&self) -> &'a str {
        self.0
    }

    /// Returns the field value.
    pub fn value(&self) -> &'a str {
        self.1
    }

    /// Serializes the entry as two length-prefixed strings: the key, then the value.
    ///
    /// Errors from the writer are discarded, and both strings are always attempted.
    fn to_writer<W: io::Write>(&self, mut writer: W) {
        for part in [self.key(), self.value()] {
            let _ = write_string(&mut writer, part);
        }
    }

    /// Reads one entry from the front of `data`, advancing `data` past the consumed
    /// bytes.
    ///
    /// Returns `None` if either string is truncated or is not valid UTF-8.
    fn read(data: &mut &'a [u8]) -> Option<Self> {
        let key = consume_string(data)?;
        consume_string(data).map(|value| Self::new(key, value))
    }
}
/// Accumulates form fields into a byte buffer.
///
/// Each field is stored as a length-prefixed key followed by a length-prefixed value.
struct FormDataWriter {
    buffer: Vec<u8>,
}

impl FormDataWriter {
    /// Creates a writer with an empty buffer.
    pub fn new() -> Self {
        FormDataWriter { buffer: Vec::new() }
    }

    /// Appends one `key`/`value` pair to the buffer.
    pub fn append(&mut self, key: &str, value: &str) {
        FormDataEntry::new(key, value).to_writer(&mut self.buffer);
    }

    /// Consumes the writer and returns the serialized bytes.
    pub fn into_inner(self) -> Vec<u8> {
        let FormDataWriter { buffer } = self;
        buffer
    }
}
pub struct FormDataIter<'a> {
data: &'a [u8],
}
impl<'a> FormDataIter<'a> {
pub fn new(data: &'a [u8]) -> Self {
Self { data }
}
}
impl<'a> Iterator for FormDataIter<'a> {
type Item = FormDataEntry<'a>;
fn next(&mut self) -> Option<Self::Item> {
while !self.data.is_empty() {
match FormDataEntry::read(&mut self.data) {
Some(entry) => return Some(entry),
None => relay_log::error!("form data deserialization failed"),
}
}
None
}
}
/// Extracts the multipart boundary from the first non-empty line of `data`.
///
/// Leading `\r`/`\n` bytes are skipped. That first line must have the form
/// `--<boundary>` with a non-empty, valid UTF-8 `<boundary>`, which is returned without
/// the leading dashes. Anything after the line terminator is ignored. Returns `None`
/// when there is no non-empty line or it does not match this form.
pub fn get_multipart_boundary(data: &[u8]) -> Option<&str> {
    let first_line = data
        .split(|&byte| matches!(byte, b'\r' | b'\n'))
        .find(|line| !line.is_empty())?;
    let boundary = first_line.strip_prefix(b"--")?;
    if boundary.is_empty() {
        return None;
    }
    std::str::from_utf8(boundary).ok()
}
/// Drains `multipart` into envelope [`Items`].
///
/// Every field that carries a file name becomes an [`ItemType::Attachment`] item:
/// - its attachment type is chosen by `infer_type` from the field name;
/// - its filename is taken from the field;
/// - its payload is the field body, tagged with the field's content type when one is
///   declared.
///
/// Named fields without a file name are read as text and encoded into a single
/// [`ItemType::FormData`] item. That item is pushed after all attachments, and only
/// when the encoded form data is non-empty. Fields with neither a name nor a file name
/// are skipped with a trace log.
///
/// # Errors
///
/// Propagates any [`multer::Error`] raised while reading fields or their bodies.
pub async fn multipart_items<F>(
    mut multipart: Multipart<'_>,
    mut infer_type: F,
) -> Result<Items, multer::Error>
where
    F: FnMut(Option<&str>) -> AttachmentType,
{
    let mut items = Items::new();
    let mut form_data = FormDataWriter::new();

    while let Some(field) = multipart.next_field().await? {
        match field.file_name() {
            Some(file_name) => {
                let mut attachment = Item::new(ItemType::Attachment);
                attachment.set_attachment_type(infer_type(field.name()));
                attachment.set_filename(file_name);
                match field.content_type() {
                    Some(content_type) => {
                        attachment
                            .set_payload(content_type.as_ref().into(), field.bytes().await?);
                    }
                    None => {
                        attachment.set_payload_without_content_type(field.bytes().await?);
                    }
                }
                items.push(attachment);
            }
            // The name is copied into an owned `String`, presumably because
            // `field.text()` consumes `field`.
            None => match field.name().map(str::to_owned) {
                Some(field_name) => {
                    let text = field.text().await?;
                    form_data.append(&field_name, &text);
                }
                None => {
                    relay_log::trace!("multipart content without name or file_name");
                }
            },
        }
    }

    let encoded = form_data.into_inner();
    if !encoded.is_empty() {
        let mut form_item = Item::new(ItemType::FormData);
        form_item.set_payload(ContentType::Text, encoded);
        items.push(form_item);
    }

    Ok(items)
}
/// Builds a streaming [`Multipart`] reader over the body of `request`.
///
/// The boundary is parsed from the `content-type` header. A missing header, or one that
/// is not valid visible ASCII, is treated as an empty string. The whole stream is
/// limited to `config.max_attachments_size()` bytes and each field to
/// `config.max_attachment_size()` bytes.
///
/// # Errors
///
/// Returns the [`multer::Error`] from `multer::parse_boundary` when no boundary can be
/// parsed.
pub fn multipart_from_request(
    request: Request,
    config: &Config,
) -> Result<Multipart<'static>, multer::Error> {
    let boundary = {
        let content_type = match request.headers().get("content-type") {
            Some(value) => value.to_str().unwrap_or(""),
            None => "",
        };
        multer::parse_boundary(content_type)?
    };

    let constraints = multer::Constraints::new().size_limit(
        multer::SizeLimit::new()
            .whole_stream(config.max_attachments_size() as u64)
            .per_field(config.max_attachment_size() as u64),
    );

    let body_stream = request.into_body().into_data_stream();
    Ok(Multipart::with_constraints(body_stream, boundary, constraints))
}
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use super::*;

    /// Boundary detection skips leading blank lines, accepts any line terminator, and
    /// rejects inputs whose first non-empty line is not `--<boundary>`.
    #[test]
    fn test_get_boundary() {
        let examples: &[(&[u8], Option<&str>)] = &[
            (b"--some_val", Some("some_val")),
            (b"--\nsecond line", None),
            (b"\n\r--some_val", Some("some_val")),
            (b"\n\r--some_val\nadfa", Some("some_val")),
            (b"\n\r--some_val\rfasdf", Some("some_val")),
            (b"\n\r--some_val\r\nfasdf", Some("some_val")),
            (b"\n\rsome_val", None),
            (b"", None),
            (b"--", None),
            (b"\r\n\r\n", None),
            (b"--\xff", None),
        ];
        for (input, expected) in examples {
            let boundary = get_multipart_boundary(input);
            // Name the failing input so a regression points at the exact case.
            assert_eq!(
                *expected,
                boundary,
                "input: {:?}",
                String::from_utf8_lossy(input)
            );
        }
    }

    /// Entries written by `FormDataWriter` are read back in order, including empty
    /// values.
    #[test]
    fn test_formdata() {
        let mut writer = FormDataWriter::new();
        writer.append("foo", "foo");
        writer.append("bar", "");
        writer.append("blub", "blub");
        let payload = writer.into_inner();
        let iter = FormDataIter::new(&payload);
        let entries: Vec<_> = iter.collect();
        assert_eq!(
            entries,
            vec![
                FormDataEntry::new("foo", "foo"),
                FormDataEntry::new("bar", ""),
                FormDataEntry::new("blub", "blub"),
            ]
        );
    }

    /// Multi-byte UTF-8 survives a round trip, since length prefixes count bytes.
    #[test]
    fn test_formdata_non_ascii() {
        let mut writer = FormDataWriter::new();
        writer.append("grüße", "日本語");
        let payload = writer.into_inner();
        let entries: Vec<_> = FormDataIter::new(&payload).collect();
        assert_eq!(entries, vec![FormDataEntry::new("grüße", "日本語")]);
    }

    /// A buffer cut off mid-entry yields every complete entry before the cut. It then
    /// terminates instead of looping on the unreadable tail.
    #[test]
    fn test_truncated_formdata() {
        let mut writer = FormDataWriter::new();
        writer.append("foo", "bar");
        writer.append("baz", "qux");
        let mut payload = writer.into_inner();
        payload.truncate(payload.len() - 1);

        let entries: Vec<_> = FormDataIter::new(&payload).collect();
        assert_eq!(entries, vec![FormDataEntry::new("foo", "bar")]);
    }

    /// An empty buffer yields no entries.
    #[test]
    fn test_empty_formdata() {
        let writer = FormDataWriter::new();
        let payload = writer.into_inner();
        let iter = FormDataIter::new(&payload);
        let entries: Vec<_> = iter.collect();
        assert_eq!(entries, vec![]);
    }

    /// The closing boundary is accepted even without a trailing line terminator.
    #[tokio::test]
    async fn missing_trailing_newline() -> anyhow::Result<()> {
        let data = "--X-BOUNDARY\r\nContent-Disposition: form-data; \
            name=\"my_text_field\"\r\n\r\nabcd\r\n--X-BOUNDARY--";
        let stream = futures::stream::once(async { Ok::<_, Infallible>(data) });
        let mut multipart = Multipart::new(stream, "X-BOUNDARY");
        assert!(multipart.next_field().await?.is_some());
        assert!(multipart.next_field().await?.is_none());
        Ok(())
    }
}