use std::{collections::HashMap, fmt, sync::Arc};
use base64ct::{Base64UrlUnpadded, Encoding};
use chrono::{DateTime, Duration, Utc};
use headers::{Authorization, HeaderMapExt};
use http::Request;
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
#[cfg(feature = "keystore")]
use mas_jose::constraints::Constrainable;
use mas_jose::{
    claims::{self, ClaimError},
    jwa::SymmetricKey,
    jwt::{JsonWebSignatureHeader, Jwt},
};
#[cfg(feature = "keystore")]
use mas_keystore::Keystore;
use rand::Rng;
use serde::Serialize;
use serde_json::Value;
use serde_with::skip_serializing_none;
use tower::BoxError;
use url::Url;
use crate::error::CredentialsError;
/// The client authentication methods that can be expressed with
/// [`ClientCredentials`], one per variant.
pub const CLIENT_SUPPORTED_AUTH_METHODS: &[OAuthClientAuthenticationMethod] = &[
    OAuthClientAuthenticationMethod::None,
    OAuthClientAuthenticationMethod::ClientSecretBasic,
    OAuthClientAuthenticationMethod::ClientSecretPost,
    OAuthClientAuthenticationMethod::ClientSecretJwt,
    OAuthClientAuthenticationMethod::PrivateKeyJwt,
];
/// Signature of a custom JWT signing function.
///
/// It receives the claims to sign and the signing algorithm to use. It returns
/// the signed JWT as a string (used directly as the client assertion), or an
/// error.
pub type JwtSigningFn =
    dyn Fn(HashMap<String, Value>, JsonWebSignatureAlg) -> Result<String, BoxError> + Send + Sync;
/// How to sign the client assertion of the `private_key_jwt` authentication
/// method.
#[derive(Clone)]
pub enum JwtSigningMethod {
    /// Sign with a key from a [`Keystore`] that matches the requested
    /// algorithm.
    #[cfg(feature = "keystore")]
    Keystore(Keystore),
    /// Sign with a caller-provided function.
    Custom(Arc<JwtSigningFn>),
}
impl JwtSigningMethod {
    #[cfg(feature = "keystore")]
    #[must_use]
    pub fn with_keystore(keystore: Keystore) -> Self {
        Self::Keystore(keystore)
    }
    #[must_use]
    pub fn with_custom_signing_method<F>(signing_fn: F) -> Self
    where
        F: Fn(HashMap<String, Value>, JsonWebSignatureAlg) -> Result<String, BoxError>
            + Send
            + Sync
            + 'static,
    {
        Self::Custom(Arc::new(signing_fn))
    }
    #[cfg(feature = "keystore")]
    #[must_use]
    pub fn keystore(&self) -> Option<&Keystore> {
        match self {
            JwtSigningMethod::Keystore(k) => Some(k),
            JwtSigningMethod::Custom(_) => None,
        }
    }
    #[must_use]
    pub fn jwt_custom(&self) -> Option<&JwtSigningFn> {
        match self {
            JwtSigningMethod::Custom(s) => Some(s.as_ref()),
            #[cfg(feature = "keystore")]
            JwtSigningMethod::Keystore(_) => None,
        }
    }
}
/// The credentials used to authenticate a client in a request, one variant per
/// supported client authentication method.
#[derive(Clone)]
pub enum ClientCredentials {
    /// No client authentication: only `client_id` is sent, in the request
    /// body.
    None {
        /// The client ID.
        client_id: String,
    },
    /// `client_id` and `client_secret` are sent in an HTTP Basic
    /// `Authorization` header.
    ClientSecretBasic {
        /// The client ID.
        client_id: String,
        /// The client secret.
        client_secret: String,
    },
    /// `client_id` and `client_secret` are sent as parameters in the request
    /// body.
    ClientSecretPost {
        /// The client ID.
        client_id: String,
        /// The client secret.
        client_secret: String,
    },
    /// A JWT client assertion signed with a symmetric key built from
    /// `client_secret` is sent in the request body.
    ClientSecretJwt {
        /// The client ID, used as the assertion's `iss` and `sub`.
        client_id: String,
        /// The secret from which the symmetric signing key is built.
        client_secret: String,
        /// The algorithm used to sign the assertion.
        signing_algorithm: JsonWebSignatureAlg,
        /// The URL used as the assertion's `aud` claim.
        token_endpoint: Url,
    },
    /// A JWT client assertion signed via `jwt_signing_method` is sent in the
    /// request body.
    PrivateKeyJwt {
        /// The client ID, used as the assertion's `iss` and `sub`.
        client_id: String,
        /// How the assertion gets signed.
        jwt_signing_method: JwtSigningMethod,
        /// The algorithm used to sign the assertion.
        signing_algorithm: JsonWebSignatureAlg,
        /// The URL used as the assertion's `aud` claim.
        token_endpoint: Url,
    },
}
impl ClientCredentials {
    /// Returns the client ID, whichever authentication method is used.
    #[must_use]
    pub fn client_id(&self) -> &str {
        // Every variant carries a `client_id`, so this pattern is irrefutable.
        let (Self::None { client_id }
        | Self::ClientSecretBasic { client_id, .. }
        | Self::ClientSecretPost { client_id, .. }
        | Self::ClientSecretJwt { client_id, .. }
        | Self::PrivateKeyJwt { client_id, .. }) = self;
        client_id
    }

