use actix_http::header::HeaderName; use std::future::{ready, Ready}; use tracing::{span, Instrument, Level}; use actix_service::{forward_ready, Service, Transform}; use actix_web::dev::{ServiceRequest, ServiceResponse}; use futures_util::future::LocalBoxFuture; use reqwest::header::HeaderValue; const X_REQUEST_ID: &str = "x-request-id"; pub struct RequestIdMiddleware; impl Transform for RequestIdMiddleware where S: Service, Error = actix_web::Error>, S::Future: 'static, B: 'static, { type Response = ServiceResponse; type Error = actix_web::Error; type Transform = RequestIdMiddlewareService; type InitError = (); type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ready(Ok(RequestIdMiddlewareService { service })) } } pub struct RequestIdMiddlewareService { service: S, } impl Service for RequestIdMiddlewareService where S: Service, Error = actix_web::Error>, S::Future: 'static, B: 'static, { type Response = ServiceResponse; type Error = actix_web::Error; type Future = LocalBoxFuture<'static, Result>; forward_ready!(service); fn call(&self, mut req: ServiceRequest) -> Self::Future { // Skip generate request id for metrics requests if req.path() == "/metrics" { let fut = self.service.call(req); Box::pin(fut) } else { let request_id = get_request_id(&req).unwrap_or_else(|| { let request_id = uuid::Uuid::new_v4().to_string(); if let Ok(header_value) = HeaderValue::from_str(&request_id) { req .headers_mut() .insert(HeaderName::from_static(X_REQUEST_ID), header_value); } request_id }); let client_info = get_client_info(&req); let span = span!(Level::INFO, "request", request_id = %request_id, client_version = client_info.client_version, payload_size = client_info.payload_size ); let fut = self.service.call(req); Box::pin(async move { let mut res = fut.instrument(span).await?; // Insert the request id to the response header if let Ok(header_value) = HeaderValue::from_str(&request_id) { res .headers_mut() .insert(HeaderName::from_static(X_REQUEST_ID), header_value); } Ok(res) }) } } } pub fn get_request_id(req: &ServiceRequest) -> Option { match req.headers().get(HeaderName::from_static(X_REQUEST_ID)) { Some(h) => match h.to_str() { Ok(s) => Some(s.to_owned()), Err(e) => { tracing::error!("Failed to get request id from header: {}", e); None }, }, None => None, } } #[inline] fn get_client_info(req: &ServiceRequest) -> ClientInfo { let payload_size = req .headers() .get("content-length") .and_then(|val| val.to_str().ok()) .and_then(|val| val.parse::().ok()) .unwrap_or_default(); let client_version = req .headers() .get("client-version") .and_then(|val| val.to_str().ok()); ClientInfo { payload_size, client_version, } } struct ClientInfo<'a> { payload_size: usize, client_version: Option<&'a str>, }