Skip to main content

relay_server/
http.rs

//! Abstractions for dealing with HTTP clients.
//!
//! All of it is implemented as concrete types rather than traits, because trait objects would
//! have to be boxed to be transferable between actors. Trait objects in turn do not allow
//! consuming `self`, using generic methods, or referencing the `Self` type in return values, all
//! of which are very useful in builder types.
//!
//! Note: This literally does what the `http` crate is supposed to do. That crate has builder
//! objects and common request objects; it's just that nobody bothers to implement the conversion
//! logic.
11use std::io;
12use std::time::Duration;
13
14use http::header::InvalidHeaderValue;
15use relay_config::HttpEncoding;
16pub use reqwest::StatusCode;
17use reqwest::header::{HeaderMap, HeaderValue};
18use serde::de::DeserializeOwned;
19
20#[derive(Debug, thiserror::Error)]
21pub enum HttpError {
22    #[error("payload too large")]
23    Overflow,
24    #[error("could not send request")]
25    Reqwest(#[from] reqwest::Error),
26    #[error("failed to stream payload")]
27    Io(#[from] io::Error),
28    #[error("failed to parse JSON response")]
29    Json(#[from] serde_json::Error),
30    #[error("request was retried or not initialized")]
31    Misconfigured,
32    #[error("failed to construct header")]
33    Header(#[from] InvalidHeaderValue),
34}
35
36impl HttpError {
37    /// Returns `true` if the error indicates a network downtime.
38    pub fn is_network_error(&self) -> bool {
39        match self {
40            Self::Io(_) => true,
41            // note: status codes are not handled here because we never call error_for_status. This
42            // logic is part of upstream service.
43            Self::Reqwest(error) => error.is_timeout(),
44            Self::Json(_) => false,
45            Self::Overflow => false,
46            Self::Misconfigured => false,
47            Self::Header(_) => false,
48        }
49    }
50}
51
52pub struct Request(pub reqwest::Request);
53
54pub struct RequestBuilder {
55    builder: Option<reqwest::RequestBuilder>,
56}
57
58impl RequestBuilder {
59    pub fn reqwest(builder: reqwest::RequestBuilder) -> Self {
60        RequestBuilder {
61            builder: Some(builder),
62        }
63    }
64
65    pub fn finish(self) -> Result<Request, HttpError> {
66        // The builder is not optional, instead the option is used inside `build` so that it can be
67        // moved temporarily. Therefore, the `unwrap` here is infallible.
68        Ok(Request(self.builder.unwrap().build()?))
69    }
70
71    fn build<F>(&mut self, f: F) -> &mut Self
72    where
73        F: FnOnce(reqwest::RequestBuilder) -> reqwest::RequestBuilder,
74    {
75        self.builder = self.builder.take().map(f);
76        self
77    }
78
79    /// Add a new header, not replacing existing ones.
80    pub fn header(&mut self, key: impl AsRef<str>, value: impl AsRef<[u8]>) -> &mut Self {
81        self.build(|builder| builder.header(key.as_ref(), value.as_ref()))
82    }
83
84    /// Add an optional header, not replacing existing ones.
85    ///
86    /// If the value is `Some`, the header is added. If the value is `None`, headers are not
87    /// changed.
88    pub fn header_opt(
89        &mut self,
90        key: impl AsRef<str>,
91        value: Option<impl AsRef<[u8]>>,
92    ) -> &mut Self {
93        match value {
94            Some(value) => self.build(|builder| builder.header(key.as_ref(), value.as_ref())),
95            None => self,
96        }
97    }
98
99    pub fn content_encoding(&mut self, encoding: HttpEncoding) -> &mut Self {
100        self.header_opt("content-encoding", encoding.name())
101    }
102
103    /// Enables a total request timeout.
104    ///
105    /// See [`reqwest::RequestBuilder::timeout`].
106    pub fn timeout(&mut self, timeout: Duration) -> &mut Self {
107        self.build(|builder| builder.timeout(timeout))
108    }
109
110    pub fn body(&mut self, body: impl Into<reqwest::Body>) -> &mut Self {
111        self.build(|builder| builder.body(body))
112    }
113}
114
115pub struct Response(pub reqwest::Response);
116
117impl Response {
118    pub fn status(&self) -> StatusCode {
119        self.0.status()
120    }
121
122    pub async fn consume(&mut self) -> Result<(), HttpError> {
123        // Consume the request payload such that the underlying connection returns to a "clean
124        // state" and can be reused by the client. This is explicitly required, see:
125        // https://github.com/seanmonstar/reqwest/issues/1272#issuecomment-839813308
126        while self.0.chunk().await?.is_some() {}
127        Ok(())
128    }
129
130    pub fn get_header(&self, key: impl AsRef<str>) -> Option<&[u8]> {
131        Some(self.0.headers().get(key.as_ref())?.as_bytes())
132    }
133
134    pub fn get_all_headers(&self, key: impl AsRef<str>) -> Vec<&[u8]> {
135        self.0
136            .headers()
137            .get_all(key.as_ref())
138            .into_iter()
139            .map(|value| value.as_bytes())
140            .collect()
141    }
142
143    pub fn headers(&self) -> &HeaderMap<HeaderValue> {
144        self.0.headers()
145    }
146
147    pub async fn bytes(self, limit: usize) -> Result<Vec<u8>, HttpError> {
148        let Self(mut request) = self;
149
150        let mut body = Vec::with_capacity(limit.min(8192));
151        while let Some(chunk) = request.chunk().await? {
152            if (body.len() + chunk.len()) > limit {
153                return Err(HttpError::Overflow);
154            }
155
156            body.extend_from_slice(&chunk);
157        }
158
159        Ok(body)
160    }
161
162    pub async fn json<T>(self, limit: usize) -> Result<T, HttpError>
163    where
164        T: 'static + DeserializeOwned,
165    {
166        let bytes = self.bytes(limit).await?;
167        serde_json::from_slice(&bytes).map_err(HttpError::Json)
168    }
169}