relay_server/
http.rs

1//! Abstractions for dealing with HTTP clients.
2//!
//! All of it is implemented as concrete types rather than traits: trait objects would have to be
//! boxed to be transferable between actors, and trait objects in turn do not allow consuming
//! `self`, using generic methods, or referencing the `Self` type in return values, all of which
//! is very useful to do in builder types.
7//!
8//! Note: This literally does what the `http` crate is supposed to do. That crate has builder
9//! objects and common request objects, it's just that nobody bothers to implement the conversion
10//! logic.
11use std::io;
12use std::time::Duration;
13
14use relay_config::HttpEncoding;
15pub use reqwest::StatusCode;
16use reqwest::header::{HeaderMap, HeaderValue};
17use serde::de::DeserializeOwned;
18
19#[derive(Debug, thiserror::Error)]
20pub enum HttpError {
21    #[error("payload too large")]
22    Overflow,
23    #[error("could not send request")]
24    Reqwest(#[from] reqwest::Error),
25    #[error("failed to stream payload")]
26    Io(#[from] io::Error),
27    #[error("failed to parse JSON response")]
28    Json(#[from] serde_json::Error),
29    #[error("request was retried or not initialized")]
30    Misconfigured,
31}
32
33impl HttpError {
34    /// Returns `true` if the error indicates a network downtime.
35    pub fn is_network_error(&self) -> bool {
36        match self {
37            Self::Io(_) => true,
38            // note: status codes are not handled here because we never call error_for_status. This
39            // logic is part of upstream service.
40            Self::Reqwest(error) => error.is_timeout(),
41            Self::Json(_) => false,
42            Self::Overflow => false,
43            Self::Misconfigured => false,
44        }
45    }
46}
47
48pub struct Request(pub reqwest::Request);
49
50pub struct RequestBuilder {
51    builder: Option<reqwest::RequestBuilder>,
52}
53
54impl RequestBuilder {
55    pub fn reqwest(builder: reqwest::RequestBuilder) -> Self {
56        RequestBuilder {
57            builder: Some(builder),
58        }
59    }
60
61    pub fn finish(self) -> Result<Request, HttpError> {
62        // The builder is not optional, instead the option is used inside `build` so that it can be
63        // moved temporarily. Therefore, the `unwrap` here is infallible.
64        Ok(Request(self.builder.unwrap().build()?))
65    }
66
67    fn build<F>(&mut self, f: F) -> &mut Self
68    where
69        F: FnOnce(reqwest::RequestBuilder) -> reqwest::RequestBuilder,
70    {
71        self.builder = self.builder.take().map(f);
72        self
73    }
74
75    /// Add a new header, not replacing existing ones.
76    pub fn header(&mut self, key: impl AsRef<str>, value: impl AsRef<[u8]>) -> &mut Self {
77        self.build(|builder| builder.header(key.as_ref(), value.as_ref()))
78    }
79
80    /// Add an optional header, not replacing existing ones.
81    ///
82    /// If the value is `Some`, the header is added. If the value is `None`, headers are not
83    /// changed.
84    pub fn header_opt(
85        &mut self,
86        key: impl AsRef<str>,
87        value: Option<impl AsRef<[u8]>>,
88    ) -> &mut Self {
89        match value {
90            Some(value) => self.build(|builder| builder.header(key.as_ref(), value.as_ref())),
91            None => self,
92        }
93    }
94
95    pub fn content_encoding(&mut self, encoding: HttpEncoding) -> &mut Self {
96        self.header_opt("content-encoding", encoding.name())
97    }
98
99    /// Enables a total request timeout.
100    ///
101    /// See [`reqwest::RequestBuilder::timeout`].
102    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
103        self.build(|builder| builder.timeout(timeout))
104    }
105
106    pub fn body(&mut self, body: impl Into<reqwest::Body>) -> &mut Self {
107        self.build(|builder| builder.body(body))
108    }
109}
110
111pub struct Response(pub reqwest::Response);
112
113impl Response {
114    pub fn status(&self) -> StatusCode {
115        self.0.status()
116    }
117
118    pub async fn consume(&mut self) -> Result<(), HttpError> {
119        // Consume the request payload such that the underlying connection returns to a "clean
120        // state" and can be reused by the client. This is explicitly required, see:
121        // https://github.com/seanmonstar/reqwest/issues/1272#issuecomment-839813308
122        while self.0.chunk().await?.is_some() {}
123        Ok(())
124    }
125
126    pub fn get_header(&self, key: impl AsRef<str>) -> Option<&[u8]> {
127        Some(self.0.headers().get(key.as_ref())?.as_bytes())
128    }
129
130    pub fn get_all_headers(&self, key: impl AsRef<str>) -> Vec<&[u8]> {
131        self.0
132            .headers()
133            .get_all(key.as_ref())
134            .into_iter()
135            .map(|value| value.as_bytes())
136            .collect()
137    }
138
139    pub fn headers(&self) -> &HeaderMap<HeaderValue> {
140        self.0.headers()
141    }
142
143    pub async fn bytes(self, limit: usize) -> Result<Vec<u8>, HttpError> {
144        let Self(mut request) = self;
145
146        let mut body = Vec::with_capacity(limit.min(8192));
147        while let Some(chunk) = request.chunk().await? {
148            if (body.len() + chunk.len()) > limit {
149                return Err(HttpError::Overflow);
150            }
151
152            body.extend_from_slice(&chunk);
153        }
154
155        Ok(body)
156    }
157
158    pub async fn json<T>(self, limit: usize) -> Result<T, HttpError>
159    where
160        T: 'static + DeserializeOwned,
161    {
162        let bytes = self.bytes(limit).await?;
163        serde_json::from_slice(&bytes).map_err(HttpError::Json)
164    }
165}