    /// Attaches these credentials to `request`, wrapping its body in a
    /// [`RequestWithClientCredentials`].
    ///
    /// Depending on the authentication method, the credentials are either
    /// added to the body or sent as an HTTP Basic `Authorization` header.
    /// `now` and `rng` are only consumed by the JWT-based methods.
    ///
    /// # Errors
    ///
    /// Returns an error if the credentials cannot be prepared, e.g. when
    /// building or signing a client assertion fails.
    pub(crate) fn apply_to_request<T: Serialize>(
        self,
        request: Request<T>,
        now: DateTime<Utc>,
        rng: &mut impl Rng,
    ) -> Result<Request<RequestWithClientCredentials<T>>, CredentialsError> {
        let prepared = RequestClientCredentials::try_from_credentials(self, now, rng)?;
        let (parts, inner_body) = request.into_parts();

        let (body_credentials, basic_auth) = match prepared {
            RequestClientCredentials::Body(body_credentials) => (Some(body_credentials), None),
            RequestClientCredentials::Header(HeaderClientCredentials {
                client_id,
                client_secret,
            }) => {
                // RFC 6749 §2.3.1: both values are form-urlencoded before being
                // put into the Basic credentials.
                let encode = |raw: &str| {
                    form_urlencoded::byte_serialize(raw.as_bytes()).collect::<String>()
                };
                let auth = Authorization::basic(&encode(&client_id), &encode(&client_secret));
                (None, Some(auth))
            }
        };

        let wrapped = RequestWithClientCredentials {
            body: inner_body,
            credentials: body_credentials,
        };
        let mut request = Request::from_parts(parts, wrapped);
        if let Some(auth) = basic_auth {
            request.headers_mut().typed_insert(auth);
        }
        Ok(request)
    }
}
/// Debug output that never includes secrets or signing material: only the
/// client ID and, for the JWT methods, the algorithm and token endpoint.
impl fmt::Debug for ClientCredentials {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (name, client_id, jwt_params) = match self {
            // Nothing is hidden for this variant, so the output is exhaustive.
            Self::None { client_id } => {
                return f.debug_struct("None").field("client_id", client_id).finish();
            }
            Self::ClientSecretBasic { client_id, .. } => ("ClientSecretBasic", client_id, None),
            Self::ClientSecretPost { client_id, .. } => ("ClientSecretPost", client_id, None),
            Self::ClientSecretJwt {
                client_id,
                signing_algorithm,
                token_endpoint,
                ..
            } => (
                "ClientSecretJwt",
                client_id,
                Some((signing_algorithm, token_endpoint)),
            ),
            Self::PrivateKeyJwt {
                client_id,
                signing_algorithm,
                token_endpoint,
                ..
            } => (
                "PrivateKeyJwt",
                client_id,
                Some((signing_algorithm, token_endpoint)),
            ),
        };

