1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::io;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use futures_util::{Stream, StreamExt as _};
9use multer::Field;
10use objectstore_types::metadata::{Compression, Metadata};
11use percent_encoding::NON_ALPHANUMERIC;
12use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
13use reqwest::multipart::Part;
14
15use crate::error::Error;
16use crate::put::PutBody;
17use crate::{
18 DeleteBuilder, DeleteResponse, GetBuilder, GetResponse, HeadBuilder, HeadResponse, ObjectKey,
19 PutBuilder, PutResponse, Session, get, put,
20};
21
// Per-part headers used by the batch endpoint to describe a single operation in the
// multipart request and response bodies.

/// Header carrying the percent-encoded object key of an operation.
const HEADER_BATCH_OPERATION_KEY: &str = "x-sn-batch-operation-key";
/// Header naming the operation kind of a request part (`get`, `insert`, `delete`, `head`).
const HEADER_BATCH_OPERATION_KIND: &str = "x-sn-batch-operation-kind";
/// Response header identifying which request part a response part answers.
const HEADER_BATCH_OPERATION_INDEX: &str = "x-sn-batch-operation-index";
/// Response header carrying the HTTP status code of an individual operation.
const HEADER_BATCH_OPERATION_STATUS: &str = "x-sn-batch-operation-status";
26
27const MAX_BATCH_OPS: usize = 1000;
29
30const MAX_BATCH_PART_SIZE: u32 = 1024 * 1024; const DEFAULT_INDIVIDUAL_CONCURRENCY: usize = 5;
37
38const DEFAULT_BATCH_CONCURRENCY: usize = 3;
42
43const MAX_BATCH_BODY_SIZE: u64 = 100 * 1024 * 1024; #[derive(Debug)]
51pub struct ManyBuilder {
52 session: Session,
53 operations: Vec<BatchOperation>,
54 max_individual_concurrency: Option<usize>,
55 max_batch_concurrency: Option<usize>,
56}
57
58impl Session {
59 pub fn many(&self) -> ManyBuilder {
64 ManyBuilder {
65 session: self.clone(),
66 operations: vec![],
67 max_individual_concurrency: None,
68 max_batch_concurrency: None,
69 }
70 }
71}
72
73#[derive(Debug)]
74#[allow(clippy::large_enum_variant)]
75enum BatchOperation {
76 Get {
77 key: ObjectKey,
78 decompress: bool,
79 accept_encoding: Vec<Compression>,
80 },
81 Insert {
82 key: Option<ObjectKey>,
83 metadata: Metadata,
84 body: PutBody,
85 },
86 Delete {
87 key: ObjectKey,
88 },
89 Head {
90 key: ObjectKey,
91 },
92}
93
94impl From<GetBuilder> for BatchOperation {
95 fn from(value: GetBuilder) -> Self {
96 let GetBuilder {
97 key,
98 decompress,
99 accept_encoding,
100 session: _session,
101 } = value;
102 BatchOperation::Get {
103 key,
104 decompress,
105 accept_encoding,
106 }
107 }
108}
109
110impl From<PutBuilder> for BatchOperation {
111 fn from(value: PutBuilder) -> Self {
112 let PutBuilder {
113 key,
114 metadata,
115 body,
116 session: _session,
117 } = value;
118 BatchOperation::Insert {
119 key,
120 metadata,
121 body,
122 }
123 }
124}
125
126impl From<DeleteBuilder> for BatchOperation {
127 fn from(value: DeleteBuilder) -> Self {
128 let DeleteBuilder {
129 key,
130 session: _session,
131 } = value;
132 BatchOperation::Delete { key }
133 }
134}
135
136impl From<HeadBuilder> for BatchOperation {
137 fn from(value: HeadBuilder) -> Self {
138 let HeadBuilder {
139 key,
140 session: _session,
141 } = value;
142 BatchOperation::Head { key }
143 }
144}
145
146impl BatchOperation {
147 async fn into_part(self) -> crate::Result<Part> {
148 match self {
149 BatchOperation::Get { key, .. } => {
150 let headers = operation_headers("get", Some(&key));
151 Ok(Part::text("").headers(headers))
152 }
153 BatchOperation::Insert {
154 key,
155 metadata,
156 body,
157 } => {
158 let mut headers = operation_headers("insert", key.as_deref());
159 headers.extend(metadata.to_headers("")?);
160
161 let body = put::maybe_compress(body, metadata.compression).await?;
162 Ok(Part::stream(body).headers(headers))
163 }
164 BatchOperation::Delete { key } => {
165 let headers = operation_headers("delete", Some(&key));
166 Ok(Part::text("").headers(headers))
167 }
168 BatchOperation::Head { key } => {
169 let headers = operation_headers("head", Some(&key));
170 Ok(Part::text("").headers(headers))
171 }
172 }
173 }
174}
175
176fn operation_headers(operation: &str, key: Option<&str>) -> HeaderMap {
177 let mut headers = HeaderMap::new();
178 headers.insert(
179 HeaderName::from_static(HEADER_BATCH_OPERATION_KIND),
180 HeaderValue::from_str(operation).expect("operation kind is always a valid header value"),
181 );
182 if let Some(key) = key {
183 let encoded =
184 percent_encoding::percent_encode(key.as_bytes(), NON_ALPHANUMERIC).to_string();
185 headers.insert(
186 HeaderName::from_static(HEADER_BATCH_OPERATION_KEY),
187 HeaderValue::try_from(encoded)
188 .expect("percent-encoded string is always a valid header value"),
189 );
190 }
191 headers
192}
193
194#[derive(Debug)]
196pub enum OperationResult {
197 Get(ObjectKey, Result<Option<GetResponse>, Error>),
201 Put(ObjectKey, Result<PutResponse, Error>),
203 Delete(ObjectKey, Result<DeleteResponse, Error>),
205 Head(ObjectKey, Result<HeadResponse, Error>),
209 Error(Error),
215}
216
217enum OperationContext {
219 Get {
220 key: ObjectKey,
221 decompress: bool,
222 accept_encoding: Vec<Compression>,
223 },
224 Insert {
225 key: Option<ObjectKey>,
226 },
227 Delete {
228 key: ObjectKey,
229 },
230 Head {
231 key: ObjectKey,
232 },
233}
234
235impl From<&BatchOperation> for OperationContext {
236 fn from(op: &BatchOperation) -> Self {
237 match op {
238 BatchOperation::Get {
239 key,
240 decompress,
241 accept_encoding,
242 } => OperationContext::Get {
243 key: key.clone(),
244 decompress: *decompress,
245 accept_encoding: accept_encoding.clone(),
246 },
247 BatchOperation::Insert { key, .. } => OperationContext::Insert { key: key.clone() },
248 BatchOperation::Delete { key } => OperationContext::Delete { key: key.clone() },
249 BatchOperation::Head { key } => OperationContext::Head { key: key.clone() },
250 }
251 }
252}
253
254impl OperationContext {
255 fn key(&self) -> Option<&str> {
256 match self {
257 OperationContext::Get { key, .. }
258 | OperationContext::Delete { key }
259 | OperationContext::Head { key } => Some(key),
260 OperationContext::Insert { key } => key.as_deref(),
261 }
262 }
263}
264
265enum Classified {
267 Batchable(BatchOperation, u64),
269 Individual(BatchOperation),
271 Failed(OperationResult),
273}
274
275fn error_result(ctx: OperationContext, error: Error) -> OperationResult {
277 let key = ctx.key().unwrap_or("<unknown>").to_owned();
278 match ctx {
279 OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
280 OperationContext::Insert { .. } => OperationResult::Put(key, Err(error)),
281 OperationContext::Delete { .. } => OperationResult::Delete(key, Err(error)),
282 OperationContext::Head { .. } => OperationResult::Head(key, Err(error)),
283 }
284}
285
286impl OperationResult {
287 async fn from_field(
288 field: Field<'_>,
289 context_map: &HashMap<usize, OperationContext>,
290 ) -> (Option<usize>, Self) {
291 match Self::try_from_field(field, context_map).await {
292 Ok((index, result)) => (Some(index), result),
293 Err(e) => (None, OperationResult::Error(e)),
294 }
295 }
296
297 async fn try_from_field(
298 field: Field<'_>,
299 context_map: &HashMap<usize, OperationContext>,
300 ) -> Result<(usize, Self), Error> {
301 let mut headers = field.headers().clone();
302
303 let index: usize = headers
304 .remove(HEADER_BATCH_OPERATION_INDEX)
305 .and_then(|v| v.to_str().ok().and_then(|s| s.parse().ok()))
306 .ok_or_else(|| {
307 Error::MalformedResponse(format!(
308 "missing or invalid {HEADER_BATCH_OPERATION_INDEX} header"
309 ))
310 })?;
311
312 let status: u16 = headers
313 .remove(HEADER_BATCH_OPERATION_STATUS)
314 .and_then(|v| {
315 v.to_str().ok().and_then(|s| {
316 s.split_once(' ')
319 .map(|(code, _)| code)
320 .unwrap_or(s)
321 .parse()
322 .ok()
323 })
324 })
325 .ok_or_else(|| {
326 Error::MalformedResponse(format!(
327 "missing or invalid {HEADER_BATCH_OPERATION_STATUS} header"
328 ))
329 })?;
330
331 let ctx = context_map.get(&index).ok_or_else(|| {
332 Error::MalformedResponse(format!(
333 "response references unknown operation index {index}"
334 ))
335 })?;
336
337 let key = headers
339 .remove(HEADER_BATCH_OPERATION_KEY)
340 .and_then(|v| {
341 v.to_str()
342 .ok()
343 .and_then(|encoded| {
344 percent_encoding::percent_decode_str(encoded)
345 .decode_utf8()
346 .ok()
347 })
348 .map(|s| s.into_owned())
349 })
350 .or_else(|| ctx.key().map(str::to_owned));
351
352 let body = field.bytes().await?;
353
354 let is_error = status >= 400
355 && !(matches!(
356 ctx,
357 OperationContext::Get { .. } | OperationContext::Head { .. }
358 ) && status == 404);
359
360 let key = match key {
365 Some(key) => key,
366 None if is_error => "<unknown>".to_owned(),
367 None => {
368 return Err(Error::MalformedResponse(format!(
369 "missing or invalid {HEADER_BATCH_OPERATION_KEY} header"
370 )));
371 }
372 };
373 if is_error {
374 let message = String::from_utf8_lossy(&body).into_owned();
375 let error = Error::OperationFailure { status, message };
376
377 return Ok((
378 index,
379 match ctx {
380 OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
381 OperationContext::Insert { .. } => OperationResult::Put(key, Err(error)),
382 OperationContext::Delete { .. } => OperationResult::Delete(key, Err(error)),
383 OperationContext::Head { .. } => OperationResult::Head(key, Err(error)),
384 },
385 ));
386 }
387
388 let result = match ctx {
389 OperationContext::Get {
390 decompress,
391 accept_encoding,
392 ..
393 } => {
394 if status == 404 {
395 OperationResult::Get(key, Ok(None))
396 } else {
397 let mut metadata = Metadata::from_headers(&headers, "")?;
398
399 let stream =
400 futures_util::stream::once(async move { Ok::<_, io::Error>(body) }).boxed();
401 let stream =
402 get::maybe_decompress(stream, &mut metadata, *decompress, accept_encoding);
403
404 OperationResult::Get(key, Ok(Some(GetResponse { metadata, stream })))
405 }
406 }
407 OperationContext::Insert { .. } => {
408 OperationResult::Put(key.clone(), Ok(PutResponse { key }))
409 }
410 OperationContext::Delete { .. } => OperationResult::Delete(key, Ok(())),
411 OperationContext::Head { .. } => {
412 if status == 404 {
413 OperationResult::Head(key, Ok(None))
414 } else {
415 let metadata = Metadata::from_headers(&headers, "")?;
416 OperationResult::Head(key, Ok(Some(metadata)))
417 }
418 }
419 };
420 Ok((index, result))
421 }
422}
423
424pub struct OperationResults(Pin<Box<dyn Stream<Item = OperationResult> + Send>>);
426
427impl fmt::Debug for OperationResults {
428 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
429 f.write_str("OperationResults([Stream])")
430 }
431}
432
433impl Stream for OperationResults {
434 type Item = OperationResult;
435
436 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
437 self.0.as_mut().poll_next(cx)
438 }
439}
440
441impl OperationResults {
442 pub async fn error_for_failures(
447 mut self,
448 ) -> crate::Result<(), impl Iterator<Item = crate::Error>> {
449 let mut errs = Vec::new();
450 while let Some(res) = self.next().await {
451 match res {
452 OperationResult::Get(_, get) => {
453 if let Err(e) = get {
454 errs.push(e);
455 }
456 }
457 OperationResult::Put(_, put) => {
458 if let Err(e) = put {
459 errs.push(e);
460 }
461 }
462 OperationResult::Delete(_, delete) => {
463 if let Err(e) = delete {
464 errs.push(e);
465 }
466 }
467 OperationResult::Head(_, head) => {
468 if let Err(e) = head {
469 errs.push(e);
470 }
471 }
472 OperationResult::Error(error) => errs.push(error),
473 }
474 }
475 if errs.is_empty() {
476 return Ok(());
477 }
478 Err(errs.into_iter())
479 }
480}
481
482async fn send_batch(
483 session: &Session,
484 operations: Vec<BatchOperation>,
485) -> crate::Result<Vec<OperationResult>> {
486 let mut context_map: HashMap<usize, OperationContext> = operations
487 .iter()
488 .enumerate()
489 .map(|(idx, op)| (idx, OperationContext::from(op)))
490 .collect();
491 let num_operations = operations.len();
492
493 let mut form = reqwest::multipart::Form::new();
494 for op in operations.into_iter() {
495 let part = op.into_part().await?;
496 form = form.part("part", part);
497 }
498
499 let request = session.batch_request()?.multipart(form);
500 let response = request.send().await?.error_for_status()?;
501
502 let boundary = response
503 .headers()
504 .get(CONTENT_TYPE)
505 .and_then(|v| v.to_str().ok())
506 .ok_or_else(|| Error::MalformedResponse("missing Content-Type header".to_owned()))
507 .map(multer::parse_boundary)??;
508
509 let byte_stream = response.bytes_stream().map(|r| r.map_err(io::Error::other));
510 let mut multipart = multer::Multipart::new(byte_stream, boundary);
511
512 let mut results = Vec::new();
513 let mut seen_indices = HashSet::new();
514 while let Some(field) = multipart.next_field().await? {
515 let (index, result) = OperationResult::from_field(field, &context_map).await;
516 if let Some(idx) = index {
517 seen_indices.insert(idx);
518 }
519 results.push(result);
520 }
521
522 for idx in 0..num_operations {
523 if !seen_indices.contains(&idx) {
524 let error = Error::MalformedResponse(format!(
525 "server did not return a response for operation at index {idx}"
526 ));
527 let result = match context_map.remove(&idx) {
528 Some(ctx) => error_result(ctx, error),
529 None => OperationResult::Error(error),
530 };
531 results.push(result);
532 }
533 }
534
535 Ok(results)
536}
537
538fn classify_fail(key: Option<ObjectKey>, error: Error) -> Classified {
539 Classified::Failed(OperationResult::Put(
540 key.unwrap_or_else(|| "<unknown>".to_owned()),
541 Err(error),
542 ))
543}
544
545async fn classify(op: BatchOperation) -> Classified {
550 match op {
551 BatchOperation::Insert {
552 key,
553 metadata,
554 body,
555 } => {
556 let size = match &body {
557 PutBody::Buffer(bytes) => Some(bytes.len() as u64),
558 PutBody::File(file) => match file.metadata().await {
559 Ok(meta) => Some(meta.len()),
560 Err(err) => return classify_fail(key, err.into()),
561 },
562 PutBody::Path(path) => match tokio::fs::metadata(path).await {
563 Ok(meta) => Some(meta.len()),
564 Err(err) => return classify_fail(key, err.into()),
565 },
566 PutBody::Stream(_) => None,
568 };
569
570 let op = BatchOperation::Insert {
571 key,
572 metadata,
573 body,
574 };
575
576 match size {
577 Some(s) if s <= MAX_BATCH_PART_SIZE as u64 => Classified::Batchable(op, s),
578 _ => Classified::Individual(op),
579 }
580 }
581 other => Classified::Batchable(other, 0),
582 }
583}
584
585async fn partition(
589 operations: Vec<BatchOperation>,
590) -> (
591 Vec<(BatchOperation, u64)>,
592 Vec<BatchOperation>,
593 Vec<OperationResult>,
594) {
595 let classified = futures_util::future::join_all(operations.into_iter().map(classify)).await;
596 let mut batchable = Vec::new();
597 let mut individual = Vec::new();
598 let mut failed = Vec::new();
599 for item in classified {
600 match item {
601 Classified::Batchable(op, size) => batchable.push((op, size)),
602 Classified::Individual(op) => individual.push(op),
603 Classified::Failed(result) => failed.push(result),
604 }
605 }
606 (batchable, individual, failed)
607}
608
609async fn execute_individual(op: BatchOperation, session: &Session) -> OperationResult {
611 match op {
612 BatchOperation::Get {
613 key,
614 decompress,
615 accept_encoding,
616 } => {
617 let get = GetBuilder {
618 session: session.clone(),
619 key: key.clone(),
620 decompress,
621 accept_encoding,
622 };
623 OperationResult::Get(key, get.send().await)
624 }
625 BatchOperation::Insert {
626 key,
627 metadata,
628 body,
629 } => {
630 let error_key = key.clone().unwrap_or_else(|| "<unknown>".to_owned());
631 let put = PutBuilder {
632 session: session.clone(),
633 metadata,
634 key,
635 body,
636 };
637 match put.send().await {
638 Ok(response) => OperationResult::Put(response.key.clone(), Ok(response)),
639 Err(err) => OperationResult::Put(error_key, Err(err)),
640 }
641 }
642 BatchOperation::Delete { key } => {
643 let delete = DeleteBuilder {
644 session: session.clone(),
645 key: key.clone(),
646 };
647 OperationResult::Delete(key, delete.send().await)
648 }
649 BatchOperation::Head { key } => {
650 let head = HeadBuilder {
651 session: session.clone(),
652 key: key.clone(),
653 };
654 OperationResult::Head(key, head.send().await)
655 }
656 }
657}
658
659async fn execute_batch(operations: Vec<BatchOperation>, session: &Session) -> Vec<OperationResult> {
663 let contexts: Vec<_> = operations.iter().map(OperationContext::from).collect();
664 match send_batch(session, operations).await {
665 Ok(results) => results,
666 Err(e) => {
667 let shared = Arc::new(e);
668 contexts
669 .into_iter()
670 .map(|ctx| error_result(ctx, Error::Batch(shared.clone())))
671 .collect()
672 }
673 }
674}
675
676fn iter_batches(ops: Vec<(BatchOperation, u64)>) -> impl Iterator<Item = Vec<BatchOperation>> {
681 let mut remaining = ops.into_iter().peekable();
682
683 std::iter::from_fn(move || {
684 remaining.peek()?;
685 let mut batch_size = 0;
686 let mut batch = Vec::new();
687
688 while let Some((_, op_size)) = remaining.peek() {
689 if batch.len() >= MAX_BATCH_OPS
690 || (!batch.is_empty() && batch_size + op_size > MAX_BATCH_BODY_SIZE)
691 {
692 break;
693 }
694
695 let (op, op_size) = remaining.next().expect("peeked above");
696 batch_size += op_size;
697 batch.push(op);
698 }
699
700 Some(batch)
701 })
702}
703
704impl ManyBuilder {
705 pub async fn send(self) -> OperationResults {
709 let session = self.session;
710 let individual_concurrency = self
711 .max_individual_concurrency
712 .unwrap_or(DEFAULT_INDIVIDUAL_CONCURRENCY)
713 .max(1);
714 let batch_concurrency = self
715 .max_batch_concurrency
716 .unwrap_or(DEFAULT_BATCH_CONCURRENCY)
717 .max(1);
718
719 let (batchable, individual, failed) = partition(self.operations).await;
721
722 let individual_results = futures_util::stream::iter(individual)
724 .map({
725 let session = session.clone();
726 move |op| {
727 let session = session.clone();
728 async move { execute_individual(op, &session).await }
729 }
730 })
731 .buffer_unordered(individual_concurrency);
732
733 let batch_results = futures_util::stream::iter(iter_batches(batchable))
735 .map(move |chunk| {
736 let session = session.clone();
737 async move { execute_batch(chunk, &session).await }
738 })
739 .buffer_unordered(batch_concurrency)
740 .flat_map(futures_util::stream::iter);
741
742 let results = futures_util::stream::iter(failed)
743 .chain(individual_results)
744 .chain(batch_results);
745
746 OperationResults(results.boxed())
747 }
748
749 pub fn max_individual_concurrency(mut self, concurrency: usize) -> Self {
755 self.max_individual_concurrency = Some(concurrency);
756 self
757 }
758
759 pub fn max_batch_concurrency(mut self, concurrency: usize) -> Self {
765 self.max_batch_concurrency = Some(concurrency);
766 self
767 }
768
769 #[allow(private_bounds)]
779 pub fn push<B: Into<BatchOperation>>(mut self, builder: B) -> Self {
780 self.operations.push(builder.into());
781 self
782 }
783}
784
#[cfg(test)]
mod tests {
    use super::*;

