You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-28 11:02:02 +03:00
WIP: upstream OIDC provider support
This commit is contained in:
5
Cargo.lock
generated
5
Cargo.lock
generated
@ -2687,6 +2687,7 @@ dependencies = [
|
|||||||
"mas-email",
|
"mas-email",
|
||||||
"mas-handlers",
|
"mas-handlers",
|
||||||
"mas-http",
|
"mas-http",
|
||||||
|
"mas-iana",
|
||||||
"mas-listener",
|
"mas-listener",
|
||||||
"mas-policy",
|
"mas-policy",
|
||||||
"mas-router",
|
"mas-router",
|
||||||
@ -2694,6 +2695,7 @@ dependencies = [
|
|||||||
"mas-storage",
|
"mas-storage",
|
||||||
"mas-tasks",
|
"mas-tasks",
|
||||||
"mas-templates",
|
"mas-templates",
|
||||||
|
"oauth2-types",
|
||||||
"opentelemetry",
|
"opentelemetry",
|
||||||
"opentelemetry-http",
|
"opentelemetry-http",
|
||||||
"opentelemetry-jaeger",
|
"opentelemetry-jaeger",
|
||||||
@ -2761,6 +2763,7 @@ dependencies = [
|
|||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
"serde",
|
"serde",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
|
"ulid",
|
||||||
"url",
|
"url",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -2825,6 +2828,7 @@ dependencies = [
|
|||||||
"mas-iana",
|
"mas-iana",
|
||||||
"mas-jose",
|
"mas-jose",
|
||||||
"mas-keystore",
|
"mas-keystore",
|
||||||
|
"mas-oidc-client",
|
||||||
"mas-policy",
|
"mas-policy",
|
||||||
"mas-router",
|
"mas-router",
|
||||||
"mas-storage",
|
"mas-storage",
|
||||||
@ -3345,7 +3349,6 @@ dependencies = [
|
|||||||
"data-encoding",
|
"data-encoding",
|
||||||
"http",
|
"http",
|
||||||
"indoc",
|
"indoc",
|
||||||
"itertools",
|
|
||||||
"language-tags",
|
"language-tags",
|
||||||
"mas-iana",
|
"mas-iana",
|
||||||
"mas-jose",
|
"mas-jose",
|
||||||
|
@ -45,6 +45,7 @@ mas-config = { path = "../config" }
|
|||||||
mas-email = { path = "../email" }
|
mas-email = { path = "../email" }
|
||||||
mas-handlers = { path = "../handlers", default-features = false }
|
mas-handlers = { path = "../handlers", default-features = false }
|
||||||
mas-http = { path = "../http", default-features = false, features = ["axum", "client"] }
|
mas-http = { path = "../http", default-features = false, features = ["axum", "client"] }
|
||||||
|
mas-iana = { path = "../iana" }
|
||||||
mas-listener = { path = "../listener" }
|
mas-listener = { path = "../listener" }
|
||||||
mas-policy = { path = "../policy" }
|
mas-policy = { path = "../policy" }
|
||||||
mas-router = { path = "../router" }
|
mas-router = { path = "../router" }
|
||||||
@ -52,6 +53,7 @@ mas-spa = { path = "../spa" }
|
|||||||
mas-storage = { path = "../storage" }
|
mas-storage = { path = "../storage" }
|
||||||
mas-tasks = { path = "../tasks" }
|
mas-tasks = { path = "../tasks" }
|
||||||
mas-templates = { path = "../templates" }
|
mas-templates = { path = "../templates" }
|
||||||
|
oauth2-types = { path = "../oauth2-types" }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
indoc = "1.0.7"
|
indoc = "1.0.7"
|
||||||
|
@ -13,8 +13,10 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use argon2::Argon2;
|
use argon2::Argon2;
|
||||||
use clap::Parser;
|
use clap::{Parser, ValueEnum};
|
||||||
use mas_config::{DatabaseConfig, RootConfig};
|
use mas_config::{DatabaseConfig, RootConfig};
|
||||||
|
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||||
|
use mas_router::UrlBuilder;
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
oauth2::client::{insert_client_from_config, lookup_client, truncate_clients},
|
oauth2::client::{insert_client_from_config, lookup_client, truncate_clients},
|
||||||
user::{
|
user::{
|
||||||
@ -22,6 +24,7 @@ use mas_storage::{
|
|||||||
},
|
},
|
||||||
Clock, LookupError,
|
Clock, LookupError,
|
||||||
};
|
};
|
||||||
|
use oauth2_types::scope::Scope;
|
||||||
use rand::SeedableRng;
|
use rand::SeedableRng;
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
@ -31,6 +34,110 @@ pub(super) struct Options {
|
|||||||
subcommand: Subcommand,
|
subcommand: Subcommand,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
|
||||||
|
enum AuthenticationMethod {
|
||||||
|
/// Client doesn't use any authentication
|
||||||
|
None,
|
||||||
|
|
||||||
|
/// Client sends its `client_secret` in the request body
|
||||||
|
ClientSecretPost,
|
||||||
|
|
||||||
|
/// Client sends its `client_secret` in the authorization header
|
||||||
|
ClientSecretBasic,
|
||||||
|
|
||||||
|
/// Client uses its `client_secret` to sign a client assertion JWT
|
||||||
|
ClientSecretJwt,
|
||||||
|
|
||||||
|
/// Client uses its private keys to sign a client assertion JWT
|
||||||
|
PrivateKeyJwt,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthenticationMethod {
|
||||||
|
fn requires_client_secret(self) -> bool {
|
||||||
|
matches!(
|
||||||
|
self,
|
||||||
|
Self::ClientSecretJwt | Self::ClientSecretPost | Self::ClientSecretBasic
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<AuthenticationMethod> for OAuthClientAuthenticationMethod {
|
||||||
|
fn from(val: AuthenticationMethod) -> Self {
|
||||||
|
(&val).into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&AuthenticationMethod> for OAuthClientAuthenticationMethod {
|
||||||
|
fn from(val: &AuthenticationMethod) -> Self {
|
||||||
|
match val {
|
||||||
|
AuthenticationMethod::None => OAuthClientAuthenticationMethod::None,
|
||||||
|
AuthenticationMethod::ClientSecretPost => {
|
||||||
|
OAuthClientAuthenticationMethod::ClientSecretPost
|
||||||
|
}
|
||||||
|
AuthenticationMethod::ClientSecretBasic => {
|
||||||
|
OAuthClientAuthenticationMethod::ClientSecretBasic
|
||||||
|
}
|
||||||
|
AuthenticationMethod::ClientSecretJwt => {
|
||||||
|
OAuthClientAuthenticationMethod::ClientSecretJwt
|
||||||
|
}
|
||||||
|
AuthenticationMethod::PrivateKeyJwt => OAuthClientAuthenticationMethod::PrivateKeyJwt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
|
||||||
|
enum SigningAlgorithm {
|
||||||
|
#[value(name = "HS256")]
|
||||||
|
HS256,
|
||||||
|
#[value(name = "HS384")]
|
||||||
|
HS384,
|
||||||
|
#[value(name = "HS512")]
|
||||||
|
HS512,
|
||||||
|
#[value(name = "RS256")]
|
||||||
|
RS256,
|
||||||
|
#[value(name = "RS384")]
|
||||||
|
RS384,
|
||||||
|
#[value(name = "RS512")]
|
||||||
|
RS512,
|
||||||
|
#[value(name = "PS256")]
|
||||||
|
PS256,
|
||||||
|
#[value(name = "PS384")]
|
||||||
|
PS384,
|
||||||
|
#[value(name = "PS512")]
|
||||||
|
PS512,
|
||||||
|
#[value(name = "ES256")]
|
||||||
|
ES256,
|
||||||
|
#[value(name = "ES384")]
|
||||||
|
ES384,
|
||||||
|
#[value(name = "ES256K")]
|
||||||
|
ES256K,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<SigningAlgorithm> for JsonWebSignatureAlg {
|
||||||
|
fn from(val: SigningAlgorithm) -> Self {
|
||||||
|
(&val).into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&SigningAlgorithm> for JsonWebSignatureAlg {
|
||||||
|
fn from(val: &SigningAlgorithm) -> Self {
|
||||||
|
match val {
|
||||||
|
SigningAlgorithm::HS256 => Self::Hs256,
|
||||||
|
SigningAlgorithm::HS384 => Self::Hs384,
|
||||||
|
SigningAlgorithm::HS512 => Self::Hs512,
|
||||||
|
SigningAlgorithm::RS256 => Self::Rs256,
|
||||||
|
SigningAlgorithm::RS384 => Self::Rs384,
|
||||||
|
SigningAlgorithm::RS512 => Self::Rs512,
|
||||||
|
SigningAlgorithm::PS256 => Self::Ps256,
|
||||||
|
SigningAlgorithm::PS384 => Self::Ps384,
|
||||||
|
SigningAlgorithm::PS512 => Self::Ps512,
|
||||||
|
SigningAlgorithm::ES256 => Self::Es256,
|
||||||
|
SigningAlgorithm::ES384 => Self::Es384,
|
||||||
|
SigningAlgorithm::ES256K => Self::Es256K,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
enum Subcommand {
|
enum Subcommand {
|
||||||
/// Register a new user
|
/// Register a new user
|
||||||
@ -48,9 +155,38 @@ enum Subcommand {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
truncate: bool,
|
truncate: bool,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/// Add an OAuth 2.0 upstream
|
||||||
|
#[command(name = "add-oauth-upstream")]
|
||||||
|
AddOAuthUpstream {
|
||||||
|
/// Issuer URL
|
||||||
|
issuer: String,
|
||||||
|
|
||||||
|
/// Scope to ask for when authorizing with this upstream.
|
||||||
|
///
|
||||||
|
/// This should include at least the `openid` scope.
|
||||||
|
scope: Scope,
|
||||||
|
|
||||||
|
/// Client authentication method used when using the token endpoint.
|
||||||
|
#[arg(value_enum)]
|
||||||
|
token_endpoint_auth_method: AuthenticationMethod,
|
||||||
|
|
||||||
|
/// Client ID
|
||||||
|
client_id: String,
|
||||||
|
|
||||||
|
/// JWT signing algorithm used when authenticating for the token
|
||||||
|
/// endpoint.
|
||||||
|
#[arg(long, value_enum)]
|
||||||
|
signing_alg: Option<SigningAlgorithm>,
|
||||||
|
|
||||||
|
/// Client Secret
|
||||||
|
#[arg(long)]
|
||||||
|
client_secret: Option<String>,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Options {
|
impl Options {
|
||||||
|
#[allow(clippy::too_many_lines)]
|
||||||
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
|
pub async fn run(&self, root: &super::Options) -> anyhow::Result<()> {
|
||||||
use Subcommand as SC;
|
use Subcommand as SC;
|
||||||
let clock = Clock::default();
|
let clock = Clock::default();
|
||||||
@ -71,11 +207,13 @@ impl Options {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
SC::Users => {
|
SC::Users => {
|
||||||
warn!("Not implemented yet");
|
warn!("Not implemented yet");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
SC::VerifyEmail { username, email } => {
|
SC::VerifyEmail { username, email } => {
|
||||||
let config: DatabaseConfig = root.load_config()?;
|
let config: DatabaseConfig = root.load_config()?;
|
||||||
let pool = config.connect().await?;
|
let pool = config.connect().await?;
|
||||||
@ -90,6 +228,7 @@ impl Options {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
SC::ImportClients { truncate } => {
|
SC::ImportClients { truncate } => {
|
||||||
let config: RootConfig = root.load_config()?;
|
let config: RootConfig = root.load_config()?;
|
||||||
let pool = config.database.connect().await?;
|
let pool = config.database.connect().await?;
|
||||||
@ -144,6 +283,64 @@ impl Options {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SC::AddOAuthUpstream {
|
||||||
|
issuer,
|
||||||
|
scope,
|
||||||
|
token_endpoint_auth_method,
|
||||||
|
client_id,
|
||||||
|
client_secret,
|
||||||
|
signing_alg,
|
||||||
|
} => {
|
||||||
|
let config: RootConfig = root.load_config()?;
|
||||||
|
let encrypter = config.secrets.encrypter();
|
||||||
|
let pool = config.database.connect().await?;
|
||||||
|
let url_builder = UrlBuilder::new(config.http.public_base);
|
||||||
|
let mut conn = pool.acquire().await?;
|
||||||
|
|
||||||
|
let requires_client_secret = token_endpoint_auth_method.requires_client_secret();
|
||||||
|
|
||||||
|
let token_endpoint_auth_method: OAuthClientAuthenticationMethod =
|
||||||
|
token_endpoint_auth_method.into();
|
||||||
|
|
||||||
|
let token_endpoint_signing_alg: Option<JsonWebSignatureAlg> =
|
||||||
|
signing_alg.as_ref().map(Into::into);
|
||||||
|
|
||||||
|
tracing::info!(%issuer, %scope, %token_endpoint_auth_method, %client_id, "Adding OAuth upstream");
|
||||||
|
|
||||||
|
if client_secret.is_none() && requires_client_secret {
|
||||||
|
tracing::warn!("Token endpoint auth method requires a client secret, but none were provided");
|
||||||
|
}
|
||||||
|
|
||||||
|
let encrypted_client_secret = client_secret
|
||||||
|
.as_deref()
|
||||||
|
.map(|client_secret| encrypter.encryt_to_string(client_secret.as_bytes()))
|
||||||
|
.transpose()?;
|
||||||
|
|
||||||
|
let provider = mas_storage::upstream_oauth2::add_provider(
|
||||||
|
&mut conn,
|
||||||
|
&mut rng,
|
||||||
|
&clock,
|
||||||
|
issuer.clone(),
|
||||||
|
scope.clone(),
|
||||||
|
token_endpoint_auth_method,
|
||||||
|
token_endpoint_signing_alg,
|
||||||
|
client_id.clone(),
|
||||||
|
encrypted_client_secret,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
|
||||||
|
let auth_uri = url_builder.upstream_oauth_authorize(provider.id);
|
||||||
|
tracing::info!(
|
||||||
|
%provider.id,
|
||||||
|
%provider.client_id,
|
||||||
|
provider.redirect_uri = %redirect_uri,
|
||||||
|
"Test authorization by going to {auth_uri}"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -119,7 +119,9 @@ impl SecretsConfig {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let key = JsonWebKey::new(key).with_kid(item.kid.clone());
|
let key = JsonWebKey::new(key)
|
||||||
|
.with_kid(item.kid.clone())
|
||||||
|
.with_use(mas_iana::jose::JsonWebKeyUse::Sig);
|
||||||
keys.push(key);
|
keys.push(key);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ serde = "1.0.148"
|
|||||||
url = { version = "2.3.1", features = ["serde"] }
|
url = { version = "2.3.1", features = ["serde"] }
|
||||||
crc = "3.0.0"
|
crc = "3.0.0"
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
|
ulid = "1.0.0"
|
||||||
|
|
||||||
mas-iana = { path = "../iana" }
|
mas-iana = { path = "../iana" }
|
||||||
mas-jose = { path = "../jose" }
|
mas-jose = { path = "../jose" }
|
||||||
|
@ -27,6 +27,7 @@ pub(crate) mod compat;
|
|||||||
pub(crate) mod oauth2;
|
pub(crate) mod oauth2;
|
||||||
pub(crate) mod tokens;
|
pub(crate) mod tokens;
|
||||||
pub(crate) mod traits;
|
pub(crate) mod traits;
|
||||||
|
pub(crate) mod upstream_oauth2;
|
||||||
pub(crate) mod users;
|
pub(crate) mod users;
|
||||||
|
|
||||||
pub use self::{
|
pub use self::{
|
||||||
@ -40,6 +41,9 @@ pub use self::{
|
|||||||
},
|
},
|
||||||
tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType},
|
tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType},
|
||||||
traits::{StorageBackend, StorageBackendMarker},
|
traits::{StorageBackend, StorageBackendMarker},
|
||||||
|
upstream_oauth2::{
|
||||||
|
UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider,
|
||||||
|
},
|
||||||
users::{
|
users::{
|
||||||
Authentication, BrowserSession, User, UserEmail, UserEmailVerification,
|
Authentication, BrowserSession, User, UserEmail, UserEmailVerification,
|
||||||
UserEmailVerificationState,
|
UserEmailVerificationState,
|
||||||
|
48
crates/data-model/src/upstream_oauth2/mod.rs
Normal file
48
crates/data-model/src/upstream_oauth2/mod.rs
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||||
|
use oauth2_types::scope::Scope;
|
||||||
|
use serde::Serialize;
|
||||||
|
use ulid::Ulid;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||||
|
pub struct UpstreamOAuthProvider {
|
||||||
|
pub id: Ulid,
|
||||||
|
pub issuer: String,
|
||||||
|
pub scope: Scope,
|
||||||
|
pub client_id: String,
|
||||||
|
pub encrypted_client_secret: Option<String>,
|
||||||
|
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||||
|
pub token_endpoint_auth_method: OAuthClientAuthenticationMethod,
|
||||||
|
pub created_at: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||||
|
pub struct UpstreamOAuthLink {
|
||||||
|
pub id: Ulid,
|
||||||
|
pub subject: String,
|
||||||
|
pub created_at: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||||
|
pub struct UpstreamOAuthAuthorizationSession {
|
||||||
|
pub id: Ulid,
|
||||||
|
pub state: String,
|
||||||
|
pub code_challenge_verifier: Option<String>,
|
||||||
|
pub nonce: String,
|
||||||
|
pub created_at: DateTime<Utc>,
|
||||||
|
pub completed_at: Option<DateTime<Utc>>,
|
||||||
|
}
|
@ -52,7 +52,6 @@ rand_chacha = "0.3.1"
|
|||||||
headers = "0.3.8"
|
headers = "0.3.8"
|
||||||
ulid = "1.0.0"
|
ulid = "1.0.0"
|
||||||
|
|
||||||
oauth2-types = { path = "../oauth2-types" }
|
|
||||||
mas-axum-utils = { path = "../axum-utils", default-features = false }
|
mas-axum-utils = { path = "../axum-utils", default-features = false }
|
||||||
mas-data-model = { path = "../data-model" }
|
mas-data-model = { path = "../data-model" }
|
||||||
mas-email = { path = "../email" }
|
mas-email = { path = "../email" }
|
||||||
@ -61,10 +60,12 @@ mas-http = { path = "../http", default-features = false }
|
|||||||
mas-iana = { path = "../iana" }
|
mas-iana = { path = "../iana" }
|
||||||
mas-jose = { path = "../jose" }
|
mas-jose = { path = "../jose" }
|
||||||
mas-keystore = { path = "../keystore" }
|
mas-keystore = { path = "../keystore" }
|
||||||
|
mas-oidc-client = { path = "../oidc-client" }
|
||||||
mas-policy = { path = "../policy" }
|
mas-policy = { path = "../policy" }
|
||||||
mas-router = { path = "../router" }
|
mas-router = { path = "../router" }
|
||||||
mas-storage = { path = "../storage" }
|
mas-storage = { path = "../storage" }
|
||||||
mas-templates = { path = "../templates" }
|
mas-templates = { path = "../templates" }
|
||||||
|
oauth2-types = { path = "../oauth2-types" }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
indoc = "1.0.7"
|
indoc = "1.0.7"
|
||||||
|
@ -52,6 +52,7 @@ mod compat;
|
|||||||
mod graphql;
|
mod graphql;
|
||||||
mod health;
|
mod health;
|
||||||
mod oauth2;
|
mod oauth2;
|
||||||
|
mod upstream_oauth2;
|
||||||
mod views;
|
mod views;
|
||||||
|
|
||||||
pub use compat::MatrixHomeserver;
|
pub use compat::MatrixHomeserver;
|
||||||
@ -233,6 +234,7 @@ where
|
|||||||
Encrypter: FromRef<S>,
|
Encrypter: FromRef<S>,
|
||||||
Templates: FromRef<S>,
|
Templates: FromRef<S>,
|
||||||
Mailer: FromRef<S>,
|
Mailer: FromRef<S>,
|
||||||
|
Keystore: FromRef<S>,
|
||||||
{
|
{
|
||||||
Router::new()
|
Router::new()
|
||||||
.route(
|
.route(
|
||||||
@ -296,6 +298,14 @@ where
|
|||||||
mas_router::CompatLoginSsoComplete::route(),
|
mas_router::CompatLoginSsoComplete::route(),
|
||||||
get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post),
|
get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post),
|
||||||
)
|
)
|
||||||
|
.route(
|
||||||
|
mas_router::UpstreamOAuth2Authorize::route(),
|
||||||
|
get(self::upstream_oauth2::authorize::get),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
mas_router::UpstreamOAuth2Callback::route(),
|
||||||
|
get(self::upstream_oauth2::callback::get),
|
||||||
|
)
|
||||||
.layer(AndThenLayer::new(
|
.layer(AndThenLayer::new(
|
||||||
move |response: axum::response::Response| async move {
|
move |response: axum::response::Response| async move {
|
||||||
if response.status().is_server_error() {
|
if response.status().is_server_error() {
|
||||||
@ -315,43 +325,6 @@ where
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
#[must_use]
|
|
||||||
#[allow(clippy::trait_duplication_in_bounds)]
|
|
||||||
pub fn router<S, B>(state: S) -> RouterService<B>
|
|
||||||
where
|
|
||||||
B: HttpBody + Send + 'static,
|
|
||||||
<B as HttpBody>::Data: Into<Bytes> + Send,
|
|
||||||
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
|
||||||
S: Clone + Send + Sync + 'static,
|
|
||||||
Keystore: FromRef<S>,
|
|
||||||
UrlBuilder: FromRef<S>,
|
|
||||||
Arc<PolicyFactory>: FromRef<S>,
|
|
||||||
PgPool: FromRef<S>,
|
|
||||||
Encrypter: FromRef<S>,
|
|
||||||
Templates: FromRef<S>,
|
|
||||||
Mailer: FromRef<S>,
|
|
||||||
MatrixHomeserver: FromRef<S>,
|
|
||||||
mas_graphql::Schema: FromRef<S>,
|
|
||||||
{
|
|
||||||
let healthcheck_router = healthcheck_router();
|
|
||||||
let discovery_router = discovery_router();
|
|
||||||
let api_router = api_router();
|
|
||||||
let graphql_router = graphql_router(true);
|
|
||||||
let compat_router = compat_router();
|
|
||||||
let human_router = human_router(Templates::from_ref(&state));
|
|
||||||
|
|
||||||
Router::new()
|
|
||||||
.merge(healthcheck_router)
|
|
||||||
.merge(discovery_router)
|
|
||||||
.merge(human_router)
|
|
||||||
.merge(api_router)
|
|
||||||
.merge(graphql_router)
|
|
||||||
.merge(compat_router)
|
|
||||||
.with_state(state)
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
async fn test_state(pool: PgPool) -> Result<AppState, anyhow::Error> {
|
async fn test_state(pool: PgPool) -> Result<AppState, anyhow::Error> {
|
||||||
use mas_email::MailTransport;
|
use mas_email::MailTransport;
|
||||||
|
149
crates/handlers/src/upstream_oauth2/authorize.rs
Normal file
149
crates/handlers/src/upstream_oauth2/authorize.rs
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::{Path, State},
|
||||||
|
response::{IntoResponse, Redirect},
|
||||||
|
};
|
||||||
|
use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
|
||||||
|
use hyper::StatusCode;
|
||||||
|
use mas_http::ClientInitError;
|
||||||
|
use mas_keystore::Encrypter;
|
||||||
|
use mas_oidc_client::{
|
||||||
|
error::{AuthorizationError, DiscoveryError},
|
||||||
|
requests::authorization_code::AuthorizationRequestData,
|
||||||
|
};
|
||||||
|
use mas_router::UrlBuilder;
|
||||||
|
use mas_storage::{upstream_oauth2::lookup_provider, LookupResultExt};
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use thiserror::Error;
|
||||||
|
use ulid::Ulid;
|
||||||
|
|
||||||
|
use super::http_service;
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub(crate) enum RouteError {
|
||||||
|
#[error("Provider not found")]
|
||||||
|
ProviderNotFound,
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Authorization(#[from] AuthorizationError),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
InternalError(Box<dyn std::error::Error>),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Anyhow(#[from] anyhow::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<sqlx::Error> for RouteError {
|
||||||
|
fn from(e: sqlx::Error) -> Self {
|
||||||
|
Self::InternalError(Box::new(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<DiscoveryError> for RouteError {
|
||||||
|
fn from(e: DiscoveryError) -> Self {
|
||||||
|
Self::InternalError(Box::new(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<mas_storage::upstream_oauth2::ProviderLookupError> for RouteError {
|
||||||
|
fn from(e: mas_storage::upstream_oauth2::ProviderLookupError) -> Self {
|
||||||
|
Self::InternalError(Box::new(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ClientInitError> for RouteError {
|
||||||
|
fn from(e: ClientInitError) -> Self {
|
||||||
|
Self::InternalError(Box::new(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for RouteError {
|
||||||
|
fn into_response(self) -> axum::response::Response {
|
||||||
|
match self {
|
||||||
|
Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
|
||||||
|
Self::Authorization(e) => {
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
||||||
|
}
|
||||||
|
Self::InternalError(e) => {
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
||||||
|
}
|
||||||
|
Self::Anyhow(e) => {
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) async fn get(
|
||||||
|
State(pool): State<PgPool>,
|
||||||
|
State(url_builder): State<UrlBuilder>,
|
||||||
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
|
Path(provider_id): Path<Ulid>,
|
||||||
|
) -> Result<impl IntoResponse, RouteError> {
|
||||||
|
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||||
|
|
||||||
|
let mut txn = pool.begin().await?;
|
||||||
|
|
||||||
|
let provider = lookup_provider(&mut txn, provider_id)
|
||||||
|
.await
|
||||||
|
.to_option()?
|
||||||
|
.ok_or(RouteError::ProviderNotFound)?;
|
||||||
|
|
||||||
|
let http_service = http_service("upstream-discover").await?;
|
||||||
|
|
||||||
|
// First, discover the provider
|
||||||
|
let metadata =
|
||||||
|
mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?;
|
||||||
|
|
||||||
|
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
|
||||||
|
|
||||||
|
let data = AuthorizationRequestData {
|
||||||
|
client_id: &provider.client_id,
|
||||||
|
scope: &provider.scope,
|
||||||
|
prompt: None,
|
||||||
|
redirect_uri: &redirect_uri,
|
||||||
|
code_challenge_methods_supported: metadata.code_challenge_methods_supported.as_deref(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Build an authorization request for it
|
||||||
|
let (url, data) = mas_oidc_client::requests::authorization_code::build_authorization_url(
|
||||||
|
metadata.authorization_endpoint().clone(),
|
||||||
|
data,
|
||||||
|
&mut rng,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let session = mas_storage::upstream_oauth2::add_session(
|
||||||
|
&mut txn,
|
||||||
|
&mut rng,
|
||||||
|
&clock,
|
||||||
|
&provider,
|
||||||
|
data.state,
|
||||||
|
data.code_challenge_verifier,
|
||||||
|
data.nonce,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// TODO: handle that cookie somewhere else?
|
||||||
|
let mut cookie = Cookie::new("upstream-oauth2-session-id", session.id.to_string());
|
||||||
|
cookie.set_path("/");
|
||||||
|
cookie.set_http_only(true);
|
||||||
|
let cookie_jar = cookie_jar.add(cookie);
|
||||||
|
|
||||||
|
txn.commit().await?;
|
||||||
|
|
||||||
|
Ok((cookie_jar, Redirect::temporary(url.as_str())))
|
||||||
|
}
|
290
crates/handlers/src/upstream_oauth2/callback.rs
Normal file
290
crates/handlers/src/upstream_oauth2/callback.rs
Normal file
@ -0,0 +1,290 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use anyhow::Context;
|
||||||
|
use axum::{
|
||||||
|
extract::{Path, Query, State},
|
||||||
|
response::IntoResponse,
|
||||||
|
Json,
|
||||||
|
};
|
||||||
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
|
use hyper::StatusCode;
|
||||||
|
use mas_http::ClientInitError;
|
||||||
|
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||||
|
use mas_keystore::{Encrypter, Keystore};
|
||||||
|
use mas_oidc_client::{
|
||||||
|
error::{DiscoveryError, JwksError, TokenAuthorizationCodeError},
|
||||||
|
requests::{authorization_code::AuthorizationValidationData, jose::JwtVerificationData},
|
||||||
|
types::client_credentials::ClientCredentials,
|
||||||
|
};
|
||||||
|
use mas_router::UrlBuilder;
|
||||||
|
use mas_storage::{upstream_oauth2::lookup_session, LookupResultExt};
|
||||||
|
use oauth2_types::errors::ClientErrorCode;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use thiserror::Error;
|
||||||
|
use ulid::Ulid;
|
||||||
|
|
||||||
|
use super::http_service;
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct QueryParams {
|
||||||
|
state: String,
|
||||||
|
|
||||||
|
#[serde(flatten)]
|
||||||
|
code_or_error: CodeOrError,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum CodeOrError {
|
||||||
|
Code {
|
||||||
|
code: String,
|
||||||
|
},
|
||||||
|
Error {
|
||||||
|
error: ClientErrorCode,
|
||||||
|
error_description: Option<String>,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
error_uri: Option<String>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub(crate) enum RouteError {
|
||||||
|
#[error("Session not found")]
|
||||||
|
SessionNotFound,
|
||||||
|
|
||||||
|
#[error("Provider mismatch")]
|
||||||
|
ProviderMismatch,
|
||||||
|
|
||||||
|
#[error("State parameter mismatch")]
|
||||||
|
StateMismatch,
|
||||||
|
|
||||||
|
#[error("Error from the provider: {error}")]
|
||||||
|
ClientError {
|
||||||
|
error: ClientErrorCode,
|
||||||
|
error_description: Option<String>,
|
||||||
|
},
|
||||||
|
|
||||||
|
#[error("Provider is missing a client secret")]
|
||||||
|
MissingClientSecret,
|
||||||
|
|
||||||
|
#[error("Missing session cookie")]
|
||||||
|
MissingCookie,
|
||||||
|
|
||||||
|
#[error("Invalid session cookie")]
|
||||||
|
InvalidCookie(#[source] ulid::DecodeError),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
InternalError(Box<dyn std::error::Error>),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Anyhow(#[from] anyhow::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<sqlx::Error> for RouteError {
|
||||||
|
fn from(e: sqlx::Error) -> Self {
|
||||||
|
Self::InternalError(Box::new(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<DiscoveryError> for RouteError {
|
||||||
|
fn from(e: DiscoveryError) -> Self {
|
||||||
|
Self::InternalError(Box::new(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<JwksError> for RouteError {
|
||||||
|
fn from(e: JwksError) -> Self {
|
||||||
|
Self::InternalError(Box::new(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<TokenAuthorizationCodeError> for RouteError {
|
||||||
|
fn from(e: TokenAuthorizationCodeError) -> Self {
|
||||||
|
Self::InternalError(Box::new(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<mas_storage::upstream_oauth2::SessionLookupError> for RouteError {
|
||||||
|
fn from(e: mas_storage::upstream_oauth2::SessionLookupError) -> Self {
|
||||||
|
Self::InternalError(Box::new(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ClientInitError> for RouteError {
|
||||||
|
fn from(e: ClientInitError) -> Self {
|
||||||
|
Self::InternalError(Box::new(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for RouteError {
|
||||||
|
fn into_response(self) -> axum::response::Response {
|
||||||
|
match self {
|
||||||
|
Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
|
||||||
|
Self::InternalError(e) => {
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
||||||
|
}
|
||||||
|
Self::Anyhow(e) => {
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response()
|
||||||
|
}
|
||||||
|
e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_lines)]
|
||||||
|
pub(crate) async fn get(
|
||||||
|
State(pool): State<PgPool>,
|
||||||
|
State(url_builder): State<UrlBuilder>,
|
||||||
|
State(encrypter): State<Encrypter>,
|
||||||
|
State(keystore): State<Keystore>,
|
||||||
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
|
Path(provider_id): Path<Ulid>,
|
||||||
|
Query(params): Query<QueryParams>,
|
||||||
|
) -> Result<impl IntoResponse, RouteError> {
|
||||||
|
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||||
|
|
||||||
|
let mut txn = pool.begin().await?;
|
||||||
|
|
||||||
|
// XXX: that cookie should be managed elsewhere
|
||||||
|
let cookie = cookie_jar
|
||||||
|
.get("upstream-oauth2-session-id")
|
||||||
|
.ok_or(RouteError::MissingCookie)?;
|
||||||
|
|
||||||
|
let session_id: Ulid = cookie.value().parse().map_err(RouteError::InvalidCookie)?;
|
||||||
|
|
||||||
|
let (provider, session) = lookup_session(&mut txn, session_id)
|
||||||
|
.await
|
||||||
|
.to_option()?
|
||||||
|
.ok_or(RouteError::SessionNotFound)?;
|
||||||
|
|
||||||
|
if provider.id != provider_id {
|
||||||
|
// The provider in the session cookie should match the one from the URL
|
||||||
|
return Err(RouteError::ProviderMismatch);
|
||||||
|
}
|
||||||
|
|
||||||
|
if params.state != session.state {
|
||||||
|
// The state in the session cookie should match the one from the params
|
||||||
|
return Err(RouteError::StateMismatch);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Let's extract the code from the params, and return if there was an error
|
||||||
|
let code = match params.code_or_error {
|
||||||
|
CodeOrError::Error {
|
||||||
|
error,
|
||||||
|
error_description,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
return Err(RouteError::ClientError {
|
||||||
|
error,
|
||||||
|
error_description,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
CodeOrError::Code { code } => code,
|
||||||
|
};
|
||||||
|
|
||||||
|
let http_service = http_service("upstream-code-exchange").await?;
|
||||||
|
|
||||||
|
// XXX: we shouldn't discover on-the-fly
|
||||||
|
// Discover the provider
|
||||||
|
let metadata =
|
||||||
|
mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?;
|
||||||
|
|
||||||
|
// Fetch the JWKS
|
||||||
|
let jwks =
|
||||||
|
mas_oidc_client::requests::jose::fetch_jwks(&http_service, metadata.jwks_uri()).await?;
|
||||||
|
|
||||||
|
// Figure out the client credentials
|
||||||
|
let client_id = provider.client_id.clone();
|
||||||
|
// Decrypt the client secret
|
||||||
|
let client_secret = provider
|
||||||
|
.encrypted_client_secret
|
||||||
|
.map(|encrypted_client_secret| {
|
||||||
|
encrypter
|
||||||
|
.decrypt_string(&encrypted_client_secret)
|
||||||
|
.and_then(|client_secret| {
|
||||||
|
String::from_utf8(client_secret)
|
||||||
|
.context("Client secret contains non-UTF8 bytes")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.transpose()?;
|
||||||
|
|
||||||
|
let token_endpoint = metadata.token_endpoint();
|
||||||
|
|
||||||
|
let client_credentials = match provider.token_endpoint_auth_method {
|
||||||
|
OAuthClientAuthenticationMethod::None => ClientCredentials::None { client_id },
|
||||||
|
OAuthClientAuthenticationMethod::ClientSecretPost => ClientCredentials::ClientSecretPost {
|
||||||
|
client_id,
|
||||||
|
client_secret: client_secret.ok_or(RouteError::MissingClientSecret)?,
|
||||||
|
},
|
||||||
|
OAuthClientAuthenticationMethod::ClientSecretBasic => {
|
||||||
|
ClientCredentials::ClientSecretBasic {
|
||||||
|
client_id,
|
||||||
|
client_secret: client_secret.ok_or(RouteError::MissingClientSecret)?,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
OAuthClientAuthenticationMethod::ClientSecretJwt => ClientCredentials::ClientSecretJwt {
|
||||||
|
client_id,
|
||||||
|
client_secret: client_secret.ok_or(RouteError::MissingClientSecret)?,
|
||||||
|
signing_algorithm: provider
|
||||||
|
.token_endpoint_signing_alg
|
||||||
|
.unwrap_or(mas_iana::jose::JsonWebSignatureAlg::Rs256),
|
||||||
|
token_endpoint: token_endpoint.clone(),
|
||||||
|
},
|
||||||
|
OAuthClientAuthenticationMethod::PrivateKeyJwt => ClientCredentials::PrivateKeyJwt {
|
||||||
|
client_id,
|
||||||
|
jwt_signing_method:
|
||||||
|
mas_oidc_client::types::client_credentials::JwtSigningMethod::Keystore(keystore),
|
||||||
|
signing_algorithm: provider
|
||||||
|
.token_endpoint_signing_alg
|
||||||
|
.unwrap_or(mas_iana::jose::JsonWebSignatureAlg::Rs256),
|
||||||
|
token_endpoint: token_endpoint.clone(),
|
||||||
|
},
|
||||||
|
// XXX: The database should never have an unsupported method in it
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
|
||||||
|
|
||||||
|
let validation_data = AuthorizationValidationData {
|
||||||
|
state: session.state,
|
||||||
|
nonce: session.nonce,
|
||||||
|
code_challenge_verifier: session.code_challenge_verifier,
|
||||||
|
redirect_uri,
|
||||||
|
};
|
||||||
|
|
||||||
|
let id_token_verification_data = JwtVerificationData {
|
||||||
|
issuer: &provider.issuer,
|
||||||
|
jwks: &jwks,
|
||||||
|
// TODO: make that configurable
|
||||||
|
signing_algorithm: &mas_iana::jose::JsonWebSignatureAlg::Rs256,
|
||||||
|
client_id: &provider.client_id,
|
||||||
|
};
|
||||||
|
|
||||||
|
let (response, _id_token) =
|
||||||
|
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
|
||||||
|
&http_service,
|
||||||
|
client_credentials,
|
||||||
|
token_endpoint,
|
||||||
|
code,
|
||||||
|
validation_data,
|
||||||
|
Some(id_token_verification_data),
|
||||||
|
clock.now(),
|
||||||
|
&mut rng,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Json(response))
|
||||||
|
}
|
35
crates/handlers/src/upstream_oauth2/mod.rs
Normal file
35
crates/handlers/src/upstream_oauth2/mod.rs
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use axum::body::Full;
|
||||||
|
use mas_http::{BodyToBytesResponseLayer, ClientInitError, ClientLayer, HttpService};
|
||||||
|
use tower::{
|
||||||
|
util::{MapErrLayer, MapRequestLayer},
|
||||||
|
BoxError, Layer,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub(crate) mod authorize;
|
||||||
|
pub(crate) mod callback;
|
||||||
|
|
||||||
|
async fn http_service(operation: &'static str) -> Result<HttpService, ClientInitError> {
|
||||||
|
let client = (
|
||||||
|
MapErrLayer::new(BoxError::from),
|
||||||
|
MapRequestLayer::new(|req: hyper::Request<_>| req.map(Full::new)),
|
||||||
|
BodyToBytesResponseLayer::default(),
|
||||||
|
ClientLayer::new(operation),
|
||||||
|
)
|
||||||
|
.layer(mas_http::make_untraced_client().await?);
|
||||||
|
|
||||||
|
Ok(HttpService::new(client))
|
||||||
|
}
|
@ -27,8 +27,8 @@ serde_json = "1.0.89"
|
|||||||
serde_urlencoded = "0.7.1"
|
serde_urlencoded = "0.7.1"
|
||||||
thiserror = "1.0.37"
|
thiserror = "1.0.37"
|
||||||
tokio = { version = "1.22.0", features = ["sync", "parking_lot"], optional = true }
|
tokio = { version = "1.22.0", features = ["sync", "parking_lot"], optional = true }
|
||||||
tower = { version = "0.4.13", features = ["timeout", "limit"] }
|
tower = { version = "0.4.13", features = ["limit"] }
|
||||||
tower-http = { version = "0.3.5", features = ["follow-redirect", "decompression-full", "set-header", "compression-full", "cors", "util"] }
|
tower-http = { version = "0.3.5", features = ["timeout", "follow-redirect", "decompression-full", "set-header", "compression-full", "cors", "util"] }
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-opentelemetry = "0.18.0"
|
tracing-opentelemetry = "0.18.0"
|
||||||
webpki = { version = "0.22.0", optional = true }
|
webpki = { version = "0.22.0", optional = true }
|
||||||
|
@ -16,7 +16,6 @@ use std::{convert::Infallible, net::SocketAddr};
|
|||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use http::{Request, Response};
|
use http::{Request, Response};
|
||||||
use http_body::{combinators::BoxBody, Body};
|
|
||||||
use hyper::{
|
use hyper::{
|
||||||
client::{
|
client::{
|
||||||
connect::dns::{GaiResolver, Name},
|
connect::dns::{GaiResolver, Name},
|
||||||
@ -26,14 +25,11 @@ use hyper::{
|
|||||||
};
|
};
|
||||||
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
|
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tower::{
|
use tower::{Layer, Service};
|
||||||
util::{MapErrLayer, MapResponseLayer},
|
|
||||||
Layer, Service,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
layers::{
|
layers::{
|
||||||
client::{ClientLayer, ClientResponse},
|
client::ClientLayer,
|
||||||
otel::{TraceDns, TraceLayer},
|
otel::{TraceDns, TraceLayer},
|
||||||
},
|
},
|
||||||
BoxCloneSyncService, BoxError,
|
BoxCloneSyncService, BoxError,
|
||||||
@ -229,32 +225,20 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a traced HTTP client, with a default timeout, which follows redirects
|
/// Create a traced HTTP client, with a default timeout, which follows redirects
|
||||||
/// and handles compression
|
|
||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
///
|
///
|
||||||
/// Returns an error if it failed to initialize
|
/// Returns an error if it failed to initialize
|
||||||
pub async fn client<B, E>(
|
pub async fn client<B, E>(
|
||||||
operation: &'static str,
|
operation: &'static str,
|
||||||
) -> Result<
|
) -> Result<BoxCloneSyncService<Request<B>, Response<hyper::Body>, hyper::Error>, ClientInitError>
|
||||||
BoxCloneSyncService<Request<B>, Response<BoxBody<bytes::Bytes, ClientError>>, ClientError>,
|
|
||||||
ClientInitError,
|
|
||||||
>
|
|
||||||
where
|
where
|
||||||
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
|
B: http_body::Body<Data = Bytes, Error = E> + Default + Send + 'static,
|
||||||
E: Into<BoxError> + 'static,
|
E: Into<BoxError> + 'static,
|
||||||
{
|
{
|
||||||
let client = make_traced_client().await?;
|
let client = make_traced_client().await?;
|
||||||
|
|
||||||
let layer = (
|
let client = ClientLayer::new(operation).layer(client);
|
||||||
// Convert the errors to ClientError to help dealing with them
|
|
||||||
MapErrLayer::new(ClientError::from),
|
|
||||||
MapResponseLayer::new(|r: ClientResponse<hyper::Body>| {
|
|
||||||
r.map(|body| body.map_err(ClientError::from).boxed())
|
|
||||||
}),
|
|
||||||
ClientLayer::new(operation),
|
|
||||||
);
|
|
||||||
let client = BoxCloneSyncService::new(layer.layer(client));
|
|
||||||
|
|
||||||
Ok(client)
|
Ok(BoxCloneSyncService::new(client))
|
||||||
}
|
}
|
||||||
|
@ -38,6 +38,14 @@ impl<S, B> Error<S, B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<E> Error<E, E> {
|
||||||
|
pub fn unify(self) -> E {
|
||||||
|
match self {
|
||||||
|
Self::Service { inner } | Self::Body { inner } => inner,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct BodyToBytesResponse<S> {
|
pub struct BodyToBytesResponse<S> {
|
||||||
inner: S,
|
inner: S,
|
||||||
|
@ -17,13 +17,12 @@ use std::{marker::PhantomData, time::Duration};
|
|||||||
use http::{header::USER_AGENT, HeaderValue, Request, Response};
|
use http::{header::USER_AGENT, HeaderValue, Request, Response};
|
||||||
use tower::{
|
use tower::{
|
||||||
limit::{ConcurrencyLimit, ConcurrencyLimitLayer},
|
limit::{ConcurrencyLimit, ConcurrencyLimitLayer},
|
||||||
timeout::{Timeout, TimeoutLayer},
|
|
||||||
Layer, Service,
|
Layer, Service,
|
||||||
};
|
};
|
||||||
use tower_http::{
|
use tower_http::{
|
||||||
decompression::{Decompression, DecompressionBody, DecompressionLayer},
|
|
||||||
follow_redirect::{FollowRedirect, FollowRedirectLayer},
|
follow_redirect::{FollowRedirect, FollowRedirectLayer},
|
||||||
set_header::{SetRequestHeader, SetRequestHeaderLayer},
|
set_header::{SetRequestHeader, SetRequestHeaderLayer},
|
||||||
|
timeout::{Timeout, TimeoutLayer},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::otel::TraceLayer;
|
use super::otel::TraceLayer;
|
||||||
@ -48,9 +47,6 @@ impl<B> ClientLayer<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub type ClientResponse<B> = Response<DecompressionBody<B>>;
|
|
||||||
|
|
||||||
impl<ReqBody, ResBody, S, E> Layer<S> for ClientLayer<ReqBody>
|
impl<ReqBody, ResBody, S, E> Layer<S> for ClientLayer<ReqBody>
|
||||||
where
|
where
|
||||||
S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = E>
|
S: Service<Request<ReqBody>, Response = Response<ResBody>, Error = E>
|
||||||
@ -63,21 +59,14 @@ where
|
|||||||
S::Future: Send + 'static,
|
S::Future: Send + 'static,
|
||||||
E: Into<BoxError>,
|
E: Into<BoxError>,
|
||||||
{
|
{
|
||||||
type Service = Decompression<
|
type Service = SetRequestHeader<
|
||||||
SetRequestHeader<
|
TraceHttpClient<ConcurrencyLimit<FollowRedirect<TraceHttpClient<Timeout<S>>>>>,
|
||||||
TraceHttpClient<ConcurrencyLimit<FollowRedirect<TraceHttpClient<Timeout<S>>>>>,
|
HeaderValue,
|
||||||
HeaderValue,
|
|
||||||
>,
|
|
||||||
>;
|
>;
|
||||||
|
|
||||||
fn layer(&self, inner: S) -> Self::Service {
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
// Note that most layers here just forward the error type. Two notables
|
// Note that all layers here just forward the error type.
|
||||||
// exceptions are:
|
|
||||||
// - the TimeoutLayer
|
|
||||||
// - the DecompressionLayer
|
|
||||||
// Those layers do type erasure of the error.
|
|
||||||
(
|
(
|
||||||
DecompressionLayer::new(),
|
|
||||||
SetRequestHeaderLayer::overriding(USER_AGENT, MAS_USER_AGENT.clone()),
|
SetRequestHeaderLayer::overriding(USER_AGENT, MAS_USER_AGENT.clone()),
|
||||||
// A trace that has the whole operation, with all the redirects, timeouts and rate
|
// A trace that has the whole operation, with all the redirects, timeouts and rate
|
||||||
// limits in it
|
// limits in it
|
||||||
|
@ -18,7 +18,6 @@ chrono = "0.4.23"
|
|||||||
sha2 = "0.10.6"
|
sha2 = "0.10.6"
|
||||||
data-encoding = "2.3.2"
|
data-encoding = "2.3.2"
|
||||||
thiserror = "1.0.37"
|
thiserror = "1.0.37"
|
||||||
itertools = "0.10.5"
|
|
||||||
|
|
||||||
mas-iana = { path = "../iana" }
|
mas-iana = { path = "../iana" }
|
||||||
mas-jose = { path = "../jose" }
|
mas-jose = { path = "../jose" }
|
||||||
|
@ -20,7 +20,6 @@
|
|||||||
|
|
||||||
use std::{collections::BTreeSet, fmt, iter::FromIterator, str::FromStr};
|
use std::{collections::BTreeSet, fmt, iter::FromIterator, str::FromStr};
|
||||||
|
|
||||||
use itertools::Itertools;
|
|
||||||
use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
|
use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
|
||||||
use parse_display::{Display, FromStr};
|
use parse_display::{Display, FromStr};
|
||||||
use serde_with::{DeserializeFromStr, SerializeDisplay};
|
use serde_with::{DeserializeFromStr, SerializeDisplay};
|
||||||
@ -127,14 +126,23 @@ impl FromStr for ResponseType {
|
|||||||
|
|
||||||
impl fmt::Display for ResponseType {
|
impl fmt::Display for ResponseType {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
let res = Itertools::intersperse(self.iter().map(ToString::to_string), ' '.to_string())
|
let mut iter = self.iter();
|
||||||
.collect::<String>();
|
|
||||||
|
|
||||||
if res.is_empty() {
|
// First item shouldn't have a leading space
|
||||||
write!(f, "none")
|
if let Some(first) = iter.next() {
|
||||||
|
first.fmt(f)?;
|
||||||
} else {
|
} else {
|
||||||
f.write_str(&res)
|
// If the whole iterator is empty, write 'none' instead
|
||||||
|
write!(f, "none")?;
|
||||||
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write the other items with a leading space
|
||||||
|
for item in iter {
|
||||||
|
write!(f, " {item}")?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@
|
|||||||
|
|
||||||
use std::{borrow::Cow, collections::BTreeSet, iter::FromIterator, ops::Deref, str::FromStr};
|
use std::{borrow::Cow, collections::BTreeSet, iter::FromIterator, ops::Deref, str::FromStr};
|
||||||
|
|
||||||
use itertools::Itertools;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
@ -106,9 +105,9 @@ impl Deref for ScopeToken {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToString for ScopeToken {
|
impl std::fmt::Display for ScopeToken {
|
||||||
fn to_string(&self) -> String {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
self.0.to_string()
|
self.0.fmt(f)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -169,10 +168,17 @@ impl Scope {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ToString for Scope {
|
impl std::fmt::Display for Scope {
|
||||||
fn to_string(&self) -> String {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
let it = self.0.iter().map(ScopeToken::to_string);
|
for (index, token) in self.0.iter().enumerate() {
|
||||||
Itertools::intersperse(it, ' '.to_string()).collect()
|
if index == 0 {
|
||||||
|
write!(f, "{token}")?;
|
||||||
|
} else {
|
||||||
|
write!(f, " {token}")?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ use http::Request;
|
|||||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||||
use mas_jose::{
|
use mas_jose::{
|
||||||
claims::{self, ClaimError},
|
claims::{self, ClaimError},
|
||||||
|
constraints::Constrainable,
|
||||||
jwa::SymmetricKey,
|
jwa::SymmetricKey,
|
||||||
jwt::{JsonWebSignatureHeader, Jwt},
|
jwt::{JsonWebSignatureHeader, Jwt},
|
||||||
};
|
};
|
||||||
@ -338,7 +339,12 @@ impl RequestClientCredentials {
|
|||||||
.signing_key_for_algorithm(&signing_algorithm)
|
.signing_key_for_algorithm(&signing_algorithm)
|
||||||
.ok_or(CredentialsError::NoPrivateKeyFound)?;
|
.ok_or(CredentialsError::NoPrivateKeyFound)?;
|
||||||
let signer = key.params().signing_key_for_alg(&signing_algorithm)?;
|
let signer = key.params().signing_key_for_alg(&signing_algorithm)?;
|
||||||
let header = JsonWebSignatureHeader::new(signing_algorithm);
|
let mut header = JsonWebSignatureHeader::new(signing_algorithm);
|
||||||
|
|
||||||
|
if let Some(kid) = key.kid() {
|
||||||
|
header = header.with_kid(kid);
|
||||||
|
}
|
||||||
|
|
||||||
Jwt::sign(header, claims, &signer)?.to_string()
|
Jwt::sign(header, claims, &signer)?.to_string()
|
||||||
}
|
}
|
||||||
JwtSigningMethod::Custom(jwt_signing_fn) => {
|
JwtSigningMethod::Custom(jwt_signing_fn) => {
|
||||||
|
@ -524,6 +524,52 @@ impl Route for CompatLoginSsoComplete {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// `GET /upstream/authorize/:id`
|
||||||
|
pub struct UpstreamOAuth2Authorize {
|
||||||
|
id: Ulid,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UpstreamOAuth2Authorize {
|
||||||
|
#[must_use]
|
||||||
|
pub const fn new(id: Ulid) -> Self {
|
||||||
|
Self { id }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Route for UpstreamOAuth2Authorize {
|
||||||
|
type Query = ();
|
||||||
|
fn route() -> &'static str {
|
||||||
|
"/upstream/authorize/:provider_id"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn path(&self) -> std::borrow::Cow<'static, str> {
|
||||||
|
format!("/upstream/authorize/{}", self.id).into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `GET /upstream/callback/:id`
|
||||||
|
pub struct UpstreamOAuth2Callback {
|
||||||
|
id: Ulid,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UpstreamOAuth2Callback {
|
||||||
|
#[must_use]
|
||||||
|
pub const fn new(id: Ulid) -> Self {
|
||||||
|
Self { id }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Route for UpstreamOAuth2Callback {
|
||||||
|
type Query = ();
|
||||||
|
fn route() -> &'static str {
|
||||||
|
"/upstream/callback/:provider_id"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn path(&self) -> std::borrow::Cow<'static, str> {
|
||||||
|
format!("/upstream/callback/{}", self.id).into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// `GET /assets`
|
/// `GET /assets`
|
||||||
pub struct StaticAsset {
|
pub struct StaticAsset {
|
||||||
path: String,
|
path: String,
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
//! Utility to build URLs
|
//! Utility to build URLs
|
||||||
|
|
||||||
|
use ulid::Ulid;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
use crate::traits::Route;
|
use crate::traits::Route;
|
||||||
@ -97,4 +98,16 @@ impl UrlBuilder {
|
|||||||
pub fn static_asset(&self, path: String) -> Url {
|
pub fn static_asset(&self, path: String) -> Url {
|
||||||
self.url_for(&crate::endpoints::StaticAsset::new(path))
|
self.url_for(&crate::endpoints::StaticAsset::new(path))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Upstream redirect URI
|
||||||
|
#[must_use]
|
||||||
|
pub fn upstream_oauth_callback(&self, id: Ulid) -> Url {
|
||||||
|
self.url_for(&crate::endpoints::UpstreamOAuth2Callback::new(id))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Upstream authorize URI
|
||||||
|
#[must_use]
|
||||||
|
pub fn upstream_oauth_authorize(&self, id: Ulid) -> Url {
|
||||||
|
self.url_for(&crate::endpoints::UpstreamOAuth2Authorize::new(id))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
84
crates/storage/migrations/20221121151402_upstream_oauth.sql
Normal file
84
crates/storage/migrations/20221121151402_upstream_oauth.sql
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
-- Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
--
|
||||||
|
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
-- you may not use this file except in compliance with the License.
|
||||||
|
-- You may obtain a copy of the License at
|
||||||
|
--
|
||||||
|
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
--
|
||||||
|
-- Unless required by applicable law or agreed to in writing, software
|
||||||
|
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
-- See the License for the specific language governing permissions and
|
||||||
|
-- limitations under the License.
|
||||||
|
|
||||||
|
CREATE TABLE "upstream_oauth_providers" (
|
||||||
|
"upstream_oauth_provider_id" UUID NOT NULL
|
||||||
|
CONSTRAINT "upstream_oauth_providers_pkey"
|
||||||
|
PRIMARY KEY,
|
||||||
|
|
||||||
|
"issuer" TEXT NOT NULL,
|
||||||
|
|
||||||
|
"scope" TEXT NOT NULL,
|
||||||
|
|
||||||
|
"client_id" TEXT NOT NULL,
|
||||||
|
|
||||||
|
-- Used for client_secret_basic, client_secret_post and client_secret_jwt auth methods
|
||||||
|
"encrypted_client_secret" TEXT,
|
||||||
|
|
||||||
|
-- Used for client_secret_jwt and private_key_jwt auth methods
|
||||||
|
"token_endpoint_signing_alg" TEXT,
|
||||||
|
|
||||||
|
"token_endpoint_auth_method" TEXT NOT NULL,
|
||||||
|
|
||||||
|
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE "upstream_oauth_links" (
|
||||||
|
"upstream_oauth_link_id" UUID NOT NULL
|
||||||
|
CONSTRAINT "upstream_oauth_links_pkey"
|
||||||
|
PRIMARY KEY,
|
||||||
|
|
||||||
|
"upstream_oauth_provider_id" UUID NOT NULL
|
||||||
|
CONSTRAINT "upstream_oauth_links_upstream_oauth_provider_fkey"
|
||||||
|
REFERENCES "upstream_oauth_providers" ("upstream_oauth_provider_id"),
|
||||||
|
|
||||||
|
-- The user is initially NULL when logging in the first time.
|
||||||
|
-- It then either links to an existing account, or creates a new one from scratch.
|
||||||
|
"user_id" UUID
|
||||||
|
CONSTRAINT "upstream_oauth_link_user_fkey"
|
||||||
|
REFERENCES "users" ("user_id"),
|
||||||
|
|
||||||
|
"subject" TEXT NOT NULL,
|
||||||
|
|
||||||
|
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||||
|
|
||||||
|
-- There should only be one entry per subject/provider tuple
|
||||||
|
CONSTRAINT "upstream_oauth_links_subject_unique"
|
||||||
|
UNIQUE ("upstream_oauth_provider_id", "subject")
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE "upstream_oauth_authorization_sessions" (
|
||||||
|
"upstream_oauth_authorization_session_id" UUID NOT NULL
|
||||||
|
CONSTRAINT "upstream_oauth_authorization_sessions_pkey"
|
||||||
|
PRIMARY KEY,
|
||||||
|
|
||||||
|
"upstream_oauth_provider_id" UUID NOT NULL
|
||||||
|
CONSTRAINT "upstream_oauth_authorization_sessions_upstream_oauth_provider_fkey"
|
||||||
|
REFERENCES "upstream_oauth_providers" ("upstream_oauth_provider_id"),
|
||||||
|
|
||||||
|
-- The link it resolves to at the end of the authorization grant
|
||||||
|
"upstream_oauth_link_id" UUID
|
||||||
|
CONSTRAINT "upstream_oauth_authorization_sessions_upstream_oauth_link_fkey"
|
||||||
|
REFERENCES "upstream_oauth_links" ("upstream_oauth_link_id"),
|
||||||
|
|
||||||
|
"state" TEXT NOT NULL
|
||||||
|
CONSTRAINT "upstream_oauth_authorization_sessions_state_unique"
|
||||||
|
UNIQUE,
|
||||||
|
|
||||||
|
"code_challenge_verifier" TEXT,
|
||||||
|
"nonce" TEXT NOT NULL,
|
||||||
|
|
||||||
|
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||||
|
"completed_at" TIMESTAMP WITH TIME ZONE
|
||||||
|
);
|
@ -214,6 +214,68 @@
|
|||||||
},
|
},
|
||||||
"query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = $1\n "
|
"query": "\n SELECT\n c.oauth2_client_id,\n c.encrypted_client_secret,\n ARRAY(\n SELECT redirect_uri\n FROM oauth2_client_redirect_uris r\n WHERE r.oauth2_client_id = c.oauth2_client_id\n ) AS \"redirect_uris!\",\n c.grant_type_authorization_code,\n c.grant_type_refresh_token,\n c.client_name,\n c.logo_uri,\n c.client_uri,\n c.policy_uri,\n c.tos_uri,\n c.jwks_uri,\n c.jwks,\n c.id_token_signed_response_alg,\n c.userinfo_signed_response_alg,\n c.token_endpoint_auth_method,\n c.token_endpoint_auth_signing_alg,\n c.initiate_login_uri\n FROM oauth2_clients c\n\n WHERE c.oauth2_client_id = $1\n "
|
||||||
},
|
},
|
||||||
|
"0af182315b36766eca8e232280986bade0202d1b1d64160a99cd14eadcbfc25b": {
|
||||||
|
"describe": {
|
||||||
|
"columns": [
|
||||||
|
{
|
||||||
|
"name": "upstream_oauth_provider_id",
|
||||||
|
"ordinal": 0,
|
||||||
|
"type_info": "Uuid"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "issuer",
|
||||||
|
"ordinal": 1,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "scope",
|
||||||
|
"ordinal": 2,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "client_id",
|
||||||
|
"ordinal": 3,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "encrypted_client_secret",
|
||||||
|
"ordinal": 4,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "token_endpoint_signing_alg",
|
||||||
|
"ordinal": 5,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "token_endpoint_auth_method",
|
||||||
|
"ordinal": 6,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "created_at",
|
||||||
|
"ordinal": 7,
|
||||||
|
"type_info": "Timestamptz"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nullable": [
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
false
|
||||||
|
],
|
||||||
|
"parameters": {
|
||||||
|
"Left": [
|
||||||
|
"Uuid"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n "
|
||||||
|
},
|
||||||
"0b49cde0b7b79f79ec261502ab89bcffa81f9f5ed2f922a41b1718274b9e3073": {
|
"0b49cde0b7b79f79ec261502ab89bcffa81f9f5ed2f922a41b1718274b9e3073": {
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [
|
"columns": [
|
||||||
@ -291,6 +353,25 @@
|
|||||||
},
|
},
|
||||||
"query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at)\n VALUES ($1, $2, $3)\n "
|
"query": "\n INSERT INTO user_session_authentications\n (user_session_authentication_id, user_session_id, created_at)\n VALUES ($1, $2, $3)\n "
|
||||||
},
|
},
|
||||||
|
"1ee5cecfafd4726a4ebc08da8a34c09178e6e1e072581c8fca9d3d76967792cb": {
|
||||||
|
"describe": {
|
||||||
|
"columns": [],
|
||||||
|
"nullable": [],
|
||||||
|
"parameters": {
|
||||||
|
"Left": [
|
||||||
|
"Uuid",
|
||||||
|
"Text",
|
||||||
|
"Text",
|
||||||
|
"Text",
|
||||||
|
"Text",
|
||||||
|
"Text",
|
||||||
|
"Text",
|
||||||
|
"Timestamptz"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)\n "
|
||||||
|
},
|
||||||
"2153118b364a33582e7f598acce3789fcb8d938948a819b15cf0b6d37edf58b2": {
|
"2153118b364a33582e7f598acce3789fcb8d938948a819b15cf0b6d37edf58b2": {
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [],
|
"columns": [],
|
||||||
@ -1010,6 +1091,23 @@
|
|||||||
},
|
},
|
||||||
"query": "\n SELECT scope_token\n FROM oauth2_consents\n WHERE user_id = $1 AND oauth2_client_id = $2\n "
|
"query": "\n SELECT scope_token\n FROM oauth2_consents\n WHERE user_id = $1 AND oauth2_client_id = $2\n "
|
||||||
},
|
},
|
||||||
|
"53a652f0892d25654fe937962913f2f964463fd09f518066fbc83808edc5b394": {
|
||||||
|
"describe": {
|
||||||
|
"columns": [],
|
||||||
|
"nullable": [],
|
||||||
|
"parameters": {
|
||||||
|
"Left": [
|
||||||
|
"Uuid",
|
||||||
|
"Uuid",
|
||||||
|
"Text",
|
||||||
|
"Text",
|
||||||
|
"Text",
|
||||||
|
"Timestamptz"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL)\n "
|
||||||
|
},
|
||||||
"559a486756d08d101eb7188ef6637b9d24c024d056795b8121f7f04a7f9db6a3": {
|
"559a486756d08d101eb7188ef6637b9d24c024d056795b8121f7f04a7f9db6a3": {
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [
|
"columns": [
|
||||||
@ -1157,6 +1255,104 @@
|
|||||||
},
|
},
|
||||||
"query": "\n UPDATE oauth2_access_tokens\n SET revoked_at = $2\n WHERE oauth2_access_token_id = $1\n "
|
"query": "\n UPDATE oauth2_access_tokens\n SET revoked_at = $2\n WHERE oauth2_access_token_id = $1\n "
|
||||||
},
|
},
|
||||||
|
"6c8816b2618db8d04ab9393429866d9af59ad280949947fc025c89baffe6a455": {
|
||||||
|
"describe": {
|
||||||
|
"columns": [
|
||||||
|
{
|
||||||
|
"name": "upstream_oauth_authorization_session_id",
|
||||||
|
"ordinal": 0,
|
||||||
|
"type_info": "Uuid"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "upstream_oauth_provider_id",
|
||||||
|
"ordinal": 1,
|
||||||
|
"type_info": "Uuid"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "state",
|
||||||
|
"ordinal": 2,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "code_challenge_verifier",
|
||||||
|
"ordinal": 3,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "nonce",
|
||||||
|
"ordinal": 4,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "created_at",
|
||||||
|
"ordinal": 5,
|
||||||
|
"type_info": "Timestamptz"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "completed_at",
|
||||||
|
"ordinal": 6,
|
||||||
|
"type_info": "Timestamptz"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "provider_issuer",
|
||||||
|
"ordinal": 7,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "provider_scope",
|
||||||
|
"ordinal": 8,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "provider_client_id",
|
||||||
|
"ordinal": 9,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "provider_encrypted_client_secret",
|
||||||
|
"ordinal": 10,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "provider_token_endpoint_auth_method",
|
||||||
|
"ordinal": 11,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "provider_token_endpoint_signing_alg",
|
||||||
|
"ordinal": 12,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "provider_created_at",
|
||||||
|
"ordinal": 13,
|
||||||
|
"type_info": "Timestamptz"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nullable": [
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
false
|
||||||
|
],
|
||||||
|
"parameters": {
|
||||||
|
"Left": [
|
||||||
|
"Uuid"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"query": "\n SELECT\n ua.upstream_oauth_authorization_session_id,\n ua.upstream_oauth_provider_id,\n ua.state,\n ua.code_challenge_verifier,\n ua.nonce,\n ua.created_at,\n ua.completed_at,\n up.issuer AS \"provider_issuer\",\n up.scope AS \"provider_scope\",\n up.client_id AS \"provider_client_id\",\n up.encrypted_client_secret AS \"provider_encrypted_client_secret\",\n up.token_endpoint_auth_method AS \"provider_token_endpoint_auth_method\",\n up.token_endpoint_signing_alg AS \"provider_token_endpoint_signing_alg\",\n up.created_at AS \"provider_created_at\"\n FROM upstream_oauth_authorization_sessions ua\n INNER JOIN upstream_oauth_providers up\n USING (upstream_oauth_provider_id)\n WHERE upstream_oauth_authorization_session_id = $1\n "
|
||||||
|
},
|
||||||
"7262f81a335a984c4051383d2ede7455ff65ed90fbd3151d625f8a21fd26cb05": {
|
"7262f81a335a984c4051383d2ede7455ff65ed90fbd3151d625f8a21fd26cb05": {
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [],
|
"columns": [],
|
||||||
|
@ -126,6 +126,7 @@ impl StorageBackendMarker for PostgresqlBackend {}
|
|||||||
pub mod compat;
|
pub mod compat;
|
||||||
pub mod oauth2;
|
pub mod oauth2;
|
||||||
pub(crate) mod pagination;
|
pub(crate) mod pagination;
|
||||||
|
pub mod upstream_oauth2;
|
||||||
pub mod user;
|
pub mod user;
|
||||||
|
|
||||||
/// Embedded migrations, allowing them to run on startup
|
/// Embedded migrations, allowing them to run on startup
|
||||||
|
21
crates/storage/src/upstream_oauth2/mod.rs
Normal file
21
crates/storage/src/upstream_oauth2/mod.rs
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
mod provider;
|
||||||
|
mod session;
|
||||||
|
|
||||||
|
pub use self::{
|
||||||
|
provider::{add_provider, lookup_provider, ProviderLookupError},
|
||||||
|
session::{add_session, lookup_session, SessionLookupError},
|
||||||
|
};
|
159
crates/storage/src/upstream_oauth2/provider.rs
Normal file
159
crates/storage/src/upstream_oauth2/provider.rs
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use mas_data_model::UpstreamOAuthProvider;
|
||||||
|
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||||
|
use oauth2_types::scope::Scope;
|
||||||
|
use rand::Rng;
|
||||||
|
use sqlx::PgExecutor;
|
||||||
|
use thiserror::Error;
|
||||||
|
use ulid::Ulid;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::{Clock, DatabaseInconsistencyError, LookupError};
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
#[error("Failed to lookup upstream OAuth 2.0 provider")]
|
||||||
|
pub enum ProviderLookupError {
|
||||||
|
Driver(#[from] sqlx::Error),
|
||||||
|
Inconcistency(#[from] DatabaseInconsistencyError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LookupError for ProviderLookupError {
|
||||||
|
fn not_found(&self) -> bool {
|
||||||
|
matches!(self, Self::Driver(sqlx::Error::RowNotFound))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ProviderLookup {
|
||||||
|
upstream_oauth_provider_id: Uuid,
|
||||||
|
issuer: String,
|
||||||
|
scope: String,
|
||||||
|
client_id: String,
|
||||||
|
encrypted_client_secret: Option<String>,
|
||||||
|
token_endpoint_signing_alg: Option<String>,
|
||||||
|
token_endpoint_auth_method: String,
|
||||||
|
created_at: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
skip_all,
|
||||||
|
fields(upstream_oauth_provider.id = %id),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
pub async fn lookup_provider(
|
||||||
|
executor: impl PgExecutor<'_>,
|
||||||
|
id: Ulid,
|
||||||
|
) -> Result<UpstreamOAuthProvider, ProviderLookupError> {
|
||||||
|
let res = sqlx::query_as!(
|
||||||
|
ProviderLookup,
|
||||||
|
r#"
|
||||||
|
SELECT
|
||||||
|
upstream_oauth_provider_id,
|
||||||
|
issuer,
|
||||||
|
scope,
|
||||||
|
client_id,
|
||||||
|
encrypted_client_secret,
|
||||||
|
token_endpoint_signing_alg,
|
||||||
|
token_endpoint_auth_method,
|
||||||
|
created_at
|
||||||
|
FROM upstream_oauth_providers
|
||||||
|
WHERE upstream_oauth_provider_id = $1
|
||||||
|
"#,
|
||||||
|
Uuid::from(id),
|
||||||
|
)
|
||||||
|
.fetch_one(executor)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(UpstreamOAuthProvider {
|
||||||
|
id: res.upstream_oauth_provider_id.into(),
|
||||||
|
issuer: res.issuer,
|
||||||
|
scope: res.scope.parse().map_err(|_| DatabaseInconsistencyError)?,
|
||||||
|
client_id: res.client_id,
|
||||||
|
encrypted_client_secret: res.encrypted_client_secret,
|
||||||
|
token_endpoint_auth_method: res
|
||||||
|
.token_endpoint_auth_method
|
||||||
|
.parse()
|
||||||
|
.map_err(|_| DatabaseInconsistencyError)?,
|
||||||
|
token_endpoint_signing_alg: res
|
||||||
|
.token_endpoint_signing_alg
|
||||||
|
.map(|x| x.parse())
|
||||||
|
.transpose()
|
||||||
|
.map_err(|_| DatabaseInconsistencyError)?,
|
||||||
|
created_at: res.created_at,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
upstream_oauth_provider.id,
|
||||||
|
upstream_oauth_provider.issuer = %issuer,
|
||||||
|
upstream_oauth_provider.client_id = %client_id,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub async fn add_provider(
|
||||||
|
executor: impl PgExecutor<'_>,
|
||||||
|
mut rng: impl Rng + Send,
|
||||||
|
clock: &Clock,
|
||||||
|
issuer: String,
|
||||||
|
scope: Scope,
|
||||||
|
token_endpoint_auth_method: OAuthClientAuthenticationMethod,
|
||||||
|
token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||||
|
client_id: String,
|
||||||
|
encrypted_client_secret: Option<String>,
|
||||||
|
) -> Result<UpstreamOAuthProvider, sqlx::Error> {
|
||||||
|
let created_at = clock.now();
|
||||||
|
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||||
|
tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
|
||||||
|
|
||||||
|
sqlx::query!(
|
||||||
|
r#"
|
||||||
|
INSERT INTO upstream_oauth_providers (
|
||||||
|
upstream_oauth_provider_id,
|
||||||
|
issuer,
|
||||||
|
scope,
|
||||||
|
token_endpoint_auth_method,
|
||||||
|
token_endpoint_signing_alg,
|
||||||
|
client_id,
|
||||||
|
encrypted_client_secret,
|
||||||
|
created_at
|
||||||
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
|
"#,
|
||||||
|
Uuid::from(id),
|
||||||
|
&issuer,
|
||||||
|
scope.to_string(),
|
||||||
|
token_endpoint_auth_method.to_string(),
|
||||||
|
token_endpoint_signing_alg.as_ref().map(ToString::to_string),
|
||||||
|
&client_id,
|
||||||
|
encrypted_client_secret.as_deref(),
|
||||||
|
created_at,
|
||||||
|
)
|
||||||
|
.execute(executor)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(UpstreamOAuthProvider {
|
||||||
|
id,
|
||||||
|
issuer,
|
||||||
|
scope,
|
||||||
|
client_id,
|
||||||
|
encrypted_client_secret,
|
||||||
|
token_endpoint_signing_alg,
|
||||||
|
token_endpoint_auth_method,
|
||||||
|
created_at,
|
||||||
|
})
|
||||||
|
}
|
184
crates/storage/src/upstream_oauth2/session.rs
Normal file
184
crates/storage/src/upstream_oauth2/session.rs
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthProvider};
|
||||||
|
use rand::Rng;
|
||||||
|
use sqlx::PgExecutor;
|
||||||
|
use thiserror::Error;
|
||||||
|
use ulid::Ulid;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::{Clock, DatabaseInconsistencyError, LookupError};
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
#[error("Failed to lookup upstream OAuth 2.0 authorization session")]
|
||||||
|
pub enum SessionLookupError {
|
||||||
|
Driver(#[from] sqlx::Error),
|
||||||
|
Inconcistency(#[from] DatabaseInconsistencyError),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LookupError for SessionLookupError {
|
||||||
|
fn not_found(&self) -> bool {
|
||||||
|
matches!(self, Self::Driver(sqlx::Error::RowNotFound))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SessionLookup {
|
||||||
|
upstream_oauth_authorization_session_id: Uuid,
|
||||||
|
upstream_oauth_provider_id: Uuid,
|
||||||
|
state: String,
|
||||||
|
code_challenge_verifier: Option<String>,
|
||||||
|
nonce: String,
|
||||||
|
created_at: DateTime<Utc>,
|
||||||
|
completed_at: Option<DateTime<Utc>>,
|
||||||
|
provider_issuer: String,
|
||||||
|
provider_scope: String,
|
||||||
|
provider_client_id: String,
|
||||||
|
provider_encrypted_client_secret: Option<String>,
|
||||||
|
provider_token_endpoint_auth_method: String,
|
||||||
|
provider_token_endpoint_signing_alg: Option<String>,
|
||||||
|
provider_created_at: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
skip_all,
|
||||||
|
fields(upstream_oauth_authorization_session.id = %id),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
pub async fn lookup_session(
|
||||||
|
executor: impl PgExecutor<'_>,
|
||||||
|
id: Ulid,
|
||||||
|
) -> Result<(UpstreamOAuthProvider, UpstreamOAuthAuthorizationSession), SessionLookupError> {
|
||||||
|
let res = sqlx::query_as!(
|
||||||
|
SessionLookup,
|
||||||
|
r#"
|
||||||
|
SELECT
|
||||||
|
ua.upstream_oauth_authorization_session_id,
|
||||||
|
ua.upstream_oauth_provider_id,
|
||||||
|
ua.state,
|
||||||
|
ua.code_challenge_verifier,
|
||||||
|
ua.nonce,
|
||||||
|
ua.created_at,
|
||||||
|
ua.completed_at,
|
||||||
|
up.issuer AS "provider_issuer",
|
||||||
|
up.scope AS "provider_scope",
|
||||||
|
up.client_id AS "provider_client_id",
|
||||||
|
up.encrypted_client_secret AS "provider_encrypted_client_secret",
|
||||||
|
up.token_endpoint_auth_method AS "provider_token_endpoint_auth_method",
|
||||||
|
up.token_endpoint_signing_alg AS "provider_token_endpoint_signing_alg",
|
||||||
|
up.created_at AS "provider_created_at"
|
||||||
|
FROM upstream_oauth_authorization_sessions ua
|
||||||
|
INNER JOIN upstream_oauth_providers up
|
||||||
|
USING (upstream_oauth_provider_id)
|
||||||
|
WHERE upstream_oauth_authorization_session_id = $1
|
||||||
|
"#,
|
||||||
|
Uuid::from(id),
|
||||||
|
)
|
||||||
|
.fetch_one(executor)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let provider = UpstreamOAuthProvider {
|
||||||
|
id: res.upstream_oauth_provider_id.into(),
|
||||||
|
issuer: res
|
||||||
|
.provider_issuer
|
||||||
|
.parse()
|
||||||
|
.map_err(|_| DatabaseInconsistencyError)?,
|
||||||
|
scope: res
|
||||||
|
.provider_scope
|
||||||
|
.parse()
|
||||||
|
.map_err(|_| DatabaseInconsistencyError)?,
|
||||||
|
client_id: res.provider_client_id,
|
||||||
|
encrypted_client_secret: res.provider_encrypted_client_secret,
|
||||||
|
token_endpoint_auth_method: res
|
||||||
|
.provider_token_endpoint_auth_method
|
||||||
|
.parse()
|
||||||
|
.map_err(|_| DatabaseInconsistencyError)?,
|
||||||
|
token_endpoint_signing_alg: res
|
||||||
|
.provider_token_endpoint_signing_alg
|
||||||
|
.map(|x| x.parse())
|
||||||
|
.transpose()
|
||||||
|
.map_err(|_| DatabaseInconsistencyError)?,
|
||||||
|
created_at: res.provider_created_at,
|
||||||
|
};
|
||||||
|
|
||||||
|
let session = UpstreamOAuthAuthorizationSession {
|
||||||
|
id: res.upstream_oauth_authorization_session_id.into(),
|
||||||
|
state: res.state,
|
||||||
|
code_challenge_verifier: res.code_challenge_verifier,
|
||||||
|
nonce: res.nonce,
|
||||||
|
created_at: res.created_at,
|
||||||
|
completed_at: res.completed_at,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok((provider, session))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
upstream_oauth_provider.id = %provider.id,
|
||||||
|
upstream_oauth_provider.issuer = %provider.issuer,
|
||||||
|
upstream_oauth_provider.client_id = %provider.client_id,
|
||||||
|
upstream_oauth_authorization_session.id,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
pub async fn add_session(
|
||||||
|
executor: impl PgExecutor<'_>,
|
||||||
|
mut rng: impl Rng + Send,
|
||||||
|
clock: &Clock,
|
||||||
|
provider: &UpstreamOAuthProvider,
|
||||||
|
state: String,
|
||||||
|
code_challenge_verifier: Option<String>,
|
||||||
|
nonce: String,
|
||||||
|
) -> Result<UpstreamOAuthAuthorizationSession, sqlx::Error> {
|
||||||
|
let created_at = clock.now();
|
||||||
|
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||||
|
tracing::Span::current().record(
|
||||||
|
"upstream_oauth_authorization_session.id",
|
||||||
|
tracing::field::display(id),
|
||||||
|
);
|
||||||
|
|
||||||
|
sqlx::query!(
|
||||||
|
r#"
|
||||||
|
INSERT INTO upstream_oauth_authorization_sessions (
|
||||||
|
upstream_oauth_authorization_session_id,
|
||||||
|
upstream_oauth_provider_id,
|
||||||
|
state,
|
||||||
|
code_challenge_verifier,
|
||||||
|
nonce,
|
||||||
|
created_at,
|
||||||
|
completed_at
|
||||||
|
) VALUES ($1, $2, $3, $4, $5, $6, NULL)
|
||||||
|
"#,
|
||||||
|
Uuid::from(id),
|
||||||
|
Uuid::from(provider.id),
|
||||||
|
&state,
|
||||||
|
code_challenge_verifier.as_deref(),
|
||||||
|
nonce,
|
||||||
|
created_at,
|
||||||
|
)
|
||||||
|
.execute(executor)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(UpstreamOAuthAuthorizationSession {
|
||||||
|
id,
|
||||||
|
state,
|
||||||
|
code_challenge_verifier,
|
||||||
|
nonce,
|
||||||
|
created_at,
|
||||||
|
completed_at: None,
|
||||||
|
})
|
||||||
|
}
|
Reference in New Issue
Block a user