use headers::{ContentLength, HeaderMapExt, Host, UserAgent};
use http::{header::USER_AGENT, HeaderValue, Request, Response};
use hyper_util::client::legacy::connect::HttpInfo;
use mas_tower::{
    DurationRecorderLayer, DurationRecorderService, EnrichSpan, InFlightCounterLayer,
    InFlightCounterService, MakeSpan, MetricsAttributes, TraceContextLayer, TraceContextService,
    TraceLayer, TraceService,
};
use opentelemetry::KeyValue;
use opentelemetry_semantic_conventions::{
    attribute::{HTTP_REQUEST_BODY_SIZE, HTTP_RESPONSE_BODY_SIZE},
    trace::{
        CLIENT_ADDRESS, CLIENT_PORT, HTTP_REQUEST_METHOD, HTTP_RESPONSE_STATUS_CODE,
        NETWORK_PROTOCOL_NAME, NETWORK_TRANSPORT, NETWORK_TYPE, SERVER_ADDRESS, SERVER_PORT,
        URL_FULL, USER_AGENT_ORIGINAL,
    },
};
use tower::{
    limit::{ConcurrencyLimit, GlobalConcurrencyLimitLayer},
    Layer,
};
use tower_http::{
    follow_redirect::{FollowRedirect, FollowRedirectLayer},
    set_header::{SetRequestHeader, SetRequestHeaderLayer},
};
use tracing::Span;
/// The service produced by [`ClientLayer`] when wrapping an inner HTTP client
/// service `S`.
///
/// From outermost to innermost, a request passes through: the `User-Agent`
/// header setter, the duration recorder, the in-flight counter, the concurrency
/// limit, redirect following, tracing, and the trace-context service, before
/// reaching `S`. This nesting must match the layer order in
/// `<ClientLayer as Layer<S>>::layer`.
pub type ClientService<S> = SetRequestHeader<
    DurationRecorderService<
        InFlightCounterService<
            ConcurrencyLimit<
                FollowRedirect<
                    TraceService<
                        TraceContextService<S>,
                        MakeSpanForRequest,
                        EnrichSpanOnResponse,
                        EnrichSpanOnError,
                    >,
                >,
            >,
            OnRequestLabels,
        >,
        OnRequestLabels,
        OnResponseLabels,
        KeyValue,
    >,
    HeaderValue,
