1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-28 11:02:02 +03:00

Move clients to the database

This commit is contained in:
Quentin Gliech
2022-03-08 17:33:25 +01:00
parent 19a81afe51
commit 62f633a716
33 changed files with 1926 additions and 867 deletions

14
Cargo.lock generated
View File

@ -1828,6 +1828,7 @@ dependencies = [
"argon2", "argon2",
"atty", "atty",
"clap", "clap",
"data-encoding",
"dotenv", "dotenv",
"futures 0.3.21", "futures 0.3.21",
"hyper", "hyper",
@ -1845,6 +1846,7 @@ dependencies = [
"opentelemetry-otlp", "opentelemetry-otlp",
"opentelemetry-semantic-conventions", "opentelemetry-semantic-conventions",
"opentelemetry-zipkin", "opentelemetry-zipkin",
"rand",
"reqwest", "reqwest",
"schemars", "schemars",
"serde_json", "serde_json",
@ -1870,12 +1872,9 @@ dependencies = [
"chrono", "chrono",
"elliptic-curve", "elliptic-curve",
"figment", "figment",
"futures-util",
"http",
"http-body",
"indoc", "indoc",
"lettre", "lettre",
"mas-http", "mas-iana",
"mas-jose", "mas-jose",
"p256", "p256",
"pem-rfc7468", "pem-rfc7468",
@ -1889,7 +1888,6 @@ dependencies = [
"sqlx", "sqlx",
"thiserror", "thiserror",
"tokio", "tokio",
"tower",
"tracing", "tracing",
"url", "url",
] ]
@ -1901,6 +1899,7 @@ dependencies = [
"chrono", "chrono",
"crc", "crc",
"mas-iana", "mas-iana",
"mas-jose",
"oauth2-types", "oauth2-types",
"rand", "rand",
"serde", "serde",
@ -2068,10 +2067,12 @@ dependencies = [
"chrono", "chrono",
"mas-data-model", "mas-data-model",
"mas-iana", "mas-iana",
"mas-jose",
"oauth2-types", "oauth2-types",
"password-hash", "password-hash",
"rand", "rand",
"serde", "serde",
"serde_json",
"sqlx", "sqlx",
"thiserror", "thiserror",
"tokio", "tokio",
@ -2123,9 +2124,12 @@ dependencies = [
"crc", "crc",
"data-encoding", "data-encoding",
"headers", "headers",
"http",
"http-body",
"hyper", "hyper",
"mas-config", "mas-config",
"mas-data-model", "mas-data-model",
"mas-http",
"mas-iana", "mas-iana",
"mas-jose", "mas-jose",
"mas-storage", "mas-storage",

View File

@ -22,6 +22,8 @@ argon2 = { version = "0.3.4", features = ["password-hash"] }
reqwest = { version = "0.11.9", features = ["rustls-tls"], default-features = false, optional = true } reqwest = { version = "0.11.9", features = ["rustls-tls"], default-features = false, optional = true }
watchman_client = "0.7.1" watchman_client = "0.7.1"
atty = "0.2.14" atty = "0.2.14"
rand = "0.8.5"
data-encoding = "2.3.2"
tracing = "0.1.31" tracing = "0.1.31"
tracing-appender = "0.2.1" tracing-appender = "0.2.1"

View File

@ -14,9 +14,13 @@
use argon2::Argon2; use argon2::Argon2;
use clap::Parser; use clap::Parser;
use mas_config::DatabaseConfig; use data_encoding::BASE64;
use mas_storage::user::{ use mas_config::{DatabaseConfig, RootConfig};
lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, register_user, use mas_storage::{
oauth2::client::{insert_client_from_config, lookup_client_by_client_id, truncate_clients},
user::{
lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, register_user,
},
}; };
use tracing::{info, warn}; use tracing::{info, warn};
@ -36,6 +40,13 @@ enum Subcommand {
/// Mark email address as verified /// Mark email address as verified
VerifyEmail { username: String, email: String }, VerifyEmail { username: String, email: String },
/// Import clients from config
ImportClients {
/// Remove all clients before importing
#[clap(long)]
truncate: bool,
},
} }
impl Options { impl Options {
@ -71,6 +82,65 @@ impl Options {
txn.commit().await?; txn.commit().await?;
info!(?email, "Email marked as verified"); info!(?email, "Email marked as verified");
Ok(())
}
SC::ImportClients { truncate } => {
let config: RootConfig = root.load_config()?;
let pool = config.database.connect().await?;
let encrypter = config.secrets.encrypter();
let mut txn = pool.begin().await?;
if *truncate {
warn!("Removing all clients first");
truncate_clients(&mut txn).await?;
}
for client in config.clients.iter() {
let client_id = &client.client_id;
let res = lookup_client_by_client_id(&mut txn, client_id).await;
match res {
Ok(_) => {
warn!(%client_id, "Skipping already imported client");
continue;
}
Err(e) if e.not_found() => {}
Err(e) => anyhow::bail!(e),
}
info!(%client_id, "Importing client");
let client_secret = client.client_secret();
let client_auth_method = client.client_auth_method();
let jwks = client.jwks();
let jwks_uri = client.jwks_uri();
let redirect_uris = &client.redirect_uris;
// TODO: should be moved somewhere else
let encrypted_client_secret = client_secret
.map(|client_secret| {
let nonce: [u8; 12] = rand::random();
let message = encrypter.encrypt(&nonce, client_secret.as_bytes())?;
let concat = [&nonce[..], &message[..]].concat();
let res = BASE64.encode(&concat);
anyhow::Ok(res)
})
.transpose()?;
insert_client_from_config(
&mut txn,
client_id,
client_auth_method,
encrypted_client_secret.as_deref(),
jwks,
jwks_uri,
redirect_uris,
)
.await?;
}
txn.commit().await?;
Ok(()) Ok(())
} }
} }

View File

@ -35,8 +35,4 @@ pem-rfc7468 = "0.3.1"
indoc = "1.0.4" indoc = "1.0.4"
mas-jose = { path = "../jose" } mas-jose = { path = "../jose" }
mas-http = { path = "../http" } mas-iana = { path = "../iana" }
tower = { version = "0.4.12", features = ["util"] }
http = "0.2.6"
http-body = "0.4.4"
futures-util = "0.3.21"

View File

@ -15,15 +15,12 @@
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use async_trait::async_trait; use async_trait::async_trait;
use futures_util::future::Either; use mas_iana::oauth::OAuthClientAuthenticationMethod;
use http::Request; use mas_jose::JsonWebKeySet;
use mas_http::HttpServiceExt;
use mas_jose::{DynamicJwksStore, JsonWebKeySet, StaticJwksStore, VerifyingKeystore};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none; use serde_with::skip_serializing_none;
use thiserror::Error; use thiserror::Error;
use tower::{BoxError, ServiceExt};
use url::Url; use url::Url;
use super::ConfigurationSection; use super::ConfigurationSection;
@ -35,41 +32,6 @@ pub enum JwksOrJwksUri {
JwksUri(Url), JwksUri(Url),
} }
impl JwksOrJwksUri {
pub fn key_store(&self) -> Either<StaticJwksStore, DynamicJwksStore> {
// Assert that the output is both a VerifyingKeystore and Send
fn assert<T: Send + VerifyingKeystore>(t: T) -> T {
t
}
let inner = match self {
Self::Jwks(jwks) => Either::Left(StaticJwksStore::new(jwks.clone())),
Self::JwksUri(uri) => {
let uri = uri.clone();
// TODO: get the client from somewhere else?
let exporter = mas_http::client("fetch-jwks")
.json::<JsonWebKeySet>()
.map_request(move |_: ()| {
Request::builder()
.method("GET")
// TODO: change the Uri type in config to avoid reparsing here
.uri(uri.to_string())
.body(http_body::Empty::new())
.unwrap()
})
.map_response(http::Response::into_body)
.map_err(BoxError::from)
.boxed_clone();
Either::Right(DynamicJwksStore::new(exporter))
}
};
assert(inner)
}
}
impl From<JsonWebKeySet> for JwksOrJwksUri { impl From<JsonWebKeySet> for JwksOrJwksUri {
fn from(jwks: JsonWebKeySet) -> Self { fn from(jwks: JsonWebKeySet) -> Self {
Self::Jwks(jwks) Self::Jwks(jwks)
@ -131,24 +93,53 @@ pub struct InvalidRedirectUriError;
impl ClientConfig { impl ClientConfig {
#[doc(hidden)] #[doc(hidden)]
pub fn resolve_redirect_uri<'a>( #[must_use]
&'a self, pub fn client_secret(&self) -> Option<&str> {
suggested_uri: &'a Option<Url>, match &self.client_auth_method {
) -> Result<&'a Url, InvalidRedirectUriError> { ClientAuthMethodConfig::ClientSecretPost { client_secret }
suggested_uri.as_ref().map_or_else( | ClientAuthMethodConfig::ClientSecretBasic { client_secret }
|| self.redirect_uris.get(0).ok_or(InvalidRedirectUriError), | ClientAuthMethodConfig::ClientSecretJwt { client_secret } => Some(client_secret),
|suggested_uri| self.check_redirect_uri(suggested_uri), _ => None,
) }
} }
fn check_redirect_uri<'a>( #[doc(hidden)]
&self, #[must_use]
redirect_uri: &'a Url, pub fn client_auth_method(&self) -> OAuthClientAuthenticationMethod {
) -> Result<&'a Url, InvalidRedirectUriError> { match &self.client_auth_method {
if self.redirect_uris.contains(redirect_uri) { ClientAuthMethodConfig::None => OAuthClientAuthenticationMethod::None,
Ok(redirect_uri) ClientAuthMethodConfig::ClientSecretBasic { .. } => {
} else { OAuthClientAuthenticationMethod::ClientSecretBasic
Err(InvalidRedirectUriError) }
ClientAuthMethodConfig::ClientSecretPost { .. } => {
OAuthClientAuthenticationMethod::ClientSecretPost
}
ClientAuthMethodConfig::ClientSecretJwt { .. } => {
OAuthClientAuthenticationMethod::ClientSecretJwt
}
ClientAuthMethodConfig::PrivateKeyJwt(_) => {
OAuthClientAuthenticationMethod::PrivateKeyJwt
}
}
}
#[doc(hidden)]
#[must_use]
pub fn jwks(&self) -> Option<&JsonWebKeySet> {
match &self.client_auth_method {
ClientAuthMethodConfig::PrivateKeyJwt(JwksOrJwksUri::Jwks(jwks)) => Some(jwks),
_ => None,
}
}
#[doc(hidden)]
#[must_use]
pub fn jwks_uri(&self) -> Option<&Url> {
match &self.client_auth_method {
ClientAuthMethodConfig::PrivateKeyJwt(JwksOrJwksUri::JwksUri(jwks_uri)) => {
Some(jwks_uri)
}
_ => None,
} }
} }
} }