    /// A batchable operation with the given size; only the size matters to `iter_batches`.
    fn sized_op(size: u64) -> (BatchOperation, u64) {
        let op = BatchOperation::Delete {
            key: "k".to_owned(),
        };
        (op, size)
    }

    /// `count` operations of `size` bytes each.
    fn uniform(count: usize, size: u64) -> Vec<(BatchOperation, u64)> {
        std::iter::repeat_with(|| sized_op(size)).take(count).collect()
    }

    /// Lengths of the batches `iter_batches` produces for `ops`.
    fn batch_lengths(ops: Vec<(BatchOperation, u64)>) -> Vec<usize> {
        iter_batches(ops).map(|batch| batch.len()).collect()
    }

    #[test]
    fn iter_batches_empty() {
        assert_eq!(iter_batches(Vec::new()).count(), 0);
    }

    #[test]
    fn iter_batches_single_batch_count_limit() {
        assert_eq!(batch_lengths(uniform(1000, 1)), vec![1000]);
    }

    #[test]
    fn iter_batches_splits_on_count_limit() {
        assert_eq!(batch_lengths(uniform(1001, 1)), vec![1000, 1]);
    }

    #[test]
    fn iter_batches_exactly_at_size_limit() {
        assert_eq!(batch_lengths(uniform(100, 1024 * 1024)), vec![100]);
    }

    #[test]
    fn iter_batches_splits_on_size_limit() {
        assert_eq!(batch_lengths(uniform(101, 1024 * 1024)), vec![100, 1]);
    }

    #[test]
    fn iter_batches_size_limit_hits_before_count_limit() {
        let op_size: u64 = 600 * 1024;
        let lengths = batch_lengths(uniform(200, op_size));
        let per_batch = (MAX_BATCH_BODY_SIZE / op_size) as usize;

        assert!(lengths.len() > 1, "expected multiple batches");
        let (_last, full) = lengths.split_last().expect("at least one batch");
        for &len in full {
            assert_eq!(len, per_batch);
        }
    }
}