1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::io;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use futures_util::{Stream, StreamExt as _};
9use multer::Field;
10use objectstore_types::metadata::{Compression, Metadata};
11use percent_encoding::NON_ALPHANUMERIC;
12use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
13use reqwest::multipart::Part;
14
15use crate::error::Error;
16use crate::put::PutBody;
17use crate::{
18 DeleteBuilder, DeleteResponse, GetBuilder, GetResponse, ObjectKey, PutBuilder, PutResponse,
19 Session, get, put,
20};
21
/// Multipart-part header carrying the (percent-encoded) object key of a batch operation.
const HEADER_BATCH_OPERATION_KEY: &str = "x-sn-batch-operation-key";
/// Multipart-part header carrying the kind of a batch operation (`get`, `insert`, `delete`).
const HEADER_BATCH_OPERATION_KIND: &str = "x-sn-batch-operation-kind";
/// Multipart-part header correlating a response part with the index of the submitted operation.
const HEADER_BATCH_OPERATION_INDEX: &str = "x-sn-batch-operation-index";
/// Multipart-part header carrying the per-operation HTTP status of a response part.
const HEADER_BATCH_OPERATION_STATUS: &str = "x-sn-batch-operation-status";

/// Maximum number of operations sent in a single batch request.
const MAX_BATCH_OPS: usize = 1000;

/// Maximum body size for an insert to be included in a batch request;
/// larger (or unknown-size) inserts are executed as individual requests.
const MAX_BATCH_PART_SIZE: u32 = 1024 * 1024;

/// Default number of individual (non-batched) requests run concurrently.
const DEFAULT_INDIVIDUAL_CONCURRENCY: usize = 5;

/// Default number of batch requests run concurrently.
const DEFAULT_BATCH_CONCURRENCY: usize = 3;

/// Maximum total body size of a single batch request.
const MAX_BATCH_BODY_SIZE: u64 = 100 * 1024 * 1024;

