1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::io;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use futures_util::{Stream, StreamExt as _};
9use multer::Field;
10use objectstore_types::metadata::{Compression, Metadata};
11use percent_encoding::NON_ALPHANUMERIC;
12use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
13use reqwest::multipart::Part;
14
15use crate::error::Error;
16use crate::put::PutBody;
17use crate::{
18 DeleteBuilder, DeleteResponse, GetBuilder, GetResponse, ObjectKey, PutBuilder, PutResponse,
19 Session, get, put,
20};
21
/// Multipart header carrying the (percent-encoded) object key of an operation.
const HEADER_BATCH_OPERATION_KEY: &str = "x-sn-batch-operation-key";
/// Multipart header carrying the operation kind (`get`, `insert`, or `delete`).
const HEADER_BATCH_OPERATION_KIND: &str = "x-sn-batch-operation-kind";
/// Multipart response header referencing the zero-based index of the request operation.
const HEADER_BATCH_OPERATION_INDEX: &str = "x-sn-batch-operation-index";
/// Multipart response header carrying the HTTP-style status of a single operation.
const HEADER_BATCH_OPERATION_STATUS: &str = "x-sn-batch-operation-status";

/// Maximum number of operations sent in a single batch request.
const MAX_BATCH_OPS: usize = 1000;

/// File-backed insert bodies larger than this (in bytes) are executed as
/// individual requests instead of being inlined into a batch multipart body.
const MAX_BATCH_PART_SIZE: u32 = 1024 * 1024;
/// Maximum number of individual (non-batched) requests in flight at once.
const MAX_INDIVIDUAL_CONCURRENCY: usize = 5;

