1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-28 11:02:02 +03:00

Move clients to the database

This commit is contained in:
Quentin Gliech
2022-03-08 17:33:25 +01:00
parent 19a81afe51
commit 62f633a716
33 changed files with 1926 additions and 867 deletions

14
Cargo.lock generated
View File

@ -1828,6 +1828,7 @@ dependencies = [
"argon2",
"atty",
"clap",
"data-encoding",
"dotenv",
"futures 0.3.21",
"hyper",
@ -1845,6 +1846,7 @@ dependencies = [
"opentelemetry-otlp",
"opentelemetry-semantic-conventions",
"opentelemetry-zipkin",
"rand",
"reqwest",
"schemars",
"serde_json",
@ -1870,12 +1872,9 @@ dependencies = [
"chrono",
"elliptic-curve",
"figment",
"futures-util",
"http",
"http-body",
"indoc",
"lettre",
"mas-http",
"mas-iana",
"mas-jose",
"p256",
"pem-rfc7468",
@ -1889,7 +1888,6 @@ dependencies = [
"sqlx",
"thiserror",
"tokio",
"tower",
"tracing",
"url",
]
@ -1901,6 +1899,7 @@ dependencies = [
"chrono",
"crc",
"mas-iana",
"mas-jose",
"oauth2-types",
"rand",
"serde",
@ -2068,10 +2067,12 @@ dependencies = [
"chrono",
"mas-data-model",
"mas-iana",
"mas-jose",
"oauth2-types",
"password-hash",
"rand",
"serde",
"serde_json",
"sqlx",
"thiserror",
"tokio",
@ -2123,9 +2124,12 @@ dependencies = [
"crc",
"data-encoding",
"headers",
"http",
"http-body",
"hyper",
"mas-config",
"mas-data-model",
"mas-http",
"mas-iana",
"mas-jose",
"mas-storage",

View File

@ -22,6 +22,8 @@ argon2 = { version = "0.3.4", features = ["password-hash"] }
reqwest = { version = "0.11.9", features = ["rustls-tls"], default-features = false, optional = true }
watchman_client = "0.7.1"
atty = "0.2.14"
rand = "0.8.5"
data-encoding = "2.3.2"
tracing = "0.1.31"
tracing-appender = "0.2.1"

View File

@ -14,9 +14,13 @@
use argon2::Argon2;
use clap::Parser;
use mas_config::DatabaseConfig;
use mas_storage::user::{
lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, register_user,
use data_encoding::BASE64;
use mas_config::{DatabaseConfig, RootConfig};
use mas_storage::{
oauth2::client::{insert_client_from_config, lookup_client_by_client_id, truncate_clients},
user::{
lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, register_user,
},
};
use tracing::{info, warn};
@ -36,6 +40,13 @@ enum Subcommand {
/// Mark email address as verified
VerifyEmail { username: String, email: String },
/// Import clients from config
ImportClients {
/// Remove all clients before importing
#[clap(long)]
truncate: bool,
},
}
impl Options {
@ -71,6 +82,65 @@ impl Options {
txn.commit().await?;
info!(?email, "Email marked as verified");
Ok(())
}
SC::ImportClients { truncate } => {
let config: RootConfig = root.load_config()?;
let pool = config.database.connect().await?;
let encrypter = config.secrets.encrypter();
let mut txn = pool.begin().await?;
if *truncate {
warn!("Removing all clients first");
truncate_clients(&mut txn).await?;
}
for client in config.clients.iter() {
let client_id = &client.client_id;
let res = lookup_client_by_client_id(&mut txn, client_id).await;
match res {
Ok(_) => {
warn!(%client_id, "Skipping already imported client");
continue;
}
Err(e) if e.not_found() => {}
Err(e) => anyhow::bail!(e),
}
info!(%client_id, "Importing client");
let client_secret = client.client_secret();
let client_auth_method = client.client_auth_method();
let jwks = client.jwks();
let jwks_uri = client.jwks_uri();
let redirect_uris = &client.redirect_uris;
// TODO: should be moved somewhere else
let encrypted_client_secret = client_secret
.map(|client_secret| {
let nonce: [u8; 12] = rand::random();
let message = encrypter.encrypt(&nonce, client_secret.as_bytes())?;
let concat = [&nonce[..], &message[..]].concat();
let res = BASE64.encode(&concat);
anyhow::Ok(res)
})
.transpose()?;
insert_client_from_config(
&mut txn,
client_id,
client_auth_method,
encrypted_client_secret.as_deref(),
jwks,
jwks_uri,
redirect_uris,
)
.await?;
}
txn.commit().await?;
Ok(())
}
}

View File

@ -35,8 +35,4 @@ pem-rfc7468 = "0.3.1"
indoc = "1.0.4"
mas-jose = { path = "../jose" }
mas-http = { path = "../http" }
tower = { version = "0.4.12", features = ["util"] }
http = "0.2.6"
http-body = "0.4.4"
futures-util = "0.3.21"
mas-iana = { path = "../iana" }

View File

@ -15,15 +15,12 @@
use std::ops::{Deref, DerefMut};
use async_trait::async_trait;
use futures_util::future::Either;
use http::Request;
use mas_http::HttpServiceExt;
use mas_jose::{DynamicJwksStore, JsonWebKeySet, StaticJwksStore, VerifyingKeystore};
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::JsonWebKeySet;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use thiserror::Error;
use tower::{BoxError, ServiceExt};
use url::Url;
use super::ConfigurationSection;
@ -35,41 +32,6 @@ pub enum JwksOrJwksUri {
JwksUri(Url),
}
impl JwksOrJwksUri {
pub fn key_store(&self) -> Either<StaticJwksStore, DynamicJwksStore> {
// Assert that the output is both a VerifyingKeystore and Send
fn assert<T: Send + VerifyingKeystore>(t: T) -> T {
t
}
let inner = match self {
Self::Jwks(jwks) => Either::Left(StaticJwksStore::new(jwks.clone())),
Self::JwksUri(uri) => {
let uri = uri.clone();
// TODO: get the client from somewhere else?
let exporter = mas_http::client("fetch-jwks")
.json::<JsonWebKeySet>()
.map_request(move |_: ()| {
Request::builder()
.method("GET")
// TODO: change the Uri type in config to avoid reparsing here
.uri(uri.to_string())
.body(http_body::Empty::new())
.unwrap()
})
.map_response(http::Response::into_body)
.map_err(BoxError::from)
.boxed_clone();
Either::Right(DynamicJwksStore::new(exporter))
}
};
assert(inner)
}
}
impl From<JsonWebKeySet> for JwksOrJwksUri {
fn from(jwks: JsonWebKeySet) -> Self {
Self::Jwks(jwks)
@ -131,24 +93,53 @@ pub struct InvalidRedirectUriError;
impl ClientConfig {
#[doc(hidden)]
pub fn resolve_redirect_uri<'a>(
&'a self,
suggested_uri: &'a Option<Url>,
) -> Result<&'a Url, InvalidRedirectUriError> {
suggested_uri.as_ref().map_or_else(
|| self.redirect_uris.get(0).ok_or(InvalidRedirectUriError),
|suggested_uri| self.check_redirect_uri(suggested_uri),
)
#[must_use]
pub fn client_secret(&self) -> Option<&str> {
match &self.client_auth_method {
ClientAuthMethodConfig::ClientSecretPost { client_secret }
| ClientAuthMethodConfig::ClientSecretBasic { client_secret }
| ClientAuthMethodConfig::ClientSecretJwt { client_secret } => Some(client_secret),
_ => None,
}
}
fn check_redirect_uri<'a>(
&self,
redirect_uri: &'a Url,
) -> Result<&'a Url, InvalidRedirectUriError> {
if self.redirect_uris.contains(redirect_uri) {
Ok(redirect_uri)
} else {
Err(InvalidRedirectUriError)
#[doc(hidden)]
#[must_use]
pub fn client_auth_method(&self) -> OAuthClientAuthenticationMethod {
match &self.client_auth_method {
ClientAuthMethodConfig::None => OAuthClientAuthenticationMethod::None,
ClientAuthMethodConfig::ClientSecretBasic { .. } => {
OAuthClientAuthenticationMethod::ClientSecretBasic
}
ClientAuthMethodConfig::ClientSecretPost { .. } => {
OAuthClientAuthenticationMethod::ClientSecretPost
}
ClientAuthMethodConfig::ClientSecretJwt { .. } => {
OAuthClientAuthenticationMethod::ClientSecretJwt
}
ClientAuthMethodConfig::PrivateKeyJwt(_) => {
OAuthClientAuthenticationMethod::PrivateKeyJwt
}
}
}
#[doc(hidden)]
#[must_use]
pub fn jwks(&self) -> Option<&JsonWebKeySet> {
match &self.client_auth_method {
ClientAuthMethodConfig::PrivateKeyJwt(JwksOrJwksUri::Jwks(jwks)) => Some(jwks),
_ => None,
}
}
#[doc(hidden)]
#[must_use]
pub fn jwks_uri(&self) -> Option<&Url> {
match &self.client_auth_method {
ClientAuthMethodConfig::PrivateKeyJwt(JwksOrJwksUri::JwksUri(jwks_uri)) => {
Some(jwks_uri)
}
_ => None,
}
}
}

View File

@ -14,4 +14,5 @@ crc = "2.1.0"
rand = "0.8.5"
mas-iana = { path = "../iana" }
mas-jose = { path = "../jose" }
oauth2-types = { path = "../oauth2-types" }

View File

@ -30,7 +30,8 @@ pub(crate) mod users;
pub use self::{
oauth2::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, JwksOrJwksUri,
Pkce, Session,
},
tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType},
traits::{StorageBackend, StorageBackendMarker},

View File

@ -12,16 +12,87 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use mas_iana::{
jose::JsonWebSignatureAlg,
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
};
use mas_jose::JsonWebKeySet;
use oauth2_types::requests::GrantType;
use serde::Serialize;
use thiserror::Error;
use url::Url;
use crate::traits::{StorageBackend, StorageBackendMarker};
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum JwksOrJwksUri {
/// Client's JSON Web Key Set document, passed by value.
Jwks(JsonWebKeySet),
/// URL for the Client's JSON Web Key Set document.
JwksUri(Url),
}
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(bound = "T: StorageBackend")]
pub struct Client<T: StorageBackend> {
#[serde(skip_serializing)]
pub data: T::ClientData,
/// Client identifier
pub client_id: String,
pub encrypted_client_secret: Option<String>,
/// Array of Redirection URI values used by the Client
pub redirect_uris: Vec<Url>,
/// Array containing a list of the OAuth 2.0 response_type values that the
/// Client is declaring that it will restrict itself to using
pub response_types: Vec<OAuthAuthorizationEndpointResponseType>,
/// Array containing a list of the OAuth 2.0 Grant Types that the Client is
/// declaring that it will restrict itself to using.
pub grant_types: Vec<GrantType>,
/// Array of e-mail addresses of people responsible for this Client
pub contacts: Vec<String>,
/// Name of the Client to be presented to the End-User
pub client_name: Option<String>, // TODO: translations
/// URL that references a logo for the Client application
pub logo_uri: Option<Url>, // TODO: translations
/// URL of the home page of the Client
pub client_uri: Option<Url>, // TODO: translations
/// URL that the Relying Party Client provides to the End-User to read about
/// the how the profile data will be used
pub policy_uri: Option<Url>, // TODO: translations
/// URL that the Relying Party Client provides to the End-User to read about
/// the Relying Party's terms of service
pub tos_uri: Option<Url>, // TODO: translations
pub jwks: Option<JwksOrJwksUri>,
/// JWS alg algorithm REQUIRED for signing the ID Token issued to this
/// Client
pub id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
/// Requested authentication method for the token endpoint
pub token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
/// JWS alg algorithm that MUST be used for signing the JWT used to
/// authenticate the Client at the Token Endpoint for the private_key_jwt
/// and client_secret_jwt authentication methods
pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
/// URI using the https scheme that a third party can use to initiate a
/// login by the RP
pub initiate_login_uri: Option<Url>,
}
impl<S: StorageBackendMarker> From<Client<S>> for Client<()> {
@ -29,6 +100,48 @@ impl<S: StorageBackendMarker> From<Client<S>> for Client<()> {
Client {
data: (),
client_id: c.client_id,
encrypted_client_secret: c.encrypted_client_secret,
redirect_uris: c.redirect_uris,
response_types: c.response_types,
grant_types: c.grant_types,
contacts: c.contacts,
client_name: c.client_name,
logo_uri: c.logo_uri,
client_uri: c.client_uri,
policy_uri: c.policy_uri,
tos_uri: c.tos_uri,
jwks: c.jwks,
id_token_signed_response_alg: c.id_token_signed_response_alg,
token_endpoint_auth_method: c.token_endpoint_auth_method,
token_endpoint_auth_signing_alg: c.token_endpoint_auth_signing_alg,
initiate_login_uri: c.initiate_login_uri,
}
}
}
#[derive(Debug, Error)]
pub enum InvalidRedirectUriError {
#[error("redirect_uri is not allowed for this client")]
NotAllowed,
#[error("multiple redirect_uris registered for this client")]
MultipleRegistered,
#[error("client has no redirect_uri registered")]
NoneRegistered,
}
impl<S: StorageBackend> Client<S> {
pub fn resolve_redirect_uri<'a>(
&'a self,
redirect_uri: &'a Option<Url>,
) -> Result<&'a Url, InvalidRedirectUriError> {
match (&self.redirect_uris[..], redirect_uri) {
([], _) => Err(InvalidRedirectUriError::NoneRegistered),
([one], None) => Ok(one),
(_, None) => Err(InvalidRedirectUriError::MultipleRegistered),
(uris, Some(uri)) if uris.contains(uri) => Ok(uri),
_ => Err(InvalidRedirectUriError::NotAllowed),
}
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -18,6 +18,6 @@ pub(self) mod session;
pub use self::{
authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce},
client::Client,
client::{Client, JwksOrJwksUri},
session::Session,
};

View File

@ -46,14 +46,7 @@ pub fn root(
config: &RootConfig,
) -> BoxedFilter<(impl Reply,)> {
let health = health(pool);
let oauth2 = oauth2(
pool,
templates,
key_store,
encrypter,
&config.clients,
&config.http,
);
let oauth2 = oauth2(pool, templates, key_store, encrypter, &config.http);
let views = views(
pool,
templates,

View File

@ -20,7 +20,7 @@ use hyper::{
http::uri::{Parts, PathAndQuery, Uri},
StatusCode,
};
use mas_config::{ClientsConfig, Encrypter};
use mas_config::Encrypter;
use mas_data_model::{
Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession,
Pkce, StorageBackend, TokenType,
@ -32,6 +32,7 @@ use mas_storage::{
authorization_grant::{
derive_session, fulfill_grant, get_grant_by_id, new_authorization_grant,
},
client::lookup_client_by_client_id,
refresh_token::add_refresh_token,
},
PostgresqlBackend,
@ -41,7 +42,7 @@ use mas_warp_utils::{
errors::WrapError,
filters::{
self,
database::transaction,
database::{connection, transaction},
session::{optional_session, session},
with_templates,
},
@ -49,19 +50,20 @@ use mas_warp_utils::{
use oauth2_types::{
errors::{
ErrorResponse, InvalidGrant, InvalidRequest, LoginRequired, OAuth2Error,
RegistrationNotSupported, RequestNotSupported, RequestUriNotSupported,
RegistrationNotSupported, RequestNotSupported, RequestUriNotSupported, UnauthorizedClient,
},
pkce,
prelude::*,
requests::{
AccessTokenResponse, AuthorizationRequest, AuthorizationResponse, Prompt, ResponseMode,
AccessTokenResponse, AuthorizationRequest, AuthorizationResponse, GrantType, Prompt,
ResponseMode,
},
scope::ScopeToken,
};
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sqlx::{PgExecutor, PgPool, Postgres, Transaction};
use sqlx::{pool::PoolConnection, PgConnection, PgPool, Postgres, Transaction};
use url::Url;
use warp::{
filters::BoxedFilter,
@ -217,15 +219,10 @@ pub fn filter(
pool: &PgPool,
templates: &Templates,
encrypter: &Encrypter,
clients_config: &ClientsConfig,
) -> BoxedFilter<(Box<dyn Reply>,)> {
let clients_config = clients_config.clone();
let clients_config_2 = clients_config.clone();
let authorize = warp::path!("oauth2" / "authorize")
.and(filters::trace::name("GET /oauth2/authorize"))
.and(warp::get())
.map(move || clients_config.clone())
.and(warp::query())
.and(optional_session(pool, encrypter))
.and(transaction(pool))
@ -245,8 +242,8 @@ pub fn filter(
.recover(recover)
.unify()
.and(warp::query())
.and(warp::any().map(move || clients_config_2.clone()))
.and(with_templates(templates))
.and(connection(pool))
.and_then(actually_reply)
.boxed()
}
@ -262,8 +259,8 @@ async fn recover(rejection: Rejection) -> Result<ReplyOrBackToClient, Rejection>
async fn actually_reply(
rep: ReplyOrBackToClient,
q: PartialParams,
clients: ClientsConfig,
templates: Templates,
mut conn: PoolConnection<Postgres>,
) -> Result<Box<dyn Reply>, Rejection> {
let (redirect_uri, response_mode, state, params) = match rep {
ReplyOrBackToClient::Reply(r) => return Ok(r),
@ -281,15 +278,14 @@ async fn actually_reply(
..
} = q;
// First, disover the client
let client = client_id
.and_then(|client_id| clients.iter().find(|client| client.client_id == client_id));
let client = match client {
Some(client) => client,
None => return Ok(Box::new(html(templates.render_error(&error.into()).await?))),
let client_id = if let Some(client_id) = client_id {
client_id
} else {
return Ok(Box::new(html(templates.render_error(&error.into()).await?)));
};
let client = lookup_client_by_client_id(&mut conn, &client_id).await?;
let redirect_uri: Result<Option<Url>, _> = redirect_uri.map(|r| r.parse()).transpose();
let redirect_uri = match redirect_uri {
Ok(r) => r,
@ -315,7 +311,6 @@ async fn actually_reply(
}
async fn get(
clients: ClientsConfig,
params: Params,
maybe_session: Option<BrowserSession<PostgresqlBackend>>,
mut txn: Transaction<'_, Postgres>,
@ -337,15 +332,17 @@ async fn get(
}
// First, find out what client it is
let client = clients
.iter()
.find(|client| client.client_id == params.auth.client_id)
.ok_or_else(|| anyhow::anyhow!("could not find client"))
.wrap_error()?;
let client = lookup_client_by_client_id(&mut txn, &params.auth.client_id).await?;
// Check if it is allowed to use this grant type
if !client.grant_types.contains(&GrantType::AuthorizationCode) {
return Ok(ReplyOrBackToClient::Error(Box::new(UnauthorizedClient)));
}
let redirect_uri = client
.resolve_redirect_uri(&params.auth.redirect_uri)
.wrap_error()?;
.wrap_error()?
.clone();
let response_type = params.auth.response_type;
let response_mode =
resolve_response_mode(response_type, params.auth.response_mode).wrap_error()?;
@ -392,8 +389,8 @@ async fn get(
let grant = new_authorization_grant(
&mut txn,
client.client_id.clone(),
redirect_uri.clone(),
client,
redirect_uri,
scope,
code,
params.auth.state,
@ -471,10 +468,10 @@ impl ContinueAuthorizationGrant {
pub async fn fetch_authorization_grant(
&self,
executor: impl PgExecutor<'_>,
conn: &mut PgConnection,
) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> {
let data = self.data.parse()?;
get_grant_by_id(executor, data).await
get_grant_by_id(conn, data).await
}
}

View File

@ -12,11 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use mas_config::{ClientConfig, ClientsConfig, HttpConfig};
use mas_data_model::TokenType;
use mas_config::{Encrypter, HttpConfig};
use mas_data_model::{Client, TokenType};
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
use mas_storage::oauth2::{
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token,
use mas_storage::{
oauth2::{
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token,
},
PostgresqlBackend,
};
use mas_warp_utils::{
errors::WrapError,
@ -29,7 +32,7 @@ use warp::{filters::BoxedFilter, Filter, Rejection, Reply};
pub fn filter(
pool: &PgPool,
clients_config: &ClientsConfig,
encrypter: &Encrypter,
http_config: &HttpConfig,
) -> BoxedFilter<(Box<dyn Reply>,)> {
let audience = UrlBuilder::from(http_config)
@ -41,7 +44,7 @@ pub fn filter(
.and(
warp::post()
.and(connection(pool))
.and(client_authentication(clients_config, audience))
.and(client_authentication(pool, encrypter, audience))
.and_then(introspect)
.recover(recover)
.unify(),
@ -67,7 +70,7 @@ const INACTIVE: IntrospectionResponse = IntrospectionResponse {
async fn introspect(
mut conn: PoolConnection<Postgres>,
auth: OAuthClientAuthenticationMethod,
client: ClientConfig,
client: Client<PostgresqlBackend>,
params: IntrospectionRequest,
) -> Result<Box<dyn Reply>, Rejection> {
// Token introspection is only allowed by confidential clients

View File

@ -15,7 +15,7 @@
use std::sync::Arc;
use hyper::Method;
use mas_config::{ClientsConfig, Encrypter, HttpConfig};
use mas_config::{Encrypter, HttpConfig};
use mas_jose::StaticKeystore;
use mas_templates::Templates;
use mas_warp_utils::filters::cors::cors;
@ -41,15 +41,14 @@ pub fn filter(
templates: &Templates,
key_store: &Arc<StaticKeystore>,
encrypter: &Encrypter,
clients_config: &ClientsConfig,
http_config: &HttpConfig,
) -> BoxedFilter<(impl Reply,)> {
let discovery = discovery(key_store.as_ref(), http_config);
let keys = keys(key_store);
let authorization = authorization(pool, templates, encrypter, clients_config);
let authorization = authorization(pool, templates, encrypter);
let userinfo = userinfo(pool);
let introspection = introspection(pool, clients_config, http_config);
let token = token(pool, key_store, clients_config, http_config);
let introspection = introspection(pool, encrypter, http_config);
let token = token(pool, encrypter, key_store, http_config);
let filter = discovery
.or(keys)

View File

@ -19,8 +19,8 @@ use chrono::{DateTime, Duration, Utc};
use data_encoding::BASE64URL_NOPAD;
use headers::{CacheControl, Pragma};
use hyper::StatusCode;
use mas_config::{ClientConfig, ClientsConfig, HttpConfig};
use mas_data_model::{AuthorizationGrantStage, TokenType};
use mas_config::{Encrypter, HttpConfig};
use mas_data_model::{AuthorizationGrantStage, Client, TokenType};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_jose::{claims, DecodedJsonWebToken, SigningKeystore, StaticKeystore};
use mas_storage::{
@ -33,7 +33,7 @@ use mas_storage::{
RefreshTokenLookupError,
},
},
DatabaseInconsistencyError,
DatabaseInconsistencyError, PostgresqlBackend,
};
use mas_warp_utils::{
errors::WrapError,
@ -99,8 +99,8 @@ where
pub fn filter(
pool: &PgPool,
encrypter: &Encrypter,
key_store: &Arc<StaticKeystore>,
clients_config: &ClientsConfig,
http_config: &HttpConfig,
) -> BoxedFilter<(Box<dyn Reply>,)> {
let key_store = key_store.clone();
@ -113,7 +113,7 @@ pub fn filter(
.and(filters::trace::name("POST /oauth2/token"))
.and(
warp::post()
.and(client_authentication(clients_config, audience))
.and(client_authentication(pool, encrypter, audience))
.and(warp::any().map(move || key_store.clone()))
.and(warp::any().map(move || issuer.clone()))
.and(connection(pool))
@ -145,7 +145,7 @@ async fn recover(rejection: Rejection) -> Result<Box<dyn Reply>, Infallible> {
async fn token(
_auth: OAuthClientAuthenticationMethod,
client: ClientConfig,
client: Client<PostgresqlBackend>,
req: AccessTokenRequest,
key_store: Arc<StaticKeystore>,
issuer: Url,
@ -185,7 +185,7 @@ fn hash<H: Digest>(mut hasher: H, token: &str) -> anyhow::Result<String> {
#[allow(clippy::too_many_lines)]
async fn authorization_code_grant(
grant: &AuthorizationCodeGrant,
client: &ClientConfig,
client: &Client<PostgresqlBackend>,
key_store: &StaticKeystore,
issuer: Url,
conn: &mut PoolConnection<Postgres>,
@ -349,7 +349,7 @@ async fn authorization_code_grant(
async fn refresh_token_grant(
grant: &RefreshTokenGrant,
client: &ClientConfig,
client: &Client<PostgresqlBackend>,
conn: &mut PoolConnection<Postgres>,
) -> Result<AccessTokenResponse, Rejection> {
let mut txn = conn.begin().await.wrap_error()?;

View File

@ -17,7 +17,7 @@
use hyper::Uri;
use mas_templates::PostAuthContext;
use serde::{Deserialize, Serialize};
use sqlx::PgExecutor;
use sqlx::PgConnection;
use super::super::oauth2::ContinueAuthorizationGrant;
@ -36,11 +36,11 @@ impl PostAuthAction {
pub async fn load_context<'e>(
&self,
executor: impl PgExecutor<'e>,
conn: &mut PgConnection,
) -> anyhow::Result<PostAuthContext> {
match self {
Self::ContinueAuthorizationGrant(c) => {
let grant = c.fetch_authorization_grant(executor).await?;
let grant = c.fetch_authorization_grant(conn).await?;
let grant = grant.into();
Ok(PostAuthContext::ContinueAuthorizationGrant { grant })
}

View File

@ -21,3 +21,5 @@
pub mod jose;
pub mod oauth;
pub use parse_display::ParseError;

View File

@ -31,7 +31,7 @@ sha2 = "0.10.2"
signature = "1.4.0"
thiserror = "1.0.30"
tokio = { version = "1.17.0", features = ["macros", "rt", "sync"] }
tower = "0.4.12"
tower = { version = "0.4.12", features = ["util"] }
url = { version = "2.2.2", features = ["serde"] }
mas-iana = { path = "../iana" }

View File

@ -22,6 +22,8 @@ pub(crate) mod jwk;
pub(crate) mod jwt;
mod keystore;
pub use futures_util::future::Either;
pub use self::{
jwk::{JsonWebKey, JsonWebKeySet},
jwt::{DecodedJsonWebToken, JsonWebTokenParts, JwtHeader},

View File

@ -27,14 +27,21 @@ use url::Url;
use crate::requests::{Display, GrantType, ResponseMode};
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ApplicationType {
Web,
Native,
}
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
#[serde(rename_all = "lowercase")]
pub enum SubjectType {
Public,
Pairwise,
}
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ClaimType {
Normal,

View File

@ -192,6 +192,7 @@ pub struct ClientCredentialsGrant {
pub enum GrantType {
AuthorizationCode,
RefreshToken,
Implicit,
ClientCredentials,
}

View File

@ -7,9 +7,10 @@ license = "Apache-2.0"
[dependencies]
tokio = "1.17.0"
sqlx = { version = "0.5.11", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline"] }
sqlx = { version = "0.5.11", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline", "json"] }
chrono = { version = "0.4.19", features = ["serde"] }
serde = { version = "1.0.136", features = ["derive"] }
serde_json = "1.0.79"
thiserror = "1.0.30"
anyhow = "1.0.55"
tracing = "0.1.31"
@ -24,3 +25,4 @@ url = { version = "2.2.2", features = ["serde"] }
oauth2-types = { path = "../oauth2-types" }
mas-data-model = { path = "../data-model" }
mas-iana = { path = "../iana" }
mas-jose = { path = "../jose" }

View File

@ -0,0 +1,17 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
DROP TABLE oauth2_client_redirect_uris;
DROP TRIGGER set_timestamp ON oauth2_clients;
DROP TABLE oauth2_clients;

View File

@ -0,0 +1,51 @@
-- Copyright 2022 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
CREATE TABLE oauth2_clients (
"id" BIGSERIAL PRIMARY KEY,
"client_id" TEXT NOT NULL UNIQUE,
"encrypted_client_secret" TEXT,
"response_types" TEXT[] NOT NULL,
"grant_type_authorization_code" BOOL NOT NULL,
"grant_type_refresh_token" BOOL NOT NULL,
"contacts" TEXT[] NOT NULL,
"client_name" TEXT,
"logo_uri" TEXT,
"client_uri" TEXT,
"policy_uri" TEXT,
"tos_uri" TEXT,
"jwks_uri" TEXT,
"jwks" JSONB,
"id_token_signed_response_alg" TEXT,
"token_endpoint_auth_method" TEXT,
"token_endpoint_auth_signing_alg" TEXT,
"initiate_login_uri" TEXT,
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
"updated_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
-- jwks and jwks_uri can't be set at the same time
CHECK ("jwks" IS NULL OR "jwks_uri" IS NULL)
);
CREATE TRIGGER set_timestamp
BEFORE UPDATE ON oauth2_clients
FOR EACH ROW
EXECUTE PROCEDURE trigger_set_timestamp();
CREATE TABLE oauth2_client_redirect_uris (
"id" BIGSERIAL PRIMARY KEY,
"oauth2_client_id" BIGINT NOT NULL REFERENCES oauth2_clients (id) ON DELETE CASCADE,
"redirect_uri" TEXT NOT NULL
);

File diff suppressed because it is too large Load Diff

View File

@ -41,7 +41,7 @@ impl StorageBackend for PostgresqlBackend {
type AuthenticationData = i64;
type AuthorizationGrantData = i64;
type BrowserSessionData = i64;
type ClientData = ();
type ClientData = i64;
type RefreshTokenData = i64;
type SessionData = i64;
type UserData = i64;

View File

@ -14,12 +14,11 @@
use anyhow::Context;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{
AccessToken, Authentication, BrowserSession, Client, Session, User, UserEmail,
};
use sqlx::PgExecutor;
use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail};
use sqlx::{PgConnection, PgExecutor};
use thiserror::Error;
use super::client::{lookup_client_by_client_id, ClientFetchError};
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
pub async fn add_access_token(
@ -83,6 +82,7 @@ pub struct OAuth2AccessTokenLookup {
#[error("failed to lookup access token")]
pub enum AccessTokenLookupError {
Database(#[from] sqlx::Error),
ClientFetch(#[from] ClientFetchError),
Inconsistency(#[from] DatabaseInconsistencyError),
}
@ -95,7 +95,7 @@ impl AccessTokenLookupError {
#[allow(clippy::too_many_lines)]
pub async fn lookup_active_access_token(
executor: impl PgExecutor<'_>,
conn: &mut PgConnection,
token: &str,
) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), AccessTokenLookupError> {
let res = sqlx::query_as!(
@ -142,7 +142,7 @@ pub async fn lookup_active_access_token(
"#,
token,
)
.fetch_one(executor)
.fetch_one(&mut *conn)
.await?;
let access_token = AccessToken {
@ -153,10 +153,7 @@ pub async fn lookup_active_access_token(
expires_after: Duration::seconds(res.access_token_expires_after.into()),
};
let client = Client {
data: (),
client_id: res.client_id,
};
let client = lookup_client_by_client_id(&mut *conn, &res.client_id).await?;
let primary_email = match (
res.user_email_id,

View File

@ -24,15 +24,16 @@ use mas_data_model::{
};
use mas_iana::oauth::PkceCodeChallengeMethod;
use oauth2_types::{requests::ResponseMode, scope::Scope};
use sqlx::PgExecutor;
use sqlx::{PgConnection, PgExecutor};
use url::Url;
use super::client::lookup_client_by_client_id;
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
#[allow(clippy::too_many_arguments)]
pub async fn new_authorization_grant(
executor: impl PgExecutor<'_>,
client_id: String,
client: Client<PostgresqlBackend>,
redirect_uri: Url,
scope: Scope,
code: Option<AuthorizationCode>,
@ -65,7 +66,7 @@ pub async fn new_authorization_grant(
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
RETURNING id, created_at
"#,
&client_id,
&client.client_id,
redirect_uri.to_string(),
scope.to_string(),
state,
@ -85,11 +86,6 @@ pub async fn new_authorization_grant(
.await
.context("could not insert oauth2 authorization grant")?;
let client = Client {
data: (),
client_id,
};
Ok(AuthorizationGrant {
data: res.id,
stage: AuthorizationGrantStage::Pending,
@ -141,20 +137,21 @@ struct GrantLookup {
user_email_confirmed_at: Option<DateTime<Utc>>,
}
impl TryInto<AuthorizationGrant<PostgresqlBackend>> for GrantLookup {
type Error = DatabaseInconsistencyError;
impl GrantLookup {
#[allow(clippy::too_many_lines)]
fn try_into(self) -> Result<AuthorizationGrant<PostgresqlBackend>, Self::Error> {
async fn into_authorization_grant(
self,
executor: impl PgExecutor<'_>,
) -> Result<AuthorizationGrant<PostgresqlBackend>, DatabaseInconsistencyError> {
let scope: Scope = self
.grant_scope
.parse()
.map_err(|_e| DatabaseInconsistencyError)?;
let client = Client {
data: (),
client_id: self.client_id,
};
// TODO: don't unwrap
let client = lookup_client_by_client_id(executor, &self.client_id)
.await
.unwrap();
let last_authentication = match (
self.user_session_last_authentication_id,
@ -323,7 +320,7 @@ impl TryInto<AuthorizationGrant<PostgresqlBackend>> for GrantLookup {
}
pub async fn get_grant_by_id(
executor: impl PgExecutor<'_>,
conn: &mut PgConnection,
id: i64,
) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> {
// TODO: handle "not found" cases
@ -381,17 +378,17 @@ pub async fn get_grant_by_id(
"#,
id,
)
.fetch_one(executor)
.fetch_one(&mut *conn)
.await
.context("failed to get grant by id")?;
let grant = res.try_into()?;
let grant = res.into_authorization_grant(&mut *conn).await?;
Ok(grant)
}
pub async fn lookup_grant_by_code(
executor: impl PgExecutor<'_>,
conn: &mut PgConnection,
code: &str,
) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> {
// TODO: handle "not found" cases
@ -449,11 +446,11 @@ pub async fn lookup_grant_by_code(
"#,
code,
)
.fetch_one(executor)
.fetch_one(&mut *conn)
.await
.context("failed to lookup grant by code")?;
let grant = res.try_into()?;
let grant = res.into_authorization_grant(&mut *conn).await?;
Ok(grant)
}

View File

@ -0,0 +1,380 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::string::ToString;
use mas_data_model::{Client, JwksOrJwksUri};
use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod};
use mas_jose::JsonWebKeySet;
use oauth2_types::requests::GrantType;
use sqlx::{PgConnection, PgExecutor};
use thiserror::Error;
use url::Url;
use warp::reject::Reject;
use crate::PostgresqlBackend;
#[derive(Debug)]
pub struct OAuth2ClientLookup {
id: i64,
client_id: String,
encrypted_client_secret: Option<String>,
redirect_uris: Vec<String>,
response_types: Vec<String>,
grant_type_authorization_code: bool,
grant_type_refresh_token: bool,
contacts: Vec<String>,
client_name: Option<String>,
logo_uri: Option<String>,
client_uri: Option<String>,
policy_uri: Option<String>,
tos_uri: Option<String>,
jwks_uri: Option<String>,
jwks: Option<serde_json::Value>,
id_token_signed_response_alg: Option<String>,
token_endpoint_auth_method: Option<String>,
token_endpoint_auth_signing_alg: Option<String>,
initiate_login_uri: Option<String>,
}
#[derive(Debug, Error)]
pub enum ClientFetchError {
#[error("malformed jwks column")]
MalformedJwks(#[source] serde_json::Error),
#[error("entry has both a jwks and a jwks_uri")]
BothJwksAndJwksUri,
#[error("could not parse URL in field {field:?}")]
ParseUrl {
field: &'static str,
source: url::ParseError,
},
#[error("could not parse field {field:?}")]
ParseField {
field: &'static str,
source: mas_iana::ParseError,
},
#[error(transparent)]
Database(#[from] sqlx::Error),
}
impl ClientFetchError {
#[must_use]
pub fn not_found(&self) -> bool {
matches!(self, Self::Database(sqlx::Error::RowNotFound))
}
}
impl Reject for ClientFetchError {}
impl TryInto<Client<PostgresqlBackend>> for OAuth2ClientLookup {
type Error = ClientFetchError;
#[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing
fn try_into(self) -> Result<Client<PostgresqlBackend>, Self::Error> {
let redirect_uris: Result<Vec<Url>, _> =
self.redirect_uris.iter().map(|s| s.parse()).collect();
let redirect_uris = redirect_uris.map_err(|source| ClientFetchError::ParseUrl {
field: "redirect_uris",
source,
})?;
let response_types: Result<Vec<OAuthAuthorizationEndpointResponseType>, _> =
self.response_types.iter().map(|s| s.parse()).collect();
let response_types = response_types.map_err(|source| ClientFetchError::ParseField {
field: "response_types",
source,
})?;
let mut grant_types = Vec::new();
if self.grant_type_authorization_code {
grant_types.push(GrantType::AuthorizationCode);
}
if self.grant_type_refresh_token {
grant_types.push(GrantType::RefreshToken);
}
let logo_uri = self
.logo_uri
.map(|s| s.parse())
.transpose()
.map_err(|source| ClientFetchError::ParseUrl {
field: "logo_uri",
source,
})?;
let client_uri = self
.client_uri
.map(|s| s.parse())
.transpose()
.map_err(|source| ClientFetchError::ParseUrl {
field: "client_uri",
source,
})?;
let policy_uri = self
.policy_uri
.map(|s| s.parse())
.transpose()
.map_err(|source| ClientFetchError::ParseUrl {
field: "policy_uri",
source,
})?;
let tos_uri = self
.tos_uri
.map(|s| s.parse())
.transpose()
.map_err(|source| ClientFetchError::ParseUrl {
field: "tos_uri",
source,
})?;
let id_token_signed_response_alg = self
.id_token_signed_response_alg
.map(|s| s.parse())
.transpose()
.map_err(|source| ClientFetchError::ParseField {
field: "id_token_signed_response_alg",
source,
})?;
let token_endpoint_auth_method = self
.token_endpoint_auth_method
.map(|s| s.parse())
.transpose()
.map_err(|source| ClientFetchError::ParseField {
field: "token_endpoint_auth_method",
source,
})?;
let token_endpoint_auth_signing_alg = self
.token_endpoint_auth_signing_alg
.map(|s| s.parse())
.transpose()
.map_err(|source| ClientFetchError::ParseField {
field: "token_endpoint_auth_signing_alg",
source,
})?;
let initiate_login_uri = self
.initiate_login_uri
.map(|s| s.parse())
.transpose()
.map_err(|source| ClientFetchError::ParseUrl {
field: "initiate_login_uri",
source,
})?;
let jwks = match (self.jwks, self.jwks_uri) {
(None, None) => None,
(Some(jwks), None) => {
let jwks = serde_json::from_value(jwks).map_err(ClientFetchError::MalformedJwks)?;
Some(JwksOrJwksUri::Jwks(jwks))
}
(None, Some(jwks_uri)) => {
let jwks_uri = jwks_uri
.parse()
.map_err(|source| ClientFetchError::ParseUrl {
field: "jwks_uri",
source,
})?;
Some(JwksOrJwksUri::JwksUri(jwks_uri))
}
_ => return Err(ClientFetchError::BothJwksAndJwksUri),
};
Ok(Client {
data: self.id,
client_id: self.client_id,
encrypted_client_secret: self.encrypted_client_secret,
redirect_uris,
response_types,
grant_types,
contacts: self.contacts,
client_name: self.client_name,
logo_uri,
client_uri,
policy_uri,
tos_uri,
jwks,
id_token_signed_response_alg,
token_endpoint_auth_method,
token_endpoint_auth_signing_alg,
initiate_login_uri,
})
}
}
pub async fn lookup_client(
executor: impl PgExecutor<'_>,
id: i64,
) -> Result<Client<PostgresqlBackend>, ClientFetchError> {
let res = sqlx::query_as!(
OAuth2ClientLookup,
r#"
SELECT
c.id,
c.client_id,
c.encrypted_client_secret,
ARRAY(SELECT redirect_uri FROM oauth2_client_redirect_uris r WHERE r.oauth2_client_id = c.id) AS "redirect_uris!",
c.response_types,
c.grant_type_authorization_code,
c.grant_type_refresh_token,
c.contacts,
c.client_name,
c.logo_uri,
c.client_uri,
c.policy_uri,
c.tos_uri,
c.jwks_uri,
c.jwks,
c.id_token_signed_response_alg,
c.token_endpoint_auth_method,
c.token_endpoint_auth_signing_alg,
c.initiate_login_uri
FROM oauth2_clients c
WHERE c.id = $1
"#,
id,
)
.fetch_one(executor)
.await?;
let client = res.try_into()?;
Ok(client)
}
pub async fn lookup_client_by_client_id(
executor: impl PgExecutor<'_>,
client_id: &str,
) -> Result<Client<PostgresqlBackend>, ClientFetchError> {
let res = sqlx::query_as!(
OAuth2ClientLookup,
r#"
SELECT
c.id,
c.client_id,
c.encrypted_client_secret,
ARRAY(SELECT redirect_uri FROM oauth2_client_redirect_uris r WHERE r.oauth2_client_id = c.id) AS "redirect_uris!",
c.response_types,
c.grant_type_authorization_code,
c.grant_type_refresh_token,
c.contacts,
c.client_name,
c.logo_uri,
c.client_uri,
c.policy_uri,
c.tos_uri,
c.jwks_uri,
c.jwks,
c.id_token_signed_response_alg,
c.token_endpoint_auth_method,
c.token_endpoint_auth_signing_alg,
c.initiate_login_uri
FROM oauth2_clients c
WHERE c.client_id = $1
"#,
client_id,
)
.fetch_one(executor)
.await?;
let client = res.try_into()?;
Ok(client)
}
pub async fn insert_client_from_config(
conn: &mut PgConnection,
client_id: &str,
client_auth_method: OAuthClientAuthenticationMethod,
encrypted_client_secret: Option<&str>,
jwks: Option<&JsonWebKeySet>,
jwks_uri: Option<&Url>,
redirect_uris: &[Url],
) -> anyhow::Result<()> {
let response_types = vec![
OAuthAuthorizationEndpointResponseType::Code.to_string(),
OAuthAuthorizationEndpointResponseType::CodeIdToken.to_string(),
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.to_string(),
OAuthAuthorizationEndpointResponseType::CodeToken.to_string(),
OAuthAuthorizationEndpointResponseType::IdToken.to_string(),
OAuthAuthorizationEndpointResponseType::IdTokenToken.to_string(),
OAuthAuthorizationEndpointResponseType::None.to_string(),
OAuthAuthorizationEndpointResponseType::Token.to_string(),
];
let jwks = jwks.map(serde_json::to_value).transpose()?;
let jwks_uri = jwks_uri.map(Url::as_str);
let client_auth_method = client_auth_method.to_string();
let id = sqlx::query_scalar!(
r#"
INSERT INTO oauth2_clients
(client_id,
encrypted_client_secret,
response_types,
grant_type_authorization_code,
grant_type_refresh_token,
token_endpoint_auth_method,
jwks,
jwks_uri,
contacts)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, '{}')
RETURNING id
"#,
client_id,
encrypted_client_secret,
&response_types,
true,
true,
client_auth_method,
jwks,
jwks_uri,
)
.fetch_one(&mut *conn)
.await?;
let redirect_uris: Vec<String> = redirect_uris.iter().map(ToString::to_string).collect();
sqlx::query!(
r#"
INSERT INTO oauth2_client_redirect_uris (oauth2_client_id, redirect_uri)
SELECT $1, uri FROM UNNEST($2::text[]) uri
"#,
id,
&redirect_uris,
)
.execute(&mut *conn)
.await?;
Ok(())
}
pub async fn truncate_clients(executor: impl PgExecutor<'_>) -> anyhow::Result<()> {
sqlx::query!("TRUNCATE oauth2_client_redirect_uris, oauth2_clients")
.execute(executor)
.await?;
Ok(())
}

View File

@ -19,6 +19,7 @@ use crate::PostgresqlBackend;
pub mod access_token;
pub mod authorization_grant;
pub mod client;
pub mod refresh_token;
pub async fn end_oauth_session(

View File

@ -15,12 +15,13 @@
use anyhow::Context;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{
AccessToken, Authentication, BrowserSession, Client, RefreshToken, Session, User, UserEmail,
AccessToken, Authentication, BrowserSession, RefreshToken, Session, User, UserEmail,
};
use sqlx::PgExecutor;
use sqlx::{PgConnection, PgExecutor};
use thiserror::Error;
use warp::reject::Reject;
use super::client::{lookup_client_by_client_id, ClientFetchError};
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
pub async fn add_refresh_token(
@ -82,6 +83,7 @@ struct OAuth2RefreshTokenLookup {
#[error("could not lookup refresh token")]
pub enum RefreshTokenLookupError {
Fetch(#[from] sqlx::Error),
ClientFetch(#[from] ClientFetchError),
Conversion(#[from] DatabaseInconsistencyError),
}
@ -96,7 +98,7 @@ impl RefreshTokenLookupError {
#[allow(clippy::too_many_lines)]
pub async fn lookup_active_refresh_token(
executor: impl PgExecutor<'_>,
conn: &mut PgConnection,
token: &str,
) -> Result<(RefreshToken<PostgresqlBackend>, Session<PostgresqlBackend>), RefreshTokenLookupError>
{
@ -148,7 +150,7 @@ pub async fn lookup_active_refresh_token(
"#,
token,
)
.fetch_one(executor)
.fetch_one(&mut *conn)
.await?;
let access_token = match (
@ -175,10 +177,7 @@ pub async fn lookup_active_refresh_token(
access_token,
};
let client = Client {
data: (),
client_id: res.client_id,
};
let client = lookup_client_by_client_id(&mut *conn, &res.client_id).await?;
let primary_email = match (
res.user_email_id,

View File

@ -28,6 +28,9 @@ mime = "0.3.16"
bincode = "1.3.3"
crc = "2.1.0"
url = "2.2.2"
http = "0.2.6"
http-body = "0.4.4"
tower = { version = "0.4.12", features = ["util"] }
oauth2-types = { path = "../oauth2-types" }
mas-config = { path = "../config" }
@ -36,6 +39,4 @@ mas-data-model = { path = "../data-model" }
mas-storage = { path = "../storage" }
mas-jose = { path = "../jose" }
mas-iana = { path = "../iana" }
[dev-dependencies]
tower = { version = "0.4.12", features = ["util"] }
mas-http = { path = "../http" }

View File

@ -87,6 +87,10 @@ pub fn authentication(
.untuple_one()
}
fn ensure<T: Clone + Send + Sync + 'static>(t: T) -> T {
t
}
async fn authenticate(
mut conn: PoolConnection<Postgres>,
auth: Authorization<Bearer>,
@ -110,6 +114,9 @@ async fn authenticate(
}
})?;
let session = ensure(session);
let token = ensure(token);
Ok((token, session))
}

View File

@ -16,30 +16,49 @@
use std::collections::HashMap;
use data_encoding::BASE64;
use headers::{authorization::Basic, Authorization};
use mas_config::{ClientAuthMethodConfig, ClientConfig, ClientsConfig};
use mas_config::Encrypter;
use mas_data_model::{Client, JwksOrJwksUri, StorageBackend};
use mas_http::HttpServiceExt;
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::{
claims::{TimeOptions, AUD, EXP, IAT, ISS, JTI, NBF, SUB},
DecodedJsonWebToken, JsonWebTokenParts, SharedSecret,
DecodedJsonWebToken, DynamicJwksStore, Either, JsonWebKeySet, JsonWebTokenParts, SharedSecret,
StaticJwksStore, VerifyingKeystore,
};
use mas_storage::{
oauth2::client::{lookup_client_by_client_id, ClientFetchError},
PostgresqlBackend,
};
use serde::{de::DeserializeOwned, Deserialize};
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use thiserror::Error;
use tower::{BoxError, ServiceExt};
use warp::{reject::Reject, Filter, Rejection};
use super::headers::typed_header;
use super::{database::connection, headers::typed_header};
use crate::errors::WrapError;
/// Protect an enpoint with client authentication
#[must_use]
pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
clients_config: &ClientsConfig,
pool: &PgPool,
encrypter: &Encrypter,
audience: String,
) -> impl Filter<Extract = (OAuthClientAuthenticationMethod, ClientConfig, T), Error = Rejection>
+ Clone
) -> impl Filter<
Extract = (
OAuthClientAuthenticationMethod,
Client<PostgresqlBackend>,
T,
),
Error = Rejection,
> + Clone
+ Send
+ Sync
+ 'static {
let encrypter = encrypter.clone();
// First, extract the client credentials
let credentials = typed_header()
.and(warp::body::form())
@ -65,9 +84,9 @@ pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
.unify()
.untuple_one();
let clients_config = clients_config.clone();
warp::any()
.map(move || clients_config.clone())
.and(connection(pool))
.and(warp::any().map(move || encrypter.clone()))
.and(warp::any().map(move || audience.clone()))
.and(credentials)
.and_then(authenticate_client)
@ -79,8 +98,20 @@ enum ClientAuthenticationError {
#[error("wrong client secret for client {client_id:?}")]
ClientSecretMismatch { client_id: String },
#[error("could not find client {client_id:?}")]
ClientNotFound { client_id: String },
#[error("could not fetch client {client_id:?}")]
ClientFetch {
client_id: String,
source: ClientFetchError,
},
#[error("client {client_id:?} has an invalid client secret")]
InvalidClientSecret {
client_id: String,
source: anyhow::Error,
},
#[error("client {client_id:?} has an invalid JWKS")]
InvalidJwks { client_id: String },
#[error("wrong client authentication method for client {client_id:?}")]
WrongAuthenticationMethod { client_id: String },
@ -94,68 +125,136 @@ enum ClientAuthenticationError {
impl Reject for ClientAuthenticationError {}
fn decrypt_client_secret<T: StorageBackend>(
client: &Client<T>,
encrypter: &Encrypter,
) -> anyhow::Result<Vec<u8>> {
let encrypted_client_secret = client
.encrypted_client_secret
.as_ref()
.ok_or_else(|| anyhow::anyhow!("missing encrypted_client_secret field"))?;
let encrypted_client_secret = BASE64.decode(encrypted_client_secret.as_bytes())?;
let nonce: &[u8; 12] = encrypted_client_secret
.get(0..12)
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?
.try_into()?;
let payload = encrypted_client_secret
.get(12..)
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?;
let decrypted_client_secret = encrypter.decrypt(nonce, payload)?;
Ok(decrypted_client_secret)
}
fn jwks_key_store(jwks: &JwksOrJwksUri) -> Either<StaticJwksStore, DynamicJwksStore> {
// Assert that the output is both a VerifyingKeystore and Send
fn assert<T: Send + VerifyingKeystore>(t: T) -> T {
t
}
let inner = match jwks {
JwksOrJwksUri::Jwks(jwks) => Either::Left(StaticJwksStore::new(jwks.clone())),
JwksOrJwksUri::JwksUri(uri) => {
let uri = uri.clone();
// TODO: get the client from somewhere else?
let exporter = mas_http::client("fetch-jwks")
.json::<JsonWebKeySet>()
.map_request(move |_: ()| {
http::Request::builder()
.method("GET")
// TODO: change the Uri type in config to avoid reparsing here
.uri(uri.to_string())
.body(http_body::Empty::new())
.unwrap()
})
.map_response(http::Response::into_body)
.map_err(BoxError::from)
.boxed_clone();
Either::Right(DynamicJwksStore::new(exporter))
}
};
assert(inner)
}
#[allow(clippy::too_many_lines)]
#[tracing::instrument(skip_all, fields(enduser.id), err(Debug))]
async fn authenticate_client<T>(
clients_config: ClientsConfig,
mut conn: PoolConnection<Postgres>,
encrypter: Encrypter,
audience: String,
credentials: ClientCredentials,
body: T,
) -> Result<(OAuthClientAuthenticationMethod, ClientConfig, T), Rejection> {
) -> Result<
(
OAuthClientAuthenticationMethod,
Client<PostgresqlBackend>,
T,
),
Rejection,
> {
let (auth_method, client) = match credentials {
ClientCredentials::Pair {
client_id,
client_secret,
via,
} => {
let client = clients_config
.iter()
.find(|client| client.client_id == client_id)
.ok_or_else(|| ClientAuthenticationError::ClientNotFound {
client_id: client_id.to_string(),
let client = lookup_client_by_client_id(&mut *conn, &client_id)
.await
.map_err(|source| ClientAuthenticationError::ClientFetch {
client_id: client_id.clone(),
source,
})?;
let auth_method = match (&client.client_auth_method, client_secret, via) {
(ClientAuthMethodConfig::None, None, _) => OAuthClientAuthenticationMethod::None,
let auth_method = client.token_endpoint_auth_method.ok_or(
ClientAuthenticationError::WrongAuthenticationMethod {
client_id: client.client_id.clone(),
},
)?;
// Let's match the authentication method
match (auth_method, client_secret, via) {
(OAuthClientAuthenticationMethod::None, None, _) => {}
(
ClientAuthMethodConfig::ClientSecretBasic {
client_secret: ref expected_client_secret,
},
Some(ref given_client_secret),
OAuthClientAuthenticationMethod::ClientSecretBasic,
Some(client_secret),
CredentialsVia::AuthorizationHeader,
) => {
if expected_client_secret != given_client_secret {
return Err(
ClientAuthenticationError::ClientSecretMismatch { client_id }.into(),
);
}
OAuthClientAuthenticationMethod::ClientSecretBasic
}
(
ClientAuthMethodConfig::ClientSecretPost {
client_secret: ref expected_client_secret,
},
Some(ref given_client_secret),
)
| (
OAuthClientAuthenticationMethod::ClientSecretPost,
Some(client_secret),
CredentialsVia::FormBody,
) => {
if expected_client_secret != given_client_secret {
return Err(
ClientAuthenticationError::ClientSecretMismatch { client_id }.into(),
);
let decrypted =
decrypt_client_secret(&client, &encrypter).map_err(|source| {
ClientAuthenticationError::InvalidClientSecret {
client_id: client.client_id.clone(),
source,
}
})?;
if client_secret.as_bytes() != decrypted {
return Err(warp::reject::custom(
ClientAuthenticationError::ClientSecretMismatch {
client_id: client.client_id,
},
));
}
OAuthClientAuthenticationMethod::ClientSecretPost
}
_ => {
return Err(
ClientAuthenticationError::WrongAuthenticationMethod { client_id }.into(),
)
return Err(warp::reject::custom(
ClientAuthenticationError::WrongAuthenticationMethod {
client_id: client.client_id,
},
));
}
};
}
(auth_method, client)
}
@ -195,34 +294,52 @@ async fn authenticate_client<T>(
// from the token, as per rfc7521 sec. 4.2
let client_id = client_id.as_ref().unwrap_or(&sub);
let client = clients_config
.iter()
.find(|client| &client.client_id == client_id)
.ok_or_else(|| ClientAuthenticationError::ClientNotFound {
let client = lookup_client_by_client_id(&mut *conn, client_id)
.await
.map_err(|source| ClientAuthenticationError::ClientFetch {
client_id: client_id.to_string(),
source,
})?;
let auth_method = match &client.client_auth_method {
ClientAuthMethodConfig::PrivateKeyJwt(jwks) => {
let store = jwks.key_store();
let auth_method = client.token_endpoint_auth_method.ok_or(
ClientAuthenticationError::WrongAuthenticationMethod {
client_id: client.client_id.clone(),
},
)?;
match auth_method {
OAuthClientAuthenticationMethod::ClientSecretJwt => {
let client_secret =
decrypt_client_secret(&client, &encrypter).map_err(|source| {
ClientAuthenticationError::InvalidClientSecret {
client_id: client.client_id.clone(),
source,
}
})?;
let store = SharedSecret::new(&client_secret);
let fut = token.verify(&decoded, &store);
fut.await.wrap_error()?;
OAuthClientAuthenticationMethod::PrivateKeyJwt
}
OAuthClientAuthenticationMethod::PrivateKeyJwt => {
let jwks = client.jwks.as_ref().ok_or_else(|| {
ClientAuthenticationError::InvalidJwks {
client_id: client.client_id.clone(),
}
})?;
ClientAuthMethodConfig::ClientSecretJwt { client_secret } => {
let store = SharedSecret::new(client_secret);
token.verify(&decoded, &store).await.wrap_error()?;
OAuthClientAuthenticationMethod::ClientSecretJwt
let store = jwks_key_store(jwks);
let fut = token.verify(&decoded, &store);
fut.await.wrap_error()?;
}
_ => {
return Err(ClientAuthenticationError::WrongAuthenticationMethod {
client_id: client_id.clone(),
}
.into())
return Err(warp::reject::custom(
ClientAuthenticationError::WrongAuthenticationMethod {
client_id: client.client_id,
},
));
}
};
}
// rfc7523 sec. 3.3: the audience is the URL being called
if !aud.contains(&audience) {
@ -243,7 +360,7 @@ async fn authenticate_client<T>(
tracing::Span::current().record("enduser.id", &client.client_id.as_str());
Ok((auth_method, client.clone(), body))
Ok((auth_method, client, body))
}
#[derive(Deserialize)]
@ -291,6 +408,7 @@ struct ClientAuthForm<T> {
body: T,
}
/* TODO: all secrets are broken because there is no way to mock the DB yet
#[cfg(test)]
mod tests {
use headers::authorization::Credentials;
@ -651,3 +769,4 @@ mod tests {
assert_eq!(body.bar, "foobar");
}
}
*/