objectstore_server/extractors/
batch.rs1use std::fmt::Debug;
7
8use axum::extract::{
9 FromRequest, Multipart, Request,
10 multipart::{Field, MultipartError, MultipartRejection},
11};
12use bytes::BytesMut;
13use futures::{StreamExt, stream::BoxStream};
14use objectstore_service::streaming::{Delete, Get, Insert, Operation};
15use objectstore_types::metadata::Metadata;
16use thiserror::Error;
17
18use crate::batch::{HEADER_BATCH_OPERATION_KEY, HEADER_BATCH_OPERATION_KIND};
19
20#[derive(Debug, Error)]
22pub enum BatchError {
23 #[error("bad request: {0}")]
25 BadRequest(String),
26
27 #[error("multipart error: {0}")]
29 Multipart(#[from] MultipartError),
30
31 #[error("metadata error: {0}")]
33 Metadata(#[from] objectstore_types::metadata::Error),
34
35 #[error("batch limit exceeded: {0}")]
37 LimitExceeded(String),
38
39 #[error("rate limited")]
41 RateLimited,
42
43 #[error("response part serialization error: {context}")]
45 ResponseSerialization {
46 context: String,
48 #[source]
50 cause: Box<dyn std::error::Error + Send + Sync>,
51 },
52}
53
54async fn try_operation_from_field(mut field: Field<'_>) -> Result<Operation, BatchError> {
55 let kind = field
56 .headers()
57 .get(HEADER_BATCH_OPERATION_KIND)
58 .ok_or_else(|| {
59 BatchError::BadRequest(format!("missing {HEADER_BATCH_OPERATION_KIND} header"))
60 })?;
61 let kind = kind
62 .to_str()
63 .map_err(|_| {
64 BatchError::BadRequest(format!(
65 "unable to convert {HEADER_BATCH_OPERATION_KIND} header value to string"
66 ))
67 })?
68 .to_lowercase();
69
70 let key = field
71 .headers()
72 .get(HEADER_BATCH_OPERATION_KEY)
73 .map(|v| {
74 let s = v.to_str().map_err(|_| {
75 BatchError::BadRequest(format!(
76 "unable to convert {HEADER_BATCH_OPERATION_KEY} header value to string"
77 ))
78 })?;
79 percent_encoding::percent_decode_str(s)
80 .decode_utf8()
81 .map(|decoded| decoded.into_owned())
82 .map_err(|_| {
83 BatchError::BadRequest(format!(
84 "unable to percent-decode {HEADER_BATCH_OPERATION_KEY} header value"
85 ))
86 })
87 })
88 .transpose()?;
89
90 let operation = match kind.as_str() {
91 "get" => Operation::Get(Get {
92 key: key.ok_or_else(|| {
93 BatchError::BadRequest(format!(
94 "missing {HEADER_BATCH_OPERATION_KEY} header for {kind} operation"
95 ))
96 })?,
97 }),
98 "delete" => Operation::Delete(Delete {
99 key: key.ok_or_else(|| {
100 BatchError::BadRequest(format!(
101 "missing {HEADER_BATCH_OPERATION_KEY} header for {kind} operation"
102 ))
103 })?,
104 }),
105 "insert" => {
106 let metadata = Metadata::from_headers(field.headers(), "")?;
107 let mut payload = BytesMut::new();
108 while let Some(chunk) = field.chunk().await? {
109 if payload.len() + chunk.len() > MAX_FIELD_SIZE {
110 return Err(BatchError::LimitExceeded(format!(
111 "individual request in batch exceeds body size limit of {MAX_FIELD_SIZE} bytes"
112 )));
113 }
114 payload.extend_from_slice(&chunk);
115 }
116 Operation::Insert(Insert {
117 key,
118 metadata,
119 payload: payload.freeze(),
120 })
121 }
122 _ => {
123 return Err(BatchError::BadRequest(format!(
124 "invalid operation kind: {kind}"
125 )));
126 }
127 };
128 Ok(operation)
129}
130
131pub struct BatchOperationStream(pub BoxStream<'static, Result<Operation, BatchError>>);
133
134impl Debug for BatchOperationStream {
135 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
136 f.debug_struct("BatchOperationStream").finish()
137 }
138}
139
/// Largest insert payload, in bytes, accepted for a single operation in a batch (1 MiB).
const MAX_FIELD_SIZE: usize = 1 << 20;

/// Largest number of operations accepted in one batch request.
const MAX_OPERATIONS: usize = 1_000;
142
143impl<S> FromRequest<S> for BatchOperationStream
144where
145 S: Send + Sync,
146{
147 type Rejection = MultipartRejection;
148
149 async fn from_request(request: Request, state: &S) -> Result<Self, Self::Rejection> {
150 let mut multipart = Multipart::from_request(request, state).await?;
151
152 let requests = async_stream::stream! {
153 let mut count = 0;
154 loop {
155 let field = match multipart.next_field().await {
156 Ok(Some(field)) => field,
157 Ok(None) => break,
158 Err(e) => {
159 yield Err(BatchError::from(e));
160 continue;
161 }
162 };
163 if count >= MAX_OPERATIONS {
164 yield Err(BatchError::LimitExceeded(format!(
165 "exceeded {MAX_OPERATIONS} operations per batch request"
166 )));
167 continue;
168 }
169 count += 1;
170 yield try_operation_from_field(field).await;
171 }
172 }
173 .boxed();
174
175 Ok(Self(requests))
176 }
177}
178
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use axum::body::Body;
    use axum::http::{Request, header::CONTENT_TYPE};
    use futures::StreamExt;
    use objectstore_service::streaming::Operation;
    use objectstore_types::metadata::{ExpirationPolicy, HEADER_EXPIRATION, HEADER_ORIGIN};
    use percent_encoding::NON_ALPHANUMERIC;

    /// Runs `body` through the extractor as a `multipart/form-data` request using the boundary
    /// `boundary` and collects every item the resulting stream yields.
    async fn extract_all(body: String) -> Vec<Result<Operation, BatchError>> {
        let request = Request::builder()
            .header(CONTENT_TYPE, "multipart/form-data; boundary=boundary")
            .body(Body::from(body))
            .unwrap();

        BatchOperationStream::from_request(request, &())
            .await
            .unwrap()
            .0
            .collect()
            .await
    }

    /// A well-formed batch with a get, two inserts and a delete is parsed in order.
    #[tokio::test]
    async fn test_valid_request_works() {
        let insert1_data = b"first blob data";
        let insert2_data = b"second blob data";
        let expiration = ExpirationPolicy::TimeToLive(Duration::from_hours(1));
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key0}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key1}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             {insert1}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key2}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             {HEADER_EXPIRATION}: {expiration}\r\n\
             {HEADER_ORIGIN}: 203.0.113.42\r\n\
             Content-Type: text/plain\r\n\
             \r\n\
             {insert2}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key3}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
            // Percent-encoded form of "test/0"; the extractor must decode it.
            key0 = "test%2F0",
            key1 = "test1",
            key2 = "test2",
            key3 = "test3",
            insert1 = String::from_utf8_lossy(insert1_data),
            insert2 = String::from_utf8_lossy(insert2_data),
        );

        let operations = extract_all(body).await;
        assert_eq!(operations.len(), 4);

        match operations[0].as_ref().unwrap() {
            Operation::Get(get) => assert_eq!(get.key, "test/0"),
            _ => panic!("expected get operation"),
        }

        match operations[1].as_ref().unwrap() {
            Operation::Insert(insert) => {
                assert_eq!(insert.key.as_deref(), Some("test1"));
                assert_eq!(insert.metadata.content_type, "application/octet-stream");
                assert_eq!(insert.metadata.origin, None);
                assert_eq!(insert.payload.as_ref(), insert1_data);
            }
            _ => panic!("expected insert operation"),
        }

        match operations[2].as_ref().unwrap() {
            Operation::Insert(insert) => {
                assert_eq!(insert.key.as_deref(), Some("test2"));
                assert_eq!(insert.metadata.content_type, "text/plain");
                assert_eq!(insert.metadata.expiration_policy, expiration);
                assert_eq!(insert.metadata.origin.as_deref(), Some("203.0.113.42"));
                assert_eq!(insert.payload.as_ref(), insert2_data);
            }
            _ => panic!("expected insert operation"),
        }

        match operations[3].as_ref().unwrap() {
            Operation::Delete(delete) => assert_eq!(delete.key, "test3"),
            _ => panic!("expected delete operation"),
        }
    }

    /// An insert without a key header is accepted and carries no key.
    #[tokio::test]
    async fn test_insert_without_key_header() {
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             keyless payload\r\n\
             --boundary--\r\n",
        );

        let operations = extract_all(body).await;
        assert_eq!(operations.len(), 1);

        match operations[0].as_ref().unwrap() {
            Operation::Insert(insert) => {
                assert!(insert.key.is_none());
                assert_eq!(insert.payload.as_ref(), b"keyless payload");
            }
            _ => panic!("expected insert operation"),
        }
    }

    /// A failing operation produces its own error without affecting its neighbours.
    #[tokio::test]
    async fn test_individual_errors_with_isolation() {
        let large_payload = "x".repeat(MAX_FIELD_SIZE + 1);
        let valid_key = percent_encoding::percent_encode(b"valid", NON_ALPHANUMERIC);
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             {large_payload}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {valid_key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
        );

        let operations = extract_all(body).await;
        assert_eq!(operations.len(), 5);

        // Get without a key header.
        assert!(matches!(&operations[0], Err(BatchError::BadRequest(_))));
        // Well-formed get.
        assert!(matches!(
            &operations[1],
            Ok(Operation::Get(g)) if g.key == "valid"
        ));
        // Delete without a key header.
        assert!(matches!(&operations[2], Err(BatchError::BadRequest(_))));
        // Insert whose payload is one byte over the per-field limit.
        assert!(matches!(&operations[3], Err(BatchError::LimitExceeded(_))));
        // Well-formed delete following the oversized insert.
        assert!(matches!(
            &operations[4],
            Ok(Operation::Delete(d)) if d.key == "valid"
        ));
    }

    /// The operation after the last allowed one is reported as a limit violation.
    #[tokio::test]
    async fn test_max_operations_limit_enforced() {
        // One operation more than the limit allows.
        let mut body: String = (0..=MAX_OPERATIONS)
            .map(|i| {
                let key = percent_encoding::percent_encode(
                    format!("test{i}").as_bytes(),
                    NON_ALPHANUMERIC,
                )
                .to_string();
                format!(
                    "--boundary\r\n\
                     {HEADER_BATCH_OPERATION_KEY}: {key}\r\n\
                     {HEADER_BATCH_OPERATION_KIND}: get\r\n\
                     \r\n\
                     \r\n"
                )
            })
            .collect();
        body.push_str("--boundary--\r\n");

        let operations = extract_all(body).await;

        assert_eq!(operations.len(), MAX_OPERATIONS + 1);
        assert!(matches!(
            &operations[MAX_OPERATIONS],
            Err(BatchError::LimitExceeded(_))
        ));
    }
}