objectstore_server/
http.rs

1use std::any::Any;
2use std::net::SocketAddr;
3use std::time::Duration;
4
5use anyhow::{Context, Result};
6use axum::extract::{ConnectInfo, MatchedPath, Request};
7use axum::http::{HeaderValue, StatusCode, header};
8use axum::middleware::Next;
9use axum::response::{IntoResponse, Response};
10use axum::{RequestExt, ServiceExt};
11use sentry::integrations::tower::{NewSentryLayer, SentryHttpLayer};
12use tokio::net::{TcpListener, TcpSocket};
13use tokio::signal::unix::SignalKind;
14use tokio::time::Instant;
15use tower::ServiceBuilder;
16use tower_http::catch_panic::CatchPanicLayer;
17use tower_http::metrics::InFlightRequestsLayer;
18use tower_http::metrics::in_flight_requests::InFlightRequestsCounter;
19use tower_http::set_header::SetResponseHeaderLayer;
20use tower_http::trace::{DefaultOnFailure, TraceLayer};
21use tracing::Level;
22
23use crate::config::Config;
24use crate::endpoints;
25use crate::state::{ServiceState, State};
26
/// The maximum backlog for TCP listen sockets before refusing connections.
///
/// Passed to the `listen` call when the server socket is created.
const TCP_LISTEN_BACKLOG: u32 = 1024;

/// Interval for emitting the in-flight requests gauge metric.
///
/// The gauge is emitted once per interval while the server is running.
const IN_FLIGHT_INTERVAL: Duration = Duration::from_secs(1);