/// Maximum number of batch requests in flight at once.
const MAX_BATCH_CONCURRENCY: usize = 3;
40
/// Builder for executing multiple object-store operations in one call.
///
/// Created via [`Session::many`]; operations are queued with [`ManyBuilder::push`]
/// and executed with [`ManyBuilder::send`].
#[derive(Debug)]
pub struct ManyBuilder {
    // Session used to issue both batch and individual fallback requests.
    session: Session,
    // Operations queued so far, in insertion order.
    operations: Vec<BatchOperation>,
}
52
53impl Session {
54 pub fn many(&self) -> ManyBuilder {
59 ManyBuilder {
60 session: self.clone(),
61 operations: vec![],
62 }
63 }
64}
65
/// A single queued operation, captured from the corresponding request builder.
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
enum BatchOperation {
    // Fetch an object; mirrors the request fields of `GetBuilder`.
    Get {
        key: ObjectKey,
        decompress: bool,
        accept_encoding: Vec<Compression>,
    },
    // Store an object; `key` may be absent on the builder (presumably the
    // server assigns one then — see `PutResponse::key` usage; confirm).
    Insert {
        key: Option<ObjectKey>,
        metadata: Metadata,
        body: PutBody,
    },
    // Remove an object by key.
    Delete {
        key: ObjectKey,
    },
}
83
84impl From<GetBuilder> for BatchOperation {
85 fn from(value: GetBuilder) -> Self {
86 let GetBuilder {
87 key,
88 decompress,
89 accept_encoding,
90 session: _session,
91 } = value;
92 BatchOperation::Get {
93 key,
94 decompress,
95 accept_encoding,
96 }
97 }
98}
99
100impl From<PutBuilder> for BatchOperation {
101 fn from(value: PutBuilder) -> Self {
102 let PutBuilder {
103 key,
104 metadata,
105 body,
106 session: _session,
107 } = value;
108 BatchOperation::Insert {
109 key,
110 metadata,
111 body,
112 }
113 }
114}
115
116impl From<DeleteBuilder> for BatchOperation {
117 fn from(value: DeleteBuilder) -> Self {
118 let DeleteBuilder {
119 key,
120 session: _session,
121 } = value;
122 BatchOperation::Delete { key }
123 }
124}
125
impl BatchOperation {
    /// Converts this operation into a single multipart `Part` of a batch request.
    ///
    /// The operation kind and key are carried in part headers (see
    /// `operation_headers`); only inserts have a non-empty body.
    async fn into_part(self) -> crate::Result<Part> {
        match self {
            BatchOperation::Get { key, .. } => {
                // NOTE(review): `decompress` and `accept_encoding` are not encoded
                // into the request part; they are only applied client-side when the
                // response field is processed (via `OperationContext`) — confirm the
                // server does not need them.
                let headers = operation_headers("get", Some(&key));
                Ok(Part::text("").headers(headers))
            }
            BatchOperation::Insert {
                key,
                metadata,
                body,
            } => {
                // An insert may omit the key header entirely.
                let mut headers = operation_headers("insert", key.as_deref());
                // Object metadata is serialized into additional part headers
                // (empty header-name prefix).
                headers.extend(metadata.to_headers("")?);

                // Compress the payload up front if the metadata requests it.
                let body = put::maybe_compress(body, metadata.compression);
                Ok(Part::stream(body).headers(headers))
            }
            BatchOperation::Delete { key } => {
                let headers = operation_headers("delete", Some(&key));
                Ok(Part::text("").headers(headers))
            }
        }
    }
}
151
152fn operation_headers(operation: &str, key: Option<&str>) -> HeaderMap {
153 let mut headers = HeaderMap::new();
154 headers.insert(
155 HeaderName::from_static(HEADER_BATCH_OPERATION_KIND),
156 HeaderValue::from_str(operation).expect("operation kind is always a valid header value"),
157 );
158 if let Some(key) = key {
159 let encoded =
160 percent_encoding::percent_encode(key.as_bytes(), NON_ALPHANUMERIC).to_string();
161 headers.insert(
162 HeaderName::from_static(HEADER_BATCH_OPERATION_KEY),
163 HeaderValue::try_from(encoded)
164 .expect("percent-encoded string is always a valid header value"),
165 );
166 }
167 headers
168}
169
/// The outcome of a single operation within a batch.
#[derive(Debug)]
pub enum OperationResult {
    /// Result of a get; `Ok(None)` means the object was not found (404).
    Get(ObjectKey, Result<Option<GetResponse>, Error>),
    /// Result of an insert, carrying the stored object's key on success.
    Put(ObjectKey, Result<PutResponse, Error>),
    /// Result of a delete.
    Delete(ObjectKey, Result<DeleteResponse, Error>),
    /// A failure that could not be attributed to any specific operation,
    /// e.g. an unparsable multipart field in the response.
    Error(Error),
}
188
/// Request-side context retained per operation so each response field can be
/// interpreted (and attributed to a key) after the batch round-trip.
enum OperationContext {
    Get {
        key: ObjectKey,
        // Mirrors `BatchOperation::Get`: controls client-side handling of the
        // response body in `get::maybe_decompress`.
        decompress: bool,
        accept_encoding: Vec<Compression>,
    },
    Insert {
        // `None` when the insert did not specify an explicit key.
        key: Option<ObjectKey>,
    },
    Delete {
        key: ObjectKey,
    },
}
203
204impl From<&BatchOperation> for OperationContext {
205 fn from(op: &BatchOperation) -> Self {
206 match op {
207 BatchOperation::Get {
208 key,
209 decompress,
210 accept_encoding,
211 } => OperationContext::Get {
212 key: key.clone(),
213 decompress: *decompress,
214 accept_encoding: accept_encoding.clone(),
215 },
216 BatchOperation::Insert { key, .. } => OperationContext::Insert { key: key.clone() },
217 BatchOperation::Delete { key } => OperationContext::Delete { key: key.clone() },
218 }
219 }
220}
221
222impl OperationContext {
223 fn key(&self) -> Option<&str> {
224 match self {
225 OperationContext::Get { key, .. } | OperationContext::Delete { key } => Some(key),
226 OperationContext::Insert { key } => key.as_deref(),
227 }
228 }
229}
230
/// How a queued operation will be executed.
enum Classified {
    /// Small enough (or body-less) to be inlined into a batch request.
    Batchable(BatchOperation),
    /// Executed as a standalone request (large file-backed uploads).
    Individual(BatchOperation),
    /// Failed before any request was made (e.g. file metadata could not be read).
    Failed(OperationResult),
}
240
241fn error_result(ctx: OperationContext, error: Error) -> OperationResult {
243 let key = ctx.key().unwrap_or("<unknown>").to_owned();
244 match ctx {
245 OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
246 OperationContext::Insert { .. } => OperationResult::Put(key, Err(error)),
247 OperationContext::Delete { .. } => OperationResult::Delete(key, Err(error)),
248 }
249}
250
impl OperationResult {
    /// Parses one multipart response field into a result, returning the
    /// operation index it belongs to when that could be determined.
    ///
    /// Failures that cannot be attributed to any operation are folded into an
    /// `OperationResult::Error` with no index.
    async fn from_field(
        field: Field<'_>,
        context_map: &HashMap<usize, OperationContext>,
    ) -> (Option<usize>, Self) {
        match Self::try_from_field(field, context_map).await {
            Ok((index, result)) => (Some(index), result),
            Err(e) => (None, OperationResult::Error(e)),
        }
    }

