1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::io;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use async_stream::stream;
8use futures_util::{Stream, StreamExt as _};
9use multer::Field;
10use objectstore_types::metadata::{Compression, Metadata};
11use percent_encoding::NON_ALPHANUMERIC;
12use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
13use reqwest::multipart::Part;
14
15use crate::error::Error;
16use crate::put::PutBody;
17use crate::{
18 DeleteBuilder, DeleteResponse, GetBuilder, GetResponse, ObjectKey, PutBuilder, PutResponse,
19 Session, get, put,
20};
21
/// Part header carrying the percent-encoded object key of an operation.
///
/// Written on request parts and read back from response parts.
const HEADER_BATCH_OPERATION_KEY: &str = "x-sn-batch-operation-key";
/// Request part header naming the kind of operation (`get`, `insert` or `delete`).
const HEADER_BATCH_OPERATION_KIND: &str = "x-sn-batch-operation-kind";
/// Response part header holding the index of the operation the part answers.
const HEADER_BATCH_OPERATION_INDEX: &str = "x-sn-batch-operation-index";
/// Response part header holding the status code of an individual operation.
const HEADER_BATCH_OPERATION_STATUS: &str = "x-sn-batch-operation-status";

/// Maximum number of operations placed into a single batch request.
const MAX_BATCH_OPS: usize = 1000;
29
30const MAX_BATCH_PART_SIZE: u32 = 1024 * 1024; #[derive(Debug)]
40pub struct ManyBuilder {
41 session: Session,
42 operations: Vec<BatchOperation>,
43}
44
45impl Session {
46 pub fn many(&self) -> ManyBuilder {
51 ManyBuilder {
52 session: self.clone(),
53 operations: vec![],
54 }
55 }
56}
57
58#[derive(Debug)]
59#[allow(clippy::large_enum_variant)]
60enum BatchOperation {
61 Get {
62 key: ObjectKey,
63 decompress: bool,
64 accept_encoding: Vec<Compression>,
65 },
66 Insert {
67 key: Option<ObjectKey>,
68 metadata: Metadata,
69 body: PutBody,
70 },
71 Delete {
72 key: ObjectKey,
73 },
74}
75
76impl From<GetBuilder> for BatchOperation {
77 fn from(value: GetBuilder) -> Self {
78 let GetBuilder {
79 key,
80 decompress,
81 accept_encoding,
82 session: _session,
83 } = value;
84 BatchOperation::Get {
85 key,
86 decompress,
87 accept_encoding,
88 }
89 }
90}
91
92impl From<PutBuilder> for BatchOperation {
93 fn from(value: PutBuilder) -> Self {
94 let PutBuilder {
95 key,
96 metadata,
97 body,
98 session: _session,
99 } = value;
100 BatchOperation::Insert {
101 key,
102 metadata,
103 body,
104 }
105 }
106}
107
108impl From<DeleteBuilder> for BatchOperation {
109 fn from(value: DeleteBuilder) -> Self {
110 let DeleteBuilder {
111 key,
112 session: _session,
113 } = value;
114 BatchOperation::Delete { key }
115 }
116}
117
118impl BatchOperation {
119 async fn into_part(self) -> crate::Result<Part> {
120 match self {
121 BatchOperation::Get { key, .. } => {
122 let headers = operation_headers("get", Some(&key));
123 Ok(Part::text("").headers(headers))
124 }
125 BatchOperation::Insert {
126 key,
127 metadata,
128 body,
129 } => {
130 let mut headers = operation_headers("insert", key.as_deref());
131 headers.extend(metadata.to_headers("")?);
132
133 let body = put::maybe_compress(body, metadata.compression);
134 Ok(Part::stream(body).headers(headers))
135 }
136 BatchOperation::Delete { key } => {
137 let headers = operation_headers("delete", Some(&key));
138 Ok(Part::text("").headers(headers))
139 }
140 }
141 }
142}
143
144fn operation_headers(operation: &str, key: Option<&str>) -> HeaderMap {
145 let mut headers = HeaderMap::new();
146 headers.insert(
147 HeaderName::from_static(HEADER_BATCH_OPERATION_KIND),
148 HeaderValue::from_str(operation).expect("operation kind is always a valid header value"),
149 );
150 if let Some(key) = key {
151 let encoded =
152 percent_encoding::percent_encode(key.as_bytes(), NON_ALPHANUMERIC).to_string();
153 headers.insert(
154 HeaderName::from_static(HEADER_BATCH_OPERATION_KEY),
155 HeaderValue::try_from(encoded)
156 .expect("percent-encoded string is always a valid header value"),
157 );
158 }
159 headers
160}
161
162#[derive(Debug)]
164pub enum OperationResult {
165 Get(ObjectKey, Result<Option<GetResponse>, Error>),
169 Put(ObjectKey, Result<PutResponse, Error>),
171 Delete(ObjectKey, Result<DeleteResponse, Error>),
173 Error(Error),
179}
180
181enum OperationContext {
183 Get {
184 key: ObjectKey,
185 decompress: bool,
186 accept_encoding: Vec<Compression>,
187 },
188 Insert {
189 key: Option<ObjectKey>,
190 },
191 Delete {
192 key: ObjectKey,
193 },
194}
195
196impl From<&BatchOperation> for OperationContext {
197 fn from(op: &BatchOperation) -> Self {
198 match op {
199 BatchOperation::Get {
200 key,
201 decompress,
202 accept_encoding,
203 } => OperationContext::Get {
204 key: key.clone(),
205 decompress: *decompress,
206 accept_encoding: accept_encoding.clone(),
207 },
208 BatchOperation::Insert { key, .. } => OperationContext::Insert { key: key.clone() },
209 BatchOperation::Delete { key } => OperationContext::Delete { key: key.clone() },
210 }
211 }
212}
213
214impl OperationContext {
215 fn key(&self) -> Option<&str> {
216 match self {
217 OperationContext::Get { key, .. } | OperationContext::Delete { key } => Some(key),
218 OperationContext::Insert { key } => key.as_deref(),
219 }
220 }
221}
222
223impl OperationResult {
224 async fn from_field(
225 field: Field<'_>,
226 context_map: &HashMap<usize, OperationContext>,
227 ) -> (Option<usize>, Self) {
228 match Self::try_from_field(field, context_map).await {
229 Ok((index, result)) => (Some(index), result),
230 Err(e) => (None, OperationResult::Error(e)),
231 }
232 }
233
234 async fn try_from_field(
235 field: Field<'_>,
236 context_map: &HashMap<usize, OperationContext>,
237 ) -> Result<(usize, Self), Error> {
238 let mut headers = field.headers().clone();
239
240 let index: usize = headers
241 .remove(HEADER_BATCH_OPERATION_INDEX)
242 .and_then(|v| v.to_str().ok().and_then(|s| s.parse().ok()))
243 .ok_or_else(|| {
244 Error::MalformedResponse(format!(
245 "missing or invalid {HEADER_BATCH_OPERATION_INDEX} header"
246 ))
247 })?;
248
249 let status: u16 = headers
250 .remove(HEADER_BATCH_OPERATION_STATUS)
251 .and_then(|v| {
252 v.to_str().ok().and_then(|s| {
253 s.split_once(' ')
256 .map(|(code, _)| code)
257 .unwrap_or(s)
258 .parse()
259 .ok()
260 })
261 })
262 .ok_or_else(|| {
263 Error::MalformedResponse(format!(
264 "missing or invalid {HEADER_BATCH_OPERATION_STATUS} header"
265 ))
266 })?;
267
268 let ctx = context_map.get(&index).ok_or_else(|| {
269 Error::MalformedResponse(format!(
270 "response references unknown operation index {index}"
271 ))
272 })?;
273
274 let key = headers
276 .remove(HEADER_BATCH_OPERATION_KEY)
277 .and_then(|v| {
278 v.to_str()
279 .ok()
280 .and_then(|encoded| {
281 percent_encoding::percent_decode_str(encoded)
282 .decode_utf8()
283 .ok()
284 })
285 .map(|s| s.into_owned())
286 })
287 .or_else(|| ctx.key().map(str::to_owned));
288
289 let body = field.bytes().await?;
290
291 let is_error =
292 status >= 400 && !(matches!(ctx, OperationContext::Get { .. }) && status == 404);
293
294 let key = match key {
299 Some(key) => key,
300 None if is_error => "<unknown>".to_owned(),
301 None => {
302 return Err(Error::MalformedResponse(format!(
303 "missing or invalid {HEADER_BATCH_OPERATION_KEY} header"
304 )));
305 }
306 };
307 if is_error {
308 let message = String::from_utf8_lossy(&body).into_owned();
309 let error = Error::OperationFailure { status, message };
310
311 return Ok((
312 index,
313 match ctx {
314 OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
315 OperationContext::Insert { .. } => OperationResult::Put(key, Err(error)),
316 OperationContext::Delete { .. } => OperationResult::Delete(key, Err(error)),
317 },
318 ));
319 }
320
321 let result = match ctx {
322 OperationContext::Get {
323 decompress,
324 accept_encoding,
325 ..
326 } => {
327 if status == 404 {
328 OperationResult::Get(key, Ok(None))
329 } else {
330 let mut metadata = Metadata::from_headers(&headers, "")?;
331
332 let stream =
333 futures_util::stream::once(async move { Ok::<_, io::Error>(body) }).boxed();
334 let stream =
335 get::maybe_decompress(stream, &mut metadata, *decompress, accept_encoding);
336
337 OperationResult::Get(key, Ok(Some(GetResponse { metadata, stream })))
338 }
339 }
340 OperationContext::Insert { .. } => {
341 OperationResult::Put(key.clone(), Ok(PutResponse { key }))
342 }
343 OperationContext::Delete { .. } => OperationResult::Delete(key, Ok(())),
344 };
345 Ok((index, result))
346 }
347}
348
349pub struct OperationResults(Pin<Box<dyn Stream<Item = OperationResult> + Send>>);
351
352impl fmt::Debug for OperationResults {
353 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
354 f.write_str("OperationResults([Stream])")
355 }
356}
357
358impl Stream for OperationResults {
359 type Item = OperationResult;
360
361 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
362 self.0.as_mut().poll_next(cx)
363 }
364}
365
366impl OperationResults {
367 pub async fn error_for_failures(
372 mut self,
373 ) -> crate::Result<(), impl Iterator<Item = crate::Error>> {
374 let mut errs = Vec::new();
375 while let Some(res) = self.next().await {
376 match res {
377 OperationResult::Get(_, get) => {
378 if let Err(e) = get {
379 errs.push(e);
380 }
381 }
382 OperationResult::Put(_, put) => {
383 if let Err(e) = put {
384 errs.push(e);
385 }
386 }
387 OperationResult::Delete(_, delete) => {
388 if let Err(e) = delete {
389 errs.push(e);
390 }
391 }
392 OperationResult::Error(error) => errs.push(error),
393 }
394 }
395 if errs.is_empty() {
396 return Ok(());
397 }
398 Err(errs.into_iter())
399 }
400}
401
402async fn send_batch(
403 session: &Session,
404 operations: Vec<BatchOperation>,
405) -> crate::Result<Vec<OperationResult>> {
406 let context_map: HashMap<usize, OperationContext> = operations
407 .iter()
408 .enumerate()
409 .map(|(idx, op)| (idx, OperationContext::from(op)))
410 .collect();
411 let num_operations = operations.len();
412
413 let mut form = reqwest::multipart::Form::new();
414 for op in operations.into_iter() {
415 let part = op.into_part().await?;
416 form = form.part("part", part);
417 }
418
419 let request = session.batch_request()?.multipart(form);
420 let response = request.send().await?.error_for_status()?;
421
422 let boundary = response
423 .headers()
424 .get(CONTENT_TYPE)
425 .and_then(|v| v.to_str().ok())
426 .ok_or_else(|| Error::MalformedResponse("missing Content-Type header".to_owned()))
427 .map(multer::parse_boundary)??;
428
429 let byte_stream = response.bytes_stream().map(|r| r.map_err(io::Error::other));
430 let mut multipart = multer::Multipart::new(byte_stream, boundary);
431
432 let mut results = Vec::new();
433 let mut seen_indices = HashSet::new();
434 while let Some(field) = multipart.next_field().await? {
435 let (index, result) = OperationResult::from_field(field, &context_map).await;
436 if let Some(idx) = index {
437 seen_indices.insert(idx);
438 }
439 results.push(result);
440 }
441
442 for idx in 0..num_operations {
443 if !seen_indices.contains(&idx) {
444 let error = Error::MalformedResponse(format!(
445 "server did not return a response for operation at index {idx}"
446 ));
447 let result = match context_map.get(&idx) {
448 Some(ctx) => {
449 let key = ctx.key().unwrap_or("<unknown>").to_owned();
450 match ctx {
451 OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
452 OperationContext::Insert { .. } => OperationResult::Put(key, Err(error)),
453 OperationContext::Delete { .. } => OperationResult::Delete(key, Err(error)),
454 }
455 }
456 None => OperationResult::Error(error),
457 };
458 results.push(result);
459 }
460 }
461
462 Ok(results)
463}
464
465impl ManyBuilder {
466 pub fn send(self) -> OperationResults {
470 let session = self.session;
471 let mut operations = self.operations;
472
473 let inner = stream! {
474 while !operations.is_empty() {
475 let mut batch: Vec<BatchOperation> = vec![];
476
477 while !operations.is_empty() && batch.len() < MAX_BATCH_OPS {
480 let operation = operations.pop().unwrap();
481 match operation {
482 BatchOperation::Insert {
483 key,
484 metadata,
485 body: PutBody::File(file),
486 } => {
487 let meta = match file.metadata().await {
488 Ok(meta) => meta,
489 Err(err) => {
490 let key = key.unwrap_or_else(|| "<unknown>".to_owned());
491 yield OperationResult::Put(key, Err(err.into()));
492 continue;
493 }
494 };
495
496 let size = meta.len();
497 if size <= MAX_BATCH_PART_SIZE as u64 {
498 batch.push(BatchOperation::Insert {
499 key,
500 metadata,
501 body: PutBody::File(file),
502 });
503 continue;
504 }
505 let error_key = key.clone().unwrap_or_else(|| "<unknown>".to_owned());
506 let put = PutBuilder {
507 session: session.clone(),
508 metadata,
509 key,
510 body: PutBody::File(file),
511 };
512 let res = put.send().await;
513 let res = match res {
514 Ok(ref inner) => OperationResult::Put(inner.key.clone(), res),
515 Err(err) => OperationResult::Put(error_key, Err(err)),
516 };
517 yield res;
518 }
519 _ => batch.push(operation),
521 }
522 }
523
524 if batch.is_empty() {
525 continue;
526 }
527
528 let contexts: Vec<_> =
531 batch.iter().map(OperationContext::from).collect();
532
533 match send_batch(&session, batch).await {
534 Ok(results) => {
535 for result in results {
536 yield result;
537 }
538 }
539 Err(e) => {
540 let shared = std::sync::Arc::new(e);
541 for ctx in contexts {
542 let error = Error::Batch(shared.clone());
543 let key = ctx.key().unwrap_or("<unknown>").to_owned();
544 yield match ctx {
545 OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
546 OperationContext::Insert { .. } => {
547 OperationResult::Put(key, Err(error))
548 }
549 OperationContext::Delete { .. } => {
550 OperationResult::Delete(key, Err(error))
551 }
552 };
553 }
554 }
555 }
556 }
557 };
558
559 OperationResults(Box::pin(inner))
560 }
561
562 #[allow(private_bounds)]
572 pub fn push<B: Into<BatchOperation>>(mut self, builder: B) -> Self {
573 self.operations.push(builder.into());
574 self
575 }
576}