use std::borrow::Cow;
use std::error::Error;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use axum::extract::{DefaultBodyLimit, Request};
use axum::handler::Handler;
use axum::http::{header, HeaderMap, HeaderName, HeaderValue, StatusCode, Uri};
use axum::response::{IntoResponse, Response};
use bytes::Bytes;
use once_cell::sync::Lazy;
use relay_common::glob2::GlobMatcher;
use relay_config::Config;
use tokio::sync::oneshot;
use tokio::sync::oneshot::error::RecvError;
use crate::extractors::ForwardedFor;
use crate::http::{HttpError, RequestBuilder, Response as UpstreamResponse};
use crate::service::ServiceState;
use crate::services::upstream::{Method, SendRequest, UpstreamRequest, UpstreamRequestError};
static HOP_BY_HOP_HEADERS: &[HeaderName] = &[
header::CONNECTION,
header::PROXY_AUTHENTICATE,
header::PROXY_AUTHORIZATION,
header::TE,
header::TRAILER,
header::TRANSFER_ENCODING,
header::UPGRADE,
];
static IGNORED_REQUEST_HEADERS: &[HeaderName] = &[
header::HOST,
header::CONTENT_ENCODING,
header::CONTENT_LENGTH,
];
const API_PATH: &str = "/api/";
#[derive(Debug, thiserror::Error)]
#[error("error while forwarding request: {0}")]
struct ForwardError(#[from] UpstreamRequestError);
impl From<RecvError> for ForwardError {
fn from(_: RecvError) -> Self {
Self(UpstreamRequestError::ChannelClosed)
}
}
impl IntoResponse for ForwardError {
fn into_response(self) -> Response {
match &self.0 {
UpstreamRequestError::Http(e) => match e {
HttpError::Overflow => StatusCode::PAYLOAD_TOO_LARGE.into_response(),
HttpError::Reqwest(error) => {
relay_log::error!(error = error as &dyn Error);
error
.status()
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
.into_response()
}
HttpError::Io(_) => StatusCode::BAD_GATEWAY.into_response(),
HttpError::Json(_) => StatusCode::BAD_REQUEST.into_response(),
},
UpstreamRequestError::SendFailed(e) => {
if e.is_timeout() {
StatusCode::GATEWAY_TIMEOUT.into_response()
} else {
StatusCode::BAD_GATEWAY.into_response()
}
}
error => {
relay_log::error!(error = error as &dyn Error, "unreachable code");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
}
type ForwardResponse = (StatusCode, HeaderMap<HeaderValue>, Vec<u8>);
struct ForwardRequest {
method: Method,
path: String,
headers: HeaderMap<HeaderValue>,
forwarded_for: ForwardedFor,
data: Bytes,
max_response_size: usize,
sender: oneshot::Sender<Result<ForwardResponse, UpstreamRequestError>>,
}
impl fmt::Debug for ForwardRequest {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ForwardRequest")
.field("method", &self.method)
.field("path", &self.path)
.finish()
}
}
impl UpstreamRequest for ForwardRequest {
fn method(&self) -> Method {
self.method.clone()
}
fn path(&self) -> Cow<'_, str> {
self.path.as_str().into()
}
fn retry(&self) -> bool {
false
}
fn intercept_status_errors(&self) -> bool {
false
}
fn set_relay_id(&self) -> bool {
false
}
fn route(&self) -> &'static str {
"forward"
}
fn build(&mut self, builder: &mut RequestBuilder) -> Result<(), HttpError> {
for (key, value) in &self.headers {
if !HOP_BY_HOP_HEADERS.contains(key) && !IGNORED_REQUEST_HEADERS.contains(key) {
builder.header(key, value);
}
}
builder
.header("X-Forwarded-For", self.forwarded_for.as_ref())
.body(self.data.clone());
Ok(())
}
fn respond(
self: Box<Self>,
result: Result<UpstreamResponse, UpstreamRequestError>,
) -> Pin<Box<dyn Future<Output = ()> + Send + Sync>> {
Box::pin(async move {
let result = match result {
Ok(response) => {
let status = response.status();
let headers = response
.headers()
.iter()
.filter(|(name, _)| !HOP_BY_HOP_HEADERS.contains(name))
.map(|(name, value)| (name.clone(), value.clone()))
.collect();
match response.bytes(self.max_response_size).await {
Ok(body) => Ok((status, headers, body)),
Err(error) => Err(UpstreamRequestError::Http(error)),
}
}
Err(error) => Err(error),
};
self.sender.send(result).ok();
})
}
}
async fn handle(
state: ServiceState,
forwarded_for: ForwardedFor,
method: Method,
uri: Uri,
headers: HeaderMap<HeaderValue>,
data: Bytes,
) -> Result<impl IntoResponse, ForwardError> {
if uri.path() == API_PATH || !uri.path().starts_with(API_PATH) {
return Ok(StatusCode::NOT_FOUND.into_response());
}
let (tx, rx) = oneshot::channel();
let request = ForwardRequest {
method,
path: uri.to_string(),
headers,
forwarded_for,
data,
max_response_size: state.config().max_api_payload_size(),
sender: tx,
};
state.upstream_relay().send(SendRequest(request));
let (status, headers, body) = rx.await??;
Ok(if headers.contains_key(header::CONTENT_TYPE) {
(status, headers, body).into_response()
} else {
(status, headers).into_response()
})
}
#[derive(Clone, Copy, Debug)]
enum SpecialRoute {
FileUpload,
ChunkUpload,
}
static SPECIAL_ROUTES: Lazy<GlobMatcher<SpecialRoute>> = Lazy::new(|| {
let mut m = GlobMatcher::new();
m.add(
"/api/0/projects/*/*/releases/*/files/",
SpecialRoute::FileUpload,
);
m.add(
"/api/0/projects/*/*/releases/*/dsyms/",
SpecialRoute::FileUpload,
);
m.add(
"/api/0/organizations/*/chunk-upload/",
SpecialRoute::ChunkUpload,
);
m
});
fn get_limit_for_path(path: &str, config: &Config) -> usize {
match SPECIAL_ROUTES.test(path) {
Some(SpecialRoute::FileUpload) => config.max_api_file_upload_size(),
Some(SpecialRoute::ChunkUpload) => config.max_api_chunk_upload_size(),
None => config.max_api_payload_size(),
}
}
pub fn forward(state: ServiceState, req: Request) -> impl Future<Output = Response> {
let limit = get_limit_for_path(req.uri().path(), state.config());
handle.layer(DefaultBodyLimit::max(limit)).call(req, state)
}