/// The value for the `Server` HTTP header.
///
/// Assembled at compile time as `objectstore/` followed by the crate version.
const SERVER: &str = concat!("objectstore/", env!("CARGO_PKG_VERSION"));
35
36/// Runs the objectstore HTTP server.
37///
38/// This function initializes the server, binds to the configured address, and runs until
39/// termination is requested.
40pub async fn server(config: Config) -> Result<()> {
41    tracing::info!("Starting server");
42    merni::counter!("server.start": 1);
43
44    let listener = listen(&config).context("failed to start TCP listener")?;
45    let state = State::new(config).await?;
46
47    let server_handle = tokio::spawn(async move {
48        App::new(state)
49            .graceful_shutdown(true)
50            .serve(listener)
51            .await
52    });
53
54    tokio::spawn(async move {
55        elegant_departure::get_shutdown_guard().wait().await;
56        tracing::info!("Shutting down ...");
57    });
58
59    elegant_departure::tokio::depart()
60        .on_termination()
61        .on_sigint()
62        .on_signal(SignalKind::hangup())
63        .on_signal(SignalKind::quit())
64        .await;
65
66    let server_result = server_handle.await.map_err(From::from).flatten();
67    tracing::info!("Shutdown complete");
68    server_result
69}
70
/// The objectstore web server application.
///
/// Create an instance with [`App::new`] and run it with [`App::serve`].
#[derive(Debug)]
pub struct App {
    /// The router with all routes, middleware, and state applied.
    router: axum::Router,
    /// Counter paired with the in-flight requests middleware layer of the router.
    in_flight_requests: InFlightRequestsCounter,
    /// Whether the server stops gracefully once shutdown is requested.
    graceful_shutdown: bool,
}
78
79impl App {
80    /// Creates a new application router for the given service state.
81    ///
82    /// The applications sets up middlewares and routes for the objectstore web API. Use
83    /// [`serve`](Self::serve) to run the server future.
84    pub fn new(state: ServiceState) -> Self {
85        let (in_flight_layer, in_flight_requests) = InFlightRequestsLayer::pair();
86
87        // Build the router middleware into a single service which runs _after_ routing. Service
88        // builder order defines layers added first will be called first. This means:
89        //  - Requests go from top to bottom
90        //  - Responses go from bottom to top
91        let middleware = ServiceBuilder::new()
92            .layer(axum::middleware::from_fn(emit_request_metrics))
93            .layer(in_flight_layer)
94            .layer(CatchPanicLayer::custom(handle_panic))
95            .layer(SetResponseHeaderLayer::overriding(
96                header::SERVER,
97                HeaderValue::from_static(SERVER),
98            ))
99            .layer(NewSentryLayer::new_from_top())
100            .layer(SentryHttpLayer::new().enable_transaction())
101            .layer(
102                TraceLayer::new_for_http()
103                    .make_span_with(make_http_span)
104                    .on_failure(DefaultOnFailure::new().level(Level::DEBUG)),
105            );
106
107        let router = endpoints::routes().layer(middleware).with_state(state);
108
109        App {
110            router,
111            in_flight_requests,
112            graceful_shutdown: false,
113        }
114    }
115
116    /// Enables or disables graceful shutdown for the server.
117    ///
118    /// By default, graceful shutdown is disabled.
119    pub fn graceful_shutdown(mut self, enable: bool) -> Self {
120        self.graceful_shutdown = enable;
121        self
122    }
123
124    /// Runs the web server until graceful shutdown is triggered.
125    ///
126    /// This function creates a future that runs the server. The future must be spawned or awaited for
127    /// the server to continue running.
128    pub async fn serve(self, listener: TcpListener) -> Result<()> {
129        let Self {
130            router,
131            in_flight_requests,
132            graceful_shutdown,
133        } = self;
134
135        let service =
136            ServiceExt::<Request>::into_make_service_with_connect_info::<SocketAddr>(router);
137
138        let guard = if graceful_shutdown {
139            Some(elegant_departure::get_shutdown_guard())
140        } else {
141            None
142        };
143
144        let server = async {
145            if let Some(ref guard) = guard {
146                axum::serve(listener, service)
147                    .with_graceful_shutdown(guard.wait_owned())
148                    .await
149            } else {
150                axum::serve(listener, service).await
151            }
152        };
153
154        let emitter = in_flight_requests.run_emitter(IN_FLIGHT_INTERVAL, |count| async move {
155            merni::gauge!("server.requests.in_flight": count);
156        });
157
158        let (serve_result, _) = tokio::join!(server, emitter);
159        serve_result?;
160
161        Ok(())
162    }
163}
164
165/// Create a tracing span for an HTTP request.
166///
167/// As opposed to `DefaultMakeSpan`, this also records the client IP address if available.
168fn make_http_span(request: &Request) -> tracing::Span {
169    let span = tracing::debug_span!(
170        "request",
171        method = %request.method(),
172        uri = %request.uri(),
173        version = ?request.version(),
174        client_addr = tracing::field::Empty,
175    );
176
177    if let Some(ConnectInfo(addr)) = request
178        .extensions()
179        .get::<axum::extract::ConnectInfo<SocketAddr>>()
180    {
181        span.record("client_addr", tracing::field::display(addr.ip()));
182    }
183
184    span
185}
186
187/// A panic handler that logs the panic and turns it into a 500 response.
188///
189/// Use with the [`CatchPanicLayer`] middleware.
190fn handle_panic(err: Box<dyn Any + Send + 'static>) -> Response {
191    let detail = if let Some(s) = err.downcast_ref::<String>() {
192        s.clone()
193    } else if let Some(s) = err.downcast_ref::<&str>() {
194        s.to_string()
195    } else {
196        "no error details".to_owned()
197    };
198
199    tracing::error!("panic in web handler: {detail}");
200
201    let response = (StatusCode::INTERNAL_SERVER_ERROR, detail);
202    response.into_response()
203}
204
205/// A middleware that logs web request timings as metrics.
206///
207/// Use this with [`from_fn`](axum::middleware::from_fn).
208async fn emit_request_metrics(mut request: Request, next: Next) -> Response {
209    let request_start = Instant::now();
210
211    let matched_path = request.extract_parts::<MatchedPath>().await;
212    let route = matched_path.as_ref().map_or("unknown", |m| m.as_str());
213    let method = request.method().clone();
214    merni::counter!("server.requests": 1, "route" => route, "method" => method.as_str());
215
216    let response = next.run(request).await;
217
218    merni::distribution!(
219        "server.requests.duration"@s: request_start.elapsed(),
220        "route" => route,
221        "method" => method.as_str(),
222        "status" => response.status().as_str()
223    );
224
225    response
226}
227
228fn listen(config: &Config) -> Result<TcpListener> {
229    let addr = config.http_addr;
230    let socket = match addr {
231        SocketAddr::V4(_) => TcpSocket::new_v4(),
232        SocketAddr::V6(_) => TcpSocket::new_v6(),
233    }?;
234
235    #[cfg(all(unix, not(target_os = "solaris"), not(target_os = "illumos")))]
236    socket.set_reuseport(true)?;
237    socket.bind(addr)?;
238
239    let listener = socket.listen(TCP_LISTEN_BACKLOG)?;
240    tracing::info!("HTTP server listening on {addr}");
241
242    Ok(listener)
243}