1use std::borrow::Cow;
2use std::collections::BTreeMap;
3
4use base64::Engine as _;
5use bytes::Bytes;
6use futures_util::StreamExt as _;
7use objectstore_types::metadata::Metadata;
8use objectstore_types::multipart::{
9 CompleteErrorDetail, CompleteRequest, CompleteSuccessResponse, InitiateResponse,
10 ListPartsResponse, UploadPartResponse,
11};
12use reqwest::Body;
13use serde::Deserialize;
14use tokio::io::AsyncRead;
15use tokio_util::io::ReaderStream;
16
17use crate::{ClientStream, ObjectKey, Session};
18
19pub use objectstore_types::multipart::CompletePart;
20pub use objectstore_types::multipart::ETag;
21pub use objectstore_types::multipart::PartInfo;
22pub use objectstore_types::multipart::PartNumber;
23pub use objectstore_types::multipart::UploadId;
24
25#[derive(Deserialize)]
26#[serde(untagged)]
27enum CompleteResponse {
28 Error { error: CompleteErrorDetail },
29 Success(CompleteSuccessResponse),
30}
31
32impl Session {
33 pub fn initiate_multipart_upload(&self) -> InitiateMultipartBuilder {
47 let metadata = Metadata {
48 expiration_policy: self.scope.usecase().expiration_policy(),
49 compression: self.scope.usecase().compression(),
50 ..Default::default()
51 };
52
53 InitiateMultipartBuilder {
54 session: self.clone(),
55 metadata,
56 key: None,
57 }
58 }
59
60 pub fn resume_multipart_upload(
66 &self,
67 key: impl Into<ObjectKey>,
68 upload_id: impl Into<String>,
69 ) -> crate::Result<MultipartUpload> {
70 Ok(MultipartUpload {
71 session: self.clone(),
72 key: key.into(),
73 upload_id: UploadId::new(upload_id.into())?,
74 })
75 }
76}
77
78#[derive(Debug)]
80pub struct InitiateMultipartBuilder {
81 session: Session,
82 metadata: Metadata,
83 key: Option<ObjectKey>,
84}
85
86impl InitiateMultipartBuilder {
87 pub fn key(mut self, key: impl Into<ObjectKey>) -> Self {
92 self.key = Some(key.into()).filter(|k| !k.is_empty());
93 self
94 }
95
96 pub fn compression(mut self, compression: impl Into<Option<crate::Compression>>) -> Self {
106 self.metadata.compression = compression.into();
107 self
108 }
109
110 pub fn expiration_policy(mut self, expiration_policy: crate::ExpirationPolicy) -> Self {
114 self.metadata.expiration_policy = expiration_policy;
115 self
116 }
117
118 pub fn content_type(mut self, content_type: impl Into<Cow<'static, str>>) -> Self {
123 self.metadata.content_type = content_type.into();
124 self
125 }
126
127 pub fn origin(mut self, origin: impl Into<String>) -> Self {
145 self.metadata.origin = Some(origin.into());
146 self
147 }
148
149 pub fn set_metadata(mut self, metadata: impl Into<BTreeMap<String, String>>) -> Self {
153 self.metadata.custom = metadata.into();
154 self
155 }
156
157 pub fn append_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
159 self.metadata.custom.insert(key.into(), value.into());
160 self
161 }
162
163 pub async fn send(self) -> crate::Result<MultipartUpload> {
165 let method = match self.key {
166 Some(_) => reqwest::Method::PUT,
167 None => reqwest::Method::POST,
168 };
169
170 let mut builder =
171 self.session
172 .multipart_request(method, None, self.key.as_deref(), None)?;
173
174 builder = builder.headers(self.metadata.to_headers("")?);
175
176 let response: InitiateResponse = builder.send().await?.error_for_status()?.json().await?;
177
178 Ok(MultipartUpload {
179 session: self.session,
180 key: response.key,
181 upload_id: response.upload_id,
182 })
183 }
184}
185
186#[derive(Debug)]
190pub struct MultipartUpload {
191 session: Session,
192 key: String,
193 upload_id: UploadId,
194}
195
196impl MultipartUpload {
197 pub fn upload_id(&self) -> &UploadId {
199 &self.upload_id
200 }
201
202 pub fn key(&self) -> &ObjectKey {
204 &self.key
205 }
206
207 pub async fn put(
214 &self,
215 body: impl Into<Bytes>,
216 part_number: u32,
217 content_md5: Option<&[u8; 16]>,
218 ) -> crate::Result<CompletePart> {
219 let bytes = body.into();
220 let content_length = bytes.len() as u64;
221 self.upload_part(bytes.into(), part_number, content_length, content_md5)
222 .await
223 }
224
225 pub async fn put_stream(
232 &self,
233 stream: ClientStream,
234 part_number: u32,
235 content_length: u64,
236 content_md5: Option<&[u8; 16]>,
237 ) -> crate::Result<CompletePart> {
238 self.upload_part(
239 Body::wrap_stream(stream),
240 part_number,
241 content_length,
242 content_md5,
243 )
244 .await
245 }
246
247 pub async fn put_read<R>(
254 &self,
255 reader: R,
256 part_number: u32,
257 content_length: u64,
258 content_md5: Option<&[u8; 16]>,
259 ) -> crate::Result<CompletePart>
260 where
261 R: AsyncRead + Send + Sync + 'static,
262 {
263 let stream = ReaderStream::new(reader).boxed();
264 self.put_stream(stream, part_number, content_length, content_md5)
265 .await
266 }
267
268 async fn upload_part(
269 &self,
270 body: Body,
271 part_number: u32,
272 content_length: u64,
273 content_md5: Option<&[u8; 16]>,
274 ) -> crate::Result<CompletePart> {
275 let part_number =
276 PartNumber::new(part_number).ok_or(crate::Error::InvalidPartNumber(part_number))?;
277
278 let mut builder = self
279 .session
280 .multipart_request(
281 reqwest::Method::PUT,
282 Some("parts"),
283 Some(&self.key),
284 Some(vec![
285 ("upload_id", self.upload_id.to_string()),
286 ("part_number", part_number.to_string()),
287 ]),
288 )?
289 .header(reqwest::header::CONTENT_LENGTH, content_length)
290 .body(body);
291
292 if let Some(md5) = content_md5 {
293 let encoded = base64::engine::general_purpose::STANDARD.encode(md5);
294 builder = builder.header("content-md5", encoded);
295 }
296
297 let response: UploadPartResponse = builder.send().await?.error_for_status()?.json().await?;
298 Ok(CompletePart {
299 part_number,
300 etag: response.etag,
301 })
302 }
303
304 pub async fn list_parts(&self) -> crate::Result<Vec<PartInfo>> {
306 let mut all_parts = Vec::new();
307 let mut marker = None;
308
309 loop {
310 let page = self.list_parts_page(None, marker).await?;
311 all_parts.extend(page.parts);
312
313 if !page.is_truncated {
314 return Ok(all_parts);
315 }
316 marker = page.next_part_number_marker;
317 if marker.is_none() {
318 return Err(crate::Error::MalformedResponse(
319 "server returned is_truncated=true but no next_part_number_marker. Please report a bug.".into(),
320 ));
321 }
322 }
323 }
324
325 async fn list_parts_page(
326 &self,
327 max_parts: Option<u32>,
328 part_number_marker: Option<PartNumber>,
329 ) -> crate::Result<ListPartsResponse> {
330 let mut params: Vec<(&str, String)> = vec![("upload_id", self.upload_id.to_string())];
331 if let Some(max) = max_parts {
332 params.push(("max_parts", max.to_string()));
333 }
334 if let Some(marker) = part_number_marker {
335 params.push(("part_number_marker", marker.to_string()));
336 }
337
338 let builder = self.session.multipart_request(
339 reqwest::Method::GET,
340 Some("parts"),
341 Some(&self.key),
342 Some(params),
343 )?;
344
345 let response: ListPartsResponse = builder.send().await?.error_for_status()?.json().await?;
346 Ok(response)
347 }
348
349 pub async fn abort(self) -> crate::Result<()> {
351 let builder = self.session.multipart_request(
352 reqwest::Method::DELETE,
353 None,
354 Some(&self.key),
355 Some(vec![("upload_id", self.upload_id.to_string())]),
356 )?;
357 builder.send().await?.error_for_status()?;
358 Ok(())
359 }
360
361 pub async fn complete(
363 self,
364 parts: impl IntoIterator<Item = CompletePart>,
365 ) -> crate::Result<ObjectKey> {
366 let mut parts: Vec<_> = parts.into_iter().collect();
367 parts.sort_by_key(|p| p.part_number);
368
369 let builder = self
370 .session
371 .multipart_request(
372 reqwest::Method::POST,
373 Some("complete"),
374 Some(&self.key),
375 Some(vec![("upload_id", self.upload_id.to_string())]),
376 )?
377 .json(&CompleteRequest { parts });
378
379 let response = builder.send().await?.error_for_status()?;
380 match response.json::<CompleteResponse>().await? {
381 CompleteResponse::Success(s) => Ok(s.key),
382 CompleteResponse::Error { error } => Err(crate::Error::MultipartComplete {
383 code: error.code,
384 message: error.message,
385 }),
386 }
387 }
388}