objectstore_server/extractors/
id.rs

1use std::borrow::Cow;
2
3use axum::extract::rejection::PathRejection;
4use axum::extract::{FromRequestParts, Path};
5use axum::http::request::Parts;
6use axum::response::{IntoResponse, Response};
7use objectstore_service::id::{ObjectContext, ObjectId};
8use objectstore_types::scope::{EMPTY_SCOPES, Scope, Scopes};
9use serde::{Deserialize, de};
10
11use crate::extractors::Xt;
12use crate::extractors::downstream_service::DownstreamService;
13use crate::state::ServiceState;
14
15#[derive(Debug)]
16pub enum ObjectRejection {
17    Path(PathRejection),
18    Killswitched,
19    RateLimited,
20}
21
22impl IntoResponse for ObjectRejection {
23    fn into_response(self) -> Response {
24        match self {
25            ObjectRejection::Path(rejection) => rejection.into_response(),
26            ObjectRejection::Killswitched => (
27                axum::http::StatusCode::FORBIDDEN,
28                "Object access is disabled for this scope through killswitches",
29            )
30                .into_response(),
31            ObjectRejection::RateLimited => (
32                axum::http::StatusCode::TOO_MANY_REQUESTS,
33                "Object access is rate limited",
34            )
35                .into_response(),
36        }
37    }
38}
39
40impl From<PathRejection> for ObjectRejection {
41    fn from(rejection: PathRejection) -> Self {
42        ObjectRejection::Path(rejection)
43    }
44}
45
46impl FromRequestParts<ServiceState> for Xt<ObjectId> {
47    type Rejection = ObjectRejection;
48
49    async fn from_request_parts(
50        parts: &mut Parts,
51        state: &ServiceState,
52    ) -> Result<Self, Self::Rejection> {
53        let Path(params) = Path::<ObjectParams>::from_request_parts(parts, state).await?;
54        let id = ObjectId::from_parts(params.usecase, params.scopes, params.key);
55
56        populate_sentry_context(id.context());
57        sentry::configure_scope(|s| s.set_extra("key", id.key().into()));
58
59        let service = DownstreamService::from_request_parts(parts, state)
60            .await
61            .unwrap();
62
63        if state
64            .config
65            .killswitches
66            .matches(id.context(), service.as_str())
67        {
68            return Err(ObjectRejection::Killswitched);
69        }
70
71        if !state.rate_limiter.check(id.context()) {
72            return Err(ObjectRejection::RateLimited);
73        }
74
75        Ok(Xt(id))
76    }
77}
78
79/// Path parameters used for object-level endpoints.
80///
81/// This is meant to be used with the axum `Path` extractor.
82#[derive(Clone, Debug, Deserialize)]
83struct ObjectParams {
84    usecase: String,
85    #[serde(deserialize_with = "deserialize_scopes")]
86    scopes: Scopes,
87    key: String,
88}
89
90/// Deserializes a `Scopes` instance from a string representation.
91///
92/// The string representation is a semicolon-separated list of `key=value` pairs, following the
93/// Matrix URIs proposal. An empty scopes string (`"_"`) represents no scopes.
94fn deserialize_scopes<'de, D>(deserializer: D) -> Result<Scopes, D::Error>
95where
96    D: de::Deserializer<'de>,
97{
98    let s = Cow::<str>::deserialize(deserializer)?;
99    if s == EMPTY_SCOPES {
100        return Ok(Scopes::empty());
101    }
102
103    let scopes = s
104        .split(';')
105        .map(|s| {
106            let (key, value) = s
107                .split_once("=")
108                .ok_or_else(|| de::Error::custom("scope must be 'key=value'"))?;
109
110            Scope::create(key, value).map_err(de::Error::custom)
111        })
112        .collect::<Result<_, _>>()?;
113
114    Ok(scopes)
115}
116
117impl FromRequestParts<ServiceState> for Xt<ObjectContext> {
118    type Rejection = ObjectRejection;
119
120    async fn from_request_parts(
121        parts: &mut Parts,
122        state: &ServiceState,
123    ) -> Result<Self, Self::Rejection> {
124        let Path(params) = Path::<ContextParams>::from_request_parts(parts, state).await?;
125        let context = ObjectContext {
126            usecase: params.usecase,
127            scopes: params.scopes,
128        };
129
130        populate_sentry_context(&context);
131
132        let service = DownstreamService::from_request_parts(parts, state)
133            .await
134            .unwrap();
135
136        if state
137            .config
138            .killswitches
139            .matches(&context, service.as_str())
140        {
141            return Err(ObjectRejection::Killswitched);
142        }
143
144        if !state.rate_limiter.check(&context) {
145            return Err(ObjectRejection::RateLimited);
146        }
147
148        Ok(Xt(context))
149    }
150}
151
152/// Path parameters for extracting an [`ObjectContext`] from a request path.
153///
154/// Works on both collection-level (`/objects/{usecase}/{scopes}`) and object-level
155/// (`/objects/{usecase}/{scopes}/{*key}`) routes — the extra `key` parameter is ignored.
156#[derive(Clone, Debug, Deserialize)]
157pub(super) struct ContextParams {
158    pub usecase: String,
159    #[serde(deserialize_with = "deserialize_scopes")]
160    pub scopes: Scopes,
161}
162
163fn populate_sentry_context(context: &ObjectContext) {
164    sentry::configure_scope(|s| {
165        s.set_tag("usecase", &context.usecase);
166        for scope in &context.scopes {
167            s.set_tag(&format!("scope.{}", scope.name()), scope.value());
168        }
169    });
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use serde::de::IntoDeserializer;
    use serde::de::value::{CowStrDeserializer, Error as DeError};
    use std::borrow::Cow;

    use std::collections::BTreeMap;
    use std::sync::Arc;

    use axum::Router;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use axum::routing::{get, post};
    use objectstore_service::StorageService;
    use objectstore_service::backend::in_memory::InMemoryBackend;
    use tower::ServiceExt;

    use crate::auth::PublicKeyDirectory;
    use crate::config::Config;
    use crate::killswitches::{Killswitch, Killswitches};
    use crate::rate_limits::{RateLimiter, RateLimits, ThroughputLimits};
    use crate::state::{ServiceState, Services};
    use crate::web::RequestCounter;

    // --- Scope deserialization tests ---

    /// Runs `deserialize_scopes` against a plain string input.
    fn deser_scopes(input: &str) -> Result<Scopes, DeError> {
        let deserializer: CowStrDeserializer<DeError> = Cow::Borrowed(input).into_deserializer();
        deserialize_scopes(deserializer)
    }

    #[test]
    fn parse_single_scope() {
        let scopes = deser_scopes("org=123").unwrap();
        assert_eq!(scopes.get_value("org"), Some("123"));
    }

    #[test]
    fn parse_multiple_scopes() {
        let scopes = deser_scopes("org=123;project=456").unwrap();
        assert_eq!(scopes.get_value("org"), Some("123"));
        assert_eq!(scopes.get_value("project"), Some("456"));
    }

    #[test]
    fn parse_empty_scopes() {
        let scopes = deser_scopes("_").unwrap();
        assert!(scopes.is_empty());
    }

    #[test]
    fn parse_missing_equals() {
        let result = deser_scopes("org123");
        assert!(result.is_err());
    }

    #[test]
    fn parse_invalid_scope_chars() {
        let result = deser_scopes("org=hello world");
        assert!(result.is_err());
    }

    #[test]
    fn parse_empty_key_or_value() {
        assert!(deser_scopes("=value").is_err());
        assert!(deser_scopes("key=").is_err());
    }

    // --- Extractor integration tests ---

    /// Builds a service state backed by an in-memory storage backend.
    async fn test_state(config: Config) -> ServiceState {
        let service = StorageService::new(Box::new(InMemoryBackend::new("in-memory")));
        let key_directory = PublicKeyDirectory::try_from(&config.auth).unwrap();
        let rate_limiter = RateLimiter::new(config.rate_limits.clone());

        Arc::new(Services {
            config,
            service,
            key_directory,
            rate_limiter,
            request_counter: RequestCounter::new(0),
        })
    }

    /// Handler that echoes the pieces of the extracted [`ObjectId`].
    async fn handle_object_id(Xt(id): Xt<ObjectId>) -> String {
        format!(
            "usecase={} key={} scopes_empty={}",
            id.context().usecase,
            id.key(),
            id.context().scopes.is_empty(),
        )
    }

    /// Handler that echoes the pieces of the extracted [`ObjectContext`].
    async fn handle_object_context(Xt(ctx): Xt<ObjectContext>) -> String {
        format!(
            "usecase={} scopes_empty={}",
            ctx.usecase,
            ctx.scopes.is_empty(),
        )
    }

    /// Router with one object-level (GET) and one collection-level (POST) route.
    fn test_router(state: ServiceState) -> Router {
        Router::new()
            .route("/objects/{usecase}/{scopes}/{*key}", get(handle_object_id))
            .route("/objects/{usecase}/{scopes}/", post(handle_object_context))
            .with_state(state)
    }

    /// Builds the test router on top of a fresh state created from `config`.
    async fn test_app(config: Config) -> Router {
        test_router(test_state(config).await)
    }

    /// Builds a body-less request for `uri` with the given HTTP method and, when
    /// `service` is set, an `x-downstream-service` header carrying it.
    fn build_request(method: &str, uri: &str, service: Option<&str>) -> Request<Body> {
        let mut builder = Request::builder().method(method).uri(uri);
        if let Some(service) = service {
            builder = builder.header("x-downstream-service", service);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Sends `request` through a clone of `app`.
    ///
    /// Cloning keeps `app` usable for follow-up requests; the clones share the same
    /// `ServiceState`, which the rate limit tests rely on.
    async fn send(app: &Router, request: Request<Body>) -> axum::http::Response<Body> {
        app.clone().oneshot(request).await.unwrap()
    }

    /// Reads the full response body as a UTF-8 string.
    async fn response_body(response: axum::http::Response<Body>) -> String {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    /// Config with a single killswitch that has the given `usecase` and `service`
    /// set and no scope restrictions.
    fn killswitch_config(usecase: Option<&str>, service: Option<&str>) -> Config {
        Config {
            killswitches: Killswitches(vec![Killswitch {
                usecase: usecase.map(Into::into),
                scopes: BTreeMap::new(),
                service: service.map(Into::into),
                service_matcher: Default::default(),
            }]),
            ..Config::default()
        }
    }

    /// Config that sets `global_rps` to 1 and `burst` to 0, so the tests expect the
    /// second request in quick succession to be rejected.
    fn throttled_config() -> Config {
        Config {
            rate_limits: RateLimits {
                throughput: ThroughputLimits {
                    global_rps: Some(1),
                    burst: 0,
                    ..ThroughputLimits::default()
                },
                ..RateLimits::default()
            },
            ..Config::default()
        }
    }

    // Extraction tests

    #[tokio::test]
    async fn extract_object_id_parses_path() {
        let app = test_app(Config::default()).await;

        let uri = "/objects/myusecase/org=123;project=456/my-key";
        let response = send(&app, build_request("GET", uri, None)).await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_body(response).await;
        assert!(body.contains("usecase=myusecase"));
        assert!(body.contains("key=my-key"));
        assert!(body.contains("scopes_empty=false"));
    }

    #[tokio::test]
    async fn extract_object_id_with_empty_scopes() {
        let app = test_app(Config::default()).await;

        let request = build_request("GET", "/objects/myusecase/_/my-key", None);
        let response = send(&app, request).await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_body(response).await;
        assert!(body.contains("scopes_empty=true"));
    }

    #[tokio::test]
    async fn extract_object_context_parses_path() {
        let app = test_app(Config::default()).await;

        let uri = "/objects/myusecase/org=123;project=456/";
        let response = send(&app, build_request("POST", uri, None)).await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_body(response).await;
        assert!(body.contains("usecase=myusecase"));
        assert!(body.contains("scopes_empty=false"));
    }

    #[tokio::test]
    async fn extract_object_context_with_empty_scopes() {
        let app = test_app(Config::default()).await;

        let request = build_request("POST", "/objects/myusecase/_/", None);
        let response = send(&app, request).await;

        assert_eq!(response.status(), StatusCode::OK);
        let body = response_body(response).await;
        assert!(body.contains("scopes_empty=true"));
    }

    #[tokio::test]
    async fn extract_object_id_invalid_scopes() {
        let app = test_app(Config::default()).await;

        let request = build_request("GET", "/objects/myusecase/invalid-no-equals/key", None);
        let response = send(&app, request).await;

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    // Killswitch tests

    #[tokio::test]
    async fn extract_object_id_killswitched() {
        let app = test_app(killswitch_config(Some("blocked"), None)).await;

        let blocked = build_request("GET", "/objects/blocked/org=1/key", None);
        assert_eq!(send(&app, blocked).await.status(), StatusCode::FORBIDDEN);

        let allowed = build_request("GET", "/objects/allowed/org=1/key", None);
        assert_eq!(send(&app, allowed).await.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn extract_object_context_killswitched() {
        let app = test_app(killswitch_config(Some("blocked"), None)).await;

        let blocked = build_request("POST", "/objects/blocked/org=1/", None);
        assert_eq!(send(&app, blocked).await.status(), StatusCode::FORBIDDEN);

        let allowed = build_request("POST", "/objects/allowed/org=1/", None);
        assert_eq!(send(&app, allowed).await.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn extract_object_id_killswitched_with_service() {
        let app = test_app(killswitch_config(None, Some("test-*"))).await;
        let uri = "/objects/any/org=1/key";

        // Matching service header → 403
        let matching = build_request("GET", uri, Some("test-service"));
        assert_eq!(send(&app, matching).await.status(), StatusCode::FORBIDDEN);

        // Non-matching service header → 200
        let non_matching = build_request("GET", uri, Some("other-service"));
        assert_eq!(send(&app, non_matching).await.status(), StatusCode::OK);

        // No service header → 200
        let without_header = build_request("GET", uri, None);
        assert_eq!(send(&app, without_header).await.status(), StatusCode::OK);
    }

    // Rate limiter tests

    #[tokio::test]
    async fn extract_object_id_rate_limited() {
        let app = test_app(throttled_config()).await;
        let uri = "/objects/test/org=1/key";

        let first = send(&app, build_request("GET", uri, None)).await;
        assert_eq!(first.status(), StatusCode::OK);

        let second = send(&app, build_request("GET", uri, None)).await;
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[tokio::test]
    async fn extract_object_context_rate_limited() {
        let app = test_app(throttled_config()).await;
        let uri = "/objects/test/org=1/";

        let first = send(&app, build_request("POST", uri, None)).await;
        assert_eq!(first.status(), StatusCode::OK);

        let second = send(&app, build_request("POST", uri, None)).await;
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
    }
}