View File

@ -14,4 +14,5 @@ crc = "2.1.0"
rand = "0.8.5" rand = "0.8.5"
mas-iana = { path = "../iana" } mas-iana = { path = "../iana" }
mas-jose = { path = "../jose" }
oauth2-types = { path = "../oauth2-types" } oauth2-types = { path = "../oauth2-types" }

View File

@ -30,7 +30,8 @@ pub(crate) mod users;
pub use self::{ pub use self::{
oauth2::{ oauth2::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, JwksOrJwksUri,
Pkce, Session,
}, },
tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType}, tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType},
traits::{StorageBackend, StorageBackendMarker}, traits::{StorageBackend, StorageBackendMarker},

View File

@ -12,16 +12,87 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use mas_iana::{
jose::JsonWebSignatureAlg,
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
};
use mas_jose::JsonWebKeySet;
use oauth2_types::requests::GrantType;
use serde::Serialize; use serde::Serialize;
use thiserror::Error;
use url::Url;
use crate::traits::{StorageBackend, StorageBackendMarker}; use crate::traits::{StorageBackend, StorageBackendMarker};
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum JwksOrJwksUri {
/// Client's JSON Web Key Set document, passed by value.
Jwks(JsonWebKeySet),
/// URL for the Client's JSON Web Key Set document.
JwksUri(Url),
}
#[derive(Debug, Clone, PartialEq, Serialize)] #[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(bound = "T: StorageBackend")] #[serde(bound = "T: StorageBackend")]
pub struct Client<T: StorageBackend> { pub struct Client<T: StorageBackend> {
#[serde(skip_serializing)] #[serde(skip_serializing)]
pub data: T::ClientData, pub data: T::ClientData,
/// Client identifier
pub client_id: String, pub client_id: String,
pub encrypted_client_secret: Option<String>,
/// Array of Redirection URI values used by the Client
pub redirect_uris: Vec<Url>,
/// Array containing a list of the OAuth 2.0 response_type values that the
/// Client is declaring that it will restrict itself to using
pub response_types: Vec<OAuthAuthorizationEndpointResponseType>,
/// Array containing a list of the OAuth 2.0 Grant Types that the Client is
/// declaring that it will restrict itself to using.
pub grant_types: Vec<GrantType>,
/// Array of e-mail addresses of people responsible for this Client
pub contacts: Vec<String>,
/// Name of the Client to be presented to the End-User
pub client_name: Option<String>, // TODO: translations
/// URL that references a logo for the Client application
pub logo_uri: Option<Url>, // TODO: translations
/// URL of the home page of the Client
pub client_uri: Option<Url>, // TODO: translations
/// URL that the Relying Party Client provides to the End-User to read about
/// the how the profile data will be used
pub policy_uri: Option<Url>, // TODO: translations
/// URL that the Relying Party Client provides to the End-User to read about
/// the Relying Party's terms of service
pub tos_uri: Option<Url>, // TODO: translations
pub jwks: Option<JwksOrJwksUri>,
/// JWS alg algorithm REQUIRED for signing the ID Token issued to this
/// Client
pub id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
/// Requested authentication method for the token endpoint
pub token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
/// JWS alg algorithm that MUST be used for signing the JWT used to
/// authenticate the Client at the Token Endpoint for the private_key_jwt
/// and client_secret_jwt authentication methods
pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
/// URI using the https scheme that a third party can use to initiate a
/// login by the RP
pub initiate_login_uri: Option<Url>,
} }
impl<S: StorageBackendMarker> From<Client<S>> for Client<()> { impl<S: StorageBackendMarker> From<Client<S>> for Client<()> {
@ -29,6 +100,48 @@ impl<S: StorageBackendMarker> From<Client<S>> for Client<()> {
Client { Client {
data: (), data: (),
client_id: c.client_id, client_id: c.client_id,
encrypted_client_secret: c.encrypted_client_secret,
redirect_uris: c.redirect_uris,
response_types: c.response_types,
grant_types: c.grant_types,
contacts: c.contacts,
client_name: c.client_name,
logo_uri: c.logo_uri,
client_uri: c.client_uri,
policy_uri: c.policy_uri,
tos_uri: c.tos_uri,
jwks: c.jwks,
id_token_signed_response_alg: c.id_token_signed_response_alg,
token_endpoint_auth_method: c.token_endpoint_auth_method,
token_endpoint_auth_signing_alg: c.token_endpoint_auth_signing_alg,
initiate_login_uri: c.initiate_login_uri,
}
}
}
#[derive(Debug, Error)]
pub enum InvalidRedirectUriError {
#[error("redirect_uri is not allowed for this client")]
NotAllowed,
#[error("multiple redirect_uris registered for this client")]
MultipleRegistered,
#[error("client has no redirect_uri registered")]
NoneRegistered,
}
impl<S: StorageBackend> Client<S> {
pub fn resolve_redirect_uri<'a>(
&'a self,
redirect_uri: &'a Option<Url>,
) -> Result<&'a Url, InvalidRedirectUriError> {
match (&self.redirect_uris[..], redirect_uri) {
([], _) => Err(InvalidRedirectUriError::NoneRegistered),
([one], None) => Ok(one),
(_, None) => Err(InvalidRedirectUriError::MultipleRegistered),
(uris, Some(uri)) if uris.contains(uri) => Ok(uri),
_ => Err(InvalidRedirectUriError::NotAllowed),
} }
} }
} }

View File

@ -1,4 +1,4 @@
// Copyright 2021 The Matrix.org Foundation C.I.C. // Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -18,6 +18,6 @@ pub(self) mod session;
pub use self::{ pub use self::{
authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce}, authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce},
client::Client, client::{Client, JwksOrJwksUri},
session::Session, session::Session,
}; };

View File

