objectstore_server/extractors/batch.rs
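//! Axum extractor that turns a `multipart/form-data` batch request into a lazy
//! stream of get, insert, and delete operations.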

use std::fmt::Debug;

use axum::extract::{
    FromRequest, Multipart, Request,
    multipart::{Field, MultipartError, MultipartRejection},
};
use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use bytes::Bytes;
use futures::{StreamExt, stream::BoxStream};
use objectstore_service::id::ObjectKey;
use objectstore_types::Metadata;
use thiserror::Error;

use crate::batch::{HEADER_BATCH_OPERATION_KEY, HEADER_BATCH_OPERATION_KIND};

/// Errors that can occur when processing or executing batch operations.
#[derive(Debug, Error)]
pub enum BatchError {
    /// Malformed request.
    #[error("bad request: {0}")]
    BadRequest(String),

    /// Errors in parsing or reading a multipart request body.
    #[error("multipart error: {0}")]
    Multipart(#[from] MultipartError),

    /// Errors related to de/serialization and parsing of object metadata.
    #[error("metadata error: {0}")]
    Metadata(#[from] objectstore_types::Error),

    /// Size or cardinality limit exceeded.
    #[error("batch limit exceeded: {0}")]
    LimitExceeded(String),

    /// Operation rejected due to rate limiting.
    #[error("rate limited")]
    RateLimited,

    /// Errors encountered when serializing batch response parts.
    #[error("response part serialization error: {context}")]
    ResponseSerialization {
        /// Context describing what was being serialized.
        context: String,
        /// The underlying error.
        #[source]
        cause: Box<dyn std::error::Error + Send + Sync>,
    },
}

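/// A batched read of a single object by key.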
#[derive(Debug)]
pub struct GetOperation {
    pub key: ObjectKey,
}

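/// A batched insert of a single object, carrying its metadata and raw payload.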
#[derive(Debug)]
pub struct InsertOperation {
    pub key: ObjectKey,
    pub metadata: Metadata,
    pub payload: Bytes,
}

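/// A batched deletion of a single object by key.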
#[derive(Debug)]
pub struct DeleteOperation {
    pub key: ObjectKey,
}

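/// A single operation parsed from one multipart field of a batch request.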
#[derive(Debug)]
pub enum Operation {
    Get(GetOperation),
    Insert(InsertOperation),
    Delete(DeleteOperation),
}

impl Operation {
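    /// Parses a single [`Operation`] from a multipart field.
    ///
    /// The operation kind and base64-encoded object key are read from the
    /// field's headers; for inserts, the headers are also parsed as
    /// [`Metadata`] and the field body is buffered as the payload.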
    async fn try_from_field(field: Field<'_>) -> Result<Self, BatchError> {
        let kind = field
            .headers()
            .get(HEADER_BATCH_OPERATION_KIND)
            .ok_or_else(|| {
                BatchError::BadRequest(format!("missing {HEADER_BATCH_OPERATION_KIND} header"))
            })?;
        let kind = kind
            .to_str()
            .map_err(|_| {
                BatchError::BadRequest(format!(
                    "unable to convert {HEADER_BATCH_OPERATION_KIND} header value to string"
                ))
            })?
            .to_lowercase();

        let key_header = field
            .headers()
            .get(HEADER_BATCH_OPERATION_KEY)
            .ok_or_else(|| {
                BatchError::BadRequest(format!("missing {HEADER_BATCH_OPERATION_KEY} header"))
            })?
            .to_str()
            .map_err(|_| {
                BatchError::BadRequest(format!(
                    "unable to convert {HEADER_BATCH_OPERATION_KEY} header value to string"
                ))
            })?;
        let key_bytes = BASE64_STANDARD.decode(key_header).map_err(|_| {
            BatchError::BadRequest(format!(
                "unable to base64 decode {HEADER_BATCH_OPERATION_KEY} header value"
            ))
        })?;
        let key = String::from_utf8(key_bytes).map_err(|_| {
            BatchError::BadRequest(format!(
                "{HEADER_BATCH_OPERATION_KEY} header value is not valid UTF-8"
            ))
        })?;

        let operation = match kind.as_str() {
            "get" => Operation::Get(GetOperation { key }),
            "insert" => {
                let metadata = Metadata::from_headers(field.headers(), "")?;
                let payload = field.bytes().await?;
                if payload.len() > MAX_FIELD_SIZE {
                    return Err(BatchError::LimitExceeded(format!(
                        "individual request in batch exceeds body size limit of {MAX_FIELD_SIZE} bytes"
                    )));
                }
                Operation::Insert(InsertOperation {
                    key,
                    metadata,
                    payload,
                })
            }
            "delete" => Operation::Delete(DeleteOperation { key }),
            _ => {
                return Err(BatchError::BadRequest(format!(
                    "invalid operation kind: {kind}"
                )));
            }
        };
        Ok(operation)
    }

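    /// Returns the object key this operation applies to.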
    pub fn key(&self) -> &ObjectKey {
        match self {
            Operation::Get(op) => &op.key,
            Operation::Insert(op) => &op.key,
            Operation::Delete(op) => &op.key,
        }
    }
}

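/// A lazily parsed stream of [`Operation`]s extracted from a multipart batch request.
///
/// Per-operation failures are yielded as `Err` items rather than rejecting the
/// request during extraction. A minimal usage sketch (the handler name and body
/// are illustrative, not part of this crate):
///
/// ```ignore
/// async fn batch_handler(BatchOperationStream(mut operations): BatchOperationStream) {
///     use futures::StreamExt;
///
///     while let Some(operation) = operations.next().await {
///         match operation {
///             Ok(op) => { /* dispatch on the variant and `op.key()` */ }
///             Err(err) => { /* turn `err` into an error part of the batch response */ }
///         }
///     }
/// }
/// ```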
pub struct BatchOperationStream(pub BoxStream<'static, Result<Operation, BatchError>>);

impl Debug for BatchOperationStream {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BatchOperationStream").finish()
    }
}

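// Per-operation payload and per-batch cardinality limits; exceeding either
// surfaces as `BatchError::LimitExceeded`.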
const MAX_FIELD_SIZE: usize = 1024 * 1024; // 1 MB
const MAX_OPERATIONS: usize = 1000;

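// Extraction only sets up the multipart parser; fields are read and validated
// lazily as the returned stream is polled, so per-operation errors are yielded
// as stream items instead of rejecting the whole request.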
impl<S> FromRequest<S> for BatchOperationStream
where
    S: Send + Sync,
{
    type Rejection = MultipartRejection;

    async fn from_request(request: Request, state: &S) -> Result<Self, Self::Rejection> {
        let mut multipart = Multipart::from_request(request, state).await?;

        let requests = async_stream::try_stream! {
            let mut count = 0;
            while let Some(field) = multipart.next_field().await? {
                if count >= MAX_OPERATIONS {
                    Err(BatchError::LimitExceeded(format!(
                        "exceeded {MAX_OPERATIONS} operations per batch request"
                    )))?;
                }
                count += 1;
                yield Operation::try_from_field(field).await?;
            }
        }
        .boxed();

        Ok(Self(requests))
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use axum::body::Body;
    use axum::http::{Request, header::CONTENT_TYPE};
    use futures::StreamExt;
    use objectstore_types::{ExpirationPolicy, HEADER_EXPIRATION};

    #[tokio::test]
    async fn test_valid_request_works() {
        let insert1_data = b"first blob data";
        let insert2_data = b"second blob data";
        let expiration = ExpirationPolicy::TimeToLive(Duration::from_hours(1));
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key0}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key1}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             {insert1}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key2}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             {HEADER_EXPIRATION}: {expiration}\r\n\
             Content-Type: text/plain\r\n\
             \r\n\
             {insert2}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key3}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
            key0 = BASE64_STANDARD.encode("test0"),
            key1 = BASE64_STANDARD.encode("test1"),
            key2 = BASE64_STANDARD.encode("test2"),
            key3 = BASE64_STANDARD.encode("test3"),
            insert1 = String::from_utf8_lossy(insert1_data),
            insert2 = String::from_utf8_lossy(insert2_data),
        );

        let request = Request::builder()
            .header(CONTENT_TYPE, "multipart/form-data; boundary=boundary")
            .body(Body::from(body))
            .unwrap();

        let batch_request = BatchOperationStream::from_request(request, &())
            .await
            .unwrap();

        let operations: Vec<_> = batch_request.0.collect().await;
        assert_eq!(operations.len(), 4);

        let Operation::Get(get_op) = &operations[0].as_ref().unwrap() else {
            panic!("expected get operation");
        };
        assert_eq!(get_op.key, "test0");

        let Operation::Insert(insert_op1) = &operations[1].as_ref().unwrap() else {
            panic!("expected insert operation");
        };
        assert_eq!(insert_op1.key, "test1");
        assert_eq!(insert_op1.metadata.content_type, "application/octet-stream");
        assert_eq!(insert_op1.payload.as_ref(), insert1_data);

        let Operation::Insert(insert_op2) = &operations[2].as_ref().unwrap() else {
            panic!("expected insert operation");
        };
        assert_eq!(insert_op2.key, "test2");
        assert_eq!(insert_op2.metadata.content_type, "text/plain");
        assert_eq!(insert_op2.metadata.expiration_policy, expiration);
        assert_eq!(insert_op2.payload.as_ref(), insert2_data);

        let Operation::Delete(delete_op) = &operations[3].as_ref().unwrap() else {
            panic!("expected delete operation");
        };
        assert_eq!(delete_op.key, "test3");
    }

    #[tokio::test]
    async fn test_max_operations_limit_enforced() {
        let mut body = String::new();
        for i in 0..(MAX_OPERATIONS + 1) {
            let key = BASE64_STANDARD.encode(format!("test{i}"));
            body.push_str(&format!(
                "--boundary\r\n\
                 {HEADER_BATCH_OPERATION_KEY}: {key}\r\n\
                 {HEADER_BATCH_OPERATION_KIND}: get\r\n\
                 \r\n\
                 \r\n"
            ));
        }
        body.push_str("--boundary--\r\n");

        let request = Request::builder()
            .header(CONTENT_TYPE, "multipart/form-data; boundary=boundary")
            .body(Body::from(body))
            .unwrap();

        let batch_request = BatchOperationStream::from_request(request, &())
            .await
            .unwrap();
        let operations: Vec<_> = batch_request.0.collect().await;

        assert_eq!(operations.len(), MAX_OPERATIONS + 1);
        assert!(matches!(
            &operations[MAX_OPERATIONS],
            Err(BatchError::LimitExceeded(_))
        ));
    }

    #[tokio::test]
    async fn test_operation_body_size_limit_enforced() {
        let large_payload = "x".repeat(MAX_FIELD_SIZE + 1);
        let key = BASE64_STANDARD.encode("test");
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             {large_payload}\r\n\
             --boundary--\r\n",
        );

        let request = Request::builder()
            .header(CONTENT_TYPE, "multipart/form-data; boundary=boundary")
            .body(Body::from(body))
            .unwrap();

        let batch_request = BatchOperationStream::from_request(request, &())
            .await
            .unwrap();
        let operations: Vec<_> = batch_request.0.collect().await;

        assert_eq!(operations.len(), 1);
        assert!(matches!(&operations[0], Err(BatchError::LimitExceeded(_))));
    }
}