>;
/// Builds the `http.client.request` tracing span for each outgoing request.
///
/// The optional category is reported on the span as `mas.category`; it is set
/// through `ClientLayer::with_category`.
#[derive(Clone, Default, Debug)]
pub struct MakeSpanForRequest {
    /// Category reported as `mas.category`; `None` is reported as `"UNSET"`.
    category: Option<&'static str>,
}
impl<B> MakeSpan<Request<B>> for MakeSpanForRequest {
    /// Create the `http.client.request` span for an outgoing request.
    ///
    /// Values available from the request itself are recorded immediately:
    /// - the method and URL;
    /// - the hostname and port from the `Host` header;
    /// - the `Content-Length`;
    /// - the `User-Agent`.
    ///
    /// Fields that are only known once a response or error arrives are declared
    /// as `Empty` so they can be recorded later on the same span.
    fn make_span(&self, request: &Request<B>) -> Span {
        let headers = request.headers();

        // `server.address` should carry only the hostname, so split the `Host`
        // header into its hostname and optional port rather than displaying it whole.
        let host = headers.typed_get::<Host>();
        let server_address = host.as_ref().map(|host| host.hostname());
        let server_port = host.as_ref().and_then(|host| host.port());

        let user_agent = headers
            .typed_get::<UserAgent>()
            .map(tracing::field::display);
        let content_length = headers.typed_get().map(|ContentLength(len)| len);
        let category = self.category.unwrap_or("UNSET");

        // Every field name is declared exactly once. `Span::record` targets the
        // first field with a matching name, so a duplicate would never be filled.
        tracing::info_span!(
            "http.client.request",
            "otel.kind" = "client",
            "otel.status_code" = tracing::field::Empty,
            { HTTP_REQUEST_METHOD } = %request.method(),
            { URL_FULL } = %request.uri(),
            { HTTP_RESPONSE_STATUS_CODE } = tracing::field::Empty,
            { SERVER_ADDRESS } = server_address,
            { SERVER_PORT } = server_port,
            { HTTP_REQUEST_BODY_SIZE } = content_length,
            { HTTP_RESPONSE_BODY_SIZE } = tracing::field::Empty,
            { NETWORK_TRANSPORT } = "tcp",
            { NETWORK_TYPE } = tracing::field::Empty,
            { CLIENT_ADDRESS } = tracing::field::Empty,
            { CLIENT_PORT } = tracing::field::Empty,
            { USER_AGENT_ORIGINAL } = user_agent,
            "rust.error" = tracing::field::Empty,
            "mas.category" = category,
        )
    }
}
/// Span enricher invoked when the inner service returns a response.
#[derive(Clone, Debug)]
pub struct EnrichSpanOnResponse;
impl<B> EnrichSpan<Response<B>> for EnrichSpanOnResponse {
    /// Record response-derived attributes on the request span.
    ///
    /// Marks the span as `OK` and records the status code. When a
    /// `Content-Length` header is present, also records the response body size.
    /// If [`HttpInfo`] was injected into the response extensions, records the
    /// connection's address family and both endpoints; otherwise logs a warning.
    fn enrich_span(&self, span: &Span, response: &Response<B>) {
        span.record("otel.status_code", "OK");
        span.record(HTTP_RESPONSE_STATUS_CODE, response.status().as_u16());

        if let Some(ContentLength(content_length)) = response.headers().typed_get() {
            span.record(HTTP_RESPONSE_BODY_SIZE, content_length);
        }

        let Some(http_info) = response.extensions().get::<HttpInfo>() else {
            tracing::warn!("No HttpInfo injected in response extensions");
            return;
        };

        // This span describes an *outgoing* request: the local end of the
        // connection is the client, and the remote end is the server we connected to.
        let local = http_info.local_addr();
        let remote = http_info.remote_addr();

        let family = if local.is_ipv4() { "ipv4" } else { "ipv6" };
        span.record(NETWORK_TYPE, family);
        span.record(CLIENT_ADDRESS, local.ip().to_string());
        span.record(CLIENT_PORT, local.port());
        // This replaces any `server.address`/`server.port` set when the span was
        // created with the address actually connected to.
        span.record(SERVER_ADDRESS, remote.ip().to_string());
        span.record(SERVER_PORT, remote.port());
    }
}
/// Span enricher invoked when the inner service fails with an error.
#[derive(Clone, Debug)]
pub struct EnrichSpanOnError;
impl<E> EnrichSpan<E> for EnrichSpanOnError
where
    E: std::error::Error + 'static,
{
    /// Mark the span as failed and attach the error (as a `dyn Error`) under
    /// the `rust.error` field.
    fn enrich_span(&self, span: &Span, error: &E) {
        let error = error as &dyn std::error::Error;
        span.record("otel.status_code", "ERROR");
        span.record("rust.error", error);
    }
}
/// Produces the metric attributes attached to each outgoing request.
#[derive(Clone, Default, Debug)]
pub struct OnRequestLabels {
    /// Category reported as `mas.category`; `None` is reported as `"UNSET"`.
    category: Option<&'static str>,
}
impl<B> MetricsAttributes<Request<B>> for OnRequestLabels
where
    B: 'static,
{
    type Iter<'a> = std::array::IntoIter<KeyValue, 3>;

    /// Label a request with its HTTP method, the protocol name (`"http"`) and
    /// the configured category (`"UNSET"` when none was set).
    fn attributes<'a>(&'a self, t: &'a Request<B>) -> Self::Iter<'a> {
        let method = t.method().as_str().to_owned();
        let category = self.category.unwrap_or("UNSET");

        let labels = [
            KeyValue::new(HTTP_REQUEST_METHOD, method),
            KeyValue::new(NETWORK_PROTOCOL_NAME, "http"),
            KeyValue::new("mas.category", category),
        ];
        labels.into_iter()
    }
}
/// Produces the metric attributes attached once a response is received.
#[derive(Clone, Default, Debug)]
pub struct OnResponseLabels;
impl<B> MetricsAttributes<Response<B>> for OnResponseLabels
where
    B: 'static,
{
    type Iter<'a> = std::iter::Once<KeyValue>;

    /// Label a response with its numeric HTTP status code.
    fn attributes<'a>(&'a self, t: &'a Response<B>) -> Self::Iter<'a> {
        let status_code = i64::from(t.status().as_u16());
        let label = KeyValue::new(HTTP_RESPONSE_STATUS_CODE, status_code);
        std::iter::once(label)
    }
}
/// A [`Layer`] that wraps an HTTP client service with the standard middleware
/// stack, producing a [`ClientService`].
#[derive(Clone, Debug)]
pub struct ClientLayer {
    /// Overrides the `User-Agent` header on every request.
    user_agent_layer: SetRequestHeaderLayer<HeaderValue>,
    /// Caps the number of concurrent requests.
    concurrency_limit_layer: GlobalConcurrencyLimitLayer,
    /// Follows HTTP redirects.
    follow_redirect_layer: FollowRedirectLayer,
    /// Creates the request span and enriches it on response or error.
    trace_layer: TraceLayer<MakeSpanForRequest, EnrichSpanOnResponse, EnrichSpanOnError>,
    /// Trace-context layer from `mas_tower` (presumably propagates the current
    /// trace context to the outgoing request — confirm in `mas_tower`).
    trace_context_layer: TraceContextLayer,
    /// Records request durations, labelled on request, response and error.
    duration_recorder_layer: DurationRecorderLayer<OnRequestLabels, OnResponseLabels, KeyValue>,
    /// Counts requests currently in flight.
    in_flight_counter_layer: InFlightCounterLayer<OnRequestLabels>,
}
impl Default for ClientLayer {
    fn default() -> Self {
        Self::new()
    }
}
impl ClientLayer {
    #[must_use]
    pub fn new() -> Self {
        Self {
            user_agent_layer: SetRequestHeaderLayer::overriding(
                USER_AGENT,
                HeaderValue::from_static("matrix-authentication-service/0.0.1"),
            ),
            concurrency_limit_layer: GlobalConcurrencyLimitLayer::new(10),
            follow_redirect_layer: FollowRedirectLayer::new(),
            trace_layer: TraceLayer::new(MakeSpanForRequest::default())
                .on_response(EnrichSpanOnResponse)
                .on_error(EnrichSpanOnError),
            trace_context_layer: TraceContextLayer::new(),
            duration_recorder_layer: DurationRecorderLayer::new("http.client.duration")
                .on_request(OnRequestLabels::default())
                .on_response(OnResponseLabels)
                .on_error(KeyValue::new("http.error", true)),
            in_flight_counter_layer: InFlightCounterLayer::new("http.client.active_requests")
                .on_request(OnRequestLabels::default()),
        }
    }
    #[must_use]
    pub fn with_category(mut self, category: &'static str) -> Self {
        self.trace_layer = TraceLayer::new(MakeSpanForRequest {
            category: Some(category),
        })
        .on_response(EnrichSpanOnResponse)
        .on_error(EnrichSpanOnError);
        self.duration_recorder_layer = self.duration_recorder_layer.on_request(OnRequestLabels {
            category: Some(category),
        });
        self.in_flight_counter_layer = self.in_flight_counter_layer.on_request(OnRequestLabels {
            category: Some(category),
        });
        self
    }
}
impl<S> Layer<S> for ClientLayer
where
    S: Clone,
{
    type Service = ClientService<S>;

    /// Wrap `inner` with the full client middleware stack.
    ///
    /// Layers are applied from the innermost (trace context) outwards, so the
    /// `User-Agent` setter ends up outermost, matching [`ClientService`].
    fn layer(&self, inner: S) -> Self::Service {
        let service = self.trace_context_layer.layer(inner);
        let service = self.trace_layer.layer(service);
        let service = self.follow_redirect_layer.layer(service);
        let service = self.concurrency_limit_layer.layer(service);
        let service = self.in_flight_counter_layer.layer(service);
        let service = self.duration_recorder_layer.layer(service);
        self.user_agent_layer.layer(service)
    }
}