objectstore_server/
multipart.rs1use axum::body::Body;
7use axum::response::IntoResponse as _;
8use axum::response::Response;
9use bytes::{BufMut, Bytes, BytesMut};
10use futures::Stream;
11use futures::StreamExt;
12use futures::stream::BoxStream;
13use http::HeaderMap;
14use http::HeaderValue;
15use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE};
16
17#[derive(Debug)]
19pub struct Part {
20 headers: HeaderMap,
21 body: Bytes,
22}
23
24impl Part {
25 pub fn new(body: Bytes, mut headers: HeaderMap, content_type: Option<HeaderValue>) -> Self {
28 headers.insert(
29 CONTENT_DISPOSITION,
30 HeaderValue::from_static("form-data; name=part"),
31 );
32 if let Some(content_type) = content_type {
33 headers.insert(CONTENT_TYPE, content_type);
34 }
35 Self { headers, body }
36 }
37}
38
39pub trait IntoMultipartResponse {
44 fn into_multipart_response(self, boundary: u128) -> Response;
50}
51
52impl<S, T> IntoMultipartResponse for S
53where
54 S: Stream<Item = T> + Send + 'static,
55 T: Into<Part> + Send,
56{
57 fn into_multipart_response(self, boundary: u128) -> Response {
58 let boundary_str = format!("os-boundary-{:032x}", boundary);
59 let boundary = {
60 let mut bytes = BytesMut::with_capacity(boundary_str.len() + 4);
61 bytes.put(&b"--"[..]);
62 bytes.put(boundary_str.as_bytes());
63 bytes.put(&b"\r\n"[..]);
64 bytes.freeze()
65 };
66
67 let mut headers = HeaderMap::new();
68 headers.insert(
69 CONTENT_TYPE,
70 format!("multipart/form-data; boundary=\"{}\"", &boundary_str)
71 .parse()
72 .expect("valid header value, as we just defined it as \"os-boundary-X\" where X are hex digits"),
73 );
74
75 let body: BoxStream<Result<bytes::Bytes, std::convert::Infallible>> =
76 async_stream::try_stream! {
77 let items = self;
78 futures::pin_mut!(items);
79 while let Some(item) = items.next().await {
80 yield boundary.clone();
81 let part = item.into();
82 yield serialize_headers(part.headers);
83 yield serialize_body(part.body);
84 }
85
86 let mut closing = BytesMut::with_capacity(boundary.len());
87 closing.put(boundary.slice(..boundary.len() - 2)); closing.put(&b"--"[..]);
89 yield closing.freeze();
90 }
91 .boxed();
92
93 (headers, Body::from_stream(body)).into_response()
94 }
95}
96
97fn serialize_headers(headers: HeaderMap) -> Bytes {
98 let mut res = BytesMut::with_capacity(30 + 30 * headers.len());
100 for (name, value) in &headers {
101 res.put(name.as_str().as_bytes());
102 res.put(&b": "[..]);
103 res.put(value.as_bytes());
104 res.put(&b"\r\n"[..]);
105 }
106 res.put(&b"\r\n"[..]);
107 res.freeze()
108}
109
110fn serialize_body(body: Bytes) -> Bytes {
111 let mut res = BytesMut::with_capacity(body.len() + 2);
112 res.put(body);
113 res.put(&b"\r\n"[..]);
114 res.freeze()
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::{Body, to_bytes};
    use axum::extract::{FromRequest, Multipart};
    use axum::http::Request;

    /// Round-trips a two-part response through axum's `Multipart` extractor and checks
    /// each part's name, content type, headers and payload.
    #[tokio::test]
    async fn test_multipart_response() {
        let mut extra_headers = HeaderMap::new();
        extra_headers.insert("X-Custom-Header", "custom-value".parse().unwrap());
        extra_headers.insert("X-File-Id", "12345".parse().unwrap());
        let parts = vec![
            Part::new(
                Bytes::from(r#"{"key":"value"}"#),
                HeaderMap::new(),
                Some(HeaderValue::from_static("application/json")),
            ),
            Part::new(
                Bytes::from(vec![0x00, 0x01, 0x02, 0xff, 0xfe]),
                extra_headers,
                Some(HeaderValue::from_static("application/octet-stream")),
            ),
        ];
        let boundary: u128 = 0xdeadbeef;
        let response = futures::stream::iter(parts).into_multipart_response(boundary);

        let boundary = format!("os-boundary-{:032x}", boundary);
        let content_type_str = format!("multipart/form-data; boundary=\"{}\"", boundary);
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .unwrap()
                .to_str()
                .unwrap(),
            &content_type_str
        );

        let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        let request = Request::builder()
            .header(CONTENT_TYPE, &content_type_str)
            .body(Body::from(body))
            .unwrap();
        let mut multipart = Multipart::from_request(request, &()).await.unwrap();

        let field = multipart.next_field().await.unwrap().unwrap();
        assert_eq!(field.name(), Some("part"));
        assert_eq!(field.file_name(), None);
        assert_eq!(field.content_type(), Some("application/json"));
        assert_eq!(field.headers().len(), 2);
        assert_eq!(field.bytes().await.unwrap(), r#"{"key":"value"}"#);

        let field = multipart.next_field().await.unwrap().unwrap();
        assert_eq!(field.name(), Some("part"));
        assert_eq!(field.file_name(), None);
        assert_eq!(field.content_type(), Some("application/octet-stream"));
        assert_eq!(field.headers().len(), 4);
        assert_eq!(
            field.headers().get("X-Custom-Header").unwrap(),
            "custom-value"
        );
        assert_eq!(field.headers().get("X-File-Id").unwrap(), "12345");
        assert!(field.headers().get("content-disposition").is_some());
        assert!(field.headers().get("content-type").is_some());
        assert_eq!(
            field.bytes().await.unwrap(),
            vec![0x00, 0x01, 0x02, 0xff, 0xfe]
        );

        assert!(multipart.next_field().await.unwrap().is_none());
    }

    /// Pins the exact bytes on the wire for one part. The layout is the opening
    /// delimiter, the header line, a blank line, and the body followed by CRLF. It ends
    /// with the closing delimiter, which has no trailing CRLF. Parsers tolerate
    /// variations here, so the round-trip test above would not catch them.
    #[tokio::test]
    async fn test_multipart_wire_format() {
        let part = Part::new(Bytes::from_static(b"hello"), HeaderMap::new(), None);
        let boundary: u128 = 0xdeadbeef;
        let response =
            futures::stream::iter(std::iter::once(part)).into_multipart_response(boundary);
        let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();

        let boundary_str = format!("os-boundary-{:032x}", boundary);
        let expected = format!(
            "--{b}\r\ncontent-disposition: form-data; name=part\r\n\r\nhello\r\n--{b}--",
            b = boundary_str,
        );
        assert_eq!(&body[..], expected.as_bytes());
    }

    /// An empty stream yields only the closing delimiter.
    #[tokio::test]
    async fn test_multipart_empty_stream() {
        let boundary: u128 = 1;
        let response = futures::stream::empty::<Part>().into_multipart_response(boundary);
        let body = to_bytes(response.into_body(), usize::MAX).await.unwrap();

        let expected = format!("--os-boundary-{:032x}--", boundary);
        assert_eq!(&body[..], expected.as_bytes());
    }
}