1use std::borrow::Cow;
7use std::error::Error;
8use std::fmt;
9use std::future::Future;
10use std::pin::Pin;
11
12use axum::extract::{DefaultBodyLimit, Request};
13use axum::handler::Handler;
14use axum::http::{HeaderMap, HeaderName, HeaderValue, StatusCode, Uri, header};
15use axum::response::{IntoResponse, Response};
16use bytes::Bytes;
17use once_cell::sync::Lazy;
18use relay_common::glob2::GlobMatcher;
19use relay_config::Config;
20use tokio::sync::oneshot;
21use tokio::sync::oneshot::error::RecvError;
22
23use crate::extractors::ForwardedFor;
24use crate::http::{HttpError, RequestBuilder, Response as UpstreamResponse};
25use crate::service::ServiceState;
26use crate::services::upstream::{Method, SendRequest, UpstreamRequest, UpstreamRequestError};
27
/// Headers that are handled by this endpoint itself and never passed through.
///
/// `ForwardRequest::build` skips these when copying the incoming request headers onto the
/// upstream request, and `ForwardRequest::respond` filters them out of the upstream response
/// headers before they are handed back to the client.
///
/// NOTE(review): `Keep-Alive` is not part of this list — confirm that this is intentional.
static HOP_BY_HOP_HEADERS: &[HeaderName] = &[
    header::CONNECTION,
    header::PROXY_AUTHENTICATE,
    header::PROXY_AUTHORIZATION,
    header::TE,
    header::TRAILER,
    header::TRANSFER_ENCODING,
    header::UPGRADE,
];
38
/// Request headers that are dropped in addition to [`HOP_BY_HOP_HEADERS`].
///
/// Only consulted in `ForwardRequest::build`; upstream response headers are not filtered
/// against this list.
///
/// NOTE(review): presumably the upstream HTTP client supplies its own values for these —
/// confirm against the `RequestBuilder` implementation.
static IGNORED_REQUEST_HEADERS: &[HeaderName] = &[
    header::HOST,
    header::CONTENT_ENCODING,
    header::CONTENT_LENGTH,
];
45
/// Path prefix of requests that may be forwarded.
///
/// `handle` answers `404 Not Found` for any request whose path is exactly this value or does
/// not start with it.
const API_PATH: &str = "/api/";
48
/// Error type of the forward handler.
///
/// Wraps an [`UpstreamRequestError`] so that it can be turned into an HTTP response through the
/// `IntoResponse` implementation below. The `#[from]` attribute lets `?` convert upstream errors
/// into this type directly.
#[derive(Debug, thiserror::Error)]
#[error("error while forwarding request: {0}")]
struct ForwardError(#[from] UpstreamRequestError);
54
55impl From<RecvError> for ForwardError {
56 fn from(_: RecvError) -> Self {
57 Self(UpstreamRequestError::ChannelClosed)
58 }
59}
60
61impl IntoResponse for ForwardError {
62 fn into_response(self) -> Response {
63 match &self.0 {
64 UpstreamRequestError::Http(e) => match e {
65 HttpError::Overflow => StatusCode::PAYLOAD_TOO_LARGE.into_response(),
66 HttpError::Reqwest(error) => {
67 relay_log::error!(error = error as &dyn Error);
68 error
69 .status()
70 .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
71 .into_response()
72 }
73 HttpError::Io(_) => StatusCode::BAD_GATEWAY.into_response(),
74 HttpError::Json(_) => StatusCode::BAD_REQUEST.into_response(),
75 },
76 UpstreamRequestError::SendFailed(e) => {
77 if e.is_timeout() {
78 StatusCode::GATEWAY_TIMEOUT.into_response()
79 } else {
80 StatusCode::BAD_GATEWAY.into_response()
81 }
82 }
83 error => {
84 relay_log::error!(error = error as &dyn Error, "unreachable code");
86 StatusCode::INTERNAL_SERVER_ERROR.into_response()
87 }
88 }
89 }
90}
91
/// Status code, headers and body bytes of an upstream response.
///
/// This is the success value sent from `ForwardRequest::respond` to `handle` over the oneshot
/// channel.
type ForwardResponse = (StatusCode, HeaderMap<HeaderValue>, Vec<u8>);
93
/// A request that is passed on to the upstream, together with the channel for its response.
struct ForwardRequest {
    /// HTTP method reported to the upstream service via `UpstreamRequest::method`.
    method: Method,
    /// Request target reported via `UpstreamRequest::path`; `handle` fills this with the
    /// string form of the incoming URI.
    path: String,
    /// Headers of the incoming request; filtered in `build` before being sent upstream.
    headers: HeaderMap<HeaderValue>,
    /// Value written into the `X-Forwarded-For` header of the upstream request.
    forwarded_for: ForwardedFor,
    /// Body of the upstream request.
    data: Bytes,
    /// Limit passed to `UpstreamResponse::bytes` when reading the upstream response body.
    max_response_size: usize,
    /// Channel on which `respond` delivers the upstream result to the waiting handler.
    sender: oneshot::Sender<Result<ForwardResponse, UpstreamRequestError>>,
}
103
104impl fmt::Debug for ForwardRequest {
105 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
106 f.debug_struct("ForwardRequest")
107 .field("method", &self.method)
108 .field("path", &self.path)
109 .finish()
110 }
111}
112
113impl UpstreamRequest for ForwardRequest {
114 fn method(&self) -> Method {
115 self.method.clone()
116 }
117
118 fn path(&self) -> Cow<'_, str> {
119 self.path.as_str().into()
120 }
121
122 fn retry(&self) -> bool {
123 false
124 }
125
126 fn intercept_status_errors(&self) -> bool {
127 false
128 }
129
130 fn set_relay_id(&self) -> bool {
131 false
132 }
133
134 fn route(&self) -> &'static str {
135 "forward"
136 }
137
138 fn build(&mut self, builder: &mut RequestBuilder) -> Result<(), HttpError> {
139 for (key, value) in &self.headers {
140 if !HOP_BY_HOP_HEADERS.contains(key) && !IGNORED_REQUEST_HEADERS.contains(key) {
144 builder.header(key, value);
145 }
146 }
147
148 builder
149 .header("X-Forwarded-For", self.forwarded_for.as_ref())
150 .body(self.data.clone());
151
152 Ok(())
153 }
154
155 fn respond(
156 self: Box<Self>,
157 result: Result<UpstreamResponse, UpstreamRequestError>,
158 ) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>> {
159 Box::pin(async move {
160 let result = match result {
161 Ok(response) => {
162 let status = response.status();
163 let headers = response
164 .headers()
165 .iter()
166 .filter(|(name, _)| !HOP_BY_HOP_HEADERS.contains(name))
167 .map(|(name, value)| (name.clone(), value.clone()))
168 .collect();
169
170 match response.bytes(self.max_response_size).await {
171 Ok(body) => Ok((status, headers, body)),
172 Err(error) => Err(UpstreamRequestError::Http(error)),
173 }
174 }
175 Err(error) => Err(error),
176 };
177
178 self.sender.send(result).ok();
179 })
180 }
181}
182
183async fn handle(
185 state: ServiceState,
186 forwarded_for: ForwardedFor,
187 method: Method,
188 uri: Uri,
189 headers: HeaderMap<HeaderValue>,
190 data: Bytes,
191) -> Result<impl IntoResponse, ForwardError> {
192 if uri.path() == API_PATH || !uri.path().starts_with(API_PATH) {
195 return Ok(StatusCode::NOT_FOUND.into_response());
196 }
197
198 let (tx, rx) = oneshot::channel();
199
200 let request = ForwardRequest {
201 method,
202 path: uri.to_string(),
203 headers,
204 forwarded_for,
205 data,
206 max_response_size: state.config().max_api_payload_size(),
207 sender: tx,
208 };
209
210 state.upstream_relay().send(SendRequest(request));
211 let (status, headers, body) = rx.await??;
212
213 Ok(if headers.contains_key(header::CONTENT_TYPE) {
214 (status, headers, body).into_response()
215 } else {
216 (status, headers).into_response()
217 })
218}
219
/// Classes of routes that use a request body limit other than the default API payload limit.
#[derive(Clone, Copy, Debug)]
enum SpecialRoute {
    /// Routes limited by `Config::max_api_file_upload_size`.
    FileUpload,
    /// Routes limited by `Config::max_api_chunk_upload_size`.
    ChunkUpload,
}
226
227static SPECIAL_ROUTES: Lazy<GlobMatcher<SpecialRoute>> = Lazy::new(|| {
229 let mut m = GlobMatcher::new();
230 m.add(
232 "/api/0/projects/*/*/releases/*/files/",
233 SpecialRoute::FileUpload,
234 );
235 m.add(
236 "/api/0/projects/*/*/releases/*/dsyms/",
237 SpecialRoute::FileUpload,
238 );
239 m.add(
241 "/api/0/organizations/*/chunk-upload/",
242 SpecialRoute::ChunkUpload,
243 );
244 m
245});
246
247fn get_limit_for_path(path: &str, config: &Config) -> usize {
249 match SPECIAL_ROUTES.test(path) {
250 Some(SpecialRoute::FileUpload) => config.max_api_file_upload_size(),
251 Some(SpecialRoute::ChunkUpload) => config.max_api_chunk_upload_size(),
252 None => config.max_api_payload_size(),
253 }
254}
255
256pub fn forward(state: ServiceState, req: Request) -> impl Future<Output = Response> {
269 let limit = get_limit_for_path(req.uri().path(), state.config());
270 handle.layer(DefaultBodyLimit::max(limit)).call(req, state)
271}