        let mut builder = f.debug_struct(name);
        builder.field("client_id", client_id);
        if let Some((signing_algorithm, token_endpoint)) = jwt_params {
            builder
                .field("signing_algorithm", signing_algorithm)
                .field("token_endpoint", token_endpoint);
        }
        // `..` signals that some fields were intentionally left out.
        builder.finish_non_exhaustive()
    }
}
/// Unit marker for the `client_assertion_type` body parameter.
///
/// Thanks to the `serde(rename)`, it serializes as the string
/// `urn:ietf:params:oauth:client-assertion-type:jwt-bearer`. Both JWT-based
/// authentication methods use it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")]
pub(crate) struct JwtBearerClientAssertionType;
/// Client credentials ready to be attached to a request, grouped by where
/// they go.
enum RequestClientCredentials {
    /// Parameters merged into the request body.
    Body(BodyClientCredentials),
    /// Values sent in an HTTP Basic `Authorization` header.
    Header(HeaderClientCredentials),
}
impl RequestClientCredentials {
    /// Converts [`ClientCredentials`] into the concrete data to attach to a
    /// request.
    ///
    /// For the JWT-based methods this builds and signs the client assertion.
    /// `now` sets its timestamps and `rng` generates its `jti`.
    ///
    /// # Errors
    ///
    /// Returns a [`CredentialsError`] in these cases:
    /// - the assertion claims cannot be built;
    /// - no usable signing key can be obtained for the requested algorithm;
    /// - signing fails;
    /// - a custom signing function returns an error.
    fn try_from_credentials(
        credentials: ClientCredentials,
        now: DateTime<Utc>,
        rng: &mut impl Rng,
    ) -> Result<Self, CredentialsError> {
        let res = match credentials {
            // Only the client ID is sent, in the body.
            ClientCredentials::None { client_id } => Self::Body(BodyClientCredentials {
                client_id,
                client_secret: None,
                client_assertion: None,
                client_assertion_type: None,
            }),
            // The ID and secret go to the `Authorization` header; the caller
            // does the encoding.
            ClientCredentials::ClientSecretBasic {
                client_id,
                client_secret,
            } => Self::Header(HeaderClientCredentials {
                client_id,
                client_secret,
            }),
            // The ID and secret are sent as body parameters.
            ClientCredentials::ClientSecretPost {
                client_id,
                client_secret,
            } => Self::Body(BodyClientCredentials {
                client_id,
                client_secret: Some(client_secret),
                client_assertion: None,
                client_assertion_type: None,
            }),
            // The assertion is signed with a symmetric key built from the
            // client secret; the secret itself is not sent.
            ClientCredentials::ClientSecretJwt {
                client_id,
                client_secret,
                signing_algorithm,
                token_endpoint,
            } => {
                let claims =
                    prepare_claims(client_id.clone(), token_endpoint.to_string(), now, rng)?;
                let key = SymmetricKey::new_for_alg(client_secret.into(), &signing_algorithm)?;
                let header = JsonWebSignatureHeader::new(signing_algorithm);
                let jwt = Jwt::sign(header, claims, &key)?;
                Self::Body(BodyClientCredentials {
                    client_id,
                    client_secret: None,
                    client_assertion: Some(jwt.to_string()),
                    client_assertion_type: Some(JwtBearerClientAssertionType),
                })
            }
            // The assertion is signed either with a keystore key or by a
            // caller-supplied function.
            ClientCredentials::PrivateKeyJwt {
                client_id,
                jwt_signing_method,
                signing_algorithm,
                token_endpoint,
            } => {
                let claims =
                    prepare_claims(client_id.clone(), token_endpoint.to_string(), now, rng)?;
                let client_assertion = match jwt_signing_method {
                    #[cfg(feature = "keystore")]
                    JwtSigningMethod::Keystore(keystore) => {
                        // Pick a key suited to the algorithm, then derive the
                        // concrete signer from its parameters.
                        let key = keystore
                            .signing_key_for_algorithm(&signing_algorithm)
                            .ok_or(CredentialsError::NoPrivateKeyFound)?;
                        let signer = key
                            .params()
                            .signing_key_for_alg(&signing_algorithm)
                            .map_err(|_| CredentialsError::JwtWrongAlgorithm)?;
                        let mut header = JsonWebSignatureHeader::new(signing_algorithm);
                        // Advertise the key ID when the key has one.
                        if let Some(kid) = key.kid() {
                            header = header.with_kid(kid);
                        }
                        Jwt::sign(header, claims, &signer)?.to_string()
                    }
                    JwtSigningMethod::Custom(jwt_signing_fn) => {
                        jwt_signing_fn(claims, signing_algorithm)
                            .map_err(CredentialsError::Custom)?
                    }
                };
                Self::Body(BodyClientCredentials {
                    client_id,
                    client_secret: None,
                    client_assertion: Some(client_assertion),
                    client_assertion_type: Some(JwtBearerClientAssertionType),
                })
            }
        };
        Ok(res)
    }
}
#[allow(clippy::struct_field_names)] #[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub(crate) struct BodyClientCredentials {
    client_id: String,
    client_secret: Option<String>,
    client_assertion: Option<String>,
    client_assertion_type: Option<JwtBearerClientAssertionType>,
}
/// Client credentials to send in an HTTP Basic `Authorization` header.
#[derive(Clone)]
struct HeaderClientCredentials {
    client_id: String,
    client_secret: String,
}

