From bedcf4474142c82d2b5e13c0d230e0e8c7daa9f7 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 22 Nov 2022 18:28:16 +0100 Subject: [PATCH] WIP: upstream OIDC provider support --- Cargo.lock | 5 +- crates/cli/Cargo.toml | 2 + crates/cli/src/commands/manage.rs | 199 +++++++++++- crates/config/src/sections/secrets.rs | 4 +- crates/data-model/Cargo.toml | 1 + crates/data-model/src/lib.rs | 4 + crates/data-model/src/upstream_oauth2/mod.rs | 48 +++ crates/handlers/Cargo.toml | 3 +- crates/handlers/src/lib.rs | 47 +-- .../handlers/src/upstream_oauth2/authorize.rs | 149 +++++++++ .../handlers/src/upstream_oauth2/callback.rs | 290 ++++++++++++++++++ crates/handlers/src/upstream_oauth2/mod.rs | 35 +++ crates/http/Cargo.toml | 4 +- crates/http/src/client.rs | 26 +- .../http/src/layers/body_to_bytes_response.rs | 8 + crates/http/src/layers/client.rs | 21 +- crates/oauth2-types/Cargo.toml | 1 - crates/oauth2-types/src/response_type.rs | 20 +- crates/oauth2-types/src/scope.rs | 22 +- .../src/types/client_credentials.rs | 8 +- crates/router/src/endpoints.rs | 46 +++ crates/router/src/url_builder.rs | 13 + .../20221121151402_upstream_oauth.sql | 84 +++++ crates/storage/sqlx-data.json | 196 ++++++++++++ crates/storage/src/lib.rs | 1 + crates/storage/src/upstream_oauth2/mod.rs | 21 ++ .../storage/src/upstream_oauth2/provider.rs | 159 ++++++++++ crates/storage/src/upstream_oauth2/session.rs | 184 +++++++++++ 28 files changed, 1505 insertions(+), 96 deletions(-) create mode 100644 crates/data-model/src/upstream_oauth2/mod.rs create mode 100644 crates/handlers/src/upstream_oauth2/authorize.rs create mode 100644 crates/handlers/src/upstream_oauth2/callback.rs create mode 100644 crates/handlers/src/upstream_oauth2/mod.rs create mode 100644 crates/storage/migrations/20221121151402_upstream_oauth.sql create mode 100644 crates/storage/src/upstream_oauth2/mod.rs create mode 100644 crates/storage/src/upstream_oauth2/provider.rs create mode 100644 crates/storage/src/upstream_oauth2/session.rs diff --git a/Cargo.lock b/Cargo.lock index cacf96d3..eaf14d5d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2687,6 +2687,7 @@ dependencies = [ "mas-email", "mas-handlers", "mas-http", + "mas-iana", "mas-listener", "mas-policy", "mas-router", @@ -2694,6 +2695,7 @@ dependencies = [ "mas-storage", "mas-tasks", "mas-templates", + "oauth2-types", "opentelemetry", "opentelemetry-http", "opentelemetry-jaeger", @@ -2761,6 +2763,7 @@ dependencies = [ "rand 0.8.5", "serde", "thiserror", + "ulid", "url", ] @@ -2825,6 +2828,7 @@ dependencies = [ "mas-iana", "mas-jose", "mas-keystore", + "mas-oidc-client", "mas-policy", "mas-router", "mas-storage", @@ -3345,7 +3349,6 @@ dependencies = [ "data-encoding", "http", "indoc", - "itertools", "language-tags", "mas-iana", "mas-jose", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index d37a715c..c9529d1c 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -45,6 +45,7 @@ mas-config = { path = "../config" } mas-email = { path = "../email" } mas-handlers = { path = "../handlers", default-features = false } mas-http = { path = "../http", default-features = false, features = ["axum", "client"] } +mas-iana = { path = "../iana" } mas-listener = { path = "../listener" } mas-policy = { path = "../policy" } mas-router = { path = "../router" } @@ -52,6 +53,7 @@ mas-spa = { path = "../spa" } mas-storage = { path = "../storage" } mas-tasks = { path = "../tasks" } mas-templates = { path = "../templates" } +oauth2-types = { path = "../oauth2-types" } [dev-dependencies] indoc = "1.0.7" diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index aeffa5f6..d9b2be14 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -13,8 +13,10 @@ // limitations under the License. use argon2::Argon2; -use clap::Parser; +use clap::{Parser, ValueEnum}; use mas_config::{DatabaseConfig, RootConfig}; +use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; +use mas_router::UrlBuilder; use mas_storage::{ oauth2::client::{insert_client_from_config, lookup_client, truncate_clients}, user::{ @@ -22,6 +24,7 @@ use mas_storage::{ }, Clock, LookupError, }; +use oauth2_types::scope::Scope; use rand::SeedableRng; use tracing::{info, warn}; @@ -31,6 +34,110 @@ pub(super) struct Options { subcommand: Subcommand, } +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] +enum AuthenticationMethod { + /// Client doesn't use any authentication + None, + + /// Client sends its `client_secret` in the request body + ClientSecretPost, + + /// Client sends its `client_secret` in the authorization header + ClientSecretBasic, + + /// Client uses its `client_secret` to sign a client assertion JWT + ClientSecretJwt, + + /// Client uses its private keys to sign a client assertion JWT + PrivateKeyJwt, +} + +impl AuthenticationMethod { + fn requires_client_secret(self) -> bool { + matches!( + self, + Self::ClientSecretJwt | Self::ClientSecretPost | Self::ClientSecretBasic + ) + } +} + +impl From for OAuthClientAuthenticationMethod { + fn from(val: AuthenticationMethod) -> Self { + (&val).into() + } +} + +impl From<&AuthenticationMethod> for OAuthClientAuthenticationMethod { + fn from(val: &AuthenticationMethod) -> Self { + match val { + AuthenticationMethod::None => OAuthClientAuthenticationMethod::None, + AuthenticationMethod::ClientSecretPost => { + OAuthClientAuthenticationMethod::ClientSecretPost + } + AuthenticationMethod::ClientSecretBasic => { + OAuthClientAuthenticationMethod::ClientSecretBasic + } + AuthenticationMethod::ClientSecretJwt => { + OAuthClientAuthenticationMethod::ClientSecretJwt + } + AuthenticationMethod::PrivateKeyJwt => OAuthClientAuthenticationMethod::PrivateKeyJwt, + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] +enum SigningAlgorithm { + #[value(name = "HS256")] + HS256, + #[value(name = "HS384")] + HS384, + #[value(name = "HS512")] + HS512, + #[value(name = "RS256")] + RS256, + #[value(name = "RS384")] + RS384, + #[value(name = "RS512")] + RS512, + #[value(name = "PS256")] + PS256, + #[value(name = "PS384")] + PS384, + #[value(name = "PS512")] + PS512, + #[value(name = "ES256")] + ES256, + #[value(name = "ES384")] + ES384, + #[value(name = "ES256K")] + ES256K, +} + +impl From for JsonWebSignatureAlg { + fn from(val: SigningAlgorithm) -> Self { + (&val).into() + } +} + +impl From<&SigningAlgorithm> for JsonWebSignatureAlg { + fn from(val: &SigningAlgorithm) -> Self { + match val { + SigningAlgorithm::HS256 => Self::Hs256, + SigningAlgorithm::HS384 => Self::Hs384, + SigningAlgorithm::HS512 => Self::Hs512, + SigningAlgorithm::RS256 => Self::Rs256, + SigningAlgorithm::RS384 => Self::Rs384, + SigningAlgorithm::RS512 => Self::Rs512, + SigningAlgorithm::PS256 => Self::Ps256, + SigningAlgorithm::PS384 => Self::Ps384, + SigningAlgorithm::PS512 => Self::Ps512, + SigningAlgorithm::ES256 => Self::Es256, + SigningAlgorithm::ES384 => Self::Es384, + SigningAlgorithm::ES256K => Self::Es256K, + } + } +} + #[derive(Parser, Debug)] enum Subcommand { /// Register a new user @@ -48,9 +155,38 @@ enum Subcommand { #[arg(long)] truncate: bool, }, + + /// Add an OAuth 2.0 upstream + #[command(name = "add-oauth-upstream")] + AddOAuthUpstream { + /// Issuer URL + issuer: String, + + /// Scope to ask for when authorizing with this upstream. + /// + /// This should include at least the `openid` scope. + scope: Scope, + + /// Client authentication method used when using the token endpoint. + #[arg(value_enum)] + token_endpoint_auth_method: AuthenticationMethod, + + /// Client ID + client_id: String, + + /// JWT signing algorithm used when authenticating for the token + /// endpoint. + #[arg(long, value_enum)] + signing_alg: Option, + + /// Client Secret + #[arg(long)] + client_secret: Option, + }, } impl Options { + #[allow(clippy::too_many_lines)] pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> { use Subcommand as SC; let clock = Clock::default(); @@ -71,11 +207,13 @@ impl Options { Ok(()) } + SC::Users => { warn!("Not implemented yet"); Ok(()) } + SC::VerifyEmail { username, email } => { let config: DatabaseConfig = root.load_config()?; let pool = config.connect().await?; @@ -90,6 +228,7 @@ impl Options { Ok(()) } + SC::ImportClients { truncate } => { let config: RootConfig = root.load_config()?; let pool = config.database.connect().await?; @@ -144,6 +283,64 @@ impl Options { Ok(()) } + + SC::AddOAuthUpstream { + issuer, + scope, + token_endpoint_auth_method, + client_id, + client_secret, + signing_alg, + } => { + let config: RootConfig = root.load_config()?; + let encrypter = config.secrets.encrypter(); + let pool = config.database.connect().await?; + let url_builder = UrlBuilder::new(config.http.public_base); + let mut conn = pool.acquire().await?; + + let requires_client_secret = token_endpoint_auth_method.requires_client_secret(); + + let token_endpoint_auth_method: OAuthClientAuthenticationMethod = + token_endpoint_auth_method.into(); + + let token_endpoint_signing_alg: Option = + signing_alg.as_ref().map(Into::into); + + tracing::info!(%issuer, %scope, %token_endpoint_auth_method, %client_id, "Adding OAuth upstream"); + + if client_secret.is_none() && requires_client_secret { + tracing::warn!("Token endpoint auth method requires a client secret, but none were provided"); + } + + let encrypted_client_secret = client_secret + .as_deref() + .map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes())) + .transpose()?; + + let provider = mas_storage::upstream_oauth2::add_provider( + &mut conn, + &mut rng, + &clock, + issuer.clone(), + scope.clone(), + token_endpoint_auth_method, + token_endpoint_signing_alg, + client_id.clone(), + encrypted_client_secret, + ) + .await?; + + let redirect_uri = url_builder.upstream_oauth_callback(provider.id); + let auth_uri = url_builder.upstream_oauth_authorize(provider.id); + tracing::info!( + %provider.id, + %provider.client_id, + provider.redirect_uri = %redirect_uri, + "Test authorization by going to {auth_uri}" + ); + + Ok(()) + } } } } diff --git a/crates/config/src/sections/secrets.rs b/crates/config/src/sections/secrets.rs index 25d40bcc..b4c5e7d2 100644 --- a/crates/config/src/sections/secrets.rs +++ b/crates/config/src/sections/secrets.rs @@ -119,7 +119,9 @@ impl SecretsConfig { } }; - let key = JsonWebKey::new(key).with_kid(item.kid.clone()); + let key = JsonWebKey::new(key) + .with_kid(item.kid.clone()) + .with_use(mas_iana::jose::JsonWebKeyUse::Sig); keys.push(key); } diff --git a/crates/data-model/Cargo.toml b/crates/data-model/Cargo.toml index 6407bb7d..bac92881 100644 --- a/crates/data-model/Cargo.toml +++ b/crates/data-model/Cargo.toml @@ -12,6 +12,7 @@ serde = "1.0.148" url = { version = "2.3.1", features = ["serde"] } crc = "3.0.0" rand = "0.8.5" +ulid = "1.0.0" mas-iana = { path = "../iana" } mas-jose = { path = "../jose" } diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index 344fa708..9143369f 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -27,6 +27,7 @@ pub(crate) mod compat; pub(crate) mod oauth2; pub(crate) mod tokens; pub(crate) mod traits; +pub(crate) mod upstream_oauth2; pub(crate) mod users; pub use self::{ @@ -40,6 +41,9 @@ pub use self::{ }, tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType}, traits::{StorageBackend, StorageBackendMarker}, + upstream_oauth2::{ + UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider, + }, users::{ Authentication, BrowserSession, User, UserEmail, UserEmailVerification, UserEmailVerificationState, diff --git a/crates/data-model/src/upstream_oauth2/mod.rs b/crates/data-model/src/upstream_oauth2/mod.rs new file mode 100644 index 00000000..ace6f100 --- /dev/null +++ b/crates/data-model/src/upstream_oauth2/mod.rs @@ -0,0 +1,48 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; +use oauth2_types::scope::Scope; +use serde::Serialize; +use ulid::Ulid; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct UpstreamOAuthProvider { + pub id: Ulid, + pub issuer: String, + pub scope: Scope, + pub client_id: String, + pub encrypted_client_secret: Option, + pub token_endpoint_signing_alg: Option, + pub token_endpoint_auth_method: OAuthClientAuthenticationMethod, + pub created_at: DateTime, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct UpstreamOAuthLink { + pub id: Ulid, + pub subject: String, + pub created_at: DateTime, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct UpstreamOAuthAuthorizationSession { + pub id: Ulid, + pub state: String, + pub code_challenge_verifier: Option, + pub nonce: String, + pub created_at: DateTime, + pub completed_at: Option>, +} diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index c00befe8..260816bf 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -52,7 +52,6 @@ rand_chacha = "0.3.1" headers = "0.3.8" ulid = "1.0.0" -oauth2-types = { path = "../oauth2-types" } mas-axum-utils = { path = "../axum-utils", default-features = false } mas-data-model = { path = "../data-model" } mas-email = { path = "../email" } @@ -61,10 +60,12 @@ mas-http = { path = "../http", default-features = false } mas-iana = { path = "../iana" } mas-jose = { path = "../jose" } mas-keystore = { path = "../keystore" } +mas-oidc-client = { path = "../oidc-client" } mas-policy = { path = "../policy" } mas-router = { path = "../router" } mas-storage = { path = "../storage" } mas-templates = { path = "../templates" } +oauth2-types = { path = "../oauth2-types" } [dev-dependencies] indoc = "1.0.7" diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 8c6d2e5d..974d849b 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -52,6 +52,7 @@ mod compat; mod graphql; mod health; mod oauth2; +mod upstream_oauth2; mod views; pub use compat::MatrixHomeserver; @@ -233,6 +234,7 @@ where Encrypter: FromRef, Templates: FromRef, Mailer: FromRef, + Keystore: FromRef, { Router::new() .route( @@ -296,6 +298,14 @@ where mas_router::CompatLoginSsoComplete::route(), get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post), ) + .route( + mas_router::UpstreamOAuth2Authorize::route(), + get(self::upstream_oauth2::authorize::get), + ) + .route( + mas_router::UpstreamOAuth2Callback::route(), + get(self::upstream_oauth2::callback::get), + ) .layer(AndThenLayer::new( move |response: axum::response::Response| async move { if response.status().is_server_error() { @@ -315,43 +325,6 @@ where )) } -/* -#[must_use] -#[allow(clippy::trait_duplication_in_bounds)] -pub fn router(state: S) -> RouterService -where - B: HttpBody + Send + 'static, - ::Data: Into + Send, - ::Error: std::error::Error + Send + Sync, - S: Clone + Send + Sync + 'static, - Keystore: FromRef, - UrlBuilder: FromRef, - Arc: FromRef, - PgPool: FromRef, - Encrypter: FromRef, - Templates: FromRef, - Mailer: FromRef, - MatrixHomeserver: FromRef, - mas_graphql::Schema: FromRef, -{ - let healthcheck_router = healthcheck_router(); - let discovery_router = discovery_router(); - let api_router = api_router(); - let graphql_router = graphql_router(true); - let compat_router = compat_router(); - let human_router = human_router(Templates::from_ref(&state)); - - Router::new() - .merge(healthcheck_router) - .merge(discovery_router) - .merge(human_router) - .merge(api_router) - .merge(graphql_router) - .merge(compat_router) - .with_state(state) -} -*/ - #[cfg(test)] async fn test_state(pool: PgPool) -> Result { use mas_email::MailTransport; diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs new file mode 100644 index 00000000..f473582b --- /dev/null +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -0,0 +1,149 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use axum::{ + extract::{Path, State}, + response::{IntoResponse, Redirect}, +}; +use axum_extra::extract::{cookie::Cookie, PrivateCookieJar}; +use hyper::StatusCode; +use mas_http::ClientInitError; +use mas_keystore::Encrypter; +use mas_oidc_client::{ + error::{AuthorizationError, DiscoveryError}, + requests::authorization_code::AuthorizationRequestData, +}; +use mas_router::UrlBuilder; +use mas_storage::{upstream_oauth2::lookup_provider, LookupResultExt}; +use sqlx::PgPool; +use thiserror::Error; +use ulid::Ulid; + +use super::http_service; + +#[derive(Debug, Error)] +pub(crate) enum RouteError { + #[error("Provider not found")] + ProviderNotFound, + + #[error(transparent)] + Authorization(#[from] AuthorizationError), + + #[error(transparent)] + InternalError(Box), + + #[error(transparent)] + Anyhow(#[from] anyhow::Error), +} + +impl From for RouteError { + fn from(e: sqlx::Error) -> Self { + Self::InternalError(Box::new(e)) + } +} + +impl From for RouteError { + fn from(e: DiscoveryError) -> Self { + Self::InternalError(Box::new(e)) + } +} + +impl From for RouteError { + fn from(e: mas_storage::upstream_oauth2::ProviderLookupError) -> Self { + Self::InternalError(Box::new(e)) + } +} + +impl From for RouteError { + fn from(e: ClientInitError) -> Self { + Self::InternalError(Box::new(e)) + } +} + +impl IntoResponse for RouteError { + fn into_response(self) -> axum::response::Response { + match self { + Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(), + Self::Authorization(e) => { + (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() + } + Self::InternalError(e) => { + (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() + } + Self::Anyhow(e) => { + (StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response() + } + } + } +} + +pub(crate) async fn get( + State(pool): State, + State(url_builder): State, + cookie_jar: PrivateCookieJar, + Path(provider_id): Path, +) -> Result { + let (clock, mut rng) = crate::rng_and_clock()?; + + let mut txn = pool.begin().await?; + + let provider = lookup_provider(&mut txn, provider_id) + .await + .to_option()? + .ok_or(RouteError::ProviderNotFound)?; + + let http_service = http_service("upstream-discover").await?; + + // First, discover the provider + let metadata = + mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?; + + let redirect_uri = url_builder.upstream_oauth_callback(provider.id); + + let data = AuthorizationRequestData { + client_id: &provider.client_id, + scope: &provider.scope, + prompt: None, + redirect_uri: &redirect_uri, + code_challenge_methods_supported: metadata.code_challenge_methods_supported.as_deref(), + }; + + // Build an authorization request for it + let (url, data) = mas_oidc_client::requests::authorization_code::build_authorization_url( + metadata.authorization_endpoint().clone(), + data, + &mut rng, + )?; + + let session = mas_storage::upstream_oauth2::add_session( + &mut txn, + &mut rng, + &clock, + &provider, + data.state, + data.code_challenge_verifier, + data.nonce, + ) + .await?; + + // TODO: handle that cookie somewhere else? + let mut cookie = Cookie::new("upstream-oauth2-session-id", session.id.to_string()); + cookie.set_path("/"); + cookie.set_http_only(true); + let cookie_jar = cookie_jar.add(cookie); + + txn.commit().await?; + + Ok((cookie_jar, Redirect::temporary(url.as_str()))) +} diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs new file mode 100644 index 00000000..37745cab --- /dev/null +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -0,0 +1,290 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use anyhow::Context; +use axum::{ + extract::{Path, Query, State}, + response::IntoResponse, + Json, +}; +use axum_extra::extract::PrivateCookieJar; +use hyper::StatusCode; +use mas_http::ClientInitError; +use mas_iana::oauth::OAuthClientAuthenticationMethod; +use mas_keystore::{Encrypter, Keystore}; +use mas_oidc_client::{ + error::{DiscoveryError, JwksError, TokenAuthorizationCodeError}, + requests::{authorization_code::AuthorizationValidationData, jose::JwtVerificationData}, + types::client_credentials::ClientCredentials, +}; +use mas_router::UrlBuilder; +use mas_storage::{upstream_oauth2::lookup_session, LookupResultExt}; +use oauth2_types::errors::ClientErrorCode; +use serde::Deserialize; +use sqlx::PgPool; +use thiserror::Error; +use ulid::Ulid; + +use super::http_service; + +#[derive(Deserialize)] +pub struct QueryParams { + state: String, + + #[serde(flatten)] + code_or_error: CodeOrError, +} + +#[derive(Deserialize)] +#[serde(untagged)] +enum CodeOrError { + Code { + code: String, + }, + Error { + error: ClientErrorCode, + error_description: Option, + #[allow(dead_code)] + error_uri: Option, + }, +} + +#[derive(Debug, Error)] +pub(crate) enum RouteError { + #[error("Session not found")] + SessionNotFound, + + #[error("Provider mismatch")] + ProviderMismatch, + + #[error("State parameter mismatch")] + StateMismatch, + + #[error("Error from the provider: {error}")] + ClientError { + error: ClientErrorCode, + error_description: Option, + }, + + #[error("Provider is missing a client secret")] + MissingClientSecret, + + #[error("Missing session cookie")] + MissingCookie, + + #[error("Invalid session cookie")] + InvalidCookie(#[source] ulid::DecodeError), + + #[error(transparent)] + InternalError(Box), + + #[error(transparent)] + Anyhow(#[from] anyhow::Error), +} + +impl From for RouteError { + fn from(e: sqlx::Error) -> Self { + Self::InternalError(Box::new(e)) + } +} + +impl From for RouteError { + fn from(e: DiscoveryError) -> Self { + Self::InternalError(Box::new(e)) + } +} + +impl From for RouteError { + fn from(e: JwksError) -> Self { + Self::InternalError(Box::new(e)) + } +} + +impl From for RouteError { + fn from(e: TokenAuthorizationCodeError) -> Self { + Self::InternalError(Box::new(e)) + } +} + +impl From for RouteError { + fn from(e: mas_storage::upstream_oauth2::SessionLookupError) -> Self { + Self::InternalError(Box::new(e)) + } +} + +impl From for RouteError { + fn from(e: ClientInitError) -> Self { + Self::InternalError(Box::new(e)) + } +} + +impl IntoResponse for RouteError { + fn into_response(self) -> axum::response::Response { + match self { + Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(), + Self::InternalError(e) => { + (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() + } + Self::Anyhow(e) => { + (StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response() + } + e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(), + } + } +} + +#[allow(clippy::too_many_lines)] +pub(crate) async fn get( + State(pool): State, + State(url_builder): State, + State(encrypter): State, + State(keystore): State, + cookie_jar: PrivateCookieJar, + Path(provider_id): Path, + Query(params): Query, +) -> Result { + let (clock, mut rng) = crate::rng_and_clock()?; + + let mut txn = pool.begin().await?; + + // XXX: that cookie should be managed elsewhere + let cookie = cookie_jar + .get("upstream-oauth2-session-id") + .ok_or(RouteError::MissingCookie)?; + + let session_id: Ulid = cookie.value().parse().map_err(RouteError::InvalidCookie)?; + + let (provider, session) = lookup_session(&mut txn, session_id) + .await + .to_option()? + .ok_or(RouteError::SessionNotFound)?; + + if provider.id != provider_id { + // The provider in the session cookie should match the one from the URL + return Err(RouteError::ProviderMismatch); + } + + if params.state != session.state { + // The state in the session cookie should match the one from the params + return Err(RouteError::StateMismatch); + } + + // Let's extract the code from the params, and return if there was an error + let code = match params.code_or_error { + CodeOrError::Error { + error, + error_description, + .. + } => { + return Err(RouteError::ClientError { + error, + error_description, + }) + } + CodeOrError::Code { code } => code, + }; + + let http_service = http_service("upstream-code-exchange").await?; + + // XXX: we shouldn't discover on-the-fly + // Discover the provider + let metadata = + mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?; + + // Fetch the JWKS + let jwks = + mas_oidc_client::requests::jose::fetch_jwks(&http_service, metadata.jwks_uri()).await?; + + // Figure out the client credentials + let client_id = provider.client_id.clone(); + // Decrypt the client secret + let client_secret = provider + .encrypted_client_secret + .map(|encrypted_client_secret| { + encrypter + .decrypt_string(&encrypted_client_secret) + .and_then(|client_secret| { + String::from_utf8(client_secret) + .context("Client secret contains non-UTF8 bytes") + }) + }) + .transpose()?; + + let token_endpoint = metadata.token_endpoint(); + + let client_credentials = match provider.token_endpoint_auth_method { + OAuthClientAuthenticationMethod::None => ClientCredentials::None { client_id }, + OAuthClientAuthenticationMethod::ClientSecretPost => ClientCredentials::ClientSecretPost { + client_id, + client_secret: client_secret.ok_or(RouteError::MissingClientSecret)?, + }, + OAuthClientAuthenticationMethod::ClientSecretBasic => { + ClientCredentials::ClientSecretBasic { + client_id, + client_secret: client_secret.ok_or(RouteError::MissingClientSecret)?, + } + } + OAuthClientAuthenticationMethod::ClientSecretJwt => ClientCredentials::ClientSecretJwt { + client_id, + client_secret: client_secret.ok_or(RouteError::MissingClientSecret)?, + signing_algorithm: provider + .token_endpoint_signing_alg + .unwrap_or(mas_iana::jose::JsonWebSignatureAlg::Rs256), + token_endpoint: token_endpoint.clone(), + }, + OAuthClientAuthenticationMethod::PrivateKeyJwt => ClientCredentials::PrivateKeyJwt { + client_id, + jwt_signing_method: + mas_oidc_client::types::client_credentials::JwtSigningMethod::Keystore(keystore), + signing_algorithm: provider + .token_endpoint_signing_alg + .unwrap_or(mas_iana::jose::JsonWebSignatureAlg::Rs256), + token_endpoint: token_endpoint.clone(), + }, + // XXX: The database should never have an unsupported method in it + _ => unreachable!(), + }; + + let redirect_uri = url_builder.upstream_oauth_callback(provider.id); + + let validation_data = AuthorizationValidationData { + state: session.state, + nonce: session.nonce, + code_challenge_verifier: session.code_challenge_verifier, + redirect_uri, + }; + + let id_token_verification_data = JwtVerificationData { + issuer: &provider.issuer, + jwks: &jwks, + // TODO: make that configurable + signing_algorithm: &mas_iana::jose::JsonWebSignatureAlg::Rs256, + client_id: &provider.client_id, + }; + + let (response, _id_token) = + mas_oidc_client::requests::authorization_code::access_token_with_authorization_code( + &http_service, + client_credentials, + token_endpoint, + code, + validation_data, + Some(id_token_verification_data), + clock.now(), + &mut rng, + ) + .await?; + + Ok(Json(response)) +} diff --git a/crates/handlers/src/upstream_oauth2/mod.rs b/crates/handlers/src/upstream_oauth2/mod.rs new file mode 100644 index 00000000..c6856025 --- /dev/null +++ b/crates/handlers/src/upstream_oauth2/mod.rs @@ -0,0 +1,35 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use axum::body::Full; +use mas_http::{BodyToBytesResponseLayer, ClientInitError, ClientLayer, HttpService}; +use tower::{ + util::{MapErrLayer, MapRequestLayer}, + BoxError, Layer, +}; + +pub(crate) mod authorize; +pub(crate) mod callback; + +async fn http_service(operation: &'static str) -> Result { + let client = ( + MapErrLayer::new(BoxError::from), + MapRequestLayer::new(|req: hyper::Request<_>| req.map(Full::new)), + BodyToBytesResponseLayer::default(), + ClientLayer::new(operation), + ) + .layer(mas_http::make_untraced_client().await?); + + Ok(HttpService::new(client)) +} diff --git a/crates/http/Cargo.toml b/crates/http/Cargo.toml index 06515dd8..2e2acbff 100644 --- a/crates/http/Cargo.toml +++ b/crates/http/Cargo.toml @@ -27,8 +27,8 @@ serde_json = "1.0.89" serde_urlencoded = "0.7.1" thiserror = "1.0.37" tokio = { version = "1.22.0", features = ["sync", "parking_lot"], optional = true } -tower = { version = "0.4.13", features = ["timeout", "limit"] } -tower-http = { version = "0.3.5", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors", "util"] } +tower = { version = "0.4.13", features = ["limit"] } +tower-http = { version = "0.3.5", features = ["timeout", "follow-redirect", "decompression-full", "set-header", "compression-full", "cors", "util"] } tracing = "0.1.37" tracing-opentelemetry = "0.18.0" webpki = { version = "0.22.0", optional = true } diff --git a/crates/http/src/client.rs b/crates/http/src/client.rs index a8e89aff..c80a3995 100644 --- a/crates/http/src/client.rs +++ b/crates/http/src/client.rs @@ -16,7 +16,6 @@ use std::{convert::Infallible, net::SocketAddr}; use bytes::Bytes; use http::{Request, Response}; -use http_body::{combinators::BoxBody, Body}; use hyper::{ client::{ connect::dns::{GaiResolver, Name}, @@ -26,14 +25,11 @@ use hyper::{ }; use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder}; use thiserror::Error; -use tower::{ - util::{MapErrLayer, MapResponseLayer}, - Layer, Service, -}; +use tower::{Layer, Service}; use crate::{ layers::{ - client::{ClientLayer, ClientResponse}, + client::ClientLayer, otel::{TraceDns, TraceLayer}, }, BoxCloneSyncService, BoxError, @@ -229,32 +225,20 @@ where } /// Create a traced HTTP client, with a default timeout, which follows redirects -/// and handles compression /// /// # Errors /// /// Returns an error if it failed to initialize pub async fn client( operation: &'static str, -) -> Result< - BoxCloneSyncService, Response>, ClientError>, - ClientInitError, -> +) -> Result, Response, hyper::Error>, ClientInitError> where B: http_body::Body + Default + Send + 'static, E: Into + 'static, { let client = make_traced_client().await?; - let layer = ( - // Convert the errors to ClientError to help dealing with them - MapErrLayer::new(ClientError::from), - MapResponseLayer::new(|r: ClientResponse| { - r.map(|body| body.map_err(ClientError::from).boxed()) - }), - ClientLayer::new(operation), - ); - let client = BoxCloneSyncService::new(layer.layer(client)); + let client = ClientLayer::new(operation).layer(client); - Ok(client) + Ok(BoxCloneSyncService::new(client)) } diff --git a/crates/http/src/layers/body_to_bytes_response.rs b/crates/http/src/layers/body_to_bytes_response.rs index 648a71ed..1ebb3d93 100644 --- a/crates/http/src/layers/body_to_bytes_response.rs +++ b/crates/http/src/layers/body_to_bytes_response.rs @@ -38,6 +38,14 @@ impl Error { } } +impl Error { + pub fn unify(self) -> E { + match self { + Self::Service { inner } | Self::Body { inner } => inner, + } + } +} + #[derive(Clone)] pub struct BodyToBytesResponse { inner: S, diff --git a/crates/http/src/layers/client.rs b/crates/http/src/layers/client.rs index 89ff5b8a..eb10dd5a 100644 --- a/crates/http/src/layers/client.rs +++ b/crates/http/src/layers/client.rs @@ -17,13 +17,12 @@ use std::{marker::PhantomData, time::Duration}; use http::{header::USER_AGENT, HeaderValue, Request, Response}; use tower::{ limit::{ConcurrencyLimit, ConcurrencyLimitLayer}, - timeout::{Timeout, TimeoutLayer}, Layer, Service, }; use tower_http::{ - decompression::{Decompression, DecompressionBody, DecompressionLayer}, follow_redirect::{FollowRedirect, FollowRedirectLayer}, set_header::{SetRequestHeader, SetRequestHeaderLayer}, + timeout::{Timeout, TimeoutLayer}, }; use super::otel::TraceLayer; @@ -48,9 +47,6 @@ impl ClientLayer { } } -#[allow(dead_code)] -pub type ClientResponse = Response>; - impl Layer for ClientLayer where S: Service, Response = Response, Error = E> @@ -63,21 +59,14 @@ where S::Future: Send + 'static, E: Into, { - type Service = Decompression< - SetRequestHeader< - TraceHttpClient>>>>, - HeaderValue, - >, + type Service = SetRequestHeader< + TraceHttpClient>>>>, + HeaderValue, >; fn layer(&self, inner: S) -> Self::Service { - // Note that most layers here just forward the error type. Two notables - // exceptions are: - // - the TimeoutLayer - // - the DecompressionLayer - // Those layers do type erasure of the error. + // Note that all layers here just forward the error type. ( - DecompressionLayer::new(), SetRequestHeaderLayer::overriding(USER_AGENT, MAS_USER_AGENT.clone()), // A trace that has the whole operation, with all the redirects, timeouts and rate // limits in it diff --git a/crates/oauth2-types/Cargo.toml b/crates/oauth2-types/Cargo.toml index 837b29ea..fa226f6e 100644 --- a/crates/oauth2-types/Cargo.toml +++ b/crates/oauth2-types/Cargo.toml @@ -18,7 +18,6 @@ chrono = "0.4.23" sha2 = "0.10.6" data-encoding = "2.3.2" thiserror = "1.0.37" -itertools = "0.10.5" mas-iana = { path = "../iana" } mas-jose = { path = "../jose" } diff --git a/crates/oauth2-types/src/response_type.rs b/crates/oauth2-types/src/response_type.rs index a3d287a5..1a7dc00a 100644 --- a/crates/oauth2-types/src/response_type.rs +++ b/crates/oauth2-types/src/response_type.rs @@ -20,7 +20,6 @@ use std::{collections::BTreeSet, fmt, iter::FromIterator, str::FromStr}; -use itertools::Itertools; use mas_iana::oauth::OAuthAuthorizationEndpointResponseType; use parse_display::{Display, FromStr}; use serde_with::{DeserializeFromStr, SerializeDisplay}; @@ -127,14 +126,23 @@ impl FromStr for ResponseType { impl fmt::Display for ResponseType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let res = Itertools::intersperse(self.iter().map(ToString::to_string), ' '.to_string()) - .collect::(); + let mut iter = self.iter(); - if res.is_empty() { - write!(f, "none") + // First item shouldn't have a leading space + if let Some(first) = iter.next() { + first.fmt(f)?; } else { - f.write_str(&res) + // If the whole iterator is empty, write 'none' instead + write!(f, "none")?; + return Ok(()); } + + // Write the other items with a leading space + for item in iter { + write!(f, " {item}")?; + } + + Ok(()) } } diff --git a/crates/oauth2-types/src/scope.rs b/crates/oauth2-types/src/scope.rs index f4c4a631..ab6299b0 100644 --- a/crates/oauth2-types/src/scope.rs +++ b/crates/oauth2-types/src/scope.rs @@ -20,7 +20,6 @@ use std::{borrow::Cow, collections::BTreeSet, iter::FromIterator, ops::Deref, str::FromStr}; -use itertools::Itertools; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -106,9 +105,9 @@ impl Deref for ScopeToken { } } -impl ToString for ScopeToken { - fn to_string(&self) -> String { - self.0.to_string() +impl std::fmt::Display for ScopeToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) } } @@ -169,10 +168,17 @@ impl Scope { } } -impl ToString for Scope { - fn to_string(&self) -> String { - let it = self.0.iter().map(ScopeToken::to_string); - Itertools::intersperse(it, ' '.to_string()).collect() +impl std::fmt::Display for Scope { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for (index, token) in self.0.iter().enumerate() { + if index == 0 { + write!(f, "{token}")?; + } else { + write!(f, " {token}")?; + } + } + + Ok(()) } } diff --git a/crates/oidc-client/src/types/client_credentials.rs b/crates/oidc-client/src/types/client_credentials.rs index 4fd7b9a9..1009c320 100644 --- a/crates/oidc-client/src/types/client_credentials.rs +++ b/crates/oidc-client/src/types/client_credentials.rs @@ -23,6 +23,7 @@ use http::Request; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_jose::{ claims::{self, ClaimError}, + constraints::Constrainable, jwa::SymmetricKey, jwt::{JsonWebSignatureHeader, Jwt}, }; @@ -338,7 +339,12 @@ impl RequestClientCredentials { .signing_key_for_algorithm(&signing_algorithm) .ok_or(CredentialsError::NoPrivateKeyFound)?; let signer = key.params().signing_key_for_alg(&signing_algorithm)?; - let header = JsonWebSignatureHeader::new(signing_algorithm); + let mut header = JsonWebSignatureHeader::new(signing_algorithm); + + if let Some(kid) = key.kid() { + header = header.with_kid(kid); + } + Jwt::sign(header, claims, &signer)?.to_string() } JwtSigningMethod::Custom(jwt_signing_fn) => { diff --git a/crates/router/src/endpoints.rs b/crates/router/src/endpoints.rs index 3cb390b3..08e558ff 100644 --- a/crates/router/src/endpoints.rs +++ b/crates/router/src/endpoints.rs @@ -524,6 +524,52 @@ impl Route for CompatLoginSsoComplete { } } +/// `GET /upstream/authorize/:id` +pub struct UpstreamOAuth2Authorize { + id: Ulid, +} + +impl UpstreamOAuth2Authorize { + #[must_use] + pub const fn new(id: Ulid) -> Self { + Self { id } + } +} + +impl Route for UpstreamOAuth2Authorize { + type Query = (); + fn route() -> &'static str { + "/upstream/authorize/:provider_id" + } + + fn path(&self) -> std::borrow::Cow<'static, str> { + format!("/upstream/authorize/{}", self.id).into() + } +} + +/// `GET /upstream/callback/:id` +pub struct UpstreamOAuth2Callback { + id: Ulid, +} + +impl UpstreamOAuth2Callback { + #[must_use] + pub const fn new(id: Ulid) -> Self { + Self { id } + } +} + +impl Route for UpstreamOAuth2Callback { + type Query = (); + fn route() -> &'static str { + "/upstream/callback/:provider_id" + } + + fn path(&self) -> std::borrow::Cow<'static, str> { + format!("/upstream/callback/{}", self.id).into() + } +} + /// `GET /assets` pub struct StaticAsset { path: String, diff --git a/crates/router/src/url_builder.rs b/crates/router/src/url_builder.rs index ce2abd81..35888a34 100644 --- a/crates/router/src/url_builder.rs +++ b/crates/router/src/url_builder.rs @@ -14,6 +14,7 @@ //! Utility to build URLs +use ulid::Ulid; use url::Url; use crate::traits::Route; @@ -97,4 +98,16 @@ impl UrlBuilder { pub fn static_asset(&self, path: String) -> Url { self.url_for(&crate::endpoints::StaticAsset::new(path)) } + + /// Upstream redirect URI + #[must_use] + pub fn upstream_oauth_callback(&self, id: Ulid) -> Url { + self.url_for(&crate::endpoints::UpstreamOAuth2Callback::new(id)) + } + + /// Upstream authorize URI + #[must_use] + pub fn upstream_oauth_authorize(&self, id: Ulid) -> Url { + self.url_for(&crate::endpoints::UpstreamOAuth2Authorize::new(id)) + } } diff --git a/crates/storage/migrations/20221121151402_upstream_oauth.sql b/crates/storage/migrations/20221121151402_upstream_oauth.sql new file mode 100644 index 00000000..1ed10029 --- /dev/null +++ b/crates/storage/migrations/20221121151402_upstream_oauth.sql @@ -0,0 +1,84 @@ +-- Copyright 2022 The Matrix.org Foundation C.I.C. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +CREATE TABLE "upstream_oauth_providers" ( + "upstream_oauth_provider_id" UUID NOT NULL + CONSTRAINT "upstream_oauth_providers_pkey" + PRIMARY KEY, + + "issuer" TEXT NOT NULL, + + "scope" TEXT NOT NULL, + + "client_id" TEXT NOT NULL, + + -- Used for client_secret_basic, client_secret_post and client_secret_jwt auth methods + "encrypted_client_secret" TEXT, + + -- Used for client_secret_jwt and private_key_jwt auth methods + "token_endpoint_signing_alg" TEXT, + + "token_endpoint_auth_method" TEXT NOT NULL, + + "created_at" TIMESTAMP WITH TIME ZONE NOT NULL +); + +CREATE TABLE "upstream_oauth_links" ( + "upstream_oauth_link_id" UUID NOT NULL + CONSTRAINT "upstream_oauth_links_pkey" + PRIMARY KEY, + + "upstream_oauth_provider_id" UUID NOT NULL + CONSTRAINT "upstream_oauth_links_upstream_oauth_provider_fkey" + REFERENCES "upstream_oauth_providers" ("upstream_oauth_provider_id"), + + -- The user is initially NULL when logging in the first time. + -- It then either links to an existing account, or creates a new one from scratch. + "user_id" UUID + CONSTRAINT "upstream_oauth_link_user_fkey" + REFERENCES "users" ("user_id"), + + "subject" TEXT NOT NULL, + + "created_at" TIMESTAMP WITH TIME ZONE NOT NULL, + + -- There should only be one entry per subject/provider tuple + CONSTRAINT "upstream_oauth_links_subject_unique" + UNIQUE ("upstream_oauth_provider_id", "subject") +); + +CREATE TABLE "upstream_oauth_authorization_sessions" ( + "upstream_oauth_authorization_session_id" UUID NOT NULL + CONSTRAINT "upstream_oauth_authorization_sessions_pkey" + PRIMARY KEY, + + "upstream_oauth_provider_id" UUID NOT NULL + CONSTRAINT "upstream_oauth_authorization_sessions_upstream_oauth_provider_fkey" + REFERENCES "upstream_oauth_providers" ("upstream_oauth_provider_id"), + + -- The link it resolves to at the end of the authorization grant + "upstream_oauth_link_id" UUID + CONSTRAINT "upstream_oauth_authorization_sessions_upstream_oauth_link_fkey" + REFERENCES "upstream_oauth_links" ("upstream_oauth_link_id"), + + "state" TEXT NOT NULL + CONSTRAINT "upstream_oauth_authorization_sessions_state_unique" + UNIQUE, + + "code_challenge_verifier" TEXT, + "nonce" TEXT NOT NULL, + + "created_at" TIMESTAMP WITH TIME ZONE NOT NULL, + "completed_at" TIMESTAMP WITH TIME ZONE +); diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index 05ccab02..26ab9744 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -214,6 +214,68 @@ }, "query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = $1\n " }, + "0af182315b36766eca8e232280986bade0202d1b1d64160a99cd14eadcbfc25b": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_provider_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "issuer", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "scope", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "client_id", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "encrypted_client_secret", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "token_endpoint_signing_alg", + "ordinal": 5, + "type_info": "Text" + }, + { + "name": "token_endpoint_auth_method", + "ordinal": 6, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 7, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + true, + true, + false, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n " + }, "0b49cde0b7b79f79ec261502ab89bcffa81f9f5ed2f922a41b1718274b9e3073": { "describe": { "columns": [ @@ -291,6 +353,25 @@ }, "query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at)\n VALUES ($1, $2, $3)\n " }, + "1ee5cecfafd4726a4ebc08da8a34c09178e6e1e072581c8fca9d3d76967792cb": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Text", + "Text", + "Text", + "Text", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n " + }, "2153118b364a33582e7f598acce3789fcb8d938948a819b15cf0b6d37edf58b2": { "describe": { "columns": [], @@ -1010,6 +1091,23 @@ }, "query": "\n SELECT scope_token\n FROM oauth2_consents\n WHERE user_id = $1 AND oauth2_client_id = $2\n " }, + "53a652f0892d25654fe937962913f2f964463fd09f518066fbc83808edc5b394": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Text", + "Text", + "Text", + "Timestamptz" + ] + } + }, + "query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL)\n " + }, "559a486756d08d101eb7188ef6637b9d24c024d056795b8121f7f04a7f9db6a3": { "describe": { "columns": [ @@ -1157,6 +1255,104 @@ }, "query": "\n UPDATE oauth2_access_tokens\n SET revoked_at = $2\n WHERE oauth2_access_token_id = $1\n " }, + "6c8816b2618db8d04ab9393429866d9af59ad280949947fc025c89baffe6a455": { + "describe": { + "columns": [ + { + "name": "upstream_oauth_authorization_session_id", + "ordinal": 0, + "type_info": "Uuid" + }, + { + "name": "upstream_oauth_provider_id", + "ordinal": 1, + "type_info": "Uuid" + }, + { + "name": "state", + "ordinal": 2, + "type_info": "Text" + }, + { + "name": "code_challenge_verifier", + "ordinal": 3, + "type_info": "Text" + }, + { + "name": "nonce", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "created_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "completed_at", + "ordinal": 6, + "type_info": "Timestamptz" + }, + { + "name": "provider_issuer", + "ordinal": 7, + "type_info": "Text" + }, + { + "name": "provider_scope", + "ordinal": 8, + "type_info": "Text" + }, + { + "name": "provider_client_id", + "ordinal": 9, + "type_info": "Text" + }, + { + "name": "provider_encrypted_client_secret", + "ordinal": 10, + "type_info": "Text" + }, + { + "name": "provider_token_endpoint_auth_method", + "ordinal": 11, + "type_info": "Text" + }, + { + "name": "provider_token_endpoint_signing_alg", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "provider_created_at", + "ordinal": 13, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + true, + false, + false, + true, + false, + false, + false, + true, + false, + true, + false + ], + "parameters": { + "Left": [ + "Uuid" + ] + } + }, + "query": "\n SELECT\n ua.upstream_oauth_authorization_session_id,\n ua.upstream_oauth_provider_id,\n ua.state,\n ua.code_challenge_verifier,\n ua.nonce,\n ua.created_at,\n ua.completed_at,\n up.issuer AS \"provider_issuer\",\n up.scope AS \"provider_scope\",\n up.client_id AS \"provider_client_id\",\n up.encrypted_client_secret AS \"provider_encrypted_client_secret\",\n up.token_endpoint_auth_method AS \"provider_token_endpoint_auth_method\",\n up.token_endpoint_signing_alg AS \"provider_token_endpoint_signing_alg\",\n up.created_at AS \"provider_created_at\"\n FROM upstream_oauth_authorization_sessions ua\n INNER JOIN upstream_oauth_providers up\n USING (upstream_oauth_provider_id)\n WHERE upstream_oauth_authorization_session_id = $1\n " + }, "7262f81a335a984c4051383d2ede7455ff65ed90fbd3151d625f8a21fd26cb05": { "describe": { "columns": [], diff --git a/crates/storage/src/lib.rs b/crates/storage/src/lib.rs index bd0db87c..b47ea744 100644 --- a/crates/storage/src/lib.rs +++ b/crates/storage/src/lib.rs @@ -126,6 +126,7 @@ impl StorageBackendMarker for PostgresqlBackend {} pub mod compat; pub mod oauth2; pub(crate) mod pagination; +pub mod upstream_oauth2; pub mod user; /// Embedded migrations, allowing them to run on startup diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs new file mode 100644 index 00000000..8acb8229 --- /dev/null +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -0,0 +1,21 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod provider; +mod session; + +pub use self::{ + provider::{add_provider, lookup_provider, ProviderLookupError}, + session::{add_session, lookup_session, SessionLookupError}, +}; diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs new file mode 100644 index 00000000..f6da34fa --- /dev/null +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -0,0 +1,159 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use mas_data_model::UpstreamOAuthProvider; +use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; +use oauth2_types::scope::Scope; +use rand::Rng; +use sqlx::PgExecutor; +use thiserror::Error; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{Clock, DatabaseInconsistencyError, LookupError}; + +#[derive(Debug, Error)] +#[error("Failed to lookup upstream OAuth 2.0 provider")] +pub enum ProviderLookupError { + Driver(#[from] sqlx::Error), + Inconcistency(#[from] DatabaseInconsistencyError), +} + +impl LookupError for ProviderLookupError { + fn not_found(&self) -> bool { + matches!(self, Self::Driver(sqlx::Error::RowNotFound)) + } +} + +struct ProviderLookup { + upstream_oauth_provider_id: Uuid, + issuer: String, + scope: String, + client_id: String, + encrypted_client_secret: Option, + token_endpoint_signing_alg: Option, + token_endpoint_auth_method: String, + created_at: DateTime, +} + +#[tracing::instrument( + skip_all, + fields(upstream_oauth_provider.id = %id), + err, +)] +pub async fn lookup_provider( + executor: impl PgExecutor<'_>, + id: Ulid, +) -> Result { + let res = sqlx::query_as!( + ProviderLookup, + r#" + SELECT + upstream_oauth_provider_id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at + FROM upstream_oauth_providers + WHERE upstream_oauth_provider_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(executor) + .await?; + + Ok(UpstreamOAuthProvider { + id: res.upstream_oauth_provider_id.into(), + issuer: res.issuer, + scope: res.scope.parse().map_err(|_| DatabaseInconsistencyError)?, + client_id: res.client_id, + encrypted_client_secret: res.encrypted_client_secret, + token_endpoint_auth_method: res + .token_endpoint_auth_method + .parse() + .map_err(|_| DatabaseInconsistencyError)?, + token_endpoint_signing_alg: res + .token_endpoint_signing_alg + .map(|x| x.parse()) + .transpose() + .map_err(|_| DatabaseInconsistencyError)?, + created_at: res.created_at, + }) +} + +#[tracing::instrument( + skip_all, + fields( + upstream_oauth_provider.id, + upstream_oauth_provider.issuer = %issuer, + upstream_oauth_provider.client_id = %client_id, + ), + err, +)] +#[allow(clippy::too_many_arguments)] +pub async fn add_provider( + executor: impl PgExecutor<'_>, + mut rng: impl Rng + Send, + clock: &Clock, + issuer: String, + scope: Scope, + token_endpoint_auth_method: OAuthClientAuthenticationMethod, + token_endpoint_signing_alg: Option, + client_id: String, + encrypted_client_secret: Option, +) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); + tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id)); + + sqlx::query!( + r#" + INSERT INTO upstream_oauth_providers ( + upstream_oauth_provider_id, + issuer, + scope, + token_endpoint_auth_method, + token_endpoint_signing_alg, + client_id, + encrypted_client_secret, + created_at + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + "#, + Uuid::from(id), + &issuer, + scope.to_string(), + token_endpoint_auth_method.to_string(), + token_endpoint_signing_alg.as_ref().map(ToString::to_string), + &client_id, + encrypted_client_secret.as_deref(), + created_at, + ) + .execute(executor) + .await?; + + Ok(UpstreamOAuthProvider { + id, + issuer, + scope, + client_id, + encrypted_client_secret, + token_endpoint_signing_alg, + token_endpoint_auth_method, + created_at, + }) +} diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs new file mode 100644 index 00000000..43be7a99 --- /dev/null +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -0,0 +1,184 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::{DateTime, Utc}; +use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthProvider}; +use rand::Rng; +use sqlx::PgExecutor; +use thiserror::Error; +use ulid::Ulid; +use uuid::Uuid; + +use crate::{Clock, DatabaseInconsistencyError, LookupError}; + +#[derive(Debug, Error)] +#[error("Failed to lookup upstream OAuth 2.0 authorization session")] +pub enum SessionLookupError { + Driver(#[from] sqlx::Error), + Inconcistency(#[from] DatabaseInconsistencyError), +} + +impl LookupError for SessionLookupError { + fn not_found(&self) -> bool { + matches!(self, Self::Driver(sqlx::Error::RowNotFound)) + } +} + +struct SessionLookup { + upstream_oauth_authorization_session_id: Uuid, + upstream_oauth_provider_id: Uuid, + state: String, + code_challenge_verifier: Option, + nonce: String, + created_at: DateTime, + completed_at: Option>, + provider_issuer: String, + provider_scope: String, + provider_client_id: String, + provider_encrypted_client_secret: Option, + provider_token_endpoint_auth_method: String, + provider_token_endpoint_signing_alg: Option, + provider_created_at: DateTime, +} + +#[tracing::instrument( + skip_all, + fields(upstream_oauth_authorization_session.id = %id), + err, +)] +pub async fn lookup_session( + executor: impl PgExecutor<'_>, + id: Ulid, +) -> Result<(UpstreamOAuthProvider, UpstreamOAuthAuthorizationSession), SessionLookupError> { + let res = sqlx::query_as!( + SessionLookup, + r#" + SELECT + ua.upstream_oauth_authorization_session_id, + ua.upstream_oauth_provider_id, + ua.state, + ua.code_challenge_verifier, + ua.nonce, + ua.created_at, + ua.completed_at, + up.issuer AS "provider_issuer", + up.scope AS "provider_scope", + up.client_id AS "provider_client_id", + up.encrypted_client_secret AS "provider_encrypted_client_secret", + up.token_endpoint_auth_method AS "provider_token_endpoint_auth_method", + up.token_endpoint_signing_alg AS "provider_token_endpoint_signing_alg", + up.created_at AS "provider_created_at" + FROM upstream_oauth_authorization_sessions ua + INNER JOIN upstream_oauth_providers up + USING (upstream_oauth_provider_id) + WHERE upstream_oauth_authorization_session_id = $1 + "#, + Uuid::from(id), + ) + .fetch_one(executor) + .await?; + + let provider = UpstreamOAuthProvider { + id: res.upstream_oauth_provider_id.into(), + issuer: res + .provider_issuer + .parse() + .map_err(|_| DatabaseInconsistencyError)?, + scope: res + .provider_scope + .parse() + .map_err(|_| DatabaseInconsistencyError)?, + client_id: res.provider_client_id, + encrypted_client_secret: res.provider_encrypted_client_secret, + token_endpoint_auth_method: res + .provider_token_endpoint_auth_method + .parse() + .map_err(|_| DatabaseInconsistencyError)?, + token_endpoint_signing_alg: res + .provider_token_endpoint_signing_alg + .map(|x| x.parse()) + .transpose() + .map_err(|_| DatabaseInconsistencyError)?, + created_at: res.provider_created_at, + }; + + let session = UpstreamOAuthAuthorizationSession { + id: res.upstream_oauth_authorization_session_id.into(), + state: res.state, + code_challenge_verifier: res.code_challenge_verifier, + nonce: res.nonce, + created_at: res.created_at, + completed_at: res.completed_at, + }; + + Ok((provider, session)) +} + +#[tracing::instrument( + skip_all, + fields( + upstream_oauth_provider.id = %provider.id, + upstream_oauth_provider.issuer = %provider.issuer, + upstream_oauth_provider.client_id = %provider.client_id, + upstream_oauth_authorization_session.id, + ), + err, +)] +pub async fn add_session( + executor: impl PgExecutor<'_>, + mut rng: impl Rng + Send, + clock: &Clock, + provider: &UpstreamOAuthProvider, + state: String, + code_challenge_verifier: Option, + nonce: String, +) -> Result { + let created_at = clock.now(); + let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng); + tracing::Span::current().record( + "upstream_oauth_authorization_session.id", + tracing::field::display(id), + ); + + sqlx::query!( + r#" + INSERT INTO upstream_oauth_authorization_sessions ( + upstream_oauth_authorization_session_id, + upstream_oauth_provider_id, + state, + code_challenge_verifier, + nonce, + created_at, + completed_at + ) VALUES ($1, $2, $3, $4, $5, $6, NULL) + "#, + Uuid::from(id), + Uuid::from(provider.id), + &state, + code_challenge_verifier.as_deref(), + nonce, + created_at, + ) + .execute(executor) + .await?; + + Ok(UpstreamOAuthAuthorizationSession { + id, + state, + code_challenge_verifier, + nonce, + created_at, + completed_at: None, + }) +}