@ -46,14 +46,7 @@ pub fn root(
config: &RootConfig, config: &RootConfig,
) -> BoxedFilter<(impl Reply,)> { ) -> BoxedFilter<(impl Reply,)> {
let health = health(pool); let health = health(pool);
let oauth2 = oauth2( let oauth2 = oauth2(pool, templates, key_store, encrypter, &config.http);
pool,
templates,
key_store,
encrypter,
&config.clients,
&config.http,
);
let views = views( let views = views(
pool, pool,
templates, templates,

View File

@ -20,7 +20,7 @@ use hyper::{
http::uri::{Parts, PathAndQuery, Uri}, http::uri::{Parts, PathAndQuery, Uri},
StatusCode, StatusCode,
}; };
use mas_config::{ClientsConfig, Encrypter}; use mas_config::Encrypter;
use mas_data_model::{ use mas_data_model::{
Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession, Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession,
Pkce, StorageBackend, TokenType, Pkce, StorageBackend, TokenType,
@ -32,6 +32,7 @@ use mas_storage::{
authorization_grant::{ authorization_grant::{
derive_session, fulfill_grant, get_grant_by_id, new_authorization_grant, derive_session, fulfill_grant, get_grant_by_id, new_authorization_grant,
}, },
client::lookup_client_by_client_id,
refresh_token::add_refresh_token, refresh_token::add_refresh_token,
}, },
PostgresqlBackend, PostgresqlBackend,
@ -41,7 +42,7 @@ use mas_warp_utils::{
errors::WrapError, errors::WrapError,
filters::{ filters::{
self, self,
database::transaction, database::{connection, transaction},
session::{optional_session, session}, session::{optional_session, session},
with_templates, with_templates,
}, },
@ -49,19 +50,20 @@ use mas_warp_utils::{
use oauth2_types::{ use oauth2_types::{
errors::{ errors::{
ErrorResponse, InvalidGrant, InvalidRequest, LoginRequired, OAuth2Error, ErrorResponse, InvalidGrant, InvalidRequest, LoginRequired, OAuth2Error,
RegistrationNotSupported, RequestNotSupported, RequestUriNotSupported, RegistrationNotSupported, RequestNotSupported, RequestUriNotSupported, UnauthorizedClient,
}, },
pkce, pkce,
prelude::*, prelude::*,
requests::{ requests::{
AccessTokenResponse, AuthorizationRequest, AuthorizationResponse, Prompt, ResponseMode, AccessTokenResponse, AuthorizationRequest, AuthorizationResponse, GrantType, Prompt,
ResponseMode,
}, },
scope::ScopeToken, scope::ScopeToken,
}; };
use rand::{distributions::Alphanumeric, thread_rng, Rng}; use rand::{distributions::Alphanumeric, thread_rng, Rng};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use sqlx::{PgExecutor, PgPool, Postgres, Transaction}; use sqlx::{pool::PoolConnection, PgConnection, PgPool, Postgres, Transaction};
use url::Url; use url::Url;
use warp::{ use warp::{
filters::BoxedFilter, filters::BoxedFilter,
@ -217,15 +219,10 @@ pub fn filter(
pool: &PgPool, pool: &PgPool,
templates: &Templates, templates: &Templates,
encrypter: &Encrypter, encrypter: &Encrypter,
clients_config: &ClientsConfig,
) -> BoxedFilter<(Box<dyn Reply>,)> { ) -> BoxedFilter<(Box<dyn Reply>,)> {
let clients_config = clients_config.clone();
let clients_config_2 = clients_config.clone();
let authorize = warp::path!("oauth2" / "authorize") let authorize = warp::path!("oauth2" / "authorize")
.and(filters::trace::name("GET /oauth2/authorize")) .and(filters::trace::name("GET /oauth2/authorize"))
.and(warp::get()) .and(warp::get())
.map(move || clients_config.clone())
.and(warp::query()) .and(warp::query())
.and(optional_session(pool, encrypter)) .and(optional_session(pool, encrypter))
.and(transaction(pool)) .and(transaction(pool))
@ -245,8 +242,8 @@ pub fn filter(
.recover(recover) .recover(recover)
.unify() .unify()
.and(warp::query()) .and(warp::query())
.and(warp::any().map(move || clients_config_2.clone()))
.and(with_templates(templates)) .and(with_templates(templates))
.and(connection(pool))
.and_then(actually_reply) .and_then(actually_reply)
.boxed() .boxed()
} }
@ -262,8 +259,8 @@ async fn recover(rejection: Rejection) -> Result<ReplyOrBackToClient, Rejection>
async fn actually_reply( async fn actually_reply(
rep: ReplyOrBackToClient, rep: ReplyOrBackToClient,
q: PartialParams, q: PartialParams,
clients: ClientsConfig,
templates: Templates, templates: Templates,
mut conn: PoolConnection<Postgres>,
) -> Result<Box<dyn Reply>, Rejection> { ) -> Result<Box<dyn Reply>, Rejection> {
let (redirect_uri, response_mode, state, params) = match rep { let (redirect_uri, response_mode, state, params) = match rep {
ReplyOrBackToClient::Reply(r) => return Ok(r), ReplyOrBackToClient::Reply(r) => return Ok(r),
@ -281,15 +278,14 @@ async fn actually_reply(
.. ..
} = q; } = q;
// First, disover the client let client_id = if let Some(client_id) = client_id {
let client = client_id client_id
.and_then(|client_id| clients.iter().find(|client| client.client_id == client_id)); } else {
return Ok(Box::new(html(templates.render_error(&error.into()).await?)));
let client = match client {
Some(client) => client,
None => return Ok(Box::new(html(templates.render_error(&error.into()).await?))),
}; };
let client = lookup_client_by_client_id(&mut conn, &client_id).await?;
let redirect_uri: Result<Option<Url>, _> = redirect_uri.map(|r| r.parse()).transpose(); let redirect_uri: Result<Option<Url>, _> = redirect_uri.map(|r| r.parse()).transpose();
let redirect_uri = match redirect_uri { let redirect_uri = match redirect_uri {
Ok(r) => r, Ok(r) => r,
@ -315,7 +311,6 @@ async fn actually_reply(
} }
async fn get( async fn get(
clients: ClientsConfig,
params: Params, params: Params,
maybe_session: Option<BrowserSession<PostgresqlBackend>>, maybe_session: Option<BrowserSession<PostgresqlBackend>>,
mut txn: Transaction<'_, Postgres>, mut txn: Transaction<'_, Postgres>,
@ -337,15 +332,17 @@ async fn get(
} }
// First, find out what client it is // First, find out what client it is
let client = clients let client = lookup_client_by_client_id(&mut txn, &params.auth.client_id).await?;
.iter()
.find(|client| client.client_id == params.auth.client_id) // Check if it is allowed to use this grant type
.ok_or_else(|| anyhow::anyhow!("could not find client")) if !client.grant_types.contains(&GrantType::AuthorizationCode) {
.wrap_error()?; return Ok(ReplyOrBackToClient::Error(Box::new(UnauthorizedClient)));
}
let redirect_uri = client let redirect_uri = client
.resolve_redirect_uri(&params.auth.redirect_uri) .resolve_redirect_uri(&params.auth.redirect_uri)
.wrap_error()?; .wrap_error()?
.clone();
let response_type = params.auth.response_type; let response_type = params.auth.response_type;
let response_mode = let response_mode =
resolve_response_mode(response_type, params.auth.response_mode).wrap_error()?; resolve_response_mode(response_type, params.auth.response_mode).wrap_error()?;
@ -392,8 +389,8 @@ async fn get(
let grant = new_authorization_grant( let grant = new_authorization_grant(
&mut txn, &mut txn,
client.client_id.clone(), client,
redirect_uri.clone(), redirect_uri,
scope, scope,
code, code,
params.auth.state, params.auth.state,
@ -471,10 +468,10 @@ impl ContinueAuthorizationGrant {
pub async fn fetch_authorization_grant( pub async fn fetch_authorization_grant(
&self, &self,
executor: impl PgExecutor<'_>, conn: &mut PgConnection,
) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> { ) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> {
let data = self.data.parse()?; let data = self.data.parse()?;
get_grant_by_id(executor, data).await get_grant_by_id(conn, data).await
} }
} }

View File

@ -12,11 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use mas_config::{ClientConfig, ClientsConfig, HttpConfig}; use mas_config::{Encrypter, HttpConfig};
use mas_data_model::TokenType; use mas_data_model::{Client, TokenType};
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint}; use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
use mas_storage::oauth2::{ use mas_storage::{
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token, oauth2::{
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token,
},
PostgresqlBackend,
}; };
use mas_warp_utils::{ use mas_warp_utils::{
errors::WrapError, errors::WrapError,
@ -29,7 +32,7 @@ use warp::{filters::BoxedFilter, Filter, Rejection, Reply};
pub fn filter( pub fn filter(
pool: &PgPool, pool: &PgPool,
clients_config: &ClientsConfig, encrypter: &Encrypter,
http_config: &HttpConfig, http_config: &HttpConfig,
) -> BoxedFilter<(Box<dyn Reply>,)> { ) -> BoxedFilter<(Box<dyn Reply>,)> {
let audience = UrlBuilder::from(http_config) let audience = UrlBuilder::from(http_config)
@ -41,7 +44,7 @@ pub fn filter(
.and( .and(
warp::post() warp::post()
.and(connection(pool)) .and(connection(pool))
.and(client_authentication(clients_config, audience)) .and(client_authentication(pool, encrypter, audience))
.and_then(introspect) .and_then(introspect)
.recover(recover) .recover(recover)
.unify(), .unify(),
@ -67,7 +70,7 @@ const INACTIVE: IntrospectionResponse = IntrospectionResponse {
async fn introspect( async fn introspect(
mut conn: PoolConnection<Postgres>, mut conn: PoolConnection<Postgres>,
auth: OAuthClientAuthenticationMethod, auth: OAuthClientAuthenticationMethod,
client: ClientConfig, client: Client<PostgresqlBackend>,
params: IntrospectionRequest, params: IntrospectionRequest,
) -> Result<Box<dyn Reply>, Rejection> { ) -> Result<Box<dyn Reply>, Rejection> {
// Token introspection is only allowed by confidential clients // Token introspection is only allowed by confidential clients

View File

@ -15,7 +15,7 @@
use std::sync::Arc; use std::sync::Arc;
use hyper::Method; use hyper::Method;
use mas_config::{ClientsConfig, Encrypter, HttpConfig}; use mas_config::{Encrypter, HttpConfig};
use mas_jose::StaticKeystore; use mas_jose::StaticKeystore;
use mas_templates::Templates; use mas_templates::Templates;
use mas_warp_utils::filters::cors::cors; use mas_warp_utils::filters::cors::cors;
@ -41,15 +41,14 @@ pub fn filter(
templates: &Templates, templates: &Templates,
key_store: &Arc<StaticKeystore>, key_store: &Arc<StaticKeystore>,
encrypter: &Encrypter, encrypter: &Encrypter,
clients_config: &ClientsConfig,
http_config: &HttpConfig, http_config: &HttpConfig,
) -> BoxedFilter<(impl Reply,)> { ) -> BoxedFilter<(impl Reply,)> {
let discovery = discovery(key_store.as_ref(), http_config); let discovery = discovery(key_store.as_ref(), http_config);
let keys = keys(key_store); let keys = keys(key_store);
let authorization = authorization(pool, templates, encrypter, clients_config); let authorization = authorization(pool, templates, encrypter);
let userinfo = userinfo(pool); let userinfo = userinfo(pool);
let introspection = introspection(pool, clients_config, http_config); let introspection = introspection(pool, encrypter, http_config);
let token = token(pool, key_store, clients_config, http_config); let token = token(pool, encrypter, key_store, http_config);
let filter = discovery let filter = discovery
.or(keys) .or(keys)

View File

@ -19,8 +19,8 @@ use chrono::{DateTime, Duration, Utc};
use data_encoding::BASE64URL_NOPAD; use data_encoding::BASE64URL_NOPAD;
use headers::{CacheControl, Pragma}; use headers::{CacheControl, Pragma};
use hyper::StatusCode; use hyper::StatusCode;
use mas_config::{ClientConfig, ClientsConfig, HttpConfig}; use mas_config::{Encrypter, HttpConfig};
use mas_data_model::{AuthorizationGrantStage, TokenType}; use mas_data_model::{AuthorizationGrantStage, Client, TokenType};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_jose::{claims, DecodedJsonWebToken, SigningKeystore, StaticKeystore}; use mas_jose::{claims, DecodedJsonWebToken, SigningKeystore, StaticKeystore};
use mas_storage::{ use mas_storage::{
@ -33,7 +33,7 @@ use mas_storage::{
RefreshTokenLookupError, RefreshTokenLookupError,
}, },
}, },
DatabaseInconsistencyError, DatabaseInconsistencyError, PostgresqlBackend,
}; };
use mas_warp_utils::{ use mas_warp_utils::{
errors::WrapError, errors::WrapError,
@ -99,8 +99,8 @@ where
pub fn filter( pub fn filter(
pool: &PgPool, pool: &PgPool,
encrypter: &Encrypter,
key_store: &Arc<StaticKeystore>, key_store: &Arc<StaticKeystore>,
clients_config: &ClientsConfig,
http_config: &HttpConfig, http_config: &HttpConfig,
) -> BoxedFilter<(Box<dyn Reply>,)> { ) -> BoxedFilter<(Box<dyn Reply>,)> {
let key_store = key_store.clone(); let key_store = key_store.clone();
@ -113,7 +113,7 @@ pub fn filter(
.and(filters::trace::name("POST /oauth2/token")) .and(filters::trace::name("POST /oauth2/token"))
.and( .and(
warp::post() warp::post()
.and(client_authentication(clients_config, audience)) .and(client_authentication(pool, encrypter, audience))
.and(warp::any().map(move || key_store.clone())) .and(warp::any().map(move || key_store.clone()))
.and(warp::any().map(move || issuer.clone())) .and(warp::any().map(move || issuer.clone()))
.and(connection(pool)) .and(connection(pool))
@ -145,7 +145,7 @@ async fn recover(rejection: Rejection) -> Result<Box<dyn Reply>, Infallible> {
async fn token( async fn token(
_auth: OAuthClientAuthenticationMethod, _auth: OAuthClientAuthenticationMethod,
client: ClientConfig, client: Client<PostgresqlBackend>,
req: AccessTokenRequest, req: AccessTokenRequest,
key_store: Arc<StaticKeystore>, key_store: Arc<StaticKeystore>,
issuer: Url, issuer: Url,
@ -185,7 +185,7 @@ fn hash<H: Digest>(mut hasher: H, token: &str) -> anyhow::Result<String> {
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
async fn authorization_code_grant( async fn authorization_code_grant(
grant: &AuthorizationCodeGrant, grant: &AuthorizationCodeGrant,
client: &ClientConfig, client: &Client<PostgresqlBackend>,
key_store: &StaticKeystore, key_store: &StaticKeystore,
issuer: Url, issuer: Url,
conn: &mut PoolConnection<Postgres>, conn: &mut PoolConnection<Postgres>,
@ -349,7 +349,7 @@ async fn authorization_code_grant(
async fn refresh_token_grant( async fn refresh_token_grant(
grant: &RefreshTokenGrant, grant: &RefreshTokenGrant,
client: &ClientConfig, client: &Client<PostgresqlBackend>,
conn: &mut PoolConnection<Postgres>, conn: &mut PoolConnection<Postgres>,
) -> Result<AccessTokenResponse, Rejection> { ) -> Result<AccessTokenResponse, Rejection> {
let mut txn = conn.begin().await.wrap_error()?; let mut txn = conn.begin().await.wrap_error()?;

View File

@ -17,7 +17,7 @@
use hyper::Uri; use hyper::Uri;
use mas_templates::PostAuthContext; use mas_templates::PostAuthContext;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::PgExecutor; use sqlx::PgConnection;
use super::super::oauth2::ContinueAuthorizationGrant; use super::super::oauth2::ContinueAuthorizationGrant;
@ -36,11 +36,11 @@ impl PostAuthAction {
pub async fn load_context<'e>( pub async fn load_context<'e>(
&self, &self,
executor: impl PgExecutor<'e>, conn: &mut PgConnection,
) -> anyhow::Result<PostAuthContext> { ) -> anyhow::Result<PostAuthContext> {
match self { match self {
Self::ContinueAuthorizationGrant(c) => { Self::ContinueAuthorizationGrant(c) => {
let grant = c.fetch_authorization_grant(executor).await?; let grant = c.fetch_authorization_grant(conn).await?;
let grant = grant.into(); let grant = grant.into();
Ok(PostAuthContext::ContinueAuthorizationGrant { grant }) Ok(PostAuthContext::ContinueAuthorizationGrant { grant })
} }

View File

@ -21,3 +21,5 @@
pub mod jose; pub mod jose;
pub mod oauth; pub mod oauth;
pub use parse_display::ParseError;

View File

@ -31,7 +31,7 @@ sha2 = "0.10.2"
signature = "1.4.0" signature = "1.4.0"
thiserror = "1.0.30" thiserror = "1.0.30"
tokio = { version = "1.17.0", features = ["macros", "rt", "sync"] } tokio = { version = "1.17.0", features = ["macros", "rt", "sync"] }
tower = "0.4.12" tower = { version = "0.4.12", features = ["util"] }
url = { version = "2.2.2", features = ["serde"] } url = { version = "2.2.2", features = ["serde"] }
mas-iana = { path = "../iana" } mas-iana = { path = "../iana" }

View File

@ -22,6 +22,8 @@ pub(crate) mod jwk;
pub(crate) mod jwt; pub(crate) mod jwt;
mod keystore; mod keystore;
pub use futures_util::future::Either;
pub use self::{ pub use self::{
jwk::{JsonWebKey, JsonWebKeySet}, jwk::{JsonWebKey, JsonWebKeySet},
jwt::{DecodedJsonWebToken, JsonWebTokenParts, JwtHeader}, jwt::{DecodedJsonWebToken, JsonWebTokenParts, JwtHeader},

View File

@ -27,14 +27,21 @@ use url::Url;
use crate::requests::{Display, GrantType, ResponseMode}; use crate::requests::{Display, GrantType, ResponseMode};
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ApplicationType {
Web,
Native,
}
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum SubjectType { pub enum SubjectType {
Public, Public,
Pairwise, Pairwise,
} }
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum ClaimType { pub enum ClaimType {
Normal, Normal,

View File

@ -192,6 +192,7 @@ pub struct ClientCredentialsGrant {
pub enum GrantType { pub enum GrantType {
AuthorizationCode, AuthorizationCode,
RefreshToken, RefreshToken,
Implicit,
ClientCredentials, ClientCredentials,
} }

View File

@ -7,9 +7,10 @@ license = "Apache-2.0"
[dependencies] [dependencies]
tokio = "1.17.0" tokio = "1.17.0"
sqlx = { version = "0.5.11", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline"] } sqlx = { version = "0.5.11", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline", "json"] }
chrono = { version = "0.4.19", features = ["serde"] } chrono = { version = "0.4.19", features = ["serde"] }
serde = { version = "1.0.136", features = ["derive"] } serde = { version = "1.0.136", features = ["derive"] }
serde_json = "1.0.79"
thiserror = "1.0.30" thiserror = "1.0.30"
anyhow = "1.0.55" anyhow = "1.0.55"
tracing = "0.1.31" tracing = "0.1.31"
@ -24,3 +25,4 @@ url = { version = "2.2.2", features = ["serde"] }
oauth2-types = { path = "../oauth2-types" } oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" } mas-data-model = { path = "../data-model" }
mas-iana = { path = "../iana" } mas-iana = { path = "../iana" }
mas-jose = { path = "../jose" }

View File

@ -0,0 +1,17 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
DROP TABLE oauth2_client_redirect_uris;
DROP TRIGGER set_timestamp ON oauth2_clients;
DROP TABLE oauth2_clients;

View File

@ -0,0 +1,51 @@
-- Copyright 2022 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
CREATE TABLE oauth2_clients (
"id" BIGSERIAL PRIMARY KEY,
"client_id" TEXT NOT NULL UNIQUE,
"encrypted_client_secret" TEXT,
"response_types" TEXT[] NOT NULL,
"grant_type_authorization_code" BOOL NOT NULL,
"grant_type_refresh_token" BOOL NOT NULL,
"contacts" TEXT[] NOT NULL,
"client_name" TEXT,
"logo_uri" TEXT,
"client_uri" TEXT,
"policy_uri" TEXT,
"tos_uri" TEXT,
"jwks_uri" TEXT,
"jwks" JSONB,
"id_token_signed_response_alg" TEXT,
"token_endpoint_auth_method" TEXT,
"token_endpoint_auth_signing_alg" TEXT,
"initiate_login_uri" TEXT,
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
"updated_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
-- jwks and jwks_uri can't be set at the same time
CHECK ("jwks" IS NULL OR "jwks_uri" IS NULL)
);
CREATE TRIGGER set_timestamp
BEFORE UPDATE ON oauth2_clients
FOR EACH ROW
EXECUTE PROCEDURE trigger_set_timestamp();
CREATE TABLE oauth2_client_redirect_uris (
"id" BIGSERIAL PRIMARY KEY,
"oauth2_client_id" BIGINT NOT NULL REFERENCES oauth2_clients (id) ON DELETE CASCADE,
"redirect_uri" TEXT NOT NULL
);

File diff suppressed because it is too large Load Diff

View File

@ -41,7 +41,7 @@ impl StorageBackend for PostgresqlBackend {
type AuthenticationData = i64; type AuthenticationData = i64;
type AuthorizationGrantData = i64; type AuthorizationGrantData = i64;
type BrowserSessionData = i64; type BrowserSessionData = i64;
type ClientData = (); type ClientData = i64;
type RefreshTokenData = i64; type RefreshTokenData = i64;
type SessionData = i64; type SessionData = i64;
type UserData = i64; type UserData = i64;

View File

@ -14,12 +14,11 @@
use anyhow::Context; use anyhow::Context;
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use mas_data_model::{ use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail};
AccessToken, Authentication, BrowserSession, Client, Session, User, UserEmail, use sqlx::{PgConnection, PgExecutor};
};
use sqlx::PgExecutor;
use thiserror::Error; use thiserror::Error;
use super::client::{lookup_client_by_client_id, ClientFetchError};
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend}; use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
pub async fn add_access_token( pub async fn add_access_token(
@ -83,6 +82,7 @@ pub struct OAuth2AccessTokenLookup {
#[error("failed to lookup access token")] #[error("failed to lookup access token")]
pub enum AccessTokenLookupError { pub enum AccessTokenLookupError {
Database(#[from] sqlx::Error), Database(#[from] sqlx::Error),
ClientFetch(#[from] ClientFetchError),
Inconsistency(#[from] DatabaseInconsistencyError), Inconsistency(#[from] DatabaseInconsistencyError),
} }
@ -95,7 +95,7 @@ impl AccessTokenLookupError {
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub async fn lookup_active_access_token( pub async fn lookup_active_access_token(
executor: impl PgExecutor<'_>, conn: &mut PgConnection,
token: &str, token: &str,
) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), AccessTokenLookupError> { ) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), AccessTokenLookupError> {
let res = sqlx::query_as!( let res = sqlx::query_as!(
@ -142,7 +142,7 @@ pub async fn lookup_active_access_token(
"#, "#,
token, token,
) )
.fetch_one(executor) .fetch_one(&mut *conn)
.await?; .await?;
let access_token = AccessToken { let access_token = AccessToken {
@ -153,10 +153,7 @@ pub async fn lookup_active_access_token(
expires_after: Duration::seconds(res.access_token_expires_after.into()), expires_after: Duration::seconds(res.access_token_expires_after.into()),
}; };
let client = Client { let client = lookup_client_by_client_id(&mut *conn, &res.client_id).await?;
data: (),
client_id: res.client_id,
};
let primary_email = match ( let primary_email = match (
res.user_email_id, res.user_email_id,

View File

@ -24,15 +24,16 @@ use mas_data_model::{
}; };
use mas_iana::oauth::PkceCodeChallengeMethod; use mas_iana::oauth::PkceCodeChallengeMethod;
use oauth2_types::{requests::ResponseMode, scope::Scope}; use oauth2_types::{requests::ResponseMode, scope::Scope};
use sqlx::PgExecutor; use sqlx::{PgConnection, PgExecutor};
use url::Url; use url::Url;
use super::client::lookup_client_by_client_id;
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend}; use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn new_authorization_grant( pub async fn new_authorization_grant(
executor: impl PgExecutor<'_>, executor: impl PgExecutor<'_>,
client_id: String, client: Client<PostgresqlBackend>,
redirect_uri: Url, redirect_uri: Url,
scope: Scope, scope: Scope,
code: Option<AuthorizationCode>, code: Option<AuthorizationCode>,
@ -65,7 +66,7 @@ pub async fn new_authorization_grant(
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
RETURNING id, created_at RETURNING id, created_at
"#, "#,
&client_id, &client.client_id,
redirect_uri.to_string(), redirect_uri.to_string(),
scope.to_string(), scope.to_string(),
state, state,
@ -85,11 +86,6 @@ pub async fn new_authorization_grant(
.await .await
.context("could not insert oauth2 authorization grant")?; .context("could not insert oauth2 authorization grant")?;
let client = Client {
data: (),
client_id,
};
Ok(AuthorizationGrant { Ok(AuthorizationGrant {
data: res.id, data: res.id,
stage: AuthorizationGrantStage::Pending, stage: AuthorizationGrantStage::Pending,
@ -141,20 +137,21 @@ struct GrantLookup {
user_email_confirmed_at: Option<DateTime<Utc>>, user_email_confirmed_at: Option<DateTime<Utc>>,
} }
impl TryInto<AuthorizationGrant<PostgresqlBackend>> for GrantLookup { impl GrantLookup {
type Error = DatabaseInconsistencyError;
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
fn try_into(self) -> Result<AuthorizationGrant<PostgresqlBackend>, Self::Error> { async fn into_authorization_grant(
self,
executor: impl PgExecutor<'_>,
) -> Result<AuthorizationGrant<PostgresqlBackend>, DatabaseInconsistencyError> {
let scope: Scope = self let scope: Scope = self
.grant_scope .grant_scope
.parse() .parse()
.map_err(|_e| DatabaseInconsistencyError)?; .map_err(|_e| DatabaseInconsistencyError)?;
let client = Client { // TODO: don't unwrap
data: (), let client = lookup_client_by_client_id(executor, &self.client_id)
client_id: self.client_id, .await
}; .unwrap();
let last_authentication = match ( let last_authentication = match (
self.user_session_last_authentication_id, self.user_session_last_authentication_id,
@ -323,7 +320,7 @@ impl TryInto<AuthorizationGrant<PostgresqlBackend>> for GrantLookup {
} }
pub async fn get_grant_by_id( pub async fn get_grant_by_id(
executor: impl PgExecutor<'_>, conn: &mut PgConnection,
id: i64, id: i64,
) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> { ) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> {
// TODO: handle "not found" cases // TODO: handle "not found" cases
@ -381,17 +378,17 @@ pub async fn get_grant_by_id(
"#, "#,
id, id,
) )
.fetch_one(executor) .fetch_one(&mut *conn)
.await .await
.context("failed to get grant by id")?; .context("failed to get grant by id")?;
let grant = res.try_into()?; let grant = res.into_authorization_grant(&mut *conn).await?;
Ok(grant) Ok(grant)
} }
pub async fn lookup_grant_by_code( pub async fn lookup_grant_by_code(
executor: impl PgExecutor<'_>, conn: &mut PgConnection,
code: &str, code: &str,
) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> { ) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> {
// TODO: handle "not found" cases // TODO: handle "not found" cases
@ -449,11 +446,11 @@ pub async fn lookup_grant_by_code(
"#, "#,
code, code,
) )
.fetch_one(executor) .fetch_one(&mut *conn)
.await .await
.context("failed to lookup grant by code")?; .context("failed to lookup grant by code")?;
let grant = res.try_into()?; let grant = res.into_authorization_grant(&mut *conn).await?;
Ok(grant) Ok(grant)
} }

View File

@ -0,0 +1,380 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::string::ToString;
use mas_data_model::{Client, JwksOrJwksUri};
use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod};
use mas_jose::JsonWebKeySet;
use oauth2_types::requests::GrantType;
use sqlx::{PgConnection, PgExecutor};
use thiserror::Error;
use url::Url;
use warp::reject::Reject;
use crate::PostgresqlBackend;
#[derive(Debug)]
pub struct OAuth2ClientLookup {
id: i64,
client_id: String,
encrypted_client_secret: Option<String>,
redirect_uris: Vec<String>,
response_types: Vec<String>,
grant_type_authorization_code: bool,
grant_type_refresh_token: bool,
contacts: Vec<String>,
client_name: Option<String>,
logo_uri: Option<String>,
client_uri: Option<String>,
policy_uri: Option<String>,
tos_uri: Option<String>,
jwks_uri: Option<String>,
jwks: Option<serde_json::Value>,
id_token_signed_response_alg: Option<String>,
token_endpoint_auth_method: Option<String>,
token_endpoint_auth_signing_alg: Option<String>,
initiate_login_uri: Option<String>,
}
#[derive(Debug, Error)]
pub enum ClientFetchError {
#[error("malformed jwks column")]
MalformedJwks(#[source] serde_json::Error),
#[error("entry has both a jwks and a jwks_uri")]
BothJwksAndJwksUri,
#[error("could not parse URL in field {field:?}")]
ParseUrl {
field: &'static str,
source: url::ParseError,
},
#[error("could not parse field {field:?}")]
ParseField {
field: &'static str,
source: mas_iana::ParseError,
},
#[error(transparent)]
Database(#[from] sqlx::Error),
}
impl ClientFetchError {
#[must_use]
pub fn not_found(&self) -> bool {
matches!(self, Self::Database(sqlx::Error::RowNotFound))
}
}
impl Reject for ClientFetchError {}
impl TryInto<Client<PostgresqlBackend>> for OAuth2ClientLookup {
type Error = ClientFetchError;
#[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing
fn try_into(self) -> Result<Client<PostgresqlBackend>, Self::Error> {
let redirect_uris: Result<Vec<Url>, _> =
self.redirect_uris.iter().map(|s| s.parse()).collect();
let redirect_uris = redirect_uris.map_err(|source| ClientFetchError::ParseUrl {
field: "redirect_uris",
source,
})?;
let response_types: Result<Vec<OAuthAuthorizationEndpointResponseType>, _> =
self.response_types.iter().map(|s| s.parse()).collect();
let response_types = response_types.map_err(|source| ClientFetchError::ParseField {
field: "response_types",
source,
})?;
let mut grant_types = Vec::new();
if self.grant_type_authorization_code {
grant_types.push(GrantType::AuthorizationCode);
}
if self.grant_type_refresh_token {
grant_types.push(GrantType::RefreshToken);
}
let logo_uri = self
.logo_uri
.map(|s| s.parse())
.transpose()
.map_err(|source| ClientFetchError::ParseUrl {
field: "logo_uri",
source,
})?;
let client_uri = self
.client_uri
.map(|s| s.parse())
.transpose()
.map_err(|source| ClientFetchError::ParseUrl {
field: "client_uri",
source,
})?;
let policy_uri = self
.policy_uri
.map(|s| s.parse())
.transpose()
.map_err(|source| ClientFetchError::ParseUrl {
field: "policy_uri",
source,
})?;
let tos_uri = self
.tos_uri
.map(|s| s.parse())
.transpose()
.map_err(|source| ClientFetchError::ParseUrl {
field: "tos_uri",
source,
})?;
let id_token_signed_response_alg = self
.id_token_signed_response_alg
.map(|s| s.parse())
.transpose()
.map_err(|source| ClientFetchError::ParseField {
field: "id_token_signed_response_alg",
source,
})?;
let token_endpoint_auth_method = self
.token_endpoint_auth_method
.map(|s| s.parse())
.transpose()
.map_err(|source| ClientFetchError::ParseField {
field: "token_endpoint_auth_method",
source,
})?;
let token_endpoint_auth_signing_alg = self
.token_endpoint_auth_signing_alg
.map(|s| s.parse())
.transpose()
.map_err(|source| ClientFetchError::ParseField {
field: "token_endpoint_auth_signing_alg",
source,
})?;
let initiate_login_uri = self
.initiate_login_uri
.map(|s| s.parse())
.transpose()
.map_err(|source| ClientFetchError::ParseUrl {
field: "initiate_login_uri",
source,
})?;
let jwks = match (self.jwks, self.jwks_uri) {
(None, None) => None,
(Some(jwks), None) => {
let jwks = serde_json::from_value(jwks).map_err(ClientFetchError::MalformedJwks)?;
Some(JwksOrJwksUri::Jwks(jwks))
}
(None, Some(jwks_uri)) => {
let jwks_uri = jwks_uri
.parse()
.map_err(|source| ClientFetchError::ParseUrl {
field: "jwks_uri",
source,
})?;
Some(JwksOrJwksUri::JwksUri(jwks_uri))
}
_ => return Err(ClientFetchError::BothJwksAndJwksUri),
};
Ok(Client {
data: self.id,
client_id: self.client_id,
encrypted_client_secret: self.encrypted_client_secret,
redirect_uris,
response_types,
grant_types,
contacts: self.contacts,
client_name: self.client_name,
logo_uri,
client_uri,
policy_uri,
tos_uri,
jwks,
id_token_signed_response_alg,
token_endpoint_auth_method,
token_endpoint_auth_signing_alg,
initiate_login_uri,
})
}
}
pub async fn lookup_client(
executor: impl PgExecutor<'_>,
id: i64,
) -> Result<Client<PostgresqlBackend>, ClientFetchError> {
let res = sqlx::query_as!(
OAuth2ClientLookup,
r#"
SELECT
c.id,
c.client_id,
c.encrypted_client_secret,
ARRAY(SELECT redirect_uri FROM oauth2_client_redirect_uris r WHERE r.oauth2_client_id = c.id) AS "redirect_uris!",
c.response_types,
c.grant_type_authorization_code,
c.grant_type_refresh_token,
c.contacts,
c.client_name,
c.logo_uri,
c.client_uri,
c.policy_uri,
c.tos_uri,
c.jwks_uri,
c.jwks,
c.id_token_signed_response_alg,
c.token_endpoint_auth_method,
c.token_endpoint_auth_signing_alg,
c.initiate_login_uri
FROM oauth2_clients c
WHERE c.id = $1
"#,
id,
)
.fetch_one(executor)
.await?;
let client = res.try_into()?;
Ok(client)
}
pub async fn lookup_client_by_client_id(
executor: impl PgExecutor<'_>,
client_id: &str,
) -> Result<Client<PostgresqlBackend>, ClientFetchError> {
let res = sqlx::query_as!(
OAuth2ClientLookup,
r#"
SELECT
c.id,
c.client_id,
c.encrypted_client_secret,
ARRAY(SELECT redirect_uri FROM oauth2_client_redirect_uris r WHERE r.oauth2_client_id = c.id) AS "redirect_uris!",
c.response_types,
c.grant_type_authorization_code,
c.grant_type_refresh_token,
c.contacts,
c.client_name,
c.logo_uri,
c.client_uri,
c.policy_uri,
c.tos_uri,
c.jwks_uri,
c.jwks,
c.id_token_signed_response_alg,
c.token_endpoint_auth_method,
c.token_endpoint_auth_signing_alg,
c.initiate_login_uri
FROM oauth2_clients c
WHERE c.client_id = $1
"#,
client_id,
)
.fetch_one(executor)
.await?;
let client = res.try_into()?;
Ok(client)
}
pub async fn insert_client_from_config(
conn: &mut PgConnection,
client_id: &str,
client_auth_method: OAuthClientAuthenticationMethod,
encrypted_client_secret: Option<&str>,
jwks: Option<&JsonWebKeySet>,
jwks_uri: Option<&Url>,
redirect_uris: &[Url],
) -> anyhow::Result<()> {
let response_types = vec![
OAuthAuthorizationEndpointResponseType::Code.to_string(),
OAuthAuthorizationEndpointResponseType::CodeIdToken.to_string(),
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.to_string(),
OAuthAuthorizationEndpointResponseType::CodeToken.to_string(),
OAuthAuthorizationEndpointResponseType::IdToken.to_string(),
OAuthAuthorizationEndpointResponseType::IdTokenToken.to_string(),
OAuthAuthorizationEndpointResponseType::None.to_string(),
OAuthAuthorizationEndpointResponseType::Token.to_string(),
];
let jwks = jwks.map(serde_json::to_value).transpose()?;
let jwks_uri = jwks_uri.map(Url::as_str);
let client_auth_method = client_auth_method.to_string();
let id = sqlx::query_scalar!(
r#"
INSERT INTO oauth2_clients
(client_id,
encrypted_client_secret,
response_types,
grant_type_authorization_code,
grant_type_refresh_token,
token_endpoint_auth_method,
jwks,
jwks_uri,
contacts)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, '{}')
RETURNING id
"#,
client_id,
encrypted_client_secret,
&response_types,
true,
true,
client_auth_method,
jwks,
jwks_uri,
)
.fetch_one(&mut *conn)
.await?;
let redirect_uris: Vec<String> = redirect_uris.iter().map(ToString::to_string).collect();
sqlx::query!(
r#"
INSERT INTO oauth2_client_redirect_uris (oauth2_client_id, redirect_uri)
SELECT $1, uri FROM UNNEST($2::text[]) uri
"#,
id,
&redirect_uris,
)
.execute(&mut *conn)
.await?;
Ok(())
}
pub async fn truncate_clients(executor: impl PgExecutor<'_>) -> anyhow::Result<()> {
sqlx::query!("TRUNCATE oauth2_client_redirect_uris, oauth2_clients")
.execute(executor)
.await?;
Ok(())
}

View File

@ -19,6 +19,7 @@ use crate::PostgresqlBackend;
pub mod access_token; pub mod access_token;
pub mod authorization_grant; pub mod authorization_grant;
pub mod client;
pub mod refresh_token; pub mod refresh_token;
pub async fn end_oauth_session( pub async fn end_oauth_session(

View File

@ -15,12 +15,13 @@
use anyhow::Context; use anyhow::Context;
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use mas_data_model::{ use mas_data_model::{
AccessToken, Authentication, BrowserSession, Client, RefreshToken, Session, User, UserEmail, AccessToken, Authentication, BrowserSession, RefreshToken, Session, User, UserEmail,
}; };
use sqlx::PgExecutor; use sqlx::{PgConnection, PgExecutor};
use thiserror::Error; use thiserror::Error;
use warp::reject::Reject; use warp::reject::Reject;
use super::client::{lookup_client_by_client_id, ClientFetchError};
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend}; use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
pub async fn add_refresh_token( pub async fn add_refresh_token(
@ -82,6 +83,7 @@ struct OAuth2RefreshTokenLookup {
#[error("could not lookup refresh token")] #[error("could not lookup refresh token")]
pub enum RefreshTokenLookupError { pub enum RefreshTokenLookupError {
Fetch(#[from] sqlx::Error), Fetch(#[from] sqlx::Error),
ClientFetch(#[from] ClientFetchError),
Conversion(#[from] DatabaseInconsistencyError), Conversion(#[from] DatabaseInconsistencyError),
} }
@ -96,7 +98,7 @@ impl RefreshTokenLookupError {
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub async fn lookup_active_refresh_token( pub async fn lookup_active_refresh_token(
executor: impl PgExecutor<'_>, conn: &mut PgConnection,
token: &str, token: &str,
) -> Result<(RefreshToken<PostgresqlBackend>, Session<PostgresqlBackend>), RefreshTokenLookupError> ) -> Result<(RefreshToken<PostgresqlBackend>, Session<PostgresqlBackend>), RefreshTokenLookupError>
{ {
@ -148,7 +150,7 @@ pub async fn lookup_active_refresh_token(
"#, "#,
token, token,
) )
.fetch_one(executor) .fetch_one(&mut *conn)
.await?; .await?;
let access_token = match ( let access_token = match (
@ -175,10 +177,7 @@ pub async fn lookup_active_refresh_token(
access_token, access_token,
}; };
let client = Client { let client = lookup_client_by_client_id(&mut *conn, &res.client_id).await?;
data: (),
client_id: res.client_id,
};
let primary_email = match ( let primary_email = match (
res.user_email_id, res.user_email_id,

View File

@ -28,6 +28,9 @@ mime = "0.3.16"
bincode = "1.3.3" bincode = "1.3.3"
crc = "2.1.0" crc = "2.1.0"
url = "2.2.2" url = "2.2.2"
http = "0.2.6"
http-body = "0.4.4"
tower = { version = "0.4.12", features = ["util"] }
oauth2-types = { path = "../oauth2-types" } oauth2-types = { path = "../oauth2-types" }
mas-config = { path = "../config" } mas-config = { path = "../config" }
@ -36,6 +39,4 @@ mas-data-model = { path = "../data-model" }
mas-storage = { path = "../storage" } mas-storage = { path = "../storage" }
mas-jose = { path = "../jose" } mas-jose = { path = "../jose" }
mas-iana = { path = "../iana" } mas-iana = { path = "../iana" }
mas-http = { path = "../http" }
[dev-dependencies]
tower = { version = "0.4.12", features = ["util"] }

View File

@ -87,6 +87,10 @@ pub fn authentication(
.untuple_one() .untuple_one()
} }
fn ensure<T: Clone + Send + Sync + 'static>(t: T) -> T {
t
}
async fn authenticate( async fn authenticate(
mut conn: PoolConnection<Postgres>, mut conn: PoolConnection<Postgres>,
auth: Authorization<Bearer>, auth: Authorization<Bearer>,
@ -110,6 +114,9 @@ async fn authenticate(
} }
})?; })?;
let session = ensure(session);
let token = ensure(token);
Ok((token, session)) Ok((token, session))
} }

View File

@ -16,30 +16,49 @@
use std::collections::HashMap; use std::collections::HashMap;
use data_encoding::BASE64;
use headers::{authorization::Basic, Authorization}; use headers::{authorization::Basic, Authorization};
use mas_config::{ClientAuthMethodConfig, ClientConfig, ClientsConfig}; use mas_config::Encrypter;
use mas_data_model::{Client, JwksOrJwksUri, StorageBackend};
use mas_http::HttpServiceExt;
use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::{ use mas_jose::{
claims::{TimeOptions, AUD, EXP, IAT, ISS, JTI, NBF, SUB}, claims::{TimeOptions, AUD, EXP, IAT, ISS, JTI, NBF, SUB},
DecodedJsonWebToken, JsonWebTokenParts, SharedSecret, DecodedJsonWebToken, DynamicJwksStore, Either, JsonWebKeySet, JsonWebTokenParts, SharedSecret,
StaticJwksStore, VerifyingKeystore,
};
use mas_storage::{
oauth2::client::{lookup_client_by_client_id, ClientFetchError},
PostgresqlBackend,
}; };
use serde::{de::DeserializeOwned, Deserialize}; use serde::{de::DeserializeOwned, Deserialize};
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use thiserror::Error; use thiserror::Error;
use tower::{BoxError, ServiceExt};
use warp::{reject::Reject, Filter, Rejection}; use warp::{reject::Reject, Filter, Rejection};
use super::headers::typed_header; use super::{database::connection, headers::typed_header};
use crate::errors::WrapError; use crate::errors::WrapError;
/// Protect an enpoint with client authentication /// Protect an enpoint with client authentication
#[must_use] #[must_use]
pub fn client_authentication<T: DeserializeOwned + Send + 'static>( pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
clients_config: &ClientsConfig, pool: &PgPool,
encrypter: &Encrypter,
audience: String, audience: String,
) -> impl Filter<Extract = (OAuthClientAuthenticationMethod, ClientConfig, T), Error = Rejection> ) -> impl Filter<
+ Clone Extract = (
OAuthClientAuthenticationMethod,
Client<PostgresqlBackend>,
T,
),
Error = Rejection,
> + Clone
+ Send + Send
+ Sync + Sync
+ 'static { + 'static {
let encrypter = encrypter.clone();
// First, extract the client credentials // First, extract the client credentials
let credentials = typed_header() let credentials = typed_header()
.and(warp::body::form()) .and(warp::body::form())
@ -65,9 +84,9 @@ pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
.unify() .unify()
.untuple_one(); .untuple_one();
let clients_config = clients_config.clone();
warp::any() warp::any()
.map(move || clients_config.clone()) .and(connection(pool))
.and(warp::any().map(move || encrypter.clone()))
.and(warp::any().map(move || audience.clone())) .and(warp::any().map(move || audience.clone()))
.and(credentials) .and(credentials)
.and_then(authenticate_client) .and_then(authenticate_client)
@ -79,8 +98,20 @@ enum ClientAuthenticationError {
#[error("wrong client secret for client {client_id:?}")] #[error("wrong client secret for client {client_id:?}")]
ClientSecretMismatch { client_id: String }, ClientSecretMismatch { client_id: String },
#[error("could not find client {client_id:?}")] #[error("could not fetch client {client_id:?}")]
ClientNotFound { client_id: String }, ClientFetch {
client_id: String,
source: ClientFetchError,
},
#[error("client {client_id:?} has an invalid client secret")]
InvalidClientSecret {
client_id: String,
source: anyhow::Error,
},
#[error("client {client_id:?} has an invalid JWKS")]
InvalidJwks { client_id: String },
#[error("wrong client authentication method for client {client_id:?}")] #[error("wrong client authentication method for client {client_id:?}")]
WrongAuthenticationMethod { client_id: String }, WrongAuthenticationMethod { client_id: String },
@ -94,68 +125,136 @@ enum ClientAuthenticationError {
impl Reject for ClientAuthenticationError {} impl Reject for ClientAuthenticationError {}
fn decrypt_client_secret<T: StorageBackend>(
client: &Client<T>,
encrypter: &Encrypter,
) -> anyhow::Result<Vec<u8>> {
let encrypted_client_secret = client
.encrypted_client_secret
.as_ref()
.ok_or_else(|| anyhow::anyhow!("missing encrypted_client_secret field"))?;
let encrypted_client_secret = BASE64.decode(encrypted_client_secret.as_bytes())?;
let nonce: &[u8; 12] = encrypted_client_secret
.get(0..12)
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?
.try_into()?;
let payload = encrypted_client_secret
.get(12..)
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?;
let decrypted_client_secret = encrypter.decrypt(nonce, payload)?;
Ok(decrypted_client_secret)
}
fn jwks_key_store(jwks: &JwksOrJwksUri) -> Either<StaticJwksStore, DynamicJwksStore> {
// Assert that the output is both a VerifyingKeystore and Send
fn assert<T: Send + VerifyingKeystore>(t: T) -> T {
t
}
let inner = match jwks {
JwksOrJwksUri::Jwks(jwks) => Either::Left(StaticJwksStore::new(jwks.clone())),
JwksOrJwksUri::JwksUri(uri) => {
let uri = uri.clone();
// TODO: get the client from somewhere else?
let exporter = mas_http::client("fetch-jwks")
.json::<JsonWebKeySet>()
.map_request(move |_: ()| {
http::Request::builder()
.method("GET")
// TODO: change the Uri type in config to avoid reparsing here
.uri(uri.to_string())
.body(http_body::Empty::new())
.unwrap()
})
.map_response(http::Response::into_body)
.map_err(BoxError::from)
.boxed_clone();
Either::Right(DynamicJwksStore::new(exporter))
}
};
assert(inner)
}
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
#[tracing::instrument(skip_all, fields(enduser.id), err(Debug))] #[tracing::instrument(skip_all, fields(enduser.id), err(Debug))]
async fn authenticate_client<T>( async fn authenticate_client<T>(
clients_config: ClientsConfig, mut conn: PoolConnection<Postgres>,
encrypter: Encrypter,
audience: String, audience: String,
credentials: ClientCredentials, credentials: ClientCredentials,
body: T, body: T,
) -> Result<(OAuthClientAuthenticationMethod, ClientConfig, T), Rejection> { ) -> Result<
(
OAuthClientAuthenticationMethod,
Client<PostgresqlBackend>,
T,
),
Rejection,
> {
let (auth_method, client) = match credentials { let (auth_method, client) = match credentials {
ClientCredentials::Pair { ClientCredentials::Pair {
client_id, client_id,
client_secret, client_secret,
via, via,
} => { } => {
let client = clients_config let client = lookup_client_by_client_id(&mut *conn, &client_id)
.iter() .await
.find(|client| client.client_id == client_id) .map_err(|source| ClientAuthenticationError::ClientFetch {
.ok_or_else(|| ClientAuthenticationError::ClientNotFound { client_id: client_id.clone(),
client_id: client_id.to_string(), source,
})?; })?;
let auth_method = match (&client.client_auth_method, client_secret, via) { let auth_method = client.token_endpoint_auth_method.ok_or(
(ClientAuthMethodConfig::None, None, _) => OAuthClientAuthenticationMethod::None, ClientAuthenticationError::WrongAuthenticationMethod {
client_id: client.client_id.clone(),
},
)?;
// Let's match the authentication method
match (auth_method, client_secret, via) {
(OAuthClientAuthenticationMethod::None, None, _) => {}
( (
ClientAuthMethodConfig::ClientSecretBasic { OAuthClientAuthenticationMethod::ClientSecretBasic,
client_secret: ref expected_client_secret, Some(client_secret),
},
Some(ref given_client_secret),
CredentialsVia::AuthorizationHeader, CredentialsVia::AuthorizationHeader,
) => { )
if expected_client_secret != given_client_secret { | (
return Err( OAuthClientAuthenticationMethod::ClientSecretPost,
ClientAuthenticationError::ClientSecretMismatch { client_id }.into(), Some(client_secret),
);
}
OAuthClientAuthenticationMethod::ClientSecretBasic
}
(
ClientAuthMethodConfig::ClientSecretPost {
client_secret: ref expected_client_secret,
},
Some(ref given_client_secret),
CredentialsVia::FormBody, CredentialsVia::FormBody,
) => { ) => {
if expected_client_secret != given_client_secret { let decrypted =
return Err( decrypt_client_secret(&client, &encrypter).map_err(|source| {
ClientAuthenticationError::ClientSecretMismatch { client_id }.into(), ClientAuthenticationError::InvalidClientSecret {
); client_id: client.client_id.clone(),
source,
}
})?;
if client_secret.as_bytes() != decrypted {
return Err(warp::reject::custom(
ClientAuthenticationError::ClientSecretMismatch {
client_id: client.client_id,
},
));
} }
OAuthClientAuthenticationMethod::ClientSecretPost
} }
_ => { _ => {
return Err( return Err(warp::reject::custom(
ClientAuthenticationError::WrongAuthenticationMethod { client_id }.into(), ClientAuthenticationError::WrongAuthenticationMethod {
) client_id: client.client_id,
},
));
} }
}; }
(auth_method, client) (auth_method, client)
} }
@ -195,34 +294,52 @@ async fn authenticate_client<T>(
// from the token, as per rfc7521 sec. 4.2 // from the token, as per rfc7521 sec. 4.2
let client_id = client_id.as_ref().unwrap_or(&sub); let client_id = client_id.as_ref().unwrap_or(&sub);
let client = clients_config let client = lookup_client_by_client_id(&mut *conn, client_id)
.iter() .await
.find(|client| &client.client_id == client_id) .map_err(|source| ClientAuthenticationError::ClientFetch {
.ok_or_else(|| ClientAuthenticationError::ClientNotFound {
client_id: client_id.to_string(), client_id: client_id.to_string(),
source,
})?; })?;
let auth_method = match &client.client_auth_method { let auth_method = client.token_endpoint_auth_method.ok_or(
ClientAuthMethodConfig::PrivateKeyJwt(jwks) => { ClientAuthenticationError::WrongAuthenticationMethod {
let store = jwks.key_store(); client_id: client.client_id.clone(),
},
)?;
match auth_method {
OAuthClientAuthenticationMethod::ClientSecretJwt => {
let client_secret =
decrypt_client_secret(&client, &encrypter).map_err(|source| {
ClientAuthenticationError::InvalidClientSecret {
client_id: client.client_id.clone(),
source,
}
})?;
let store = SharedSecret::new(&client_secret);
let fut = token.verify(&decoded, &store); let fut = token.verify(&decoded, &store);
fut.await.wrap_error()?; fut.await.wrap_error()?;
OAuthClientAuthenticationMethod::PrivateKeyJwt
} }
OAuthClientAuthenticationMethod::PrivateKeyJwt => {
let jwks = client.jwks.as_ref().ok_or_else(|| {
ClientAuthenticationError::InvalidJwks {
client_id: client.client_id.clone(),
}
})?;
ClientAuthMethodConfig::ClientSecretJwt { client_secret } => { let store = jwks_key_store(jwks);
let store = SharedSecret::new(client_secret); let fut = token.verify(&decoded, &store);
token.verify(&decoded, &store).await.wrap_error()?; fut.await.wrap_error()?;
OAuthClientAuthenticationMethod::ClientSecretJwt
} }
_ => { _ => {
return Err(ClientAuthenticationError::WrongAuthenticationMethod { return Err(warp::reject::custom(
client_id: client_id.clone(), ClientAuthenticationError::WrongAuthenticationMethod {
} client_id: client.client_id,
.into()) },
));
} }
}; }
// rfc7523 sec. 3.3: the audience is the URL being called // rfc7523 sec. 3.3: the audience is the URL being called
if !aud.contains(&audience) { if !aud.contains(&audience) {
@ -243,7 +360,7 @@ async fn authenticate_client<T>(
tracing::Span::current().record("enduser.id", &client.client_id.as_str()); tracing::Span::current().record("enduser.id", &client.client_id.as_str());
Ok((auth_method, client.clone(), body)) Ok((auth_method, client, body))
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -291,6 +408,7 @@ struct ClientAuthForm<T> {
body: T, body: T,
} }
/* TODO: all secrets are broken because there is no way to mock the DB yet
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use headers::authorization::Credentials; use headers::authorization::Credentials;
@ -651,3 +769,4 @@ mod tests {
assert_eq!(body.bar, "foobar"); assert_eq!(body.bar, "foobar");
} }
} }
*/