// A derived `Debug` would print `client_secret` verbatim. Redact it,
// consistent with the `Debug` implementation of `ClientCredentials`.
impl fmt::Debug for HeaderClientCredentials {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("HeaderClientCredentials")
            .field("client_id", &self.client_id)
            .finish_non_exhaustive()
    }
}
/// Builds the claim set of a client assertion JWT.
///
/// - `iss` is used for both the `iss` and `sub` claims.
/// - `aud` becomes the `aud` claim.
/// - The assertion is issued at `now` and expires five minutes later.
/// - The `jti` is 16 random bytes from `rng`, encoded as unpadded URL-safe
///   base64.
///
/// # Errors
///
/// Returns a [`ClaimError`] if inserting any claim fails.
fn prepare_claims(
    iss: String,
    aud: String,
    now: DateTime<Utc>,
    rng: &mut impl Rng,
) -> Result<HashMap<String, Value>, ClaimError> {
    /// Lifetime of the assertion: five minutes, expressed in microseconds.
    const ASSERTION_LIFETIME_MICROS: i64 = 5 * 60 * 1000 * 1000;

    let subject = iss.clone();
    let mut claims = HashMap::new();

    claims::ISS.insert(&mut claims, iss)?;
    claims::SUB.insert(&mut claims, subject)?;
    claims::AUD.insert(&mut claims, aud)?;
    claims::IAT.insert(&mut claims, now)?;
    let expires_at = now + Duration::microseconds(ASSERTION_LIFETIME_MICROS);
    claims::EXP.insert(&mut claims, expires_at)?;

    // Unique token identifier, drawn last so RNG usage matches the claim order.
    let mut jti_bytes = [0u8; 16];
    rng.fill(&mut jti_bytes);
    claims::JTI.insert(&mut claims, Base64UrlUnpadded::encode_string(&jti_bytes))?;

    Ok(claims)
}
/// A request body combined with the client credentials that belong in the
/// body.
///
/// Both fields are flattened, so the credential parameters are serialized
/// alongside the original body's fields. When `credentials` is `None`, only
/// the original body is emitted.
// `skip_serializing_none` must come before `#[derive(Serialize)]`; otherwise
// the derive never sees the attributes it adds.
#[skip_serializing_none]
#[derive(Clone, Serialize)]
pub struct RequestWithClientCredentials<T> {
    /// The original request body.
    #[serde(flatten)]
    pub(crate) body: T,
    /// The credentials to merge into the body, if the authentication method
    /// sends them there.
    #[serde(flatten)]
    pub(crate) credentials: Option<BodyClientCredentials>,
}
// Tests for the serialization of body credentials and for attaching each kind
// of client credentials to a request.
#[cfg(test)]
mod test {
    use assert_matches::assert_matches;
    use headers::authorization::Basic;
    #[cfg(feature = "keystore")]
    use mas_keystore::{JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use super::*;
    // These values contain characters that must be percent-encoded, so the
    // tests also exercise the form encoding.
    const CLIENT_ID: &str = "abcd$++";
    const CLIENT_SECRET: &str = "xyz!;?";
    const REQUEST_BODY: &str = "some_body";
    /// Minimal request body, used to check that the original body is kept
    /// next to the credentials.
    #[derive(Serialize)]
    struct Body {
        body: &'static str,
    }
    /// Current time used by the tests.
    fn now() -> DateTime<Utc> {
        // NOTE(review): `Utc::now` is presumably disallowed by the crate's
        // clippy configuration; tests opt out here.
        #[allow(clippy::disallowed_methods)]
        Utc::now()
    }
    /// `None` fields are omitted. Present fields are percent-encoded, and the
    /// assertion type serializes to its URN.
    #[test]
    fn serialize_credentials() {
        assert_eq!(
            serde_urlencoded::to_string(BodyClientCredentials {
                client_id: CLIENT_ID.to_owned(),
                client_secret: None,
                client_assertion: None,
                client_assertion_type: None,
            })
            .unwrap(),
            "client_id=abcd%24%2B%2B"
        );
        assert_eq!(
            serde_urlencoded::to_string(BodyClientCredentials {
                client_id: CLIENT_ID.to_owned(),
                client_secret: Some(CLIENT_SECRET.to_owned()),
                client_assertion: None,
                client_assertion_type: None,
            })
            .unwrap(),
            "client_id=abcd%24%2B%2B&client_secret=xyz%21%3B%3F"
        );
        assert_eq!(
            serde_urlencoded::to_string(BodyClientCredentials {
                client_id: CLIENT_ID.to_owned(),
                client_secret: None,
                client_assertion: Some(CLIENT_SECRET.to_owned()),
                client_assertion_type: Some(JwtBearerClientAssertionType)
            })
            .unwrap(),
            "client_id=abcd%24%2B%2B&client_assertion=xyz%21%3B%3F&client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"
        );
    }
    /// Credentials are flattened after the original body's fields. With no
    /// credentials, only the body is serialized.
    #[test]
    fn serialize_request_with_credentials() {
        let req = RequestWithClientCredentials {
            body: Body { body: REQUEST_BODY },
            credentials: None,
        };
        assert_eq!(serde_urlencoded::to_string(req).unwrap(), "body=some_body");
        let req = RequestWithClientCredentials {
            body: Body { body: REQUEST_BODY },
            credentials: Some(BodyClientCredentials {
                client_id: CLIENT_ID.to_owned(),
                client_secret: None,
                client_assertion: None,
                client_assertion_type: None,
            }),
        };
        assert_eq!(
            serde_urlencoded::to_string(req).unwrap(),
            "body=some_body&client_id=abcd%24%2B%2B"
        );
        let req = RequestWithClientCredentials {
            body: Body { body: REQUEST_BODY },
            credentials: Some(BodyClientCredentials {
                client_id: CLIENT_ID.to_owned(),
                client_secret: Some(CLIENT_SECRET.to_owned()),
                client_assertion: None,
                client_assertion_type: None,
            }),
        };
        assert_eq!(
            serde_urlencoded::to_string(req).unwrap(),
            "body=some_body&client_id=abcd%24%2B%2B&client_secret=xyz%21%3B%3F"
        );
        let req = RequestWithClientCredentials {
            body: Body { body: REQUEST_BODY },
            credentials: Some(BodyClientCredentials {
                client_id: CLIENT_ID.to_owned(),
                client_secret: None,
                client_assertion: Some(CLIENT_SECRET.to_owned()),
                client_assertion_type: Some(JwtBearerClientAssertionType),
            }),
        };
        assert_eq!(
            serde_urlencoded::to_string(req).unwrap(),
            "body=some_body&client_id=abcd%24%2B%2B&client_assertion=xyz%21%3B%3F&client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"
        );
    }
    /// `none`: only the client ID goes in the body; no `Authorization` header.
    #[tokio::test]
    async fn build_request_none() {
        let credentials = ClientCredentials::None {
            client_id: CLIENT_ID.to_owned(),
        };
        let request = Request::new(Body { body: REQUEST_BODY });
        let now = now();
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let request = credentials
            .apply_to_request(request, now, &mut rng)
            .unwrap();
        assert_eq!(request.headers().typed_get::<Authorization<Basic>>(), None);
        let body = request.into_body();
        assert_eq!(body.body.body, REQUEST_BODY);
        let credentials = body.credentials.unwrap();
        assert_eq!(credentials.client_id, CLIENT_ID);
        assert_eq!(credentials.client_secret, None);
        assert_eq!(credentials.client_assertion, None);
        assert_eq!(credentials.client_assertion_type, None);
    }
    /// `client_secret_basic`: both values end up form-encoded in the Basic
    /// header, and nothing is added to the body.
    #[tokio::test]
    async fn build_request_client_secret_basic() {
        let credentials = ClientCredentials::ClientSecretBasic {
            client_id: CLIENT_ID.to_owned(),
            client_secret: CLIENT_SECRET.to_owned(),
        };
        let now = now();
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let request = Request::new(Body { body: REQUEST_BODY });
        let request = credentials
            .apply_to_request(request, now, &mut rng)
            .unwrap();
        let auth = assert_matches!(
            request.headers().typed_get::<Authorization<Basic>>(),
            Some(auth) => auth
        );
        // Decode the form encoding applied before Basic encoding.
        assert_eq!(
            form_urlencoded::parse(auth.username().as_bytes())
                .next()
                .unwrap()
                .0,
            CLIENT_ID
        );
        assert_eq!(
            form_urlencoded::parse(auth.password().as_bytes())
                .next()
                .unwrap()
                .0,
            CLIENT_SECRET
        );
        let body = request.into_body();
        assert_eq!(body.body.body, REQUEST_BODY);
        assert_eq!(body.credentials, None);
    }
    /// `client_secret_post`: the ID and secret go in the body; no header.
    #[tokio::test]
    async fn build_request_client_secret_post() {
        let credentials = ClientCredentials::ClientSecretPost {
            client_id: CLIENT_ID.to_owned(),
            client_secret: CLIENT_SECRET.to_owned(),
        };
        let now = now();
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let request = Request::new(Body { body: REQUEST_BODY });
        let request = credentials
            .apply_to_request(request, now, &mut rng)
            .unwrap();
        assert_eq!(request.headers().typed_get::<Authorization<Basic>>(), None);
        let body = request.into_body();
        assert_eq!(body.body.body, REQUEST_BODY);
        let credentials = body.credentials.unwrap();
        assert_eq!(credentials.client_id, CLIENT_ID);
        assert_eq!(credentials.client_secret.unwrap(), CLIENT_SECRET);
        assert_eq!(credentials.client_assertion, None);
        assert_eq!(credentials.client_assertion_type, None);
    }
    /// `client_secret_jwt`: an assertion and its type are sent in the body,
    /// but not the secret. The assertion's content is not inspected here.
    #[tokio::test]
    async fn build_request_client_secret_jwt() {
        let credentials = ClientCredentials::ClientSecretJwt {
            client_id: CLIENT_ID.to_owned(),
            client_secret: CLIENT_SECRET.to_owned(),
            signing_algorithm: JsonWebSignatureAlg::Hs256,
            token_endpoint: Url::parse("http://localhost").unwrap(),
        };
        let now = now();
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let request = Request::new(Body { body: REQUEST_BODY });
        let request = credentials
            .apply_to_request(request, now, &mut rng)
            .unwrap();
        assert_eq!(request.headers().typed_get::<Authorization<Basic>>(), None);
        let body = request.into_body();
        assert_eq!(body.body.body, REQUEST_BODY);
        let credentials = body.credentials.unwrap();
        assert_eq!(credentials.client_id, CLIENT_ID);
        assert_eq!(credentials.client_secret, None);
        credentials.client_assertion.unwrap();
        credentials.client_assertion_type.unwrap();
    }
    /// `private_key_jwt` with a keystore holding a single generated RSA key,
    /// signing with RS256.
    #[tokio::test]
    #[cfg(feature = "keystore")]
    async fn build_request_private_key_jwt() {
        let rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
        let key = PrivateKey::generate_rsa(rng).unwrap();
        let keystore = Keystore::new(JsonWebKeySet::<PrivateKey>::new(vec![JsonWebKey::new(key)]));
        let jwt_signing_method = JwtSigningMethod::with_keystore(keystore);
        let now = now();
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let credentials = ClientCredentials::PrivateKeyJwt {
            client_id: CLIENT_ID.to_owned(),
            jwt_signing_method,
            signing_algorithm: JsonWebSignatureAlg::Rs256,
            token_endpoint: Url::parse("http://localhost").unwrap(),
        };
        let request = Request::new(Body { body: REQUEST_BODY });
        let request = credentials
            .apply_to_request(request, now, &mut rng)
            .unwrap();
        assert_eq!(request.headers().typed_get::<Authorization<Basic>>(), None);
        let body = request.into_body();
        assert_eq!(body.body.body, REQUEST_BODY);
        let credentials = body.credentials.unwrap();
        assert_eq!(credentials.client_id, CLIENT_ID);
        assert_eq!(credentials.client_secret, None);
        credentials.client_assertion.unwrap();
        credentials.client_assertion_type.unwrap();
    }
}