// objectstore_server/extractors/batch.rs
use std::fmt::Debug;

use axum::extract::{
    FromRequest, Multipart, Request,
    multipart::{Field, MultipartError, MultipartRejection},
};
use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use bytes::{Bytes, BytesMut};
use futures::{StreamExt, stream::BoxStream};
use objectstore_service::id::ObjectKey;
use objectstore_types::Metadata;
use thiserror::Error;

use crate::batch::{HEADER_BATCH_OPERATION_KEY, HEADER_BATCH_OPERATION_KIND};
16
17#[derive(Debug, Error)]
19pub enum BatchError {
20 #[error("bad request: {0}")]
22 BadRequest(String),
23
24 #[error("multipart error: {0}")]
26 Multipart(#[from] MultipartError),
27
28 #[error("metadata error: {0}")]
30 Metadata(#[from] objectstore_types::Error),
31
32 #[error("batch limit exceeded: {0}")]
34 LimitExceeded(String),
35
36 #[error("rate limited")]
38 RateLimited,
39
40 #[error("response part serialization error: {context}")]
42 ResponseSerialization {
43 context: String,
45 #[source]
47 cause: Box<dyn std::error::Error + Send + Sync>,
48 },
49}
50
51#[derive(Debug)]
52pub struct GetOperation {
53 pub key: ObjectKey,
54}
55
56#[derive(Debug)]
57pub struct InsertOperation {
58 pub key: ObjectKey,
59 pub metadata: Metadata,
60 pub payload: Bytes,
61}
62
63#[derive(Debug)]
64pub struct DeleteOperation {
65 pub key: ObjectKey,
66}
67
68#[derive(Debug)]
69pub enum Operation {
70 Get(GetOperation),
71 Insert(InsertOperation),
72 Delete(DeleteOperation),
73}
74
75impl Operation {
76 async fn try_from_field(field: Field<'_>) -> Result<Self, BatchError> {
77 let kind = field
78 .headers()
79 .get(HEADER_BATCH_OPERATION_KIND)
80 .ok_or_else(|| {
81 BatchError::BadRequest(format!("missing {HEADER_BATCH_OPERATION_KIND} header"))
82 })?;
83 let kind = kind
84 .to_str()
85 .map_err(|_| {
86 BatchError::BadRequest(format!(
87 "unable to convert {HEADER_BATCH_OPERATION_KIND} header value to string"
88 ))
89 })?
90 .to_lowercase();
91
92 let key_header = field
93 .headers()
94 .get(HEADER_BATCH_OPERATION_KEY)
95 .ok_or_else(|| {
96 BatchError::BadRequest(format!("missing {HEADER_BATCH_OPERATION_KEY} header"))
97 })?
98 .to_str()
99 .map_err(|_| {
100 BatchError::BadRequest(format!(
101 "unable to convert {HEADER_BATCH_OPERATION_KEY} header value to string"
102 ))
103 })?;
104 let key_bytes = BASE64_STANDARD.decode(key_header).map_err(|_| {
105 BatchError::BadRequest(format!(
106 "unable to base64 decode {HEADER_BATCH_OPERATION_KEY} header value"
107 ))
108 })?;
109 let key = String::from_utf8(key_bytes).map_err(|_| {
110 BatchError::BadRequest(format!(
111 "{HEADER_BATCH_OPERATION_KEY} header value is not valid UTF-8"
112 ))
113 })?;
114
115 let operation = match kind.as_str() {
116 "get" => Operation::Get(GetOperation { key }),
117 "insert" => {
118 let metadata = Metadata::from_headers(field.headers(), "")?;
119 let payload = field.bytes().await?;
120 if payload.len() > MAX_FIELD_SIZE {
121 return Err(BatchError::LimitExceeded(format!(
122 "individual request in batch exceeds body size limit of {MAX_FIELD_SIZE} bytes"
123 )));
124 }
125 Operation::Insert(InsertOperation {
126 key,
127 metadata,
128 payload,
129 })
130 }
131 "delete" => Operation::Delete(DeleteOperation { key }),
132 _ => {
133 return Err(BatchError::BadRequest(format!(
134 "invalid operation kind: {kind}"
135 )));
136 }
137 };
138 Ok(operation)
139 }
140
141 pub fn key(&self) -> &ObjectKey {
142 match self {
143 Operation::Get(op) => &op.key,
144 Operation::Insert(op) => &op.key,
145 Operation::Delete(op) => &op.key,
146 }
147 }
148}
149
150pub struct BatchOperationStream(pub BoxStream<'static, Result<Operation, BatchError>>);
151
152impl Debug for BatchOperationStream {
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 f.debug_struct("BatchOperationStream").finish()
155 }
156}
157
/// Maximum size in bytes of a single insert payload within a batch (1 MiB).
const MAX_FIELD_SIZE: usize = 1 << 20;

/// Maximum number of operations accepted in a single batch request.
const MAX_OPERATIONS: usize = 1000;
160
161impl<S> FromRequest<S> for BatchOperationStream
162where
163 S: Send + Sync,
164{
165 type Rejection = MultipartRejection;
166
167 async fn from_request(request: Request, state: &S) -> Result<Self, Self::Rejection> {
168 let mut multipart = Multipart::from_request(request, state).await?;
169
170 let requests = async_stream::try_stream! {
171 let mut count = 0;
172 while let Some(field) = multipart.next_field().await? {
173 if count >= MAX_OPERATIONS {
174 Err(BatchError::LimitExceeded(format!(
175 "exceeded {MAX_OPERATIONS} operations per batch request"
176 )))?;
177 }
178 count += 1;
179 yield Operation::try_from_field(field).await?;
180 }
181 }
182 .boxed();
183
184 Ok(Self(requests))
185 }
186}
187
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use axum::body::Body;
    use axum::http::{Request, header::CONTENT_TYPE};
    use futures::StreamExt;
    use objectstore_types::{ExpirationPolicy, HEADER_EXPIRATION};

    /// Sends `body` through the extractor as a `multipart/form-data` request with
    /// boundary `boundary` and collects every item the resulting stream yields.
    async fn extract_all(body: String) -> Vec<Result<Operation, BatchError>> {
        let request = Request::builder()
            .header(CONTENT_TYPE, "multipart/form-data; boundary=boundary")
            .body(Body::from(body))
            .unwrap();

        let batch_request = BatchOperationStream::from_request(request, &())
            .await
            .unwrap();

        batch_request.0.collect().await
    }

    /// Builds a one-part batch body for an insert of `payload` under `key`.
    fn single_insert_body(key: &str, payload: &str) -> String {
        let key = BASE64_STANDARD.encode(key);
        format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             {payload}\r\n\
             --boundary--\r\n",
        )
    }

    #[tokio::test]
    async fn test_valid_request_works() {
        let insert1_data = b"first blob data";
        let insert2_data = b"second blob data";
        let expiration = ExpirationPolicy::TimeToLive(Duration::from_hours(1));
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key0}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key1}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             Content-Type: application/octet-stream\r\n\
             \r\n\
             {insert1}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key2}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: insert\r\n\
             {HEADER_EXPIRATION}: {expiration}\r\n\
             Content-Type: text/plain\r\n\
             \r\n\
             {insert2}\r\n\
             --boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key3}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: delete\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
            key0 = BASE64_STANDARD.encode("test0"),
            key1 = BASE64_STANDARD.encode("test1"),
            key2 = BASE64_STANDARD.encode("test2"),
            key3 = BASE64_STANDARD.encode("test3"),
            insert1 = String::from_utf8_lossy(insert1_data),
            insert2 = String::from_utf8_lossy(insert2_data),
        );

        let operations = extract_all(body).await;
        assert_eq!(operations.len(), 4);

        let Operation::Get(get_op) = &operations[0].as_ref().unwrap() else {
            panic!("expected get operation");
        };
        assert_eq!(get_op.key, "test0");

        let Operation::Insert(insert_op1) = &operations[1].as_ref().unwrap() else {
            panic!("expected insert operation");
        };
        assert_eq!(insert_op1.key, "test1");
        assert_eq!(insert_op1.metadata.content_type, "application/octet-stream");
        assert_eq!(insert_op1.payload.as_ref(), insert1_data);

        let Operation::Insert(insert_op2) = &operations[2].as_ref().unwrap() else {
            panic!("expected insert operation");
        };
        assert_eq!(insert_op2.key, "test2");
        assert_eq!(insert_op2.metadata.content_type, "text/plain");
        assert_eq!(insert_op2.metadata.expiration_policy, expiration);
        assert_eq!(insert_op2.payload.as_ref(), insert2_data);

        let Operation::Delete(delete_op) = &operations[3].as_ref().unwrap() else {
            panic!("expected delete operation");
        };
        assert_eq!(delete_op.key, "test3");
    }

    #[tokio::test]
    async fn test_max_operations_limit_enforced() {
        let mut body = String::new();
        for i in 0..(MAX_OPERATIONS + 1) {
            let key = BASE64_STANDARD.encode(format!("test{i}"));
            body.push_str(&format!(
                "--boundary\r\n\
                 {HEADER_BATCH_OPERATION_KEY}: {key}\r\n\
                 {HEADER_BATCH_OPERATION_KIND}: get\r\n\
                 \r\n\
                 \r\n"
            ));
        }
        body.push_str("--boundary--\r\n");

        let operations = extract_all(body).await;

        // Every operation up to the limit parses; the one past it is replaced by
        // the limit error, after which the stream ends.
        assert_eq!(operations.len(), MAX_OPERATIONS + 1);
        assert!(operations[..MAX_OPERATIONS].iter().all(Result::is_ok));
        assert!(matches!(
            &operations[MAX_OPERATIONS],
            Err(BatchError::LimitExceeded(_))
        ));
    }

    #[tokio::test]
    async fn test_operation_body_size_limit_enforced() {
        let large_payload = "x".repeat(MAX_FIELD_SIZE + 1);
        let operations = extract_all(single_insert_body("test", &large_payload)).await;

        assert_eq!(operations.len(), 1);
        assert!(matches!(&operations[0], Err(BatchError::LimitExceeded(_))));
    }

    #[tokio::test]
    async fn test_operation_at_body_size_limit_accepted() {
        let payload = "x".repeat(MAX_FIELD_SIZE);
        let operations = extract_all(single_insert_body("test", &payload)).await;

        assert_eq!(operations.len(), 1);
        let Ok(Operation::Insert(insert_op)) = &operations[0] else {
            panic!("expected insert operation");
        };
        assert_eq!(insert_op.key, "test");
        assert_eq!(insert_op.payload.len(), MAX_FIELD_SIZE);
    }

    #[tokio::test]
    async fn test_invalid_operation_kind_rejected() {
        let key = BASE64_STANDARD.encode("test");
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KEY}: {key}\r\n\
             {HEADER_BATCH_OPERATION_KIND}: copy\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
        );

        let operations = extract_all(body).await;

        assert_eq!(operations.len(), 1);
        assert!(matches!(&operations[0], Err(BatchError::BadRequest(_))));
    }

    #[tokio::test]
    async fn test_missing_key_header_rejected() {
        let body = format!(
            "--boundary\r\n\
             {HEADER_BATCH_OPERATION_KIND}: get\r\n\
             \r\n\
             \r\n\
             --boundary--\r\n",
        );

        let operations = extract_all(body).await;

        assert_eq!(operations.len(), 1);
        assert!(matches!(&operations[0], Err(BatchError::BadRequest(_))));
    }
}