objectstore_server/extractors/
batch.rs1use std::fmt::Debug;
7
8use axum::extract::{
9 FromRequest, Multipart, Request,
10 multipart::{Field, MultipartError, MultipartRejection},
11};
12use futures::{StreamExt, stream::BoxStream};
13use objectstore_service::streaming::{Delete, Get, Insert, Operation};
14use objectstore_types::metadata::Metadata;
15use thiserror::Error;
16
17use crate::batch::{HEADER_BATCH_OPERATION_KEY, HEADER_BATCH_OPERATION_KIND};
18
19#[derive(Debug, Error)]
21pub enum BatchError {
22 #[error("bad request: {0}")]
24 BadRequest(String),
25
26 #[error("multipart error: {0}")]
28 Multipart(#[from] MultipartError),
29
30 #[error("metadata error: {0}")]
32 Metadata(#[from] objectstore_types::metadata::Error),
33
34 #[error("batch limit exceeded: {0}")]
36 LimitExceeded(String),
37
38 #[error("rate limited")]
40 RateLimited,
41
42 #[error("response part serialization error: {context}")]
44 ResponseSerialization {
45 context: String,
47 #[source]
49 cause: Box<dyn std::error::Error + Send + Sync>,
50 },
51}
52
53async fn try_operation_from_field(field: Field<'_>) -> Result<Operation, BatchError> {
54 let kind = field
55 .headers()
56 .get(HEADER_BATCH_OPERATION_KIND)
57 .ok_or_else(|| {
58 BatchError::BadRequest(format!("missing {HEADER_BATCH_OPERATION_KIND} header"))
59 })?;
60 let kind = kind
61 .to_str()
62 .map_err(|_| {
63 BatchError::BadRequest(format!(
64 "unable to convert {HEADER_BATCH_OPERATION_KIND} header value to string"
65 ))
66 })?
67 .to_lowercase();
68
69 let key = field
70 .headers()
71 .get(HEADER_BATCH_OPERATION_KEY)
72 .map(|v| {
73 let s = v.to_str().map_err(|_| {
74 BatchError::BadRequest(format!(
75 "unable to convert {HEADER_BATCH_OPERATION_KEY} header value to string"
76 ))
77 })?;
78 percent_encoding::percent_decode_str(s)
79 .decode_utf8()
80 .map(|decoded| decoded.into_owned())
81 .map_err(|_| {
82 BatchError::BadRequest(format!(
83 "unable to percent-decode {HEADER_BATCH_OPERATION_KEY} header value"
84 ))
85 })
86 })
87 .transpose()?;
88
89 let operation = match kind.as_str() {
90 "get" => Operation::Get(Get {
91 key: key.ok_or_else(|| {
92 BatchError::BadRequest(format!(
93 "missing {HEADER_BATCH_OPERATION_KEY} header for {kind} operation"
94 ))
95 })?,
96 }),
97 "delete" => Operation::Delete(Delete {
98 key: key.ok_or_else(|| {
99 BatchError::BadRequest(format!(
100 "missing {HEADER_BATCH_OPERATION_KEY} header for {kind} operation"
101 ))
102 })?,
103 }),
104 "insert" => {
105 let metadata = Metadata::from_headers(field.headers(), "")?;
106 let payload = field.bytes().await?;
107 if payload.len() > MAX_FIELD_SIZE {
108 return Err(BatchError::LimitExceeded(format!(
109 "individual request in batch exceeds body size limit of {MAX_FIELD_SIZE} bytes"
110 )));
111 }
112 Operation::Insert(Insert {
113 key,
114 metadata,
115 payload,
116 })
117 }
118 _ => {
119 return Err(BatchError::BadRequest(format!(
120 "invalid operation kind: {kind}"
121 )));
122 }
123 };
124 Ok(operation)
125}
126
127pub struct BatchOperationStream(pub BoxStream<'static, Result<Operation, BatchError>>);
129
130impl Debug for BatchOperationStream {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 f.debug_struct("BatchOperationStream").finish()
133 }
134}
135
/// Maximum size, in bytes, of a single insert payload within a batch request (1 MiB).
const MAX_FIELD_SIZE: usize = 1 << 20;

/// Maximum number of operations parsed from a single batch request.
const MAX_OPERATIONS: usize = 1_000;
138
139impl<S> FromRequest<S> for BatchOperationStream
140where
141 S: Send + Sync,
142{
143 type Rejection = MultipartRejection;
144
145 async fn from_request(request: Request, state: &S) -> Result<Self, Self::Rejection> {
146 let mut multipart = Multipart::from_request(request, state).await?;
147
148 let requests = async_stream::stream! {
149 let mut count = 0;
150 loop {
151 let field = match multipart.next_field().await {
152 Ok(Some(field)) => field,
153 Ok(None) => break,
154 Err(e) => {
155 yield Err(BatchError::from(e));
156 continue;
157 }
158 };
159 if count >= MAX_OPERATIONS {
160 yield Err(BatchError::LimitExceeded(format!(
161 "exceeded {MAX_OPERATIONS} operations per batch request"
162 )));
163 continue;
164 }
165 count += 1;
166 yield try_operation_from_field(field).await;
167 }
168 }
169 .boxed();
170
171 Ok(Self(requests))
172 }
173}
174
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use axum::body::Body;
    use axum::http::{Request, header::CONTENT_TYPE};
    use futures::StreamExt;
    use objectstore_service::streaming::Operation;
    use objectstore_types::metadata::{ExpirationPolicy, HEADER_EXPIRATION, HEADER_ORIGIN};
    use percent_encoding::NON_ALPHANUMERIC;

    /// Sends `body` through the extractor as a `multipart/form-data` request.
    ///
    /// The request uses the boundary `boundary`. Returns every item the resulting operation
    /// stream yields.
    async fn extract_all(body: String) -> Vec<Result<Operation, BatchError>> {
        let request = Request::builder()
            .header(CONTENT_TYPE, "multipart/form-data; boundary=boundary")
            .body(Body::from(body))
            .unwrap();

        let BatchOperationStream(operations) = BatchOperationStream::from_request(request, &())
            .await
            .unwrap();
        operations.collect().await
    }

    #[tokio::test]
    async fn test_valid_request_works() {
        let insert1_data = b"first blob data";
        let insert2_data = b"second blob data";
        let expiration = ExpirationPolicy::TimeToLive(Duration::from_hours(1));

        let key0 = "test%2F0";
        let key1 = "test1";
        let key2 = "test2";
        let key3 = "test3";
        let insert1 = String::from_utf8_lossy(insert1_data);
        let insert2 = String::from_utf8_lossy(insert2_data);
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key0}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key1}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             {insert1}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key2}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             {HEADER_EXPIRATION}: {expiration}\r\n\
             {HEADER_ORIGIN}: 203.0.113.42\r\n\
             Content-Type: text/plain\r\n\
             \r\n\
             {insert2}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key3}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n"
        );

        let operations = extract_all(body).await;
        assert_eq!(operations.len(), 4);

        match operations[0].as_ref().unwrap() {
            Operation::Get(get) => assert_eq!(get.key, "test/0"),
            _ => panic!("expected get operation"),
        }

        match operations[1].as_ref().unwrap() {
            Operation::Insert(insert) => {
                assert_eq!(insert.key.as_deref(), Some("test1"));
                assert_eq!(insert.metadata.content_type, "application/octet-stream");
                assert_eq!(insert.metadata.origin, None);
                assert_eq!(insert.payload.as_ref(), insert1_data);
            }
            _ => panic!("expected insert operation"),
        }

        match operations[2].as_ref().unwrap() {
            Operation::Insert(insert) => {
                assert_eq!(insert.key.as_deref(), Some("test2"));
                assert_eq!(insert.metadata.content_type, "text/plain");
                assert_eq!(insert.metadata.expiration_policy, expiration);
                assert_eq!(insert.metadata.origin.as_deref(), Some("203.0.113.42"));
                assert_eq!(insert.payload.as_ref(), insert2_data);
            }
            _ => panic!("expected insert operation"),
        }

        match operations[3].as_ref().unwrap() {
            Operation::Delete(delete) => assert_eq!(delete.key, "test3"),
            _ => panic!("expected delete operation"),
        }
    }

    #[tokio::test]
    async fn test_insert_without_key_header() {
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             keyless payload\r\n\
             --boundary--\r\n"
        );

        let operations = extract_all(body).await;
        assert_eq!(operations.len(), 1);

        match operations[0].as_ref().unwrap() {
            Operation::Insert(insert) => {
                assert!(insert.key.is_none());
                assert_eq!(insert.payload.as_ref(), b"keyless payload");
            }
            _ => panic!("expected insert operation"),
        }
    }

    #[tokio::test]
    async fn test_individual_errors_with_isolation() {
        let large_payload = "x".repeat(MAX_FIELD_SIZE + 1);
        let valid_key = percent_encoding::percent_encode(b"valid", NON_ALPHANUMERIC);
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             {large_payload}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n"
        );

        let operations = extract_all(body).await;
        assert_eq!(operations.len(), 5);

        // A get without a key fails on its own ...
        assert!(matches!(&operations[0], Err(BatchError::BadRequest(_))));
        // ... without affecting the following valid get.
        assert!(matches!(
            &operations[1].as_ref().unwrap(),
            Operation::Get(get) if get.key == "valid"
        ));
        assert!(matches!(&operations[2], Err(BatchError::BadRequest(_))));
        assert!(matches!(&operations[3], Err(BatchError::LimitExceeded(_))));
        assert!(matches!(
            &operations[4].as_ref().unwrap(),
            Operation::Delete(delete) if delete.key == "valid"
        ));
    }

    #[tokio::test]
    async fn test_max_operations_limit_enforced() {
        // One more get operation than the batch allows.
        let mut body: String = (0..=MAX_OPERATIONS)
            .map(|i| {
                let raw_key = format!("test{i}");
                let key = percent_encoding::percent_encode(raw_key.as_bytes(), NON_ALPHANUMERIC);
                format!(
                    "--boundary\r\n\
                     {HEADER_BATCH_OPERATION_KEY}: {key}\r\n\
                     {HEADER_BATCH_OPERATION_KIND}: get\r\n\
                     \r\n\
                     \r\n"
                )
            })
            .collect();
        body.push_str("--boundary--\r\n");

        let operations = extract_all(body).await;

        assert_eq!(operations.len(), MAX_OPERATIONS + 1);
        assert!(matches!(
            operations.last(),
            Some(Err(BatchError::LimitExceeded(_)))
        ));
    }
}