Skip to main content

objectstore_server/extractors/
batch.rs

1//! Axum extractor for batch operation streams.
2//!
3//! Provides [`BatchOperationStream`], which parses a multipart request body into a
4//! lazy stream of [`Operation`]s.
5
6use std::fmt::Debug;
7
8use axum::extract::{
9    FromRequest, Multipart, Request,
10    multipart::{Field, MultipartError, MultipartRejection},
11};
12use bytes::BytesMut;
13use futures::{StreamExt, stream::BoxStream};
14use objectstore_service::streaming::{Delete, Get, Insert, Operation};
15use objectstore_types::metadata::Metadata;
16use thiserror::Error;
17
18use crate::batch::{HEADER_BATCH_OPERATION_KEY, HEADER_BATCH_OPERATION_KIND};
19
/// Errors that can occur when processing or executing batch operations.
///
/// Problems with an individual operation are reported as `Err` items of a
/// [`BatchOperationStream`] instead of rejecting the whole request.
#[derive(Debug, Error)]
pub enum BatchError {
    /// Malformed request, e.g. a missing or undecodable operation kind/key header,
    /// a `get`/`delete` without a key, or an unknown operation kind.
    ///
    /// The string describes what was wrong.
    #[error("bad request: {0}")]
    BadRequest(String),

