objectstore_server/extractors/
batch.rs1use std::fmt::Debug;
7
8use axum::extract::{
9 FromRequest, Multipart, Request,
10 multipart::{Field, MultipartError, MultipartRejection},
11};
12use bytes::BytesMut;
13use futures::{StreamExt, stream::BoxStream};
14use objectstore_service::streaming::{Delete, Get, Head, Insert, Operation};
15use objectstore_types::metadata::Metadata;
16use thiserror::Error;
17
18use crate::batch::{HEADER_BATCH_OPERATION_KEY, HEADER_BATCH_OPERATION_KIND};
19
20#[derive(Debug, Error)]
22pub enum BatchError {
23 #[error("bad request: {0}")]
25 BadRequest(String),
26
27 #[error("multipart error: {0}")]
29 Multipart(#[from] MultipartError),
30
31 #[error("metadata error: {0}")]
33 Metadata(#[from] objectstore_types::metadata::Error),
34
35 #[error("batch limit exceeded: {0}")]
37 LimitExceeded(String),
38
39 #[error("rate limited")]
41 RateLimited,
42
43 #[error("response part serialization error: {context}")]
45 ResponseSerialization {
46 context: String,
48 #[source]
50 cause: Box<dyn std::error::Error + Send + Sync>,
51 },
52}
53
54async fn try_operation_from_field(mut field: Field<'_>) -> Result<Operation, BatchError> {
55 let kind = field
56 .headers()
57 .get(HEADER_BATCH_OPERATION_KIND)
58 .ok_or_else(|| {
59 BatchError::BadRequest(format!("missing {HEADER_BATCH_OPERATION_KIND} header"))
60 })?;
61 let kind = kind
62 .to_str()
63 .map_err(|_| {
64 BatchError::BadRequest(format!(
65 "unable to convert {HEADER_BATCH_OPERATION_KIND} header value to string"
66 ))
67 })?
68 .to_lowercase();
69
70 let key = field
71 .headers()
72 .get(HEADER_BATCH_OPERATION_KEY)
73 .map(|v| {
74 let s = v.to_str().map_err(|_| {
75 BatchError::BadRequest(format!(
76 "unable to convert {HEADER_BATCH_OPERATION_KEY} header value to string"
77 ))
78 })?;
79 percent_encoding::percent_decode_str(s)
80 .decode_utf8()
81 .map(|decoded| decoded.into_owned())
82 .map_err(|_| {
83 BatchError::BadRequest(format!(
84 "unable to percent-decode {HEADER_BATCH_OPERATION_KEY} header value"
85 ))
86 })
87 })
88 .transpose()?;
89
90 let operation = match kind.as_str() {
91 "get" => Operation::Get(Get {
92 key: key.ok_or_else(|| {
93 BatchError::BadRequest(format!(
94 "missing {HEADER_BATCH_OPERATION_KEY} header for {kind} operation"
95 ))
96 })?,
97 }),
98 "delete" => Operation::Delete(Delete {
99 key: key.ok_or_else(|| {
100 BatchError::BadRequest(format!(
101 "missing {HEADER_BATCH_OPERATION_KEY} header for {kind} operation"
102 ))
103 })?,
104 }),
105 "head" => Operation::Head(Head {
106 key: key.ok_or_else(|| {
107 BatchError::BadRequest(format!(
108 "missing {HEADER_BATCH_OPERATION_KEY} header for {kind} operation"
109 ))
110 })?,
111 }),
112 "insert" => {
113 let metadata = Metadata::from_headers(field.headers(), "")?;
114 let mut payload = BytesMut::new();
115 while let Some(chunk) = field.chunk().await? {
116 if payload.len() + chunk.len() > MAX_FIELD_SIZE {
117 return Err(BatchError::LimitExceeded(format!(
118 "individual request in batch exceeds body size limit of {MAX_FIELD_SIZE} bytes"
119 )));
120 }
121 payload.extend_from_slice(&chunk);
122 }
123 Operation::Insert(Insert {
124 key,
125 metadata,
126 payload: payload.freeze(),
127 })
128 }
129 _ => {
130 return Err(BatchError::BadRequest(format!(
131 "invalid operation kind: {kind}"
132 )));
133 }
134 };
135 Ok(operation)
136}
137
138pub struct BatchOperationStream(pub BoxStream<'static, Result<Operation, BatchError>>);
140
141impl Debug for BatchOperationStream {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 f.debug_struct("BatchOperationStream").finish()
144 }
145}
146
/// Largest payload, in bytes, buffered for a single `insert` part (1 MiB).
const MAX_FIELD_SIZE: usize = 1024 * 1024;

/// Largest number of operations accepted from one batch request.
const MAX_OPERATIONS: usize = 1000;
149
150impl<S> FromRequest<S> for BatchOperationStream
151where
152 S: Send + Sync,
153{
154 type Rejection = MultipartRejection;
155
156 async fn from_request(request: Request, state: &S) -> Result<Self, Self::Rejection> {
157 let mut multipart = Multipart::from_request(request, state).await?;
158
159 let requests = async_stream::stream! {
160 let mut count = 0;
161 loop {
162 let field = match multipart.next_field().await {
163 Ok(Some(field)) => field,
164 Ok(None) => break,
165 Err(e) => {
166 yield Err(BatchError::from(e));
167 continue;
168 }
169 };
170 if count >= MAX_OPERATIONS {
171 yield Err(BatchError::LimitExceeded(format!(
172 "exceeded {MAX_OPERATIONS} operations per batch request"
173 )));
174 continue;
175 }
176 count += 1;
177 yield try_operation_from_field(field).await;
178 }
179 }
180 .boxed();
181
182 Ok(Self(requests))
183 }
184}
185
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use axum::body::Body;
    use axum::http::{Request, header::CONTENT_TYPE};
    use futures::StreamExt;
    use objectstore_service::streaming::Operation;
    use objectstore_types::metadata::{ExpirationPolicy, HEADER_EXPIRATION, HEADER_ORIGIN};
    use percent_encoding::NON_ALPHANUMERIC;

    /// Runs `body` through the extractor as a `multipart/form-data` request whose
    /// boundary is `boundary`, and gathers every item the resulting stream yields.
    ///
    /// Panics if the request cannot be built or the extractor rejects it.
    async fn run_batch(body: String) -> Vec<Result<Operation, BatchError>> {
        let request = Request::builder()
            .header(CONTENT_TYPE, "multipart/form-data; boundary=boundary")
            .body(Body::from(body))
            .unwrap();

        let BatchOperationStream(stream) = BatchOperationStream::from_request(request, &())
            .await
            .unwrap();

        stream.collect().await
    }

    /// A mixed batch of get, insert and delete parts is parsed in order.
    #[tokio::test]
    async fn test_valid_request_works() {
        let insert1_data = b"first blob data";
        let insert2_data = b"second blob data";
        let expiration = ExpirationPolicy::TimeToLive(Duration::from_hours(1));
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key0}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key1}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             {insert1}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key2}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             {HEADER_EXPIRATION}: {expiration}\r\n\
             {HEADER_ORIGIN}: 203.0.113.42\r\n\
             Content-Type: text/plain\r\n\
             \r\n\
             {insert2}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key3}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
            // Percent-encoded form of "test/0".
            key0 = "test%2F0",
            key1 = "test1",
            key2 = "test2",
            key3 = "test3",
            insert1 = String::from_utf8_lossy(insert1_data),
            insert2 = String::from_utf8_lossy(insert2_data),
        );

        let operations = run_batch(body).await;
        assert_eq!(operations.len(), 4);

        match operations[0].as_ref().unwrap() {
            Operation::Get(get_op) => assert_eq!(get_op.key, "test/0"),
            _ => panic!("expected get operation"),
        }

        match operations[1].as_ref().unwrap() {
            Operation::Insert(insert_op1) => {
                assert_eq!(insert_op1.key.as_deref(), Some("test1"));
                assert_eq!(insert_op1.metadata.content_type, "application/octet-stream");
                assert_eq!(insert_op1.metadata.origin, None);
                assert_eq!(insert_op1.payload.as_ref(), insert1_data);
            }
            _ => panic!("expected insert operation"),
        }

        match operations[2].as_ref().unwrap() {
            Operation::Insert(insert_op2) => {
                assert_eq!(insert_op2.key.as_deref(), Some("test2"));
                assert_eq!(insert_op2.metadata.content_type, "text/plain");
                assert_eq!(insert_op2.metadata.expiration_policy, expiration);
                assert_eq!(insert_op2.metadata.origin.as_deref(), Some("203.0.113.42"));
                assert_eq!(insert_op2.payload.as_ref(), insert2_data);
            }
            _ => panic!("expected insert operation"),
        }

        match operations[3].as_ref().unwrap() {
            Operation::Delete(delete_op) => assert_eq!(delete_op.key, "test3"),
            _ => panic!("expected delete operation"),
        }
    }

    /// An insert part is accepted without a key header.
    #[tokio::test]
    async fn test_insert_without_key_header() {
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             keyless payload\r\n\
             --boundary--\r\n",
        );

        let operations = run_batch(body).await;
        assert_eq!(operations.len(), 1);

        match operations[0].as_ref().unwrap() {
            Operation::Insert(insert_op) => {
                assert!(insert_op.key.is_none());
                assert_eq!(insert_op.payload.as_ref(), b"keyless payload");
            }
            _ => panic!("expected insert operation"),
        }
    }

    /// A failing part yields its own error and leaves the parts around it intact.
    #[tokio::test]
    async fn test_individual_errors_with_isolation() {
        let large_payload = "x".repeat(MAX_FIELD_SIZE + 1);
        let valid_key = percent_encoding::percent_encode(b"valid", NON_ALPHANUMERIC);
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             {large_payload}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
        );

        let operations = run_batch(body).await;
        assert_eq!(operations.len(), 5);

        // Part 0: get without a key.
        assert!(matches!(&operations[0], Err(BatchError::BadRequest(_))));
        // Part 1: valid get.
        assert!(matches!(
            &operations[1].as_ref().unwrap(),
            Operation::Get(g) if g.key == "valid"
        ));
        // Part 2: delete without a key.
        assert!(matches!(&operations[2], Err(BatchError::BadRequest(_))));
        // Part 3: insert whose payload is one byte over the cap.
        assert!(matches!(&operations[3], Err(BatchError::LimitExceeded(_))));
        // Part 4: valid delete following the oversized insert.
        assert!(matches!(
            &operations[4].as_ref().unwrap(),
            Operation::Delete(d) if d.key == "valid"
        ));
    }

    /// One part more than `MAX_OPERATIONS` yields a limit error in the final slot.
    #[tokio::test]
    async fn test_max_operations_limit_enforced() {
        let mut body: String = (0..=MAX_OPERATIONS)
            .map(|i| {
                let key = percent_encoding::percent_encode(
                    format!("test{i}").as_bytes(),
                    NON_ALPHANUMERIC,
                )
                .to_string();
                format!(
                    "--boundary\r\n\
                     {HEADER_BATCH_OPERATION_KEY}: {key}\r\n\
                     {HEADER_BATCH_OPERATION_KIND}: get\r\n\
                     \r\n\
                     \r\n"
                )
            })
            .collect();
        body.push_str("--boundary--\r\n");

        let operations = run_batch(body).await;

        assert_eq!(operations.len(), MAX_OPERATIONS + 1);
        assert!(matches!(
            &operations[MAX_OPERATIONS],
            Err(BatchError::LimitExceeded(_))
        ));
    }

    /// A head part with a key is parsed into a head operation.
    #[tokio::test]
    async fn test_head_operation() {
        let key = percent_encoding::percent_encode(b"head-key", NON_ALPHANUMERIC);
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: head\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
        );

        let operations = run_batch(body).await;
        assert_eq!(operations.len(), 1);

        match operations[0].as_ref().unwrap() {
            Operation::Head(head_op) => assert_eq!(head_op.key, "head-key"),
            _ => panic!("expected head operation"),
        }
    }

    /// A head part without a key is rejected as a bad request.
    #[tokio::test]
    async fn test_head_without_key_is_error() {
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: head\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
        );

        let operations = run_batch(body).await;
        assert_eq!(operations.len(), 1);
        assert!(matches!(&operations[0], Err(BatchError::BadRequest(_))));
    }
}