Skip to main content

objectstore_server/extractors/
batch.rs

1//! Axum extractor for batch operation streams.
2//!
3//! Provides [`BatchOperationStream`], which parses a multipart request body into a
4//! lazy stream of [`Operation`]s.
5
6use std::fmt::Debug;
7
8use axum::extract::{
9    FromRequest, Multipart, Request,
10    multipart::{Field, MultipartError, MultipartRejection},
11};
12use bytes::BytesMut;
13use futures::{StreamExt, stream::BoxStream};
14use objectstore_service::streaming::{Delete, Get, Head, Insert, Operation};
15use objectstore_types::metadata::Metadata;
16use thiserror::Error;
17
18use crate::batch::{HEADER_BATCH_OPERATION_KEY, HEADER_BATCH_OPERATION_KIND};
19
/// Errors that can occur when processing or executing batch operations.
///
/// The extractor in this file produces [`BatchError::BadRequest`],
/// [`BatchError::Multipart`], [`BatchError::Metadata`] and
/// [`BatchError::LimitExceeded`]. The remaining variants are not constructed by
/// the code in this file.
#[derive(Debug, Error)]
pub enum BatchError {
    /// Malformed request.
    ///
    /// The payload is a human-readable description of what was wrong, e.g. a
    /// missing or undecodable header or an unknown operation kind.
    #[error("bad request: {0}")]
    BadRequest(String),