    /// Fallible core of `from_field`: decodes the per-operation control
    /// headers (index, status, key), looks up the request-side context, and
    /// converts the field body into the matching success or failure result.
    async fn try_from_field(
        field: Field<'_>,
        context_map: &HashMap<usize, OperationContext>,
    ) -> Result<(usize, Self), Error> {
        // Work on a copy so consumed control headers can be removed, leaving
        // only metadata headers for `Metadata::from_headers` further below.
        let mut headers = field.headers().clone();

        let index: usize = headers
            .remove(HEADER_BATCH_OPERATION_INDEX)
            .and_then(|v| v.to_str().ok().and_then(|s| s.parse().ok()))
            .ok_or_else(|| {
                Error::MalformedResponse(format!(
                    "missing or invalid {HEADER_BATCH_OPERATION_INDEX} header"
                ))
            })?;

        // The status header may be either a bare code ("200") or a code with a
        // reason phrase ("200 OK"); keep only the numeric code.
        let status: u16 = headers
            .remove(HEADER_BATCH_OPERATION_STATUS)
            .and_then(|v| {
                v.to_str().ok().and_then(|s| {
                    s.split_once(' ')
                        .map(|(code, _)| code)
                        .unwrap_or(s)
                        .parse()
                        .ok()
                })
            })
            .ok_or_else(|| {
                Error::MalformedResponse(format!(
                    "missing or invalid {HEADER_BATCH_OPERATION_STATUS} header"
                ))
            })?;

        let ctx = context_map.get(&index).ok_or_else(|| {
            Error::MalformedResponse(format!(
                "response references unknown operation index {index}"
            ))
        })?;

        // Prefer the server-reported (percent-encoded) key; fall back to the
        // key recorded at request time, if any.
        let key = headers
            .remove(HEADER_BATCH_OPERATION_KEY)
            .and_then(|v| {
                v.to_str()
                    .ok()
                    .and_then(|encoded| {
                        percent_encoding::percent_decode_str(encoded)
                            .decode_utf8()
                            .ok()
                    })
                    .map(|s| s.into_owned())
            })
            .or_else(|| ctx.key().map(str::to_owned));

        let body = field.bytes().await?;

        // A 404 on a get is not an error: it maps to `Ok(None)` below.
        let is_error =
            status >= 400 && !(matches!(ctx, OperationContext::Get { .. }) && status == 404);

        let key = match key {
            Some(key) => key,
            // For failures we still surface the error even without a key.
            None if is_error => "<unknown>".to_owned(),
            None => {
                return Err(Error::MalformedResponse(format!(
                    "missing or invalid {HEADER_BATCH_OPERATION_KEY} header"
                )));
            }
        };
        if is_error {
            // The field body carries the server's error message.
            let message = String::from_utf8_lossy(&body).into_owned();
            let error = Error::OperationFailure { status, message };

            return Ok((
                index,
                match ctx {
                    OperationContext::Get { .. } => OperationResult::Get(key, Err(error)),
                    OperationContext::Insert { .. } => OperationResult::Put(key, Err(error)),
                    OperationContext::Delete { .. } => OperationResult::Delete(key, Err(error)),
                },
            ));
        }

        let result = match ctx {
            OperationContext::Get {
                decompress,
                accept_encoding,
                ..
            } => {
                if status == 404 {
                    OperationResult::Get(key, Ok(None))
                } else {
                    // The remaining (non-control) headers describe the object.
                    let mut metadata = Metadata::from_headers(&headers, "")?;

                    // The body is already fully buffered; re-wrap it as a
                    // one-shot stream so the shared decompression path applies.
                    let stream =
                        futures_util::stream::once(async move { Ok::<_, io::Error>(body) }).boxed();
                    let stream =
                        get::maybe_decompress(stream, &mut metadata, *decompress, accept_encoding);

                    OperationResult::Get(key, Ok(Some(GetResponse { metadata, stream })))
                }
            }
            OperationContext::Insert { .. } => {
                OperationResult::Put(key.clone(), Ok(PutResponse { key }))
            }
            OperationContext::Delete { .. } => OperationResult::Delete(key, Ok(())),
        };
        Ok((index, result))
    }
}
376
/// A stream of per-operation results produced by [`ManyBuilder::send`].
///
/// Results are not necessarily yielded in submission order.
pub struct OperationResults(Pin<Box<dyn Stream<Item = OperationResult> + Send>>);
379
380impl fmt::Debug for OperationResults {
381 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
382 f.write_str("OperationResults([Stream])")
383 }
384}
385
impl Stream for OperationResults {
    type Item = OperationResult;

