relay_server/
http.rs

//! Abstractions for dealing with HTTP clients.
//!
//! All of it is implemented with concrete types rather than traits, because trait objects would
//! have to be boxed to be transferable between actors. Trait objects in turn do not allow for
//! consuming `self`, using generic methods or referencing the `Self` type in return values, all
//! of which is very useful to do in builder types.
//!
//! Note: This literally does what the `http` crate is supposed to do. That crate has builder
//! objects and common request objects, it's just that nobody bothers to implement the conversion
//! logic.
11use std::io;
12
13use bytes::Bytes;
14use relay_config::HttpEncoding;
15use reqwest::header::{HeaderMap, HeaderValue};
16pub use reqwest::StatusCode;
17use serde::de::DeserializeOwned;
18
19#[derive(Debug, thiserror::Error)]
20pub enum HttpError {
21    #[error("payload too large")]
22    Overflow,
23    #[error("could not send request")]
24    Reqwest(#[from] reqwest::Error),
25    #[error("failed to stream payload")]
26    Io(#[from] io::Error),
27    #[error("failed to parse JSON response")]
28    Json(#[from] serde_json::Error),
29}
30
31impl HttpError {
32    /// Returns `true` if the error indicates a network downtime.
33    pub fn is_network_error(&self) -> bool {
34        match self {
35            Self::Io(_) => true,
36            // note: status codes are not handled here because we never call error_for_status. This
37            // logic is part of upstream service.
38            Self::Reqwest(error) => error.is_timeout(),
39            Self::Json(_) => false,
40            HttpError::Overflow => false,
41        }
42    }
43}
44
45pub struct Request(pub reqwest::Request);
46
47pub struct RequestBuilder {
48    builder: Option<reqwest::RequestBuilder>,
49}
50
51impl RequestBuilder {
52    pub fn reqwest(builder: reqwest::RequestBuilder) -> Self {
53        RequestBuilder {
54            builder: Some(builder),
55        }
56    }
57
58    pub fn finish(self) -> Result<Request, HttpError> {
59        // The builder is not optional, instead the option is used inside `build` so that it can be
60        // moved temporarily. Therefore, the `unwrap` here is infallible.
61        Ok(Request(self.builder.unwrap().build()?))
62    }
63
64    fn build<F>(&mut self, f: F) -> &mut Self
65    where
66        F: FnOnce(reqwest::RequestBuilder) -> reqwest::RequestBuilder,
67    {
68        self.builder = self.builder.take().map(f);
69        self
70    }
71
72    /// Add a new header, not replacing existing ones.
73    pub fn header(&mut self, key: impl AsRef<str>, value: impl AsRef<[u8]>) -> &mut Self {
74        self.build(|builder| builder.header(key.as_ref(), value.as_ref()))
75    }
76
77    /// Add an optional header, not replacing existing ones.
78    ///
79    /// If the value is `Some`, the header is added. If the value is `None`, headers are not
80    /// changed.
81    pub fn header_opt(
82        &mut self,
83        key: impl AsRef<str>,
84        value: Option<impl AsRef<[u8]>>,
85    ) -> &mut Self {
86        match value {
87            Some(value) => self.build(|builder| builder.header(key.as_ref(), value.as_ref())),
88            None => self,
89        }
90    }
91
92    pub fn content_encoding(&mut self, encoding: HttpEncoding) -> &mut Self {
93        self.header_opt("content-encoding", encoding.name())
94    }
95
96    pub fn body(&mut self, body: Bytes) -> &mut Self {
97        self.build(|builder| builder.body(body))
98    }
99}
100
101pub struct Response(pub reqwest::Response);
102
103impl Response {
104    pub fn status(&self) -> StatusCode {
105        self.0.status()
106    }
107
108    pub async fn consume(&mut self) -> Result<(), HttpError> {
109        // Consume the request payload such that the underlying connection returns to a "clean
110        // state" and can be reused by the client. This is explicitly required, see:
111        // https://github.com/seanmonstar/reqwest/issues/1272#issuecomment-839813308
112        while self.0.chunk().await?.is_some() {}
113        Ok(())
114    }
115
116    pub fn get_header(&self, key: impl AsRef<str>) -> Option<&[u8]> {
117        Some(self.0.headers().get(key.as_ref())?.as_bytes())
118    }
119
120    pub fn get_all_headers(&self, key: impl AsRef<str>) -> Vec<&[u8]> {
121        self.0
122            .headers()
123            .get_all(key.as_ref())
124            .into_iter()
125            .map(|value| value.as_bytes())
126            .collect()
127    }
128
129    pub fn headers(&self) -> &HeaderMap<HeaderValue> {
130        self.0.headers()
131    }
132
133    pub async fn bytes(self, limit: usize) -> Result<Vec<u8>, HttpError> {
134        let Self(mut request) = self;
135
136        let mut body = Vec::with_capacity(limit.min(8192));
137        while let Some(chunk) = request.chunk().await? {
138            if (body.len() + chunk.len()) > limit {
139                return Err(HttpError::Overflow);
140            }
141
142            body.extend_from_slice(&chunk);
143        }
144
145        Ok(body)
146    }
147
148    pub async fn json<T>(self, limit: usize) -> Result<T, HttpError>
149    where
150        T: 'static + DeserializeOwned,
151    {
152        let bytes = self.bytes(limit).await?;
153        serde_json::from_slice(&bytes).map_err(HttpError::Json)
154    }
155}