// objectstore_server/extractors/body.rs

//! Axum extractor for bandwidth-metered request bodies.

3use std::convert::Infallible;
4
5use axum::extract::{FromRequest, FromRequestParts, Path, Request};
6use futures_util::{StreamExt, TryStreamExt};
7use objectstore_service::id::ObjectContext;
8use objectstore_service::stream::{ClientError, ClientStream};
9
10use super::id::ContextParams;
11use crate::state::ServiceState;
12
13/// An extractor that converts the request body into a metered [`ClientStream`].
14///
15/// Extracts the [`ObjectContext`] from the request path to attribute bandwidth to the correct
16/// per-usecase and per-scope accumulators in addition to the global accumulator.
17pub struct MeteredBody(pub ClientStream);
18
19impl std::fmt::Debug for MeteredBody {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        f.debug_struct("MeteredBody").finish()
22    }
23}
24
25impl FromRequest<ServiceState> for MeteredBody {
26    type Rejection = Infallible;
27
28    async fn from_request(request: Request, state: &ServiceState) -> Result<Self, Self::Rejection> {
29        let (mut parts, body) = request.into_parts();
30        let Path(params) =
31            <Path<ContextParams> as FromRequestParts<ServiceState>>::from_request_parts(
32                &mut parts, state,
33            )
34            .await
35            .expect("MeteredBody must be used on routes with {usecase} and {scopes} path params");
36        let context = ObjectContext {
37            usecase: params.usecase,
38            scopes: params.scopes,
39        };
40        let stream = body.into_data_stream().map_err(ClientError::new).boxed();
41        let stream = state.meter_stream(stream, &context).boxed();
42        Ok(Self(stream))
43    }
44}