    /// Delegates directly to the inner boxed stream.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.0.as_mut().poll_next(cx)
    }
}
393
394impl OperationResults {
395 pub async fn error_for_failures(
400 mut self,
401 ) -> crate::Result<(), impl Iterator<Item = crate::Error>> {
402 let mut errs = Vec::new();
403 while let Some(res) = self.next().await {
404 match res {
405 OperationResult::Get(_, get) => {
406 if let Err(e) = get {
407 errs.push(e);
408 }
409 }
410 OperationResult::Put(_, put) => {
411 if let Err(e) = put {
412 errs.push(e);
413 }
414 }
415 OperationResult::Delete(_, delete) => {
416 if let Err(e) = delete {
417 errs.push(e);
418 }
419 }
420 OperationResult::Error(error) => errs.push(error),
421 }
422 }
423 if errs.is_empty() {
424 return Ok(());
425 }
426 Err(errs.into_iter())
427 }
428}
429
/// Sends one chunk of operations as a single multipart batch request and
/// parses the multipart response into per-operation results.
///
/// Operations the server did not answer are reported as `MalformedResponse`
/// failures, so the returned vector covers every input operation (it may
/// additionally contain unattributable `OperationResult::Error` entries for
/// unparsable response fields).
async fn send_batch(
    session: &Session,
    operations: Vec<BatchOperation>,
) -> crate::Result<Vec<OperationResult>> {
    // Response fields reference operations by index; keep the request-side
    // context around to interpret each field.
    let mut context_map: HashMap<usize, OperationContext> = operations
        .iter()
        .enumerate()
        .map(|(idx, op)| (idx, OperationContext::from(op)))
        .collect();
    let num_operations = operations.len();

    let mut form = reqwest::multipart::Form::new();
    for op in operations.into_iter() {
        let part = op.into_part().await?;
        // NOTE(review): every part shares the name "part"; the operation index
        // appears to be implied by part order — confirm the server contract.
        form = form.part("part", part);
    }

    let request = session.batch_request()?.multipart(form);
    let response = request.send().await?.error_for_status()?;

    // The multipart boundary is carried in the response Content-Type header.
    let boundary = response
        .headers()
        .get(CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .ok_or_else(|| Error::MalformedResponse("missing Content-Type header".to_owned()))
        .map(multer::parse_boundary)??;

    let byte_stream = response.bytes_stream().map(|r| r.map_err(io::Error::other));
    let mut multipart = multer::Multipart::new(byte_stream, boundary);

    let mut results = Vec::new();
    let mut seen_indices = HashSet::new();
    while let Some(field) = multipart.next_field().await? {
        let (index, result) = OperationResult::from_field(field, &context_map).await;
        if let Some(idx) = index {
            seen_indices.insert(idx);
        }
        results.push(result);
    }

    // Synthesize failures for any operation the server silently skipped.
    for idx in 0..num_operations {
        if !seen_indices.contains(&idx) {
            let error = Error::MalformedResponse(format!(
                "server did not return a response for operation at index {idx}"
            ));
            let result = match context_map.remove(&idx) {
                Some(ctx) => error_result(ctx, error),
                None => OperationResult::Error(error),
            };
            results.push(result);
        }
    }

    Ok(results)
}
485
486async fn classify(op: BatchOperation) -> Classified {
491 match op {
492 BatchOperation::Insert {
493 key,
494 metadata,
495 body: PutBody::File(file),
496 } => {
497 let meta = match file.metadata().await {
498 Ok(meta) => meta,
499 Err(err) => {
500 let key = key.unwrap_or_else(|| "<unknown>".to_owned());
501 return Classified::Failed(OperationResult::Put(key, Err(err.into())));
502 }
503 };
504
505 let op = BatchOperation::Insert {
506 key,
507 metadata,
508 body: PutBody::File(file),
509 };
510 if meta.len() <= MAX_BATCH_PART_SIZE as u64 {
511 Classified::Batchable(op)
512 } else {
513 Classified::Individual(op)
514 }
515 }
516 other => Classified::Batchable(other),
518 }
519}
520
521async fn partition(
525 operations: Vec<BatchOperation>,
526) -> (
527 Vec<BatchOperation>,
528 Vec<BatchOperation>,
529 Vec<OperationResult>,
530) {
531 let classified = futures_util::future::join_all(operations.into_iter().map(classify)).await;
532 let mut batchable = Vec::new();
533 let mut individual = Vec::new();
534 let mut failed = Vec::new();
535 for item in classified {
536 match item {
537 Classified::Batchable(op) => batchable.push(op),
538 Classified::Individual(op) => individual.push(op),
539 Classified::Failed(result) => failed.push(result),
540 }
541 }
542 (batchable, individual, failed)
543}
544
545async fn execute_individual(op: BatchOperation, session: &Session) -> OperationResult {
547 match op {
548 BatchOperation::Get {
549 key,
550 decompress,
551 accept_encoding,
552 } => {
553 let get = GetBuilder {
554 session: session.clone(),
555 key: key.clone(),
556 decompress,
557 accept_encoding,
558 };
559 OperationResult::Get(key, get.send().await)
560 }
561 BatchOperation::Insert {
562 key,
563 metadata,
564 body,
565 } => {
566 let error_key = key.clone().unwrap_or_else(|| "<unknown>".to_owned());
567 let put = PutBuilder {
568 session: session.clone(),
569 metadata,
570 key,
571 body,
572 };
573 match put.send().await {
574 Ok(response) => OperationResult::Put(response.key.clone(), Ok(response)),
575 Err(err) => OperationResult::Put(error_key, Err(err)),
576 }
577 }
578 BatchOperation::Delete { key } => {
579 let delete = DeleteBuilder {
580 session: session.clone(),
581 key: key.clone(),
582 };
583 OperationResult::Delete(key, delete.send().await)
584 }
585 }
586}
587
588async fn execute_batch(operations: Vec<BatchOperation>, session: &Session) -> Vec<OperationResult> {
592 let contexts: Vec<_> = operations.iter().map(OperationContext::from).collect();
593 match send_batch(session, operations).await {
594 Ok(results) => results,
595 Err(e) => {
596 let shared = Arc::new(e);
597 contexts
598 .into_iter()
599 .map(|ctx| error_result(ctx, Error::Batch(shared.clone())))
600 .collect()
601 }
602 }
603}
604
impl ManyBuilder {
    /// Executes all queued operations and returns a stream of their results.
    ///
    /// Operations are first partitioned: oversized file uploads run as
    /// individual requests (at most `MAX_INDIVIDUAL_CONCURRENCY` in flight),
    /// everything else is chunked into batches of at most `MAX_BATCH_OPS`
    /// operations (at most `MAX_BATCH_CONCURRENCY` batches in flight).
    /// Results are not guaranteed to arrive in submission order.
    pub async fn send(self) -> OperationResults {
        let session = self.session;

        let (batchable, individual, failed) = partition(self.operations).await;

        let individual_results = futures_util::stream::iter(individual)
            .map({
                let session = session.clone();
                move |op| {
                    let session = session.clone();
                    async move { execute_individual(op, &session).await }
                }
            })
            .buffer_unordered(MAX_INDIVIDUAL_CONCURRENCY);

        // Chunk the batchable operations lazily via `unfold`, then run the
        // chunks with bounded concurrency and flatten their result vectors.
        let batch_results = futures_util::stream::unfold(batchable, |mut remaining| async {
            if remaining.is_empty() {
                return None;
            }
            let at = remaining.len().min(MAX_BATCH_OPS);
            let chunk: Vec<_> = remaining.drain(..at).collect();
            Some((chunk, remaining))
        })
        .map(move |chunk| {
            let session = session.clone();
            async move { execute_batch(chunk, &session).await }
        })
        .buffer_unordered(MAX_BATCH_CONCURRENCY)
        .flat_map(futures_util::stream::iter);

        // NOTE(review): `chain` drains each stage before polling the next, so
        // pre-flight failures come first, then all individual results, then
        // batch results — batches do not start until individual requests
        // finish. Confirm this sequencing is intentional.
        let results = futures_util::stream::iter(failed)
            .chain(individual_results)
            .chain(batch_results);

        OperationResults(results.boxed())
    }

    /// Queues another operation, accepting any of the get/put/delete builders.
    #[allow(private_bounds)]
    pub fn push<B: Into<BatchOperation>>(mut self, builder: B) -> Self {
        self.operations.push(builder.into());
        self
    }
}