/// Builder for executing many operations (gets, inserts, deletes) at once.
///
/// Constructed via [`Session::many`]; operations are queued with
/// [`ManyBuilder::push`] and executed with [`ManyBuilder::send`].
#[derive(Debug)]
pub struct ManyBuilder {
    // Session used to issue both batch and individual requests.
    session: Session,
    // Operations accumulated via `push`, consumed by `send`.
    operations: Vec<BatchOperation>,
    // Concurrency overrides; `None` falls back to the defaults above.
    max_individual_concurrency: Option<usize>,
    max_batch_concurrency: Option<usize>,
}
57
58impl Session {
59 pub fn many(&self) -> ManyBuilder {
64 ManyBuilder {
65 session: self.clone(),
66 operations: vec![],
67 max_individual_concurrency: None,
68 max_batch_concurrency: None,
69 }
70 }
71}
72
/// A single pending operation within a [`ManyBuilder`] request.
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
enum BatchOperation {
    /// Fetch an object; mirrors the request-relevant fields of [`GetBuilder`].
    Get {
        key: ObjectKey,
        decompress: bool,
        accept_encoding: Vec<Compression>,
    },
    /// Store an object; mirrors the request-relevant fields of [`PutBuilder`].
    /// The key may be `None` when not specified up front.
    Insert {
        key: Option<ObjectKey>,
        metadata: Metadata,
        body: PutBody,
    },
    /// Delete an object by key.
    Delete {
        key: ObjectKey,
    },
}
90
91impl From<GetBuilder> for BatchOperation {
92 fn from(value: GetBuilder) -> Self {
93 let GetBuilder {
94 key,
95 decompress,
96 accept_encoding,
97 session: _session,
98 } = value;
99 BatchOperation::Get {
100 key,
101 decompress,
102 accept_encoding,
103 }
104 }
105}
106
107impl From<PutBuilder> for BatchOperation {
108 fn from(value: PutBuilder) -> Self {
109 let PutBuilder {
110 key,
111 metadata,
112 body,
113 session: _session,
114 } = value;
115 BatchOperation::Insert {
116 key,
117 metadata,
118 body,
119 }
120 }
121}
122
123impl From<DeleteBuilder> for BatchOperation {
124 fn from(value: DeleteBuilder) -> Self {
125 let DeleteBuilder {
126 key,
127 session: _session,
128 } = value;
129 BatchOperation::Delete { key }
130 }
131}
132
133impl BatchOperation {
134 async fn into_part(self) -> crate::Result<Part> {
135 match self {
136 BatchOperation::Get { key, .. } => {
137 let headers = operation_headers("get", Some(&key));
138 Ok(Part::text("").headers(headers))
139 }
140 BatchOperation::Insert {
141 key,
142 metadata,
143 body,
144 } => {
145 let mut headers = operation_headers("insert", key.as_deref());
146 headers.extend(metadata.to_headers("")?);
147
148 let body = put::maybe_compress(body, metadata.compression).await?;
149 Ok(Part::stream(body).headers(headers))
150 }
151 BatchOperation::Delete { key } => {
152 let headers = operation_headers("delete", Some(&key));
153 Ok(Part::text("").headers(headers))
154 }
155 }
156 }
157}
158
159fn operation_headers(operation: &str, key: Option<&str>) -> HeaderMap {
160 let mut headers = HeaderMap::new();
161 headers.insert(
162 HeaderName::from_static(HEADER_BATCH_OPERATION_KIND),
163 HeaderValue::from_str(operation).expect("operation kind is always a valid header value"),
164 );
165 if let Some(key) = key {
166 let encoded =
167 percent_encoding::percent_encode(key.as_bytes(), NON_ALPHANUMERIC).to_string();
168 headers.insert(
169 HeaderName::from_static(HEADER_BATCH_OPERATION_KEY),
170 HeaderValue::try_from(encoded)
171 .expect("percent-encoded string is always a valid header value"),
172 );
173 }
174 headers
175}
176
/// The outcome of a single operation executed via [`ManyBuilder::send`].
#[derive(Debug)]
pub enum OperationResult {
    /// Result of a get; `Ok(None)` means the object was not found (404).
    Get(ObjectKey, Result<Option<GetResponse>, Error>),
    /// Result of an insert.
    Put(ObjectKey, Result<PutResponse, Error>),
    /// Result of a delete.
    Delete(ObjectKey, Result<DeleteResponse, Error>),
    /// A failure that could not be attributed to a specific operation,
    /// e.g. a response part missing its index header.
    Error(Error),
}
195
/// The request-side information needed to interpret (or synthesize) the
/// response for an operation. Kept separately because the operation itself is
/// consumed when building the multipart request body.
enum OperationContext {
    Get {
        key: ObjectKey,
        // Gets keep their decompression settings so the response body can be
        // post-processed the same way as a regular get response.
        decompress: bool,
        accept_encoding: Vec<Compression>,
    },
    Insert {
        key: Option<ObjectKey>,
    },
    Delete {
        key: ObjectKey,
    },
}
210
211impl From<&BatchOperation> for OperationContext {
212 fn from(op: &BatchOperation) -> Self {
213 match op {
214 BatchOperation::Get {
215 key,
216 decompress,
217 accept_encoding,
218 } => OperationContext::Get {
219 key: key.clone(),
220 decompress: *decompress,
221 accept_encoding: accept_encoding.clone(),
222 },
223 BatchOperation::Insert { key, .. } => OperationContext::Insert { key: key.clone() },
224 BatchOperation::Delete { key } => OperationContext::Delete { key: key.clone() },
225 }
226 }
227}
228
229impl OperationContext {
230 fn key(&self) -> Option<&str> {
231 match self {
232 OperationContext::Get { key, .. } | OperationContext::Delete { key } => Some(key),
233 OperationContext::Insert { key } => key.as_deref(),
234 }
235 }
236}
237
/// The routing decision for one operation: sent inside a batch request
/// (with its known body size), executed as an individual request, or failed
/// before any request was sent.
enum Classified {
    Batchable(BatchOperation, u64),
    Individual(BatchOperation),
    Failed(OperationResult),
}
247
248fn error_result(ctx: OperationContext, error: Error) -> OperationResult {
250 let key = ctx.key().unwrap_or("<unknown>").to_owned();
251 match ctx {
252 OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
253 OperationContext::Insert { .. } => OperationResult::Put(key, Err(error)),
254 OperationContext::Delete { .. } => OperationResult::Delete(key, Err(error)),
255 }
256}
257
impl OperationResult {
    /// Parses one multipart response field into an [`OperationResult`].
    ///
    /// On success also returns the index of the operation the field belongs
    /// to; parse failures yield `(None, OperationResult::Error(_))` because
    /// the field cannot be attributed to any operation.
    async fn from_field(
        field: Field<'_>,
        context_map: &HashMap<usize, OperationContext>,
    ) -> (Option<usize>, Self) {
        match Self::try_from_field(field, context_map).await {
            Ok((index, result)) => (Some(index), result),
            Err(e) => (None, OperationResult::Error(e)),
        }
    }

