objectstore_server/extractors/batch.rs

1//! Axum extractor for batch operation streams.
2//!
3//! Provides [`BatchOperationStream`], which parses a multipart request body into a
4//! lazy stream of [`Operation`]s.
5
6use std::fmt::Debug;
7
8use axum::extract::{
9    FromRequest, Multipart, Request,
10    multipart::{Field, MultipartError, MultipartRejection},
11};
12use futures::{StreamExt, stream::BoxStream};
13use objectstore_service::streaming::{Delete, Get, Insert, Operation};
14use objectstore_types::metadata::Metadata;
15use thiserror::Error;
16
17use crate::batch::{HEADER_BATCH_OPERATION_KEY, HEADER_BATCH_OPERATION_KIND};
18
19/// Errors that can occur when processing or executing batch operations.
20#[derive(Debug, Error)]
21pub enum BatchError {
22    /// Malformed request.
23    #[error("bad request: {0}")]
24    BadRequest(String),
25
26    /// Errors in parsing or reading a multipart request body.
27    #[error("multipart error: {0}")]
28    Multipart(#[from] MultipartError),
29
30    /// Errors related to de/serialization and parsing of object metadata.
31    #[error("metadata error: {0}")]
32    Metadata(#[from] objectstore_types::metadata::Error),
33
34    /// Size or cardinality limit exceeded.
35    #[error("batch limit exceeded: {0}")]
36    LimitExceeded(String),
37
38    /// Operation rejected due to rate limiting.
39    #[error("rate limited")]
40    RateLimited,
41
42    /// Errors encountered when serializing batch response parts.
43    #[error("response part serialization error: {context}")]
44    ResponseSerialization {
45        /// Context describing what was being serialized.
46        context: String,
47        /// The underlying error.
48        #[source]
49        cause: Box<dyn std::error::Error + Send + Sync>,
50    },
51}
52
53async fn try_operation_from_field(field: Field<'_>) -> Result<Operation, BatchError> {
54    let kind = field
55        .headers()
56        .get(HEADER_BATCH_OPERATION_KIND)
57        .ok_or_else(|| {
58            BatchError::BadRequest(format!("missing {HEADER_BATCH_OPERATION_KIND} header"))
59        })?;
60    let kind = kind
61        .to_str()
62        .map_err(|_| {
63            BatchError::BadRequest(format!(
64                "unable to convert {HEADER_BATCH_OPERATION_KIND} header value to string"
65            ))
66        })?
67        .to_lowercase();
68
69    let key = field
70        .headers()
71        .get(HEADER_BATCH_OPERATION_KEY)
72        .map(|v| {
73            let s = v.to_str().map_err(|_| {
74                BatchError::BadRequest(format!(
75                    "unable to convert {HEADER_BATCH_OPERATION_KEY} header value to string"
76                ))
77            })?;
78            percent_encoding::percent_decode_str(s)
79                .decode_utf8()
80                .map(|decoded| decoded.into_owned())
81                .map_err(|_| {
82                    BatchError::BadRequest(format!(
83                        "unable to percent-decode {HEADER_BATCH_OPERATION_KEY} header value"
84                    ))
85                })
86        })
87        .transpose()?;
88
89    let operation = match kind.as_str() {
90        "get" => Operation::Get(Get {
91            key: key.ok_or_else(|| {
92                BatchError::BadRequest(format!(
93                    "missing {HEADER_BATCH_OPERATION_KEY} header for {kind} operation"
94                ))
95            })?,
96        }),
97        "delete" => Operation::Delete(Delete {
98            key: key.ok_or_else(|| {
99                BatchError::BadRequest(format!(
100                    "missing {HEADER_BATCH_OPERATION_KEY} header for {kind} operation"
101                ))
102            })?,
103        }),
104        "insert" => {
105            let metadata = Metadata::from_headers(field.headers(), "")?;
106            let payload = field.bytes().await?;
107            if payload.len() > MAX_FIELD_SIZE {
108                return Err(BatchError::LimitExceeded(format!(
109                    "individual request in batch exceeds body size limit of {MAX_FIELD_SIZE} bytes"
110                )));
111            }
112            Operation::Insert(Insert {
113                key,
114                metadata,
115                payload,
116            })
117        }
118        _ => {
119            return Err(BatchError::BadRequest(format!(
120                "invalid operation kind: {kind}"
121            )));
122        }
123    };
124    Ok(operation)
125}
126
127/// A lazily-parsed stream of batch operations extracted from a multipart request body.
128pub struct BatchOperationStream(pub BoxStream<'static, Result<Operation, BatchError>>);
129
130impl Debug for BatchOperationStream {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        f.debug_struct("BatchOperationStream").finish()
133    }
134}
135
/// Upper bound, in bytes, on the payload of a single insert within a batch (1 MiB).
const MAX_FIELD_SIZE: usize = 1 << 20;
/// Upper bound on the number of parts parsed as operations from one batch request.
const MAX_OPERATIONS: usize = 1_000;
138
139impl<S> FromRequest<S> for BatchOperationStream
140where
141    S: Send + Sync,
142{
143    type Rejection = MultipartRejection;
144
145    async fn from_request(request: Request, state: &S) -> Result<Self, Self::Rejection> {
146        let mut multipart = Multipart::from_request(request, state).await?;
147
148        let requests = async_stream::stream! {
149            let mut count = 0;
150            loop {
151                let field = match multipart.next_field().await {
152                    Ok(Some(field)) => field,
153                    Ok(None) => break,
154                    Err(e) => {
155                        yield Err(BatchError::from(e));
156                        continue;
157                    }
158                };
159                if count >= MAX_OPERATIONS {
160                    yield Err(BatchError::LimitExceeded(format!(
161                        "exceeded {MAX_OPERATIONS} operations per batch request"
162                    )));
163                    continue;
164                }
165                count += 1;
166                yield try_operation_from_field(field).await;
167            }
168        }
169        .boxed();
170
171        Ok(Self(requests))
172    }
173}
174
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use axum::body::Body;
    use axum::http::{Request, header::CONTENT_TYPE};
    use futures::StreamExt;
    use objectstore_service::streaming::Operation;
    use objectstore_types::metadata::{ExpirationPolicy, HEADER_EXPIRATION, HEADER_ORIGIN};
    use percent_encoding::NON_ALPHANUMERIC;

    /// Runs the extractor over a `multipart/form-data` body delimited by `boundary` and
    /// collects every item of the resulting operation stream.
    async fn extract_operations(body: String) -> Vec<Result<Operation, BatchError>> {
        let request = Request::builder()
            .header(CONTENT_TYPE, "multipart/form-data; boundary=boundary")
            .body(Body::from(body))
            .unwrap();

        BatchOperationStream::from_request(request, &())
            .await
            .unwrap()
            .0
            .collect()
            .await
    }

    #[tokio::test]
    async fn test_valid_request_works() {
        let insert1_data = b"first blob data";
        let insert2_data = b"second blob data";
        let expiration = ExpirationPolicy::TimeToLive(Duration::from_hours(1));
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key0}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key1}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             {insert1}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key2}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             {HEADER_EXPIRATION}: {expiration}\r\n\
             {HEADER_ORIGIN}: 203.0.113.42\r\n\
             Content-Type: text/plain\r\n\
             \r\n\
             {insert2}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key3}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
            key0 = "test%2F0", // "test/0" percent-encoded
            key1 = "test1",
            key2 = "test2",
            key3 = "test3",
            insert1 = String::from_utf8_lossy(insert1_data),
            insert2 = String::from_utf8_lossy(insert2_data),
        );

        let operations = extract_operations(body).await;
        assert_eq!(operations.len(), 4);

        let Operation::Get(get_op) = &operations[0].as_ref().unwrap() else {
            panic!("expected get operation");
        };
        assert_eq!(get_op.key, "test/0");

        let Operation::Insert(insert_op1) = &operations[1].as_ref().unwrap() else {
            panic!("expected insert operation");
        };
        assert_eq!(insert_op1.key.as_deref(), Some("test1"));
        assert_eq!(insert_op1.metadata.content_type, "application/octet-stream");
        assert_eq!(insert_op1.metadata.origin, None);
        assert_eq!(insert_op1.payload.as_ref(), insert1_data);

        let Operation::Insert(insert_op2) = &operations[2].as_ref().unwrap() else {
            panic!("expected insert operation");
        };
        assert_eq!(insert_op2.key.as_deref(), Some("test2"));
        assert_eq!(insert_op2.metadata.content_type, "text/plain");
        assert_eq!(insert_op2.metadata.expiration_policy, expiration);
        assert_eq!(insert_op2.metadata.origin.as_deref(), Some("203.0.113.42"));
        assert_eq!(insert_op2.payload.as_ref(), insert2_data);

        let Operation::Delete(delete_op) = &operations[3].as_ref().unwrap() else {
            panic!("expected delete operation");
        };
        assert_eq!(delete_op.key, "test3");
    }

    #[tokio::test]
    async fn test_insert_without_key_header() {
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             keyless payload\r\n\
             --boundary--\r\n",
        );

        let operations = extract_operations(body).await;
        assert_eq!(operations.len(), 1);

        let Operation::Insert(insert_op) = &operations[0].as_ref().unwrap() else {
            panic!("expected insert operation");
        };
        assert!(insert_op.key.is_none());
        assert_eq!(insert_op.payload.as_ref(), b"keyless payload");
    }

    #[tokio::test]
    async fn test_operation_kind_is_case_insensitive() {
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: test\r\n\
             {HEADER_BATCH_OPERATION_KIND}: GET\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: test\r\n\
             {HEADER_BATCH_OPERATION_KIND}: Delete\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
        );

        let operations = extract_operations(body).await;
        assert_eq!(operations.len(), 2);
        assert!(matches!(&operations[0], Ok(Operation::Get(g)) if g.key == "test"));
        assert!(matches!(&operations[1], Ok(Operation::Delete(d)) if d.key == "test"));
    }

    #[tokio::test]
    async fn test_missing_or_unknown_kind_is_rejected() {
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: test\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: test\r\n\
             {HEADER_BATCH_OPERATION_KIND}: frobnicate\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: test\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
        );

        let operations = extract_operations(body).await;
        assert_eq!(operations.len(), 3);
        // missing kind header → BadRequest
        assert!(matches!(&operations[0], Err(BatchError::BadRequest(_))));
        // unknown kind → BadRequest
        assert!(matches!(&operations[1], Err(BatchError::BadRequest(_))));
        // a valid part after the rejected ones still parses
        assert!(matches!(&operations[2], Ok(Operation::Get(g)) if g.key == "test"));
    }

    #[tokio::test]
    async fn test_insert_at_size_limit_is_accepted() {
        let payload = "x".repeat(MAX_FIELD_SIZE);
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             {payload}\r\n\
             --boundary--\r\n",
        );

        let operations = extract_operations(body).await;
        assert_eq!(operations.len(), 1);

        let Operation::Insert(insert_op) = &operations[0].as_ref().unwrap() else {
            panic!("expected insert operation");
        };
        assert_eq!(insert_op.payload.len(), MAX_FIELD_SIZE);
    }

    #[tokio::test]
    async fn test_individual_errors_with_isolation() {
        let large_payload = "x".repeat(MAX_FIELD_SIZE + 1);
        let valid_key = percent_encoding::percent_encode(b"valid", NON_ALPHANUMERIC);
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             {large_payload}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
        );

        let operations = extract_operations(body).await;
        assert_eq!(operations.len(), 5);

        // get without key → BadRequest
        assert!(matches!(&operations[0], Err(BatchError::BadRequest(_))));
        // valid get
        assert!(matches!(
            &operations[1].as_ref().unwrap(),
            Operation::Get(g) if g.key == "valid"
        ));
        // delete without key → BadRequest
        assert!(matches!(&operations[2], Err(BatchError::BadRequest(_))));
        // oversized insert → LimitExceeded
        assert!(matches!(&operations[3], Err(BatchError::LimitExceeded(_))));
        // valid delete still succeeds after prior errors
        assert!(matches!(
            &operations[4].as_ref().unwrap(),
            Operation::Delete(d) if d.key == "valid"
        ));
    }

    #[tokio::test]
    async fn test_max_operations_limit_enforced() {
        let mut body = String::new();
        for i in 0..(MAX_OPERATIONS + 1) {
            let key =
                percent_encoding::percent_encode(format!("test{i}").as_bytes(), NON_ALPHANUMERIC)
                    .to_string();
            body.push_str(&format!(
                "--boundary\r\n\
                 {HEADER_BATCH_OPERATION_KEY}: {key}\r\n\
                 {HEADER_BATCH_OPERATION_KIND}: get\r\n\
                 \r\n\
                 \r\n"
            ));
        }
        body.push_str("--boundary--\r\n");

        let operations = extract_operations(body).await;

        assert_eq!(operations.len(), MAX_OPERATIONS + 1);
        assert!(matches!(
            &operations[MAX_OPERATIONS],
            Err(BatchError::LimitExceeded(_))
        ));
    }
}