    /// Errors in parsing or reading a multipart request body.
    ///
    /// Converted automatically from [`MultipartError`] via `?`.
    #[error("multipart error: {0}")]
    Multipart(#[from] MultipartError),

    /// Errors related to de/serialization and parsing of object metadata.
    ///
    /// Converted automatically from the metadata error type via `?`.
    #[error("metadata error: {0}")]
    Metadata(#[from] objectstore_types::metadata::Error),

    /// Size or cardinality limit exceeded.
    ///
    /// The payload is a human-readable description of the limit that was hit.
    #[error("batch limit exceeded: {0}")]
    LimitExceeded(String),

    /// Operation rejected due to rate limiting.
    #[error("rate limited")]
    RateLimited,

    /// Errors encountered when serializing batch response parts.
    ///
    /// Only `context` appears in the display message; `cause` is exposed as the
    /// error source.
    #[error("response part serialization error: {context}")]
    ResponseSerialization {
        /// Context describing what was being serialized.
        context: String,
        /// The underlying error.
        #[source]
        cause: Box<dyn std::error::Error + Send + Sync>,
    },
}
53
54async fn try_operation_from_field(mut field: Field<'_>) -> Result<Operation, BatchError> {
55    let kind = field
56        .headers()
57        .get(HEADER_BATCH_OPERATION_KIND)
58        .ok_or_else(|| {
59            BatchError::BadRequest(format!("missing {HEADER_BATCH_OPERATION_KIND} header"))
60        })?;
61    let kind = kind
62        .to_str()
63        .map_err(|_| {
64            BatchError::BadRequest(format!(
65                "unable to convert {HEADER_BATCH_OPERATION_KIND} header value to string"
66            ))
67        })?
68        .to_lowercase();
69
70    let key = field
71        .headers()
72        .get(HEADER_BATCH_OPERATION_KEY)
73        .map(|v| {
74            let s = v.to_str().map_err(|_| {
75                BatchError::BadRequest(format!(
76                    "unable to convert {HEADER_BATCH_OPERATION_KEY} header value to string"
77                ))
78            })?;
79            percent_encoding::percent_decode_str(s)
80                .decode_utf8()
81                .map(|decoded| decoded.into_owned())
82                .map_err(|_| {
83                    BatchError::BadRequest(format!(
84                        "unable to percent-decode {HEADER_BATCH_OPERATION_KEY} header value"
85                    ))
86                })
87        })
88        .transpose()?;
89
90    let operation = match kind.as_str() {
91        "get" => Operation::Get(Get {
92            key: key.ok_or_else(|| {
93                BatchError::BadRequest(format!(
94                    "missing {HEADER_BATCH_OPERATION_KEY} header for {kind} operation"
95                ))
96            })?,
97        }),
98        "delete" => Operation::Delete(Delete {
99            key: key.ok_or_else(|| {
100                BatchError::BadRequest(format!(
101                    "missing {HEADER_BATCH_OPERATION_KEY} header for {kind} operation"
102                ))
103            })?,
104        }),
105        "head" => Operation::Head(Head {
106            key: key.ok_or_else(|| {
107                BatchError::BadRequest(format!(
108                    "missing {HEADER_BATCH_OPERATION_KEY} header for {kind} operation"
109                ))
110            })?,
111        }),
112        "insert" => {
113            let metadata = Metadata::from_headers(field.headers(), "")?;
114            let mut payload = BytesMut::new();
115            while let Some(chunk) = field.chunk().await? {
116                if payload.len() + chunk.len() > MAX_FIELD_SIZE {
117                    return Err(BatchError::LimitExceeded(format!(
118                        "individual request in batch exceeds body size limit of {MAX_FIELD_SIZE} bytes"
119                    )));
120                }
121                payload.extend_from_slice(&chunk);
122            }
123            Operation::Insert(Insert {
124                key,
125                metadata,
126                payload: payload.freeze(),
127            })
128        }
129        _ => {
130            return Err(BatchError::BadRequest(format!(
131                "invalid operation kind: {kind}"
132            )));
133        }
134    };
135    Ok(operation)
136}
137
138/// A lazily-parsed stream of batch operations extracted from a multipart request body.
139pub struct BatchOperationStream(pub BoxStream<'static, Result<Operation, BatchError>>);
140
141impl Debug for BatchOperationStream {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        f.debug_struct("BatchOperationStream").finish()
144    }
145}
146
/// Maximum number of payload bytes accepted for a single insert part (1 MiB).
const MAX_FIELD_SIZE: usize = 1 << 20;
/// Maximum number of operations parsed from a single batch request.
const MAX_OPERATIONS: usize = 1_000;
149
150impl<S> FromRequest<S> for BatchOperationStream
151where
152    S: Send + Sync,
153{
154    type Rejection = MultipartRejection;
155
156    async fn from_request(request: Request, state: &S) -> Result<Self, Self::Rejection> {
157        let mut multipart = Multipart::from_request(request, state).await?;
158
159        let requests = async_stream::stream! {
160            let mut count = 0;
161            loop {
162                let field = match multipart.next_field().await {
163                    Ok(Some(field)) => field,
164                    Ok(None) => break,
165                    Err(e) => {
166                        yield Err(BatchError::from(e));
167                        continue;
168                    }
169                };
170                if count >= MAX_OPERATIONS {
171                    yield Err(BatchError::LimitExceeded(format!(
172                        "exceeded {MAX_OPERATIONS} operations per batch request"
173                    )));
174                    continue;
175                }
176                count += 1;
177                yield try_operation_from_field(field).await;
178            }
179        }
180        .boxed();
181
182        Ok(Self(requests))
183    }
184}
185
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use axum::body::Body;
    use axum::http::{Request, header::CONTENT_TYPE};
    use futures::StreamExt;
    use objectstore_service::streaming::Operation;
    use objectstore_types::metadata::{ExpirationPolicy, HEADER_EXPIRATION, HEADER_ORIGIN};
    use percent_encoding::NON_ALPHANUMERIC;

    /// Sends `body` through the extractor as a `multipart/form-data` request whose
    /// boundary is `boundary`, and collects everything the resulting stream yields.
    async fn collect_operations(body: String) -> Vec<Result<Operation, BatchError>> {
        let request = Request::builder()
            .header(CONTENT_TYPE, "multipart/form-data; boundary=boundary")
            .body(Body::from(body))
            .unwrap();

        let extracted = BatchOperationStream::from_request(request, &())
            .await
            .unwrap();

        extracted.0.collect::<Vec<_>>().await
    }

    /// A well-formed batch mixing get, insert and delete parts parses into the
    /// matching operations, in order, with keys percent-decoded.
    #[tokio::test]
    async fn test_valid_request_works() {
        let insert1_data = b"first blob data";
        let insert2_data = b"second blob data";
        let expiration = ExpirationPolicy::TimeToLive(Duration::from_hours(1));
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key0}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key1}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             {insert1}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key2}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             {HEADER_EXPIRATION}: {expiration}\r\n\
             {HEADER_ORIGIN}: 203.0.113.42\r\n\
             Content-Type: text/plain\r\n\
             \r\n\
             {insert2}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key3}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
            key0 = "test%2F0", // "test/0" percent-encoded
            key1 = "test1",
            key2 = "test2",
            key3 = "test3",
            insert1 = String::from_utf8_lossy(insert1_data),
            insert2 = String::from_utf8_lossy(insert2_data),
        );

        let operations = collect_operations(body).await;
        assert_eq!(operations.len(), 4);

        match operations[0].as_ref().unwrap() {
            Operation::Get(get) => assert_eq!(get.key, "test/0"),
            _ => panic!("expected get operation"),
        }

        match operations[1].as_ref().unwrap() {
            Operation::Insert(insert) => {
                assert_eq!(insert.key.as_deref(), Some("test1"));
                assert_eq!(insert.metadata.content_type, "application/octet-stream");
                assert_eq!(insert.metadata.origin, None);
                assert_eq!(insert.payload.as_ref(), insert1_data);
            }
            _ => panic!("expected insert operation"),
        }

        match operations[2].as_ref().unwrap() {
            Operation::Insert(insert) => {
                assert_eq!(insert.key.as_deref(), Some("test2"));
                assert_eq!(insert.metadata.content_type, "text/plain");
                assert_eq!(insert.metadata.expiration_policy, expiration);
                assert_eq!(insert.metadata.origin.as_deref(), Some("203.0.113.42"));
                assert_eq!(insert.payload.as_ref(), insert2_data);
            }
            _ => panic!("expected insert operation"),
        }

        match operations[3].as_ref().unwrap() {
            Operation::Delete(delete) => assert_eq!(delete.key, "test3"),
            _ => panic!("expected delete operation"),
        }
    }

    /// An insert part may omit the key header; the operation then carries no key.
    #[tokio::test]
    async fn test_insert_without_key_header() {
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             keyless payload\r\n\
             --boundary--\r\n",
        );

        let operations = collect_operations(body).await;
        assert_eq!(operations.len(), 1);

        match operations[0].as_ref().unwrap() {
            Operation::Insert(insert) => {
                assert!(insert.key.is_none());
                assert_eq!(insert.payload.as_ref(), b"keyless payload");
            }
            _ => panic!("expected insert operation"),
        }
    }

    /// Invalid parts produce their own errors without preventing later parts from
    /// being parsed.
    #[tokio::test]
    async fn test_individual_errors_with_isolation() {
        let large_payload = "x".repeat(MAX_FIELD_SIZE + 1);
        let valid_key = percent_encoding::percent_encode(b"valid", NON_ALPHANUMERIC);
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             {large_payload}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
        );

        let operations = collect_operations(body).await;
        assert_eq!(operations.len(), 5);

        // get without key → BadRequest
        assert!(matches!(&operations[0], Err(BatchError::BadRequest(_))));
        // valid get
        assert!(matches!(
            &operations[1].as_ref().unwrap(),
            Operation::Get(g) if g.key == "valid"
        ));
        // delete without key → BadRequest
        assert!(matches!(&operations[2], Err(BatchError::BadRequest(_))));
        // oversized insert → LimitExceeded
        assert!(matches!(&operations[3], Err(BatchError::LimitExceeded(_))));
        // valid delete still succeeds after prior errors
        assert!(matches!(
            &operations[4].as_ref().unwrap(),
            Operation::Delete(d) if d.key == "valid"
        ));
    }

    /// One part more than `MAX_OPERATIONS` yields a `LimitExceeded` error for the
    /// excess part.
    #[tokio::test]
    async fn test_max_operations_limit_enforced() {
        let mut body: String = (0..=MAX_OPERATIONS)
            .map(|i| {
                let key = percent_encoding::percent_encode(
                    format!("test{i}").as_bytes(),
                    NON_ALPHANUMERIC,
                )
                .to_string();
                format!(
                    "--boundary\r\n\
                     {HEADER_BATCH_OPERATION_KEY}: {key}\r\n\
                     {HEADER_BATCH_OPERATION_KIND}: get\r\n\
                     \r\n\
                     \r\n"
                )
            })
            .collect();
        body.push_str("--boundary--\r\n");

        let operations = collect_operations(body).await;

        assert_eq!(operations.len(), MAX_OPERATIONS + 1);
        assert!(matches!(
            &operations[MAX_OPERATIONS],
            Err(BatchError::LimitExceeded(_))
        ));
    }

    /// A head part with a key parses into a head operation with the decoded key.
    #[tokio::test]
    async fn test_head_operation() {
        let key = percent_encoding::percent_encode(b"head-key", NON_ALPHANUMERIC);
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: head\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
        );

        let operations = collect_operations(body).await;
        assert_eq!(operations.len(), 1);

        match operations[0].as_ref().unwrap() {
            Operation::Head(head) => assert_eq!(head.key, "head-key"),
            _ => panic!("expected head operation"),
        }
    }

    /// A head part without a key is rejected as a bad request.
    #[tokio::test]
    async fn test_head_without_key_is_error() {
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: head\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
        );

        let operations = collect_operations(body).await;
        assert_eq!(operations.len(), 1);
        assert!(matches!(&operations[0], Err(BatchError::BadRequest(_))));
    }
}