    /// Errors in parsing or reading a multipart request body.
    ///
    /// Converted via `From` from axum's [`MultipartError`].
    #[error("multipart error: {0}")]
    Multipart(#[from] MultipartError),

    /// Errors related to de/serialization and parsing of object metadata, e.g. when
    /// [`Metadata::from_headers`] rejects the headers of an `insert` operation.
    #[error("metadata error: {0}")]
    Metadata(#[from] objectstore_types::metadata::Error),

    /// Size or cardinality limit exceeded, e.g. an `insert` payload larger than
    /// `MAX_FIELD_SIZE` bytes or more than `MAX_OPERATIONS` operations in one request.
    ///
    /// The string describes which limit was hit.
    #[error("batch limit exceeded: {0}")]
    LimitExceeded(String),

    /// Operation rejected due to rate limiting.
    #[error("rate limited")]
    RateLimited,

    /// Errors encountered when serializing batch response parts.
    #[error("response part serialization error: {context}")]
    ResponseSerialization {
        /// Context describing what was being serialized.
        context: String,
        /// The underlying error.
        #[source]
        cause: Box<dyn std::error::Error + Send + Sync>,
    },
}
53
54async fn try_operation_from_field(mut field: Field<'_>) -> Result<Operation, BatchError> {
55    let kind = field
56        .headers()
57        .get(HEADER_BATCH_OPERATION_KIND)
58        .ok_or_else(|| {
59            BatchError::BadRequest(format!("missing {HEADER_BATCH_OPERATION_KIND} header"))
60        })?;
61    let kind = kind
62        .to_str()
63        .map_err(|_| {
64            BatchError::BadRequest(format!(
65                "unable to convert {HEADER_BATCH_OPERATION_KIND} header value to string"
66            ))
67        })?
68        .to_lowercase();
69
70    let key = field
71        .headers()
72        .get(HEADER_BATCH_OPERATION_KEY)
73        .map(|v| {
74            let s = v.to_str().map_err(|_| {
75                BatchError::BadRequest(format!(
76                    "unable to convert {HEADER_BATCH_OPERATION_KEY} header value to string"
77                ))
78            })?;
79            percent_encoding::percent_decode_str(s)
80                .decode_utf8()
81                .map(|decoded| decoded.into_owned())
82                .map_err(|_| {
83                    BatchError::BadRequest(format!(
84                        "unable to percent-decode {HEADER_BATCH_OPERATION_KEY} header value"
85                    ))
86                })
87        })
88        .transpose()?;
89
90    let operation = match kind.as_str() {
91        "get" => Operation::Get(Get {
92            key: key.ok_or_else(|| {
93                BatchError::BadRequest(format!(
94                    "missing {HEADER_BATCH_OPERATION_KEY} header for {kind} operation"
95                ))
96            })?,
97        }),
98        "delete" => Operation::Delete(Delete {
99            key: key.ok_or_else(|| {
100                BatchError::BadRequest(format!(
101                    "missing {HEADER_BATCH_OPERATION_KEY} header for {kind} operation"
102                ))
103            })?,
104        }),
105        "insert" => {
106            let metadata = Metadata::from_headers(field.headers(), "")?;
107            let mut payload = BytesMut::new();
108            while let Some(chunk) = field.chunk().await? {
109                if payload.len() + chunk.len() > MAX_FIELD_SIZE {
110                    return Err(BatchError::LimitExceeded(format!(
111                        "individual request in batch exceeds body size limit of {MAX_FIELD_SIZE} bytes"
112                    )));
113                }
114                payload.extend_from_slice(&chunk);
115            }
116            Operation::Insert(Insert {
117                key,
118                metadata,
119                payload: payload.freeze(),
120            })
121        }
122        _ => {
123            return Err(BatchError::BadRequest(format!(
124                "invalid operation kind: {kind}"
125            )));
126        }
127    };
128    Ok(operation)
129}
130
131/// A lazily-parsed stream of batch operations extracted from a multipart request body.
132pub struct BatchOperationStream(pub BoxStream<'static, Result<Operation, BatchError>>);
133
134impl Debug for BatchOperationStream {
135    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136        f.debug_struct("BatchOperationStream").finish()
137    }
138}
139
/// Maximum payload size, in bytes, accepted for a single `insert` operation (1 MiB).
const MAX_FIELD_SIZE: usize = 1 << 20;
/// Maximum number of operations parsed from one batch request; every further field
/// is answered with [`BatchError::LimitExceeded`] instead of being parsed.
const MAX_OPERATIONS: usize = 1_000;
142
143impl<S> FromRequest<S> for BatchOperationStream
144where
145    S: Send + Sync,
146{
147    type Rejection = MultipartRejection;
148
149    async fn from_request(request: Request, state: &S) -> Result<Self, Self::Rejection> {
150        let mut multipart = Multipart::from_request(request, state).await?;
151
152        let requests = async_stream::stream! {
153            let mut count = 0;
154            loop {
155                let field = match multipart.next_field().await {
156                    Ok(Some(field)) => field,
157                    Ok(None) => break,
158                    Err(e) => {
159                        yield Err(BatchError::from(e));
160                        continue;
161                    }
162                };
163                if count >= MAX_OPERATIONS {
164                    yield Err(BatchError::LimitExceeded(format!(
165                        "exceeded {MAX_OPERATIONS} operations per batch request"
166                    )));
167                    continue;
168                }
169                count += 1;
170                yield try_operation_from_field(field).await;
171            }
172        }
173        .boxed();
174
175        Ok(Self(requests))
176    }
177}
178
179#[cfg(test)]
180mod tests {
181    use std::time::Duration;
182
183    use super::*;
184    use axum::body::Body;
185    use axum::http::{Request, header::CONTENT_TYPE};
186    use futures::StreamExt;
187    use objectstore_service::streaming::Operation;
188    use objectstore_types::metadata::{ExpirationPolicy, HEADER_EXPIRATION, HEADER_ORIGIN};
189    use percent_encoding::NON_ALPHANUMERIC;
190
191    #[tokio::test]
192    async fn test_valid_request_works() {
193        let insert1_data = b"first blob data";
194        let insert2_data = b"second blob data";
195        let expiration = ExpirationPolicy::TimeToLive(Duration::from_hours(1));
196        let body = format!(
197            "--boundary\r\n\
198             {HEADER_BATCH_OPERATION_KEY}: {key0}\r\n\
199             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
200             \r\n\
201             \r\n\
202             --boundary\r\n\
203             {HEADER_BATCH_OPERATION_KEY}: {key1}\r\n\
204             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
205             Content-Type: application/octet-stream\r\n\
206             \r\n\
207             {insert1}\r\n\
208             --boundary\r\n\
209             {HEADER_BATCH_OPERATION_KEY}: {key2}\r\n\
210             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
211             {HEADER_EXPIRATION}: {expiration}\r\n\
212             {HEADER_ORIGIN}: 203.0.113.42\r\n\
213             Content-Type: text/plain\r\n\
214             \r\n\
215             {insert2}\r\n\
216             --boundary\r\n\
217             {HEADER_BATCH_OPERATION_KEY}: {key3}\r\n\
218             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
219             \r\n\
220             \r\n\
221             --boundary--\r\n",
222            key0 = "test%2F0", // "test/0" percent-encoded
223            key1 = "test1",
224            key2 = "test2",
225            key3 = "test3",
226            insert1 = String::from_utf8_lossy(insert1_data),
227            insert2 = String::from_utf8_lossy(insert2_data),
228        );
229
230        let request = Request::builder()
231            .header(CONTENT_TYPE, "multipart/form-data; boundary=boundary")
232            .body(Body::from(body))
233            .unwrap();
234
235        let batch_request = BatchOperationStream::from_request(request, &())
236            .await
237            .unwrap();
238
239        let operations: Vec<_> = batch_request.0.collect().await;
240        assert_eq!(operations.len(), 4);
241
242        let Operation::Get(get_op) = &operations[0].as_ref().unwrap() else {
243            panic!("expected get operation");
244        };
245        assert_eq!(get_op.key, "test/0");
246
247        let Operation::Insert(insert_op1) = &operations[1].as_ref().unwrap() else {
248            panic!("expected insert operation");
249        };
250        assert_eq!(insert_op1.key.as_deref(), Some("test1"));
251        assert_eq!(insert_op1.metadata.content_type, "application/octet-stream");
252        assert_eq!(insert_op1.metadata.origin, None);
253        assert_eq!(insert_op1.payload.as_ref(), insert1_data);
254
255        let Operation::Insert(insert_op2) = &operations[2].as_ref().unwrap() else {
256            panic!("expected insert operation");
257        };
258        assert_eq!(insert_op2.key.as_deref(), Some("test2"));
259        assert_eq!(insert_op2.metadata.content_type, "text/plain");
260        assert_eq!(insert_op2.metadata.expiration_policy, expiration);
261        assert_eq!(insert_op2.metadata.origin.as_deref(), Some("203.0.113.42"));
262        assert_eq!(insert_op2.payload.as_ref(), insert2_data);
263
264        let Operation::Delete(delete_op) = &operations[3].as_ref().unwrap() else {
265            panic!("expected delete operation");
266        };
267        assert_eq!(delete_op.key, "test3");
268    }
269
270    #[tokio::test]
271    async fn test_insert_without_key_header() {
272        let body = format!(
273            "--boundary\r\n\
274             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
275             Content-Type: application/octet-stream\r\n\
276             \r\n\
277             keyless payload\r\n\
278             --boundary--\r\n",
279        );
280
281        let request = Request::builder()
282            .header(CONTENT_TYPE, "multipart/form-data; boundary=boundary")
283            .body(Body::from(body))
284            .unwrap();
285
286        let batch_request = BatchOperationStream::from_request(request, &())
287            .await
288            .unwrap();
289
290        let operations: Vec<_> = batch_request.0.collect().await;
291        assert_eq!(operations.len(), 1);
292
293        let Operation::Insert(insert_op) = &operations[0].as_ref().unwrap() else {
294            panic!("expected insert operation");
295        };
296        assert!(insert_op.key.is_none());
297        assert_eq!(insert_op.payload.as_ref(), b"keyless payload");
298    }
299
300    #[tokio::test]
301    async fn test_individual_errors_with_isolation() {
302        let large_payload = "x".repeat(MAX_FIELD_SIZE + 1);
303        let valid_key = percent_encoding::percent_encode(b"valid", NON_ALPHANUMERIC);
304        let body = format!(
305            "--boundary\r\n\
306             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
307             \r\n\
308             \r\n\
309             --boundary\r\n\
310             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
311             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
312             \r\n\
313             \r\n\
314             --boundary\r\n\
315             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
316             \r\n\
317             \r\n\
318             --boundary\r\n\
319             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
320             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
321             Content-Type: application/octet-stream\r\n\
322             \r\n\
323             {large_payload}\r\n\
324             --boundary\r\n\
325             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
326             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
327             \r\n\
328             \r\n\
329             --boundary--\r\n",
330        );
331
332        let request = Request::builder()
333            .header(CONTENT_TYPE, "multipart/form-data; boundary=boundary")
334            .body(Body::from(body))
335            .unwrap();
336
337        let batch_request = BatchOperationStream::from_request(request, &())
338            .await
339            .unwrap();
340
341        let operations: Vec<_> = batch_request.0.collect().await;
342        assert_eq!(operations.len(), 5);
343
344        // get without key → BadRequest
345        assert!(matches!(&operations[0], Err(BatchError::BadRequest(_))));
346        // valid get
347        assert!(matches!(
348            &operations[1].as_ref().unwrap(),
349            Operation::Get(g) if g.key == "valid"
350        ));
351        // delete without key → BadRequest
352        assert!(matches!(&operations[2], Err(BatchError::BadRequest(_))));
353        // oversized insert → LimitExceeded
354        assert!(matches!(&operations[3], Err(BatchError::LimitExceeded(_))));
355        // valid delete still succeeds after prior errors
356        assert!(matches!(
357            &operations[4].as_ref().unwrap(),
358            Operation::Delete(d) if d.key == "valid"
359        ));
360    }
361
362    #[tokio::test]
363    async fn test_max_operations_limit_enforced() {
364        let mut body = String::new();
365        for i in 0..(MAX_OPERATIONS + 1) {
366            let key =
367                percent_encoding::percent_encode(format!("test{i}").as_bytes(), NON_ALPHANUMERIC)
368                    .to_string();
369            body.push_str(&format!(
370                "--boundary\r\n\
371                 {HEADER_BATCH_OPERATION_KEY}: {key}\r\n\
372                 {HEADER_BATCH_OPERATION_KIND}: get\r\n\
373                 \r\n\
374                 \r\n"
375            ));
376        }
377        body.push_str("--boundary--\r\n");
378
379        let request = Request::builder()
380            .header(CONTENT_TYPE, "multipart/form-data; boundary=boundary")
381            .body(Body::from(body))
382            .unwrap();
383
384        let batch_request = BatchOperationStream::from_request(request, &())
385            .await
386            .unwrap();
387        let operations: Vec<_> = batch_request.0.collect().await;
388
389        assert_eq!(operations.len(), MAX_OPERATIONS + 1);
390        assert!(matches!(
391            &operations[MAX_OPERATIONS],
392            Err(BatchError::LimitExceeded(_))
393        ));
394    }
395}