    /// Fallible body of [`Self::from_field`]: extracts the index, status, and
    /// key headers, reads the part body, and builds the result according to
    /// the operation's [`OperationContext`].
    async fn try_from_field(
        field: Field<'_>,
        context_map: &HashMap<usize, OperationContext>,
    ) -> Result<(usize, Self), Error> {
        // Clone the headers so the batch-protocol headers can be removed,
        // leaving only the object metadata headers for `Metadata::from_headers`.
        let mut headers = field.headers().clone();

        // The index header ties this response part back to the submitted operation.
        let index: usize = headers
            .remove(HEADER_BATCH_OPERATION_INDEX)
            .and_then(|v| v.to_str().ok().and_then(|s| s.parse().ok()))
            .ok_or_else(|| {
                Error::MalformedResponse(format!(
                    "missing or invalid {HEADER_BATCH_OPERATION_INDEX} header"
                ))
            })?;

        // The status header may be a bare code ("404") or a full status line
        // ("404 Not Found"); only the leading code is parsed.
        let status: u16 = headers
            .remove(HEADER_BATCH_OPERATION_STATUS)
            .and_then(|v| {
                v.to_str().ok().and_then(|s| {
                    s.split_once(' ')
                        .map(|(code, _)| code)
                        .unwrap_or(s)
                        .parse()
                        .ok()
                })
            })
            .ok_or_else(|| {
                Error::MalformedResponse(format!(
                    "missing or invalid {HEADER_BATCH_OPERATION_STATUS} header"
                ))
            })?;

        let ctx = context_map.get(&index).ok_or_else(|| {
            Error::MalformedResponse(format!(
                "response references unknown operation index {index}"
            ))
        })?;

        // Prefer the percent-encoded key echoed by the server, falling back
        // to the key recorded at request time.
        let key = headers
            .remove(HEADER_BATCH_OPERATION_KEY)
            .and_then(|v| {
                v.to_str()
                    .ok()
                    .and_then(|encoded| {
                        percent_encoding::percent_decode_str(encoded)
                            .decode_utf8()
                            .ok()
                    })
                    .map(|s| s.into_owned())
            })
            .or_else(|| ctx.key().map(str::to_owned));

        let body = field.bytes().await?;

        // A 404 on a get is not an error; it maps to `Ok(None)` below.
        let is_error =
            status >= 400 && !(matches!(ctx, OperationContext::Get { .. }) && status == 404);

        let key = match key {
            Some(key) => key,
            // Error results may carry a placeholder key; success results must
            // have a real one.
            None if is_error => "<unknown>".to_owned(),
            None => {
                return Err(Error::MalformedResponse(format!(
                    "missing or invalid {HEADER_BATCH_OPERATION_KEY} header"
                )));
            }
        };
        if is_error {
            // The part body carries the server's error message for this operation.
            let message = String::from_utf8_lossy(&body).into_owned();
            let error = Error::OperationFailure { status, message };

            return Ok((
                index,
                match ctx {
                    OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
                    OperationContext::Insert { .. } => OperationResult::Put(key, Err(error)),
                    OperationContext::Delete { .. } => OperationResult::Delete(key, Err(error)),
                },
            ));
        }

        let result = match ctx {
            OperationContext::Get {
                decompress,
                accept_encoding,
                ..
            } => {
                if status == 404 {
                    OperationResult::Get(key, Ok(None))
                } else {
                    let mut metadata = Metadata::from_headers(&headers, "")?;

                    // The body was fully buffered above; re-wrap it as a
                    // one-item stream so it can go through the same
                    // decompression path as a regular get response.
                    let stream =
                        futures_util::stream::once(async move { Ok::<_, io::Error>(body) }).boxed();
                    let stream =
                        get::maybe_decompress(stream, &mut metadata, *decompress, accept_encoding);

                    OperationResult::Get(key, Ok(Some(GetResponse { metadata, stream })))
                }
            }
            OperationContext::Insert { .. } => {
                OperationResult::Put(key.clone(), Ok(PutResponse { key }))
            }
            OperationContext::Delete { .. } => OperationResult::Delete(key, Ok(())),
        };
        Ok((index, result))
    }
}
383
/// Stream of [`OperationResult`]s produced by [`ManyBuilder::send`].
pub struct OperationResults(Pin<Box<dyn Stream<Item = OperationResult> + Send>>);
386
387impl fmt::Debug for OperationResults {
388 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
389 f.write_str("OperationResults([Stream])")
390 }
391}
392
impl Stream for OperationResults {
    type Item = OperationResult;

    // Delegates directly to the boxed inner stream.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.0.as_mut().poll_next(cx)
    }
}
400
401impl OperationResults {
402 pub async fn error_for_failures(
407 mut self,
408 ) -> crate::Result<(), impl Iterator<Item = crate::Error>> {
409 let mut errs = Vec::new();
410 while let Some(res) = self.next().await {
411 match res {
412 OperationResult::Get(_, get) => {
413 if let Err(e) = get {
414 errs.push(e);
415 }
416 }
417 OperationResult::Put(_, put) => {
418 if let Err(e) = put {
419 errs.push(e);
420 }
421 }
422 OperationResult::Delete(_, delete) => {
423 if let Err(e) = delete {
424 errs.push(e);
425 }
426 }
427 OperationResult::Error(error) => errs.push(error),
428 }
429 }
430 if errs.is_empty() {
431 return Ok(());
432 }
433 Err(errs.into_iter())
434 }
435}
436
/// Sends a single multipart batch request containing `operations` and parses
/// the multipart response into one result per operation.
///
/// Operations the server does not answer are reported as
/// [`Error::MalformedResponse`] results, so on success the caller always
/// receives exactly one result per submitted operation.
async fn send_batch(
    session: &Session,
    operations: Vec<BatchOperation>,
) -> crate::Result<Vec<OperationResult>> {
    // Contexts are keyed by submission index so response parts (which carry an
    // index header) can be matched back to their operation.
    let mut context_map: HashMap<usize, OperationContext> = operations
        .iter()
        .enumerate()
        .map(|(idx, op)| (idx, OperationContext::from(op)))
        .collect();
    let num_operations = operations.len();

    // One multipart part per operation; ordering defines the operation index.
    let mut form = reqwest::multipart::Form::new();
    for op in operations.into_iter() {
        let part = op.into_part().await?;
        form = form.part("part", part);
    }

    let request = session.batch_request()?.multipart(form);
    let response = request.send().await?.error_for_status()?;

    // The response is itself multipart; the boundary comes from Content-Type.
    let boundary = response
        .headers()
        .get(CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .ok_or_else(|| Error::MalformedResponse("missing Content-Type header".to_owned()))
        .map(multer::parse_boundary)??;

    let byte_stream = response.bytes_stream().map(|r| r.map_err(io::Error::other));
    let mut multipart = multer::Multipart::new(byte_stream, boundary);

    let mut results = Vec::new();
    let mut seen_indices = HashSet::new();
    while let Some(field) = multipart.next_field().await? {
        let (index, result) = OperationResult::from_field(field, &context_map).await;
        if let Some(idx) = index {
            seen_indices.insert(idx);
        }
        results.push(result);
    }

    // Synthesize an error result for every operation the server skipped.
    for idx in 0..num_operations {
        if !seen_indices.contains(&idx) {
            let error = Error::MalformedResponse(format!(
                "server did not return a response for operation at index {idx}"
            ));
            let result = match context_map.remove(&idx) {
                Some(ctx) => error_result(ctx, error),
                None => OperationResult::Error(error),
            };
            results.push(result);
        }
    }

    Ok(results)
}
492
493fn classify_fail(key: Option<ObjectKey>, error: Error) -> Classified {
494 Classified::Failed(OperationResult::Put(
495 key.unwrap_or_else(|| "<unknown>".to_owned()),
496 Err(error),
497 ))
498}
499
/// Decides how a single operation should be executed.
///
/// Gets and deletes are always batchable (with a body size of 0). Inserts are
/// batchable only when their body size is known up front and does not exceed
/// [`MAX_BATCH_PART_SIZE`]; streaming bodies (unknown size) and oversized
/// bodies run as individual requests. Inserts whose size lookup fails are
/// reported as failed without sending anything.
async fn classify(op: BatchOperation) -> Classified {
    match op {
        BatchOperation::Insert {
            key,
            metadata,
            body,
        } => {
            let size = match &body {
                PutBody::Buffer(bytes) => Some(bytes.len() as u64),
                PutBody::File(file) => match file.metadata().await {
                    Ok(meta) => Some(meta.len()),
                    Err(err) => return classify_fail(key, err.into()),
                },
                PutBody::Path(path) => match tokio::fs::metadata(path).await {
                    Ok(meta) => Some(meta.len()),
                    Err(err) => return classify_fail(key, err.into()),
                },
                // Streams have no known length ahead of time.
                PutBody::Stream(_) => None,
            };

            // Reassemble the operation after only borrowing `body` above.
            let op = BatchOperation::Insert {
                key,
                metadata,
                body,
            };

            match size {
                Some(s) if s <= MAX_BATCH_PART_SIZE as u64 => Classified::Batchable(op, s),
                _ => Classified::Individual(op),
            }
        }
        // Gets and deletes carry no request body.
        other => Classified::Batchable(other, 0),
    }
}
539
540async fn partition(
544 operations: Vec<BatchOperation>,
545) -> (
546 Vec<(BatchOperation, u64)>,
547 Vec<BatchOperation>,
548 Vec<OperationResult>,
549) {
550 let classified = futures_util::future::join_all(operations.into_iter().map(classify)).await;
551 let mut batchable = Vec::new();
552 let mut individual = Vec::new();
553 let mut failed = Vec::new();
554 for item in classified {
555 match item {
556 Classified::Batchable(op, size) => batchable.push((op, size)),
557 Classified::Individual(op) => individual.push(op),
558 Classified::Failed(result) => failed.push(result),
559 }
560 }
561 (batchable, individual, failed)
562}
563
/// Executes a single operation as a regular (non-batched) request by
/// reconstructing the corresponding builder and sending it.
async fn execute_individual(op: BatchOperation, session: &Session) -> OperationResult {
    match op {
        BatchOperation::Get {
            key,
            decompress,
            accept_encoding,
        } => {
            let get = GetBuilder {
                session: session.clone(),
                key: key.clone(),
                decompress,
                accept_encoding,
            };
            OperationResult::Get(key, get.send().await)
        }
        BatchOperation::Insert {
            key,
            metadata,
            body,
        } => {
            // Inserts may not have a key up front: keep a placeholder for the
            // error case and use the server-returned key on success.
            let error_key = key.clone().unwrap_or_else(|| "<unknown>".to_owned());
            let put = PutBuilder {
                session: session.clone(),
                metadata,
                key,
                body,
            };
            match put.send().await {
                Ok(response) => OperationResult::Put(response.key.clone(), Ok(response)),
                Err(err) => OperationResult::Put(error_key, Err(err)),
            }
        }
        BatchOperation::Delete { key } => {
            let delete = DeleteBuilder {
                session: session.clone(),
                key: key.clone(),
            };
            OperationResult::Delete(key, delete.send().await)
        }
    }
}
606
607async fn execute_batch(operations: Vec<BatchOperation>, session: &Session) -> Vec<OperationResult> {
611 let contexts: Vec<_> = operations.iter().map(OperationContext::from).collect();
612 match send_batch(session, operations).await {
613 Ok(results) => results,
614 Err(e) => {
615 let shared = Arc::new(e);
616 contexts
617 .into_iter()
618 .map(|ctx| error_result(ctx, Error::Batch(shared.clone())))
619 .collect()
620 }
621 }
622}
623
624fn iter_batches(ops: Vec<(BatchOperation, u64)>) -> impl Iterator<Item = Vec<BatchOperation>> {
629 let mut remaining = ops.into_iter().peekable();
630
631 std::iter::from_fn(move || {
632 remaining.peek()?;
633 let mut batch_size = 0;
634 let mut batch = Vec::new();
635
636 while let Some((_, op_size)) = remaining.peek() {
637 if batch.len() >= MAX_BATCH_OPS
638 || (!batch.is_empty() && batch_size + op_size > MAX_BATCH_BODY_SIZE)
639 {
640 break;
641 }
642
643 let (op, op_size) = remaining.next().expect("peeked above");
644 batch_size += op_size;
645 batch.push(op);
646 }
647
648 Some(batch)
649 })
650}
651
impl ManyBuilder {
    /// Executes all queued operations and returns a stream of their results.
    ///
    /// Operations are partitioned (see `classify`) into batch requests and
    /// individual requests; each kind runs concurrently up to its limit.
    /// The result stream yields classification failures first, then
    /// individual results, then batch results.
    pub async fn send(self) -> OperationResults {
        let session = self.session;
        // Clamp limits to at least 1 so `buffer_unordered` can make progress.
        let individual_concurrency = self
            .max_individual_concurrency
            .unwrap_or(DEFAULT_INDIVIDUAL_CONCURRENCY)
            .max(1);
        let batch_concurrency = self
            .max_batch_concurrency
            .unwrap_or(DEFAULT_BATCH_CONCURRENCY)
            .max(1);

        let (batchable, individual, failed) = partition(self.operations).await;

        let individual_results = futures_util::stream::iter(individual)
            .map({
                // Each spawned future needs its own session clone.
                let session = session.clone();
                move |op| {
                    let session = session.clone();
                    async move { execute_individual(op, &session).await }
                }
            })
            .buffer_unordered(individual_concurrency);

        let batch_results = futures_util::stream::iter(iter_batches(batchable))
            .map(move |chunk| {
                let session = session.clone();
                async move { execute_batch(chunk, &session).await }
            })
            .buffer_unordered(batch_concurrency)
            // Each batch yields a Vec of results; flatten into single items.
            .flat_map(futures_util::stream::iter);

        let results = futures_util::stream::iter(failed)
            .chain(individual_results)
            .chain(batch_results);

        OperationResults(results.boxed())
    }

    /// Overrides the maximum number of concurrently running individual
    /// requests (defaults to `DEFAULT_INDIVIDUAL_CONCURRENCY`).
    pub fn max_individual_concurrency(mut self, concurrency: usize) -> Self {
        self.max_individual_concurrency = Some(concurrency);
        self
    }

    /// Overrides the maximum number of concurrently running batch requests
    /// (defaults to `DEFAULT_BATCH_CONCURRENCY`).
    pub fn max_batch_concurrency(mut self, concurrency: usize) -> Self {
        self.max_batch_concurrency = Some(concurrency);
        self
    }

    /// Queues a get, put, or delete builder for execution via [`Self::send`].
    #[allow(private_bounds)]
    pub fn push<B: Into<BatchOperation>>(mut self, builder: B) -> Self {
        self.operations.push(builder.into());
        self
    }
}
732
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a dummy (delete) operation annotated with the given body size.
    fn op(size: u64) -> (BatchOperation, u64) {
        (
            BatchOperation::Delete {
                key: "k".to_owned(),
            },
            size,
        )
    }

    /// The number of operations in each produced batch.
    fn batch_sizes(batches: &[Vec<BatchOperation>]) -> Vec<usize> {
        batches.iter().map(Vec::len).collect()
    }

    /// Collects `iter_batches` into a `Vec` for assertions.
    fn batches(ops: Vec<(BatchOperation, u64)>) -> Vec<Vec<BatchOperation>> {
        iter_batches(ops).collect()
    }

    #[test]
    fn iter_batches_empty() {
        assert!(batches(vec![]).is_empty());
    }

    #[test]
    fn iter_batches_single_batch_count_limit() {
        // Exactly MAX_BATCH_OPS operations fit in one batch.
        let ops: Vec<_> = (0..1000).map(|_| op(1)).collect();
        assert_eq!(batch_sizes(&batches(ops)), vec![1000]);
    }

    #[test]
    fn iter_batches_splits_on_count_limit() {
        // One operation past MAX_BATCH_OPS spills into a second batch.
        let ops: Vec<_> = (0..1001).map(|_| op(1)).collect();
        assert_eq!(batch_sizes(&batches(ops)), vec![1000, 1]);
    }

    #[test]
    fn iter_batches_exactly_at_size_limit() {
        // 100 one-MiB ops total exactly MAX_BATCH_BODY_SIZE and still fit.
        let ops: Vec<_> = (0..100).map(|_| op(1024 * 1024)).collect();
        assert_eq!(batch_sizes(&batches(ops)), vec![100]);
    }

    #[test]
    fn iter_batches_splits_on_size_limit() {
        // One op past the size limit spills into a second batch.
        let ops: Vec<_> = (0..101).map(|_| op(1024 * 1024)).collect();
        assert_eq!(batch_sizes(&batches(ops)), vec![100, 1]);
    }

    #[test]
    fn iter_batches_size_limit_hits_before_count_limit() {
        // With 600 KiB ops the size limit caps batches below MAX_BATCH_OPS.
        let op_size = 600 * 1024;
        let ops: Vec<_> = (0..200).map(|_| op(op_size)).collect();
        let result = batches(ops);
        let per_batch = (MAX_BATCH_BODY_SIZE / op_size) as usize;
        assert!(result.len() > 1, "expected multiple batches");
        // Every full batch (all but the last) holds the same maximum count.
        for batch in &result[..result.len() - 1] {
            assert_eq!(batch.len(), per_batch);
        }
    }
}