You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-28 11:02:02 +03:00
Move clients to the database
This commit is contained in:
14
Cargo.lock
generated
14
Cargo.lock
generated
@ -1828,6 +1828,7 @@ dependencies = [
|
||||
"argon2",
|
||||
"atty",
|
||||
"clap",
|
||||
"data-encoding",
|
||||
"dotenv",
|
||||
"futures 0.3.21",
|
||||
"hyper",
|
||||
@ -1845,6 +1846,7 @@ dependencies = [
|
||||
"opentelemetry-otlp",
|
||||
"opentelemetry-semantic-conventions",
|
||||
"opentelemetry-zipkin",
|
||||
"rand",
|
||||
"reqwest",
|
||||
"schemars",
|
||||
"serde_json",
|
||||
@ -1870,12 +1872,9 @@ dependencies = [
|
||||
"chrono",
|
||||
"elliptic-curve",
|
||||
"figment",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"indoc",
|
||||
"lettre",
|
||||
"mas-http",
|
||||
"mas-iana",
|
||||
"mas-jose",
|
||||
"p256",
|
||||
"pem-rfc7468",
|
||||
@ -1889,7 +1888,6 @@ dependencies = [
|
||||
"sqlx",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tracing",
|
||||
"url",
|
||||
]
|
||||
@ -1901,6 +1899,7 @@ dependencies = [
|
||||
"chrono",
|
||||
"crc",
|
||||
"mas-iana",
|
||||
"mas-jose",
|
||||
"oauth2-types",
|
||||
"rand",
|
||||
"serde",
|
||||
@ -2068,10 +2067,12 @@ dependencies = [
|
||||
"chrono",
|
||||
"mas-data-model",
|
||||
"mas-iana",
|
||||
"mas-jose",
|
||||
"oauth2-types",
|
||||
"password-hash",
|
||||
"rand",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sqlx",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
@ -2123,9 +2124,12 @@ dependencies = [
|
||||
"crc",
|
||||
"data-encoding",
|
||||
"headers",
|
||||
"http",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"mas-config",
|
||||
"mas-data-model",
|
||||
"mas-http",
|
||||
"mas-iana",
|
||||
"mas-jose",
|
||||
"mas-storage",
|
||||
|
@ -22,6 +22,8 @@ argon2 = { version = "0.3.4", features = ["password-hash"] }
|
||||
reqwest = { version = "0.11.9", features = ["rustls-tls"], default-features = false, optional = true }
|
||||
watchman_client = "0.7.1"
|
||||
atty = "0.2.14"
|
||||
rand = "0.8.5"
|
||||
data-encoding = "2.3.2"
|
||||
|
||||
tracing = "0.1.31"
|
||||
tracing-appender = "0.2.1"
|
||||
|
@ -14,9 +14,13 @@
|
||||
|
||||
use argon2::Argon2;
|
||||
use clap::Parser;
|
||||
use mas_config::DatabaseConfig;
|
||||
use mas_storage::user::{
|
||||
use data_encoding::BASE64;
|
||||
use mas_config::{DatabaseConfig, RootConfig};
|
||||
use mas_storage::{
|
||||
oauth2::client::{insert_client_from_config, lookup_client_by_client_id, truncate_clients},
|
||||
user::{
|
||||
lookup_user_by_username, lookup_user_email, mark_user_email_as_verified, register_user,
|
||||
},
|
||||
};
|
||||
use tracing::{info, warn};
|
||||
|
||||
@ -36,6 +40,13 @@ enum Subcommand {
|
||||
|
||||
/// Mark email address as verified
|
||||
VerifyEmail { username: String, email: String },
|
||||
|
||||
/// Import clients from config
|
||||
ImportClients {
|
||||
/// Remove all clients before importing
|
||||
#[clap(long)]
|
||||
truncate: bool,
|
||||
},
|
||||
}
|
||||
|
||||
impl Options {
|
||||
@ -71,6 +82,65 @@ impl Options {
|
||||
txn.commit().await?;
|
||||
info!(?email, "Email marked as verified");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
SC::ImportClients { truncate } => {
|
||||
let config: RootConfig = root.load_config()?;
|
||||
let pool = config.database.connect().await?;
|
||||
let encrypter = config.secrets.encrypter();
|
||||
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
if *truncate {
|
||||
warn!("Removing all clients first");
|
||||
truncate_clients(&mut txn).await?;
|
||||
}
|
||||
|
||||
for client in config.clients.iter() {
|
||||
let client_id = &client.client_id;
|
||||
let res = lookup_client_by_client_id(&mut txn, client_id).await;
|
||||
match res {
|
||||
Ok(_) => {
|
||||
warn!(%client_id, "Skipping already imported client");
|
||||
continue;
|
||||
}
|
||||
Err(e) if e.not_found() => {}
|
||||
Err(e) => anyhow::bail!(e),
|
||||
}
|
||||
|
||||
info!(%client_id, "Importing client");
|
||||
let client_secret = client.client_secret();
|
||||
let client_auth_method = client.client_auth_method();
|
||||
let jwks = client.jwks();
|
||||
let jwks_uri = client.jwks_uri();
|
||||
let redirect_uris = &client.redirect_uris;
|
||||
|
||||
// TODO: should be moved somewhere else
|
||||
let encrypted_client_secret = client_secret
|
||||
.map(|client_secret| {
|
||||
let nonce: [u8; 12] = rand::random();
|
||||
let message = encrypter.encrypt(&nonce, client_secret.as_bytes())?;
|
||||
let concat = [&nonce[..], &message[..]].concat();
|
||||
let res = BASE64.encode(&concat);
|
||||
|
||||
anyhow::Ok(res)
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
insert_client_from_config(
|
||||
&mut txn,
|
||||
client_id,
|
||||
client_auth_method,
|
||||
encrypted_client_secret.as_deref(),
|
||||
jwks,
|
||||
jwks_uri,
|
||||
redirect_uris,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -35,8 +35,4 @@ pem-rfc7468 = "0.3.1"
|
||||
indoc = "1.0.4"
|
||||
|
||||
mas-jose = { path = "../jose" }
|
||||
mas-http = { path = "../http" }
|
||||
tower = { version = "0.4.12", features = ["util"] }
|
||||
http = "0.2.6"
|
||||
http-body = "0.4.4"
|
||||
futures-util = "0.3.21"
|
||||
mas-iana = { path = "../iana" }
|
||||
|
@ -15,15 +15,12 @@
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures_util::future::Either;
|
||||
use http::Request;
|
||||
use mas_http::HttpServiceExt;
|
||||
use mas_jose::{DynamicJwksStore, JsonWebKeySet, StaticJwksStore, VerifyingKeystore};
|
||||
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||
use mas_jose::JsonWebKeySet;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::skip_serializing_none;
|
||||
use thiserror::Error;
|
||||
use tower::{BoxError, ServiceExt};
|
||||
use url::Url;
|
||||
|
||||
use super::ConfigurationSection;
|
||||
@ -35,41 +32,6 @@ pub enum JwksOrJwksUri {
|
||||
JwksUri(Url),
|
||||
}
|
||||
|
||||
impl JwksOrJwksUri {
|
||||
pub fn key_store(&self) -> Either<StaticJwksStore, DynamicJwksStore> {
|
||||
// Assert that the output is both a VerifyingKeystore and Send
|
||||
fn assert<T: Send + VerifyingKeystore>(t: T) -> T {
|
||||
t
|
||||
}
|
||||
|
||||
let inner = match self {
|
||||
Self::Jwks(jwks) => Either::Left(StaticJwksStore::new(jwks.clone())),
|
||||
Self::JwksUri(uri) => {
|
||||
let uri = uri.clone();
|
||||
|
||||
// TODO: get the client from somewhere else?
|
||||
let exporter = mas_http::client("fetch-jwks")
|
||||
.json::<JsonWebKeySet>()
|
||||
.map_request(move |_: ()| {
|
||||
Request::builder()
|
||||
.method("GET")
|
||||
// TODO: change the Uri type in config to avoid reparsing here
|
||||
.uri(uri.to_string())
|
||||
.body(http_body::Empty::new())
|
||||
.unwrap()
|
||||
})
|
||||
.map_response(http::Response::into_body)
|
||||
.map_err(BoxError::from)
|
||||
.boxed_clone();
|
||||
|
||||
Either::Right(DynamicJwksStore::new(exporter))
|
||||
}
|
||||
};
|
||||
|
||||
assert(inner)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<JsonWebKeySet> for JwksOrJwksUri {
|
||||
fn from(jwks: JsonWebKeySet) -> Self {
|
||||
Self::Jwks(jwks)
|
||||
@ -131,24 +93,53 @@ pub struct InvalidRedirectUriError;
|
||||
|
||||
impl ClientConfig {
|
||||
#[doc(hidden)]
|
||||
pub fn resolve_redirect_uri<'a>(
|
||||
&'a self,
|
||||
suggested_uri: &'a Option<Url>,
|
||||
) -> Result<&'a Url, InvalidRedirectUriError> {
|
||||
suggested_uri.as_ref().map_or_else(
|
||||
|| self.redirect_uris.get(0).ok_or(InvalidRedirectUriError),
|
||||
|suggested_uri| self.check_redirect_uri(suggested_uri),
|
||||
)
|
||||
#[must_use]
|
||||
pub fn client_secret(&self) -> Option<&str> {
|
||||
match &self.client_auth_method {
|
||||
ClientAuthMethodConfig::ClientSecretPost { client_secret }
|
||||
| ClientAuthMethodConfig::ClientSecretBasic { client_secret }
|
||||
| ClientAuthMethodConfig::ClientSecretJwt { client_secret } => Some(client_secret),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn check_redirect_uri<'a>(
|
||||
&self,
|
||||
redirect_uri: &'a Url,
|
||||
) -> Result<&'a Url, InvalidRedirectUriError> {
|
||||
if self.redirect_uris.contains(redirect_uri) {
|
||||
Ok(redirect_uri)
|
||||
} else {
|
||||
Err(InvalidRedirectUriError)
|
||||
#[doc(hidden)]
|
||||
#[must_use]
|
||||
pub fn client_auth_method(&self) -> OAuthClientAuthenticationMethod {
|
||||
match &self.client_auth_method {
|
||||
ClientAuthMethodConfig::None => OAuthClientAuthenticationMethod::None,
|
||||
ClientAuthMethodConfig::ClientSecretBasic { .. } => {
|
||||
OAuthClientAuthenticationMethod::ClientSecretBasic
|
||||
}
|
||||
ClientAuthMethodConfig::ClientSecretPost { .. } => {
|
||||
OAuthClientAuthenticationMethod::ClientSecretPost
|
||||
}
|
||||
ClientAuthMethodConfig::ClientSecretJwt { .. } => {
|
||||
OAuthClientAuthenticationMethod::ClientSecretJwt
|
||||
}
|
||||
ClientAuthMethodConfig::PrivateKeyJwt(_) => {
|
||||
OAuthClientAuthenticationMethod::PrivateKeyJwt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
#[must_use]
|
||||
pub fn jwks(&self) -> Option<&JsonWebKeySet> {
|
||||
match &self.client_auth_method {
|
||||
ClientAuthMethodConfig::PrivateKeyJwt(JwksOrJwksUri::Jwks(jwks)) => Some(jwks),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
#[must_use]
|
||||
pub fn jwks_uri(&self) -> Option<&Url> {
|
||||
match &self.client_auth_method {
|
||||
ClientAuthMethodConfig::PrivateKeyJwt(JwksOrJwksUri::JwksUri(jwks_uri)) => {
|
||||
Some(jwks_uri)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -14,4 +14,5 @@ crc = "2.1.0"
|
||||
rand = "0.8.5"
|
||||
|
||||
mas-iana = { path = "../iana" }
|
||||
mas-jose = { path = "../jose" }
|
||||
oauth2-types = { path = "../oauth2-types" }
|
||||
|
@ -30,7 +30,8 @@ pub(crate) mod users;
|
||||
|
||||
pub use self::{
|
||||
oauth2::{
|
||||
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Pkce, Session,
|
||||
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, JwksOrJwksUri,
|
||||
Pkce, Session,
|
||||
},
|
||||
tokens::{AccessToken, RefreshToken, TokenFormatError, TokenType},
|
||||
traits::{StorageBackend, StorageBackendMarker},
|
||||
|
@ -12,16 +12,87 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use mas_iana::{
|
||||
jose::JsonWebSignatureAlg,
|
||||
oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod},
|
||||
};
|
||||
use mas_jose::JsonWebKeySet;
|
||||
use oauth2_types::requests::GrantType;
|
||||
use serde::Serialize;
|
||||
use thiserror::Error;
|
||||
use url::Url;
|
||||
|
||||
use crate::traits::{StorageBackend, StorageBackendMarker};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum JwksOrJwksUri {
|
||||
/// Client's JSON Web Key Set document, passed by value.
|
||||
Jwks(JsonWebKeySet),
|
||||
|
||||
/// URL for the Client's JSON Web Key Set document.
|
||||
JwksUri(Url),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
#[serde(bound = "T: StorageBackend")]
|
||||
pub struct Client<T: StorageBackend> {
|
||||
#[serde(skip_serializing)]
|
||||
pub data: T::ClientData,
|
||||
|
||||
/// Client identifier
|
||||
pub client_id: String,
|
||||
|
||||
pub encrypted_client_secret: Option<String>,
|
||||
|
||||
/// Array of Redirection URI values used by the Client
|
||||
pub redirect_uris: Vec<Url>,
|
||||
|
||||
/// Array containing a list of the OAuth 2.0 response_type values that the
|
||||
/// Client is declaring that it will restrict itself to using
|
||||
pub response_types: Vec<OAuthAuthorizationEndpointResponseType>,
|
||||
|
||||
/// Array containing a list of the OAuth 2.0 Grant Types that the Client is
|
||||
/// declaring that it will restrict itself to using.
|
||||
pub grant_types: Vec<GrantType>,
|
||||
|
||||
/// Array of e-mail addresses of people responsible for this Client
|
||||
pub contacts: Vec<String>,
|
||||
|
||||
/// Name of the Client to be presented to the End-User
|
||||
pub client_name: Option<String>, // TODO: translations
|
||||
|
||||
/// URL that references a logo for the Client application
|
||||
pub logo_uri: Option<Url>, // TODO: translations
|
||||
|
||||
/// URL of the home page of the Client
|
||||
pub client_uri: Option<Url>, // TODO: translations
|
||||
|
||||
/// URL that the Relying Party Client provides to the End-User to read about
|
||||
/// the how the profile data will be used
|
||||
pub policy_uri: Option<Url>, // TODO: translations
|
||||
|
||||
/// URL that the Relying Party Client provides to the End-User to read about
|
||||
/// the Relying Party's terms of service
|
||||
pub tos_uri: Option<Url>, // TODO: translations
|
||||
|
||||
pub jwks: Option<JwksOrJwksUri>,
|
||||
|
||||
/// JWS alg algorithm REQUIRED for signing the ID Token issued to this
|
||||
/// Client
|
||||
pub id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
|
||||
|
||||
/// Requested authentication method for the token endpoint
|
||||
pub token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
|
||||
|
||||
/// JWS alg algorithm that MUST be used for signing the JWT used to
|
||||
/// authenticate the Client at the Token Endpoint for the private_key_jwt
|
||||
/// and client_secret_jwt authentication methods
|
||||
pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
|
||||
/// URI using the https scheme that a third party can use to initiate a
|
||||
/// login by the RP
|
||||
pub initiate_login_uri: Option<Url>,
|
||||
}
|
||||
|
||||
impl<S: StorageBackendMarker> From<Client<S>> for Client<()> {
|
||||
@ -29,6 +100,48 @@ impl<S: StorageBackendMarker> From<Client<S>> for Client<()> {
|
||||
Client {
|
||||
data: (),
|
||||
client_id: c.client_id,
|
||||
encrypted_client_secret: c.encrypted_client_secret,
|
||||
redirect_uris: c.redirect_uris,
|
||||
response_types: c.response_types,
|
||||
grant_types: c.grant_types,
|
||||
contacts: c.contacts,
|
||||
client_name: c.client_name,
|
||||
logo_uri: c.logo_uri,
|
||||
client_uri: c.client_uri,
|
||||
policy_uri: c.policy_uri,
|
||||
tos_uri: c.tos_uri,
|
||||
jwks: c.jwks,
|
||||
id_token_signed_response_alg: c.id_token_signed_response_alg,
|
||||
token_endpoint_auth_method: c.token_endpoint_auth_method,
|
||||
token_endpoint_auth_signing_alg: c.token_endpoint_auth_signing_alg,
|
||||
initiate_login_uri: c.initiate_login_uri,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum InvalidRedirectUriError {
|
||||
#[error("redirect_uri is not allowed for this client")]
|
||||
NotAllowed,
|
||||
|
||||
#[error("multiple redirect_uris registered for this client")]
|
||||
MultipleRegistered,
|
||||
|
||||
#[error("client has no redirect_uri registered")]
|
||||
NoneRegistered,
|
||||
}
|
||||
|
||||
impl<S: StorageBackend> Client<S> {
|
||||
pub fn resolve_redirect_uri<'a>(
|
||||
&'a self,
|
||||
redirect_uri: &'a Option<Url>,
|
||||
) -> Result<&'a Url, InvalidRedirectUriError> {
|
||||
match (&self.redirect_uris[..], redirect_uri) {
|
||||
([], _) => Err(InvalidRedirectUriError::NoneRegistered),
|
||||
([one], None) => Ok(one),
|
||||
(_, None) => Err(InvalidRedirectUriError::MultipleRegistered),
|
||||
(uris, Some(uri)) if uris.contains(uri) => Ok(uri),
|
||||
_ => Err(InvalidRedirectUriError::NotAllowed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2021, 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@ -18,6 +18,6 @@ pub(self) mod session;
|
||||
|
||||
pub use self::{
|
||||
authorization_grant::{AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Pkce},
|
||||
client::Client,
|
||||
client::{Client, JwksOrJwksUri},
|
||||
session::Session,
|
||||
};
|
||||
|
@ -46,14 +46,7 @@ pub fn root(
|
||||
config: &RootConfig,
|
||||
) -> BoxedFilter<(impl Reply,)> {
|
||||
let health = health(pool);
|
||||
let oauth2 = oauth2(
|
||||
pool,
|
||||
templates,
|
||||
key_store,
|
||||
encrypter,
|
||||
&config.clients,
|
||||
&config.http,
|
||||
);
|
||||
let oauth2 = oauth2(pool, templates, key_store, encrypter, &config.http);
|
||||
let views = views(
|
||||
pool,
|
||||
templates,
|
||||
|
@ -20,7 +20,7 @@ use hyper::{
|
||||
http::uri::{Parts, PathAndQuery, Uri},
|
||||
StatusCode,
|
||||
};
|
||||
use mas_config::{ClientsConfig, Encrypter};
|
||||
use mas_config::Encrypter;
|
||||
use mas_data_model::{
|
||||
Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession,
|
||||
Pkce, StorageBackend, TokenType,
|
||||
@ -32,6 +32,7 @@ use mas_storage::{
|
||||
authorization_grant::{
|
||||
derive_session, fulfill_grant, get_grant_by_id, new_authorization_grant,
|
||||
},
|
||||
client::lookup_client_by_client_id,
|
||||
refresh_token::add_refresh_token,
|
||||
},
|
||||
PostgresqlBackend,
|
||||
@ -41,7 +42,7 @@ use mas_warp_utils::{
|
||||
errors::WrapError,
|
||||
filters::{
|
||||
self,
|
||||
database::transaction,
|
||||
database::{connection, transaction},
|
||||
session::{optional_session, session},
|
||||
with_templates,
|
||||
},
|
||||
@ -49,19 +50,20 @@ use mas_warp_utils::{
|
||||
use oauth2_types::{
|
||||
errors::{
|
||||
ErrorResponse, InvalidGrant, InvalidRequest, LoginRequired, OAuth2Error,
|
||||
RegistrationNotSupported, RequestNotSupported, RequestUriNotSupported,
|
||||
RegistrationNotSupported, RequestNotSupported, RequestUriNotSupported, UnauthorizedClient,
|
||||
},
|
||||
pkce,
|
||||
prelude::*,
|
||||
requests::{
|
||||
AccessTokenResponse, AuthorizationRequest, AuthorizationResponse, Prompt, ResponseMode,
|
||||
AccessTokenResponse, AuthorizationRequest, AuthorizationResponse, GrantType, Prompt,
|
||||
ResponseMode,
|
||||
},
|
||||
scope::ScopeToken,
|
||||
};
|
||||
use rand::{distributions::Alphanumeric, thread_rng, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use sqlx::{PgExecutor, PgPool, Postgres, Transaction};
|
||||
use sqlx::{pool::PoolConnection, PgConnection, PgPool, Postgres, Transaction};
|
||||
use url::Url;
|
||||
use warp::{
|
||||
filters::BoxedFilter,
|
||||
@ -217,15 +219,10 @@ pub fn filter(
|
||||
pool: &PgPool,
|
||||
templates: &Templates,
|
||||
encrypter: &Encrypter,
|
||||
clients_config: &ClientsConfig,
|
||||
) -> BoxedFilter<(Box<dyn Reply>,)> {
|
||||
let clients_config = clients_config.clone();
|
||||
let clients_config_2 = clients_config.clone();
|
||||
|
||||
let authorize = warp::path!("oauth2" / "authorize")
|
||||
.and(filters::trace::name("GET /oauth2/authorize"))
|
||||
.and(warp::get())
|
||||
.map(move || clients_config.clone())
|
||||
.and(warp::query())
|
||||
.and(optional_session(pool, encrypter))
|
||||
.and(transaction(pool))
|
||||
@ -245,8 +242,8 @@ pub fn filter(
|
||||
.recover(recover)
|
||||
.unify()
|
||||
.and(warp::query())
|
||||
.and(warp::any().map(move || clients_config_2.clone()))
|
||||
.and(with_templates(templates))
|
||||
.and(connection(pool))
|
||||
.and_then(actually_reply)
|
||||
.boxed()
|
||||
}
|
||||
@ -262,8 +259,8 @@ async fn recover(rejection: Rejection) -> Result<ReplyOrBackToClient, Rejection>
|
||||
async fn actually_reply(
|
||||
rep: ReplyOrBackToClient,
|
||||
q: PartialParams,
|
||||
clients: ClientsConfig,
|
||||
templates: Templates,
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
let (redirect_uri, response_mode, state, params) = match rep {
|
||||
ReplyOrBackToClient::Reply(r) => return Ok(r),
|
||||
@ -281,15 +278,14 @@ async fn actually_reply(
|
||||
..
|
||||
} = q;
|
||||
|
||||
// First, disover the client
|
||||
let client = client_id
|
||||
.and_then(|client_id| clients.iter().find(|client| client.client_id == client_id));
|
||||
|
||||
let client = match client {
|
||||
Some(client) => client,
|
||||
None => return Ok(Box::new(html(templates.render_error(&error.into()).await?))),
|
||||
let client_id = if let Some(client_id) = client_id {
|
||||
client_id
|
||||
} else {
|
||||
return Ok(Box::new(html(templates.render_error(&error.into()).await?)));
|
||||
};
|
||||
|
||||
let client = lookup_client_by_client_id(&mut conn, &client_id).await?;
|
||||
|
||||
let redirect_uri: Result<Option<Url>, _> = redirect_uri.map(|r| r.parse()).transpose();
|
||||
let redirect_uri = match redirect_uri {
|
||||
Ok(r) => r,
|
||||
@ -315,7 +311,6 @@ async fn actually_reply(
|
||||
}
|
||||
|
||||
async fn get(
|
||||
clients: ClientsConfig,
|
||||
params: Params,
|
||||
maybe_session: Option<BrowserSession<PostgresqlBackend>>,
|
||||
mut txn: Transaction<'_, Postgres>,
|
||||
@ -337,15 +332,17 @@ async fn get(
|
||||
}
|
||||
|
||||
// First, find out what client it is
|
||||
let client = clients
|
||||
.iter()
|
||||
.find(|client| client.client_id == params.auth.client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("could not find client"))
|
||||
.wrap_error()?;
|
||||
let client = lookup_client_by_client_id(&mut txn, ¶ms.auth.client_id).await?;
|
||||
|
||||
// Check if it is allowed to use this grant type
|
||||
if !client.grant_types.contains(&GrantType::AuthorizationCode) {
|
||||
return Ok(ReplyOrBackToClient::Error(Box::new(UnauthorizedClient)));
|
||||
}
|
||||
|
||||
let redirect_uri = client
|
||||
.resolve_redirect_uri(¶ms.auth.redirect_uri)
|
||||
.wrap_error()?;
|
||||
.wrap_error()?
|
||||
.clone();
|
||||
let response_type = params.auth.response_type;
|
||||
let response_mode =
|
||||
resolve_response_mode(response_type, params.auth.response_mode).wrap_error()?;
|
||||
@ -392,8 +389,8 @@ async fn get(
|
||||
|
||||
let grant = new_authorization_grant(
|
||||
&mut txn,
|
||||
client.client_id.clone(),
|
||||
redirect_uri.clone(),
|
||||
client,
|
||||
redirect_uri,
|
||||
scope,
|
||||
code,
|
||||
params.auth.state,
|
||||
@ -471,10 +468,10 @@ impl ContinueAuthorizationGrant {
|
||||
|
||||
pub async fn fetch_authorization_grant(
|
||||
&self,
|
||||
executor: impl PgExecutor<'_>,
|
||||
conn: &mut PgConnection,
|
||||
) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> {
|
||||
let data = self.data.parse()?;
|
||||
get_grant_by_id(executor, data).await
|
||||
get_grant_by_id(conn, data).await
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -12,11 +12,14 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use mas_config::{ClientConfig, ClientsConfig, HttpConfig};
|
||||
use mas_data_model::TokenType;
|
||||
use mas_config::{Encrypter, HttpConfig};
|
||||
use mas_data_model::{Client, TokenType};
|
||||
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
|
||||
use mas_storage::oauth2::{
|
||||
use mas_storage::{
|
||||
oauth2::{
|
||||
access_token::lookup_active_access_token, refresh_token::lookup_active_refresh_token,
|
||||
},
|
||||
PostgresqlBackend,
|
||||
};
|
||||
use mas_warp_utils::{
|
||||
errors::WrapError,
|
||||
@ -29,7 +32,7 @@ use warp::{filters::BoxedFilter, Filter, Rejection, Reply};
|
||||
|
||||
pub fn filter(
|
||||
pool: &PgPool,
|
||||
clients_config: &ClientsConfig,
|
||||
encrypter: &Encrypter,
|
||||
http_config: &HttpConfig,
|
||||
) -> BoxedFilter<(Box<dyn Reply>,)> {
|
||||
let audience = UrlBuilder::from(http_config)
|
||||
@ -41,7 +44,7 @@ pub fn filter(
|
||||
.and(
|
||||
warp::post()
|
||||
.and(connection(pool))
|
||||
.and(client_authentication(clients_config, audience))
|
||||
.and(client_authentication(pool, encrypter, audience))
|
||||
.and_then(introspect)
|
||||
.recover(recover)
|
||||
.unify(),
|
||||
@ -67,7 +70,7 @@ const INACTIVE: IntrospectionResponse = IntrospectionResponse {
|
||||
async fn introspect(
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
auth: OAuthClientAuthenticationMethod,
|
||||
client: ClientConfig,
|
||||
client: Client<PostgresqlBackend>,
|
||||
params: IntrospectionRequest,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
// Token introspection is only allowed by confidential clients
|
||||
|
@ -15,7 +15,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use hyper::Method;
|
||||
use mas_config::{ClientsConfig, Encrypter, HttpConfig};
|
||||
use mas_config::{Encrypter, HttpConfig};
|
||||
use mas_jose::StaticKeystore;
|
||||
use mas_templates::Templates;
|
||||
use mas_warp_utils::filters::cors::cors;
|
||||
@ -41,15 +41,14 @@ pub fn filter(
|
||||
templates: &Templates,
|
||||
key_store: &Arc<StaticKeystore>,
|
||||
encrypter: &Encrypter,
|
||||
clients_config: &ClientsConfig,
|
||||
http_config: &HttpConfig,
|
||||
) -> BoxedFilter<(impl Reply,)> {
|
||||
let discovery = discovery(key_store.as_ref(), http_config);
|
||||
let keys = keys(key_store);
|
||||
let authorization = authorization(pool, templates, encrypter, clients_config);
|
||||
let authorization = authorization(pool, templates, encrypter);
|
||||
let userinfo = userinfo(pool);
|
||||
let introspection = introspection(pool, clients_config, http_config);
|
||||
let token = token(pool, key_store, clients_config, http_config);
|
||||
let introspection = introspection(pool, encrypter, http_config);
|
||||
let token = token(pool, encrypter, key_store, http_config);
|
||||
|
||||
let filter = discovery
|
||||
.or(keys)
|
||||
|
@ -19,8 +19,8 @@ use chrono::{DateTime, Duration, Utc};
|
||||
use data_encoding::BASE64URL_NOPAD;
|
||||
use headers::{CacheControl, Pragma};
|
||||
use hyper::StatusCode;
|
||||
use mas_config::{ClientConfig, ClientsConfig, HttpConfig};
|
||||
use mas_data_model::{AuthorizationGrantStage, TokenType};
|
||||
use mas_config::{Encrypter, HttpConfig};
|
||||
use mas_data_model::{AuthorizationGrantStage, Client, TokenType};
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
use mas_jose::{claims, DecodedJsonWebToken, SigningKeystore, StaticKeystore};
|
||||
use mas_storage::{
|
||||
@ -33,7 +33,7 @@ use mas_storage::{
|
||||
RefreshTokenLookupError,
|
||||
},
|
||||
},
|
||||
DatabaseInconsistencyError,
|
||||
DatabaseInconsistencyError, PostgresqlBackend,
|
||||
};
|
||||
use mas_warp_utils::{
|
||||
errors::WrapError,
|
||||
@ -99,8 +99,8 @@ where
|
||||
|
||||
pub fn filter(
|
||||
pool: &PgPool,
|
||||
encrypter: &Encrypter,
|
||||
key_store: &Arc<StaticKeystore>,
|
||||
clients_config: &ClientsConfig,
|
||||
http_config: &HttpConfig,
|
||||
) -> BoxedFilter<(Box<dyn Reply>,)> {
|
||||
let key_store = key_store.clone();
|
||||
@ -113,7 +113,7 @@ pub fn filter(
|
||||
.and(filters::trace::name("POST /oauth2/token"))
|
||||
.and(
|
||||
warp::post()
|
||||
.and(client_authentication(clients_config, audience))
|
||||
.and(client_authentication(pool, encrypter, audience))
|
||||
.and(warp::any().map(move || key_store.clone()))
|
||||
.and(warp::any().map(move || issuer.clone()))
|
||||
.and(connection(pool))
|
||||
@ -145,7 +145,7 @@ async fn recover(rejection: Rejection) -> Result<Box<dyn Reply>, Infallible> {
|
||||
|
||||
async fn token(
|
||||
_auth: OAuthClientAuthenticationMethod,
|
||||
client: ClientConfig,
|
||||
client: Client<PostgresqlBackend>,
|
||||
req: AccessTokenRequest,
|
||||
key_store: Arc<StaticKeystore>,
|
||||
issuer: Url,
|
||||
@ -185,7 +185,7 @@ fn hash<H: Digest>(mut hasher: H, token: &str) -> anyhow::Result<String> {
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn authorization_code_grant(
|
||||
grant: &AuthorizationCodeGrant,
|
||||
client: &ClientConfig,
|
||||
client: &Client<PostgresqlBackend>,
|
||||
key_store: &StaticKeystore,
|
||||
issuer: Url,
|
||||
conn: &mut PoolConnection<Postgres>,
|
||||
@ -349,7 +349,7 @@ async fn authorization_code_grant(
|
||||
|
||||
async fn refresh_token_grant(
|
||||
grant: &RefreshTokenGrant,
|
||||
client: &ClientConfig,
|
||||
client: &Client<PostgresqlBackend>,
|
||||
conn: &mut PoolConnection<Postgres>,
|
||||
) -> Result<AccessTokenResponse, Rejection> {
|
||||
let mut txn = conn.begin().await.wrap_error()?;
|
||||
|
@ -17,7 +17,7 @@
|
||||
use hyper::Uri;
|
||||
use mas_templates::PostAuthContext;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgExecutor;
|
||||
use sqlx::PgConnection;
|
||||
|
||||
use super::super::oauth2::ContinueAuthorizationGrant;
|
||||
|
||||
@ -36,11 +36,11 @@ impl PostAuthAction {
|
||||
|
||||
pub async fn load_context<'e>(
|
||||
&self,
|
||||
executor: impl PgExecutor<'e>,
|
||||
conn: &mut PgConnection,
|
||||
) -> anyhow::Result<PostAuthContext> {
|
||||
match self {
|
||||
Self::ContinueAuthorizationGrant(c) => {
|
||||
let grant = c.fetch_authorization_grant(executor).await?;
|
||||
let grant = c.fetch_authorization_grant(conn).await?;
|
||||
let grant = grant.into();
|
||||
Ok(PostAuthContext::ContinueAuthorizationGrant { grant })
|
||||
}
|
||||
|
@ -21,3 +21,5 @@
|
||||
|
||||
pub mod jose;
|
||||
pub mod oauth;
|
||||
|
||||
pub use parse_display::ParseError;
|
||||
|
@ -31,7 +31,7 @@ sha2 = "0.10.2"
|
||||
signature = "1.4.0"
|
||||
thiserror = "1.0.30"
|
||||
tokio = { version = "1.17.0", features = ["macros", "rt", "sync"] }
|
||||
tower = "0.4.12"
|
||||
tower = { version = "0.4.12", features = ["util"] }
|
||||
url = { version = "2.2.2", features = ["serde"] }
|
||||
|
||||
mas-iana = { path = "../iana" }
|
||||
|
@ -22,6 +22,8 @@ pub(crate) mod jwk;
|
||||
pub(crate) mod jwt;
|
||||
mod keystore;
|
||||
|
||||
pub use futures_util::future::Either;
|
||||
|
||||
pub use self::{
|
||||
jwk::{JsonWebKey, JsonWebKeySet},
|
||||
jwt::{DecodedJsonWebToken, JsonWebTokenParts, JwtHeader},
|
||||
|
@ -27,14 +27,21 @@ use url::Url;
|
||||
|
||||
use crate::requests::{Display, GrantType, ResponseMode};
|
||||
|
||||
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ApplicationType {
|
||||
Web,
|
||||
Native,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum SubjectType {
|
||||
Public,
|
||||
Pairwise,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
#[derive(Serialize, Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ClaimType {
|
||||
Normal,
|
||||
|
@ -192,6 +192,7 @@ pub struct ClientCredentialsGrant {
|
||||
pub enum GrantType {
|
||||
AuthorizationCode,
|
||||
RefreshToken,
|
||||
Implicit,
|
||||
ClientCredentials,
|
||||
}
|
||||
|
||||
|
@ -7,9 +7,10 @@ license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
tokio = "1.17.0"
|
||||
sqlx = { version = "0.5.11", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline"] }
|
||||
sqlx = { version = "0.5.11", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline", "json"] }
|
||||
chrono = { version = "0.4.19", features = ["serde"] }
|
||||
serde = { version = "1.0.136", features = ["derive"] }
|
||||
serde_json = "1.0.79"
|
||||
thiserror = "1.0.30"
|
||||
anyhow = "1.0.55"
|
||||
tracing = "0.1.31"
|
||||
@ -24,3 +25,4 @@ url = { version = "2.2.2", features = ["serde"] }
|
||||
oauth2-types = { path = "../oauth2-types" }
|
||||
mas-data-model = { path = "../data-model" }
|
||||
mas-iana = { path = "../iana" }
|
||||
mas-jose = { path = "../jose" }
|
||||
|
@ -0,0 +1,17 @@
|
||||
-- Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
DROP TABLE oauth2_client_redirect_uris;
|
||||
DROP TRIGGER set_timestamp ON oauth2_clients;
|
||||
DROP TABLE oauth2_clients;
|
@ -0,0 +1,51 @@
|
||||
-- Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
CREATE TABLE oauth2_clients (
|
||||
"id" BIGSERIAL PRIMARY KEY,
|
||||
"client_id" TEXT NOT NULL UNIQUE,
|
||||
"encrypted_client_secret" TEXT,
|
||||
"response_types" TEXT[] NOT NULL,
|
||||
"grant_type_authorization_code" BOOL NOT NULL,
|
||||
"grant_type_refresh_token" BOOL NOT NULL,
|
||||
"contacts" TEXT[] NOT NULL,
|
||||
"client_name" TEXT,
|
||||
"logo_uri" TEXT,
|
||||
"client_uri" TEXT,
|
||||
"policy_uri" TEXT,
|
||||
"tos_uri" TEXT,
|
||||
"jwks_uri" TEXT,
|
||||
"jwks" JSONB,
|
||||
"id_token_signed_response_alg" TEXT,
|
||||
"token_endpoint_auth_method" TEXT,
|
||||
"token_endpoint_auth_signing_alg" TEXT,
|
||||
"initiate_login_uri" TEXT,
|
||||
|
||||
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
|
||||
"updated_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
|
||||
|
||||
-- jwks and jwks_uri can't be set at the same time
|
||||
CHECK ("jwks" IS NULL OR "jwks_uri" IS NULL)
|
||||
);
|
||||
|
||||
CREATE TRIGGER set_timestamp
|
||||
BEFORE UPDATE ON oauth2_clients
|
||||
FOR EACH ROW
|
||||
EXECUTE PROCEDURE trigger_set_timestamp();
|
||||
|
||||
CREATE TABLE oauth2_client_redirect_uris (
|
||||
"id" BIGSERIAL PRIMARY KEY,
|
||||
"oauth2_client_id" BIGINT NOT NULL REFERENCES oauth2_clients (id) ON DELETE CASCADE,
|
||||
"redirect_uri" TEXT NOT NULL
|
||||
);
|
File diff suppressed because it is too large
Load Diff
@ -41,7 +41,7 @@ impl StorageBackend for PostgresqlBackend {
|
||||
type AuthenticationData = i64;
|
||||
type AuthorizationGrantData = i64;
|
||||
type BrowserSessionData = i64;
|
||||
type ClientData = ();
|
||||
type ClientData = i64;
|
||||
type RefreshTokenData = i64;
|
||||
type SessionData = i64;
|
||||
type UserData = i64;
|
||||
|
@ -14,12 +14,11 @@
|
||||
|
||||
use anyhow::Context;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mas_data_model::{
|
||||
AccessToken, Authentication, BrowserSession, Client, Session, User, UserEmail,
|
||||
};
|
||||
use sqlx::PgExecutor;
|
||||
use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail};
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use thiserror::Error;
|
||||
|
||||
use super::client::{lookup_client_by_client_id, ClientFetchError};
|
||||
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
|
||||
|
||||
pub async fn add_access_token(
|
||||
@ -83,6 +82,7 @@ pub struct OAuth2AccessTokenLookup {
|
||||
#[error("failed to lookup access token")]
|
||||
pub enum AccessTokenLookupError {
|
||||
Database(#[from] sqlx::Error),
|
||||
ClientFetch(#[from] ClientFetchError),
|
||||
Inconsistency(#[from] DatabaseInconsistencyError),
|
||||
}
|
||||
|
||||
@ -95,7 +95,7 @@ impl AccessTokenLookupError {
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn lookup_active_access_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
conn: &mut PgConnection,
|
||||
token: &str,
|
||||
) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), AccessTokenLookupError> {
|
||||
let res = sqlx::query_as!(
|
||||
@ -142,7 +142,7 @@ pub async fn lookup_active_access_token(
|
||||
"#,
|
||||
token,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.fetch_one(&mut *conn)
|
||||
.await?;
|
||||
|
||||
let access_token = AccessToken {
|
||||
@ -153,10 +153,7 @@ pub async fn lookup_active_access_token(
|
||||
expires_after: Duration::seconds(res.access_token_expires_after.into()),
|
||||
};
|
||||
|
||||
let client = Client {
|
||||
data: (),
|
||||
client_id: res.client_id,
|
||||
};
|
||||
let client = lookup_client_by_client_id(&mut *conn, &res.client_id).await?;
|
||||
|
||||
let primary_email = match (
|
||||
res.user_email_id,
|
||||
|
@ -24,15 +24,16 @@ use mas_data_model::{
|
||||
};
|
||||
use mas_iana::oauth::PkceCodeChallengeMethod;
|
||||
use oauth2_types::{requests::ResponseMode, scope::Scope};
|
||||
use sqlx::PgExecutor;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use url::Url;
|
||||
|
||||
use super::client::lookup_client_by_client_id;
|
||||
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn new_authorization_grant(
|
||||
executor: impl PgExecutor<'_>,
|
||||
client_id: String,
|
||||
client: Client<PostgresqlBackend>,
|
||||
redirect_uri: Url,
|
||||
scope: Scope,
|
||||
code: Option<AuthorizationCode>,
|
||||
@ -65,7 +66,7 @@ pub async fn new_authorization_grant(
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
|
||||
RETURNING id, created_at
|
||||
"#,
|
||||
&client_id,
|
||||
&client.client_id,
|
||||
redirect_uri.to_string(),
|
||||
scope.to_string(),
|
||||
state,
|
||||
@ -85,11 +86,6 @@ pub async fn new_authorization_grant(
|
||||
.await
|
||||
.context("could not insert oauth2 authorization grant")?;
|
||||
|
||||
let client = Client {
|
||||
data: (),
|
||||
client_id,
|
||||
};
|
||||
|
||||
Ok(AuthorizationGrant {
|
||||
data: res.id,
|
||||
stage: AuthorizationGrantStage::Pending,
|
||||
@ -141,20 +137,21 @@ struct GrantLookup {
|
||||
user_email_confirmed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl TryInto<AuthorizationGrant<PostgresqlBackend>> for GrantLookup {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
impl GrantLookup {
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn try_into(self) -> Result<AuthorizationGrant<PostgresqlBackend>, Self::Error> {
|
||||
async fn into_authorization_grant(
|
||||
self,
|
||||
executor: impl PgExecutor<'_>,
|
||||
) -> Result<AuthorizationGrant<PostgresqlBackend>, DatabaseInconsistencyError> {
|
||||
let scope: Scope = self
|
||||
.grant_scope
|
||||
.parse()
|
||||
.map_err(|_e| DatabaseInconsistencyError)?;
|
||||
|
||||
let client = Client {
|
||||
data: (),
|
||||
client_id: self.client_id,
|
||||
};
|
||||
// TODO: don't unwrap
|
||||
let client = lookup_client_by_client_id(executor, &self.client_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let last_authentication = match (
|
||||
self.user_session_last_authentication_id,
|
||||
@ -323,7 +320,7 @@ impl TryInto<AuthorizationGrant<PostgresqlBackend>> for GrantLookup {
|
||||
}
|
||||
|
||||
pub async fn get_grant_by_id(
|
||||
executor: impl PgExecutor<'_>,
|
||||
conn: &mut PgConnection,
|
||||
id: i64,
|
||||
) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> {
|
||||
// TODO: handle "not found" cases
|
||||
@ -381,17 +378,17 @@ pub async fn get_grant_by_id(
|
||||
"#,
|
||||
id,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.fetch_one(&mut *conn)
|
||||
.await
|
||||
.context("failed to get grant by id")?;
|
||||
|
||||
let grant = res.try_into()?;
|
||||
let grant = res.into_authorization_grant(&mut *conn).await?;
|
||||
|
||||
Ok(grant)
|
||||
}
|
||||
|
||||
pub async fn lookup_grant_by_code(
|
||||
executor: impl PgExecutor<'_>,
|
||||
conn: &mut PgConnection,
|
||||
code: &str,
|
||||
) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> {
|
||||
// TODO: handle "not found" cases
|
||||
@ -449,11 +446,11 @@ pub async fn lookup_grant_by_code(
|
||||
"#,
|
||||
code,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.fetch_one(&mut *conn)
|
||||
.await
|
||||
.context("failed to lookup grant by code")?;
|
||||
|
||||
let grant = res.try_into()?;
|
||||
let grant = res.into_authorization_grant(&mut *conn).await?;
|
||||
|
||||
Ok(grant)
|
||||
}
|
||||
|
380
crates/storage/src/oauth2/client.rs
Normal file
380
crates/storage/src/oauth2/client.rs
Normal file
@ -0,0 +1,380 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::string::ToString;
|
||||
|
||||
use mas_data_model::{Client, JwksOrJwksUri};
|
||||
use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, OAuthClientAuthenticationMethod};
|
||||
use mas_jose::JsonWebKeySet;
|
||||
use oauth2_types::requests::GrantType;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use thiserror::Error;
|
||||
use url::Url;
|
||||
use warp::reject::Reject;
|
||||
|
||||
use crate::PostgresqlBackend;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OAuth2ClientLookup {
|
||||
id: i64,
|
||||
client_id: String,
|
||||
encrypted_client_secret: Option<String>,
|
||||
redirect_uris: Vec<String>,
|
||||
response_types: Vec<String>,
|
||||
grant_type_authorization_code: bool,
|
||||
grant_type_refresh_token: bool,
|
||||
contacts: Vec<String>,
|
||||
client_name: Option<String>,
|
||||
logo_uri: Option<String>,
|
||||
client_uri: Option<String>,
|
||||
policy_uri: Option<String>,
|
||||
tos_uri: Option<String>,
|
||||
jwks_uri: Option<String>,
|
||||
jwks: Option<serde_json::Value>,
|
||||
id_token_signed_response_alg: Option<String>,
|
||||
token_endpoint_auth_method: Option<String>,
|
||||
token_endpoint_auth_signing_alg: Option<String>,
|
||||
initiate_login_uri: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ClientFetchError {
|
||||
#[error("malformed jwks column")]
|
||||
MalformedJwks(#[source] serde_json::Error),
|
||||
|
||||
#[error("entry has both a jwks and a jwks_uri")]
|
||||
BothJwksAndJwksUri,
|
||||
|
||||
#[error("could not parse URL in field {field:?}")]
|
||||
ParseUrl {
|
||||
field: &'static str,
|
||||
source: url::ParseError,
|
||||
},
|
||||
|
||||
#[error("could not parse field {field:?}")]
|
||||
ParseField {
|
||||
field: &'static str,
|
||||
source: mas_iana::ParseError,
|
||||
},
|
||||
|
||||
#[error(transparent)]
|
||||
Database(#[from] sqlx::Error),
|
||||
}
|
||||
|
||||
impl ClientFetchError {
|
||||
#[must_use]
|
||||
pub fn not_found(&self) -> bool {
|
||||
matches!(self, Self::Database(sqlx::Error::RowNotFound))
|
||||
}
|
||||
}
|
||||
|
||||
impl Reject for ClientFetchError {}
|
||||
|
||||
impl TryInto<Client<PostgresqlBackend>> for OAuth2ClientLookup {
|
||||
type Error = ClientFetchError;
|
||||
|
||||
#[allow(clippy::too_many_lines)] // TODO: refactor some of the field parsing
|
||||
fn try_into(self) -> Result<Client<PostgresqlBackend>, Self::Error> {
|
||||
let redirect_uris: Result<Vec<Url>, _> =
|
||||
self.redirect_uris.iter().map(|s| s.parse()).collect();
|
||||
let redirect_uris = redirect_uris.map_err(|source| ClientFetchError::ParseUrl {
|
||||
field: "redirect_uris",
|
||||
source,
|
||||
})?;
|
||||
|
||||
let response_types: Result<Vec<OAuthAuthorizationEndpointResponseType>, _> =
|
||||
self.response_types.iter().map(|s| s.parse()).collect();
|
||||
let response_types = response_types.map_err(|source| ClientFetchError::ParseField {
|
||||
field: "response_types",
|
||||
source,
|
||||
})?;
|
||||
|
||||
let mut grant_types = Vec::new();
|
||||
if self.grant_type_authorization_code {
|
||||
grant_types.push(GrantType::AuthorizationCode);
|
||||
}
|
||||
if self.grant_type_refresh_token {
|
||||
grant_types.push(GrantType::RefreshToken);
|
||||
}
|
||||
|
||||
let logo_uri = self
|
||||
.logo_uri
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseUrl {
|
||||
field: "logo_uri",
|
||||
source,
|
||||
})?;
|
||||
|
||||
let client_uri = self
|
||||
.client_uri
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseUrl {
|
||||
field: "client_uri",
|
||||
source,
|
||||
})?;
|
||||
|
||||
let policy_uri = self
|
||||
.policy_uri
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseUrl {
|
||||
field: "policy_uri",
|
||||
source,
|
||||
})?;
|
||||
|
||||
let tos_uri = self
|
||||
.tos_uri
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseUrl {
|
||||
field: "tos_uri",
|
||||
source,
|
||||
})?;
|
||||
|
||||
let id_token_signed_response_alg = self
|
||||
.id_token_signed_response_alg
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseField {
|
||||
field: "id_token_signed_response_alg",
|
||||
source,
|
||||
})?;
|
||||
|
||||
let token_endpoint_auth_method = self
|
||||
.token_endpoint_auth_method
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseField {
|
||||
field: "token_endpoint_auth_method",
|
||||
source,
|
||||
})?;
|
||||
|
||||
let token_endpoint_auth_signing_alg = self
|
||||
.token_endpoint_auth_signing_alg
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseField {
|
||||
field: "token_endpoint_auth_signing_alg",
|
||||
source,
|
||||
})?;
|
||||
|
||||
let initiate_login_uri = self
|
||||
.initiate_login_uri
|
||||
.map(|s| s.parse())
|
||||
.transpose()
|
||||
.map_err(|source| ClientFetchError::ParseUrl {
|
||||
field: "initiate_login_uri",
|
||||
source,
|
||||
})?;
|
||||
|
||||
let jwks = match (self.jwks, self.jwks_uri) {
|
||||
(None, None) => None,
|
||||
(Some(jwks), None) => {
|
||||
let jwks = serde_json::from_value(jwks).map_err(ClientFetchError::MalformedJwks)?;
|
||||
Some(JwksOrJwksUri::Jwks(jwks))
|
||||
}
|
||||
(None, Some(jwks_uri)) => {
|
||||
let jwks_uri = jwks_uri
|
||||
.parse()
|
||||
.map_err(|source| ClientFetchError::ParseUrl {
|
||||
field: "jwks_uri",
|
||||
source,
|
||||
})?;
|
||||
|
||||
Some(JwksOrJwksUri::JwksUri(jwks_uri))
|
||||
}
|
||||
_ => return Err(ClientFetchError::BothJwksAndJwksUri),
|
||||
};
|
||||
|
||||
Ok(Client {
|
||||
data: self.id,
|
||||
client_id: self.client_id,
|
||||
encrypted_client_secret: self.encrypted_client_secret,
|
||||
redirect_uris,
|
||||
response_types,
|
||||
grant_types,
|
||||
contacts: self.contacts,
|
||||
client_name: self.client_name,
|
||||
logo_uri,
|
||||
client_uri,
|
||||
policy_uri,
|
||||
tos_uri,
|
||||
jwks,
|
||||
id_token_signed_response_alg,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_auth_signing_alg,
|
||||
initiate_login_uri,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn lookup_client(
|
||||
executor: impl PgExecutor<'_>,
|
||||
id: i64,
|
||||
) -> Result<Client<PostgresqlBackend>, ClientFetchError> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2ClientLookup,
|
||||
r#"
|
||||
SELECT
|
||||
c.id,
|
||||
c.client_id,
|
||||
c.encrypted_client_secret,
|
||||
ARRAY(SELECT redirect_uri FROM oauth2_client_redirect_uris r WHERE r.oauth2_client_id = c.id) AS "redirect_uris!",
|
||||
c.response_types,
|
||||
c.grant_type_authorization_code,
|
||||
c.grant_type_refresh_token,
|
||||
c.contacts,
|
||||
c.client_name,
|
||||
c.logo_uri,
|
||||
c.client_uri,
|
||||
c.policy_uri,
|
||||
c.tos_uri,
|
||||
c.jwks_uri,
|
||||
c.jwks,
|
||||
c.id_token_signed_response_alg,
|
||||
c.token_endpoint_auth_method,
|
||||
c.token_endpoint_auth_signing_alg,
|
||||
c.initiate_login_uri
|
||||
FROM oauth2_clients c
|
||||
|
||||
WHERE c.id = $1
|
||||
"#,
|
||||
id,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await?;
|
||||
|
||||
let client = res.try_into()?;
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
pub async fn lookup_client_by_client_id(
|
||||
executor: impl PgExecutor<'_>,
|
||||
client_id: &str,
|
||||
) -> Result<Client<PostgresqlBackend>, ClientFetchError> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2ClientLookup,
|
||||
r#"
|
||||
SELECT
|
||||
c.id,
|
||||
c.client_id,
|
||||
c.encrypted_client_secret,
|
||||
ARRAY(SELECT redirect_uri FROM oauth2_client_redirect_uris r WHERE r.oauth2_client_id = c.id) AS "redirect_uris!",
|
||||
c.response_types,
|
||||
c.grant_type_authorization_code,
|
||||
c.grant_type_refresh_token,
|
||||
c.contacts,
|
||||
c.client_name,
|
||||
c.logo_uri,
|
||||
c.client_uri,
|
||||
c.policy_uri,
|
||||
c.tos_uri,
|
||||
c.jwks_uri,
|
||||
c.jwks,
|
||||
c.id_token_signed_response_alg,
|
||||
c.token_endpoint_auth_method,
|
||||
c.token_endpoint_auth_signing_alg,
|
||||
c.initiate_login_uri
|
||||
FROM oauth2_clients c
|
||||
|
||||
WHERE c.client_id = $1
|
||||
"#,
|
||||
client_id,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await?;
|
||||
|
||||
let client = res.try_into()?;
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
pub async fn insert_client_from_config(
|
||||
conn: &mut PgConnection,
|
||||
client_id: &str,
|
||||
client_auth_method: OAuthClientAuthenticationMethod,
|
||||
encrypted_client_secret: Option<&str>,
|
||||
jwks: Option<&JsonWebKeySet>,
|
||||
jwks_uri: Option<&Url>,
|
||||
redirect_uris: &[Url],
|
||||
) -> anyhow::Result<()> {
|
||||
let response_types = vec![
|
||||
OAuthAuthorizationEndpointResponseType::Code.to_string(),
|
||||
OAuthAuthorizationEndpointResponseType::CodeIdToken.to_string(),
|
||||
OAuthAuthorizationEndpointResponseType::CodeIdTokenToken.to_string(),
|
||||
OAuthAuthorizationEndpointResponseType::CodeToken.to_string(),
|
||||
OAuthAuthorizationEndpointResponseType::IdToken.to_string(),
|
||||
OAuthAuthorizationEndpointResponseType::IdTokenToken.to_string(),
|
||||
OAuthAuthorizationEndpointResponseType::None.to_string(),
|
||||
OAuthAuthorizationEndpointResponseType::Token.to_string(),
|
||||
];
|
||||
|
||||
let jwks = jwks.map(serde_json::to_value).transpose()?;
|
||||
let jwks_uri = jwks_uri.map(Url::as_str);
|
||||
|
||||
let client_auth_method = client_auth_method.to_string();
|
||||
|
||||
let id = sqlx::query_scalar!(
|
||||
r#"
|
||||
INSERT INTO oauth2_clients
|
||||
(client_id,
|
||||
encrypted_client_secret,
|
||||
response_types,
|
||||
grant_type_authorization_code,
|
||||
grant_type_refresh_token,
|
||||
token_endpoint_auth_method,
|
||||
jwks,
|
||||
jwks_uri,
|
||||
contacts)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, '{}')
|
||||
RETURNING id
|
||||
"#,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
&response_types,
|
||||
true,
|
||||
true,
|
||||
client_auth_method,
|
||||
jwks,
|
||||
jwks_uri,
|
||||
)
|
||||
.fetch_one(&mut *conn)
|
||||
.await?;
|
||||
|
||||
let redirect_uris: Vec<String> = redirect_uris.iter().map(ToString::to_string).collect();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_client_redirect_uris (oauth2_client_id, redirect_uri)
|
||||
SELECT $1, uri FROM UNNEST($2::text[]) uri
|
||||
"#,
|
||||
id,
|
||||
&redirect_uris,
|
||||
)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn truncate_clients(executor: impl PgExecutor<'_>) -> anyhow::Result<()> {
|
||||
sqlx::query!("TRUNCATE oauth2_client_redirect_uris, oauth2_clients")
|
||||
.execute(executor)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
@ -19,6 +19,7 @@ use crate::PostgresqlBackend;
|
||||
|
||||
pub mod access_token;
|
||||
pub mod authorization_grant;
|
||||
pub mod client;
|
||||
pub mod refresh_token;
|
||||
|
||||
pub async fn end_oauth_session(
|
||||
|
@ -15,12 +15,13 @@
|
||||
use anyhow::Context;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mas_data_model::{
|
||||
AccessToken, Authentication, BrowserSession, Client, RefreshToken, Session, User, UserEmail,
|
||||
AccessToken, Authentication, BrowserSession, RefreshToken, Session, User, UserEmail,
|
||||
};
|
||||
use sqlx::PgExecutor;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use thiserror::Error;
|
||||
use warp::reject::Reject;
|
||||
|
||||
use super::client::{lookup_client_by_client_id, ClientFetchError};
|
||||
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
|
||||
|
||||
pub async fn add_refresh_token(
|
||||
@ -82,6 +83,7 @@ struct OAuth2RefreshTokenLookup {
|
||||
#[error("could not lookup refresh token")]
|
||||
pub enum RefreshTokenLookupError {
|
||||
Fetch(#[from] sqlx::Error),
|
||||
ClientFetch(#[from] ClientFetchError),
|
||||
Conversion(#[from] DatabaseInconsistencyError),
|
||||
}
|
||||
|
||||
@ -96,7 +98,7 @@ impl RefreshTokenLookupError {
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub async fn lookup_active_refresh_token(
|
||||
executor: impl PgExecutor<'_>,
|
||||
conn: &mut PgConnection,
|
||||
token: &str,
|
||||
) -> Result<(RefreshToken<PostgresqlBackend>, Session<PostgresqlBackend>), RefreshTokenLookupError>
|
||||
{
|
||||
@ -148,7 +150,7 @@ pub async fn lookup_active_refresh_token(
|
||||
"#,
|
||||
token,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.fetch_one(&mut *conn)
|
||||
.await?;
|
||||
|
||||
let access_token = match (
|
||||
@ -175,10 +177,7 @@ pub async fn lookup_active_refresh_token(
|
||||
access_token,
|
||||
};
|
||||
|
||||
let client = Client {
|
||||
data: (),
|
||||
client_id: res.client_id,
|
||||
};
|
||||
let client = lookup_client_by_client_id(&mut *conn, &res.client_id).await?;
|
||||
|
||||
let primary_email = match (
|
||||
res.user_email_id,
|
||||
|
@ -28,6 +28,9 @@ mime = "0.3.16"
|
||||
bincode = "1.3.3"
|
||||
crc = "2.1.0"
|
||||
url = "2.2.2"
|
||||
http = "0.2.6"
|
||||
http-body = "0.4.4"
|
||||
tower = { version = "0.4.12", features = ["util"] }
|
||||
|
||||
oauth2-types = { path = "../oauth2-types" }
|
||||
mas-config = { path = "../config" }
|
||||
@ -36,6 +39,4 @@ mas-data-model = { path = "../data-model" }
|
||||
mas-storage = { path = "../storage" }
|
||||
mas-jose = { path = "../jose" }
|
||||
mas-iana = { path = "../iana" }
|
||||
|
||||
[dev-dependencies]
|
||||
tower = { version = "0.4.12", features = ["util"] }
|
||||
mas-http = { path = "../http" }
|
||||
|
@ -87,6 +87,10 @@ pub fn authentication(
|
||||
.untuple_one()
|
||||
}
|
||||
|
||||
fn ensure<T: Clone + Send + Sync + 'static>(t: T) -> T {
|
||||
t
|
||||
}
|
||||
|
||||
async fn authenticate(
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
auth: Authorization<Bearer>,
|
||||
@ -110,6 +114,9 @@ async fn authenticate(
|
||||
}
|
||||
})?;
|
||||
|
||||
let session = ensure(session);
|
||||
let token = ensure(token);
|
||||
|
||||
Ok((token, session))
|
||||
}
|
||||
|
||||
|
@ -16,30 +16,49 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use data_encoding::BASE64;
|
||||
use headers::{authorization::Basic, Authorization};
|
||||
use mas_config::{ClientAuthMethodConfig, ClientConfig, ClientsConfig};
|
||||
use mas_config::Encrypter;
|
||||
use mas_data_model::{Client, JwksOrJwksUri, StorageBackend};
|
||||
use mas_http::HttpServiceExt;
|
||||
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||
use mas_jose::{
|
||||
claims::{TimeOptions, AUD, EXP, IAT, ISS, JTI, NBF, SUB},
|
||||
DecodedJsonWebToken, JsonWebTokenParts, SharedSecret,
|
||||
DecodedJsonWebToken, DynamicJwksStore, Either, JsonWebKeySet, JsonWebTokenParts, SharedSecret,
|
||||
StaticJwksStore, VerifyingKeystore,
|
||||
};
|
||||
use mas_storage::{
|
||||
oauth2::client::{lookup_client_by_client_id, ClientFetchError},
|
||||
PostgresqlBackend,
|
||||
};
|
||||
use serde::{de::DeserializeOwned, Deserialize};
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use thiserror::Error;
|
||||
use tower::{BoxError, ServiceExt};
|
||||
use warp::{reject::Reject, Filter, Rejection};
|
||||
|
||||
use super::headers::typed_header;
|
||||
use super::{database::connection, headers::typed_header};
|
||||
use crate::errors::WrapError;
|
||||
|
||||
/// Protect an enpoint with client authentication
|
||||
#[must_use]
|
||||
pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
|
||||
clients_config: &ClientsConfig,
|
||||
pool: &PgPool,
|
||||
encrypter: &Encrypter,
|
||||
audience: String,
|
||||
) -> impl Filter<Extract = (OAuthClientAuthenticationMethod, ClientConfig, T), Error = Rejection>
|
||||
+ Clone
|
||||
) -> impl Filter<
|
||||
Extract = (
|
||||
OAuthClientAuthenticationMethod,
|
||||
Client<PostgresqlBackend>,
|
||||
T,
|
||||
),
|
||||
Error = Rejection,
|
||||
> + Clone
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static {
|
||||
let encrypter = encrypter.clone();
|
||||
|
||||
// First, extract the client credentials
|
||||
let credentials = typed_header()
|
||||
.and(warp::body::form())
|
||||
@ -65,9 +84,9 @@ pub fn client_authentication<T: DeserializeOwned + Send + 'static>(
|
||||
.unify()
|
||||
.untuple_one();
|
||||
|
||||
let clients_config = clients_config.clone();
|
||||
warp::any()
|
||||
.map(move || clients_config.clone())
|
||||
.and(connection(pool))
|
||||
.and(warp::any().map(move || encrypter.clone()))
|
||||
.and(warp::any().map(move || audience.clone()))
|
||||
.and(credentials)
|
||||
.and_then(authenticate_client)
|
||||
@ -79,8 +98,20 @@ enum ClientAuthenticationError {
|
||||
#[error("wrong client secret for client {client_id:?}")]
|
||||
ClientSecretMismatch { client_id: String },
|
||||
|
||||
#[error("could not find client {client_id:?}")]
|
||||
ClientNotFound { client_id: String },
|
||||
#[error("could not fetch client {client_id:?}")]
|
||||
ClientFetch {
|
||||
client_id: String,
|
||||
source: ClientFetchError,
|
||||
},
|
||||
|
||||
#[error("client {client_id:?} has an invalid client secret")]
|
||||
InvalidClientSecret {
|
||||
client_id: String,
|
||||
source: anyhow::Error,
|
||||
},
|
||||
|
||||
#[error("client {client_id:?} has an invalid JWKS")]
|
||||
InvalidJwks { client_id: String },
|
||||
|
||||
#[error("wrong client authentication method for client {client_id:?}")]
|
||||
WrongAuthenticationMethod { client_id: String },
|
||||
@ -94,68 +125,136 @@ enum ClientAuthenticationError {
|
||||
|
||||
impl Reject for ClientAuthenticationError {}
|
||||
|
||||
fn decrypt_client_secret<T: StorageBackend>(
|
||||
client: &Client<T>,
|
||||
encrypter: &Encrypter,
|
||||
) -> anyhow::Result<Vec<u8>> {
|
||||
let encrypted_client_secret = client
|
||||
.encrypted_client_secret
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing encrypted_client_secret field"))?;
|
||||
|
||||
let encrypted_client_secret = BASE64.decode(encrypted_client_secret.as_bytes())?;
|
||||
|
||||
let nonce: &[u8; 12] = encrypted_client_secret
|
||||
.get(0..12)
|
||||
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?
|
||||
.try_into()?;
|
||||
|
||||
let payload = encrypted_client_secret
|
||||
.get(12..)
|
||||
.ok_or_else(|| anyhow::anyhow!("invalid payload serialization"))?;
|
||||
|
||||
let decrypted_client_secret = encrypter.decrypt(nonce, payload)?;
|
||||
|
||||
Ok(decrypted_client_secret)
|
||||
}
|
||||
|
||||
fn jwks_key_store(jwks: &JwksOrJwksUri) -> Either<StaticJwksStore, DynamicJwksStore> {
|
||||
// Assert that the output is both a VerifyingKeystore and Send
|
||||
fn assert<T: Send + VerifyingKeystore>(t: T) -> T {
|
||||
t
|
||||
}
|
||||
|
||||
let inner = match jwks {
|
||||
JwksOrJwksUri::Jwks(jwks) => Either::Left(StaticJwksStore::new(jwks.clone())),
|
||||
JwksOrJwksUri::JwksUri(uri) => {
|
||||
let uri = uri.clone();
|
||||
|
||||
// TODO: get the client from somewhere else?
|
||||
let exporter = mas_http::client("fetch-jwks")
|
||||
.json::<JsonWebKeySet>()
|
||||
.map_request(move |_: ()| {
|
||||
http::Request::builder()
|
||||
.method("GET")
|
||||
// TODO: change the Uri type in config to avoid reparsing here
|
||||
.uri(uri.to_string())
|
||||
.body(http_body::Empty::new())
|
||||
.unwrap()
|
||||
})
|
||||
.map_response(http::Response::into_body)
|
||||
.map_err(BoxError::from)
|
||||
.boxed_clone();
|
||||
|
||||
Either::Right(DynamicJwksStore::new(exporter))
|
||||
}
|
||||
};
|
||||
|
||||
assert(inner)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
#[tracing::instrument(skip_all, fields(enduser.id), err(Debug))]
|
||||
async fn authenticate_client<T>(
|
||||
clients_config: ClientsConfig,
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
encrypter: Encrypter,
|
||||
audience: String,
|
||||
credentials: ClientCredentials,
|
||||
body: T,
|
||||
) -> Result<(OAuthClientAuthenticationMethod, ClientConfig, T), Rejection> {
|
||||
) -> Result<
|
||||
(
|
||||
OAuthClientAuthenticationMethod,
|
||||
Client<PostgresqlBackend>,
|
||||
T,
|
||||
),
|
||||
Rejection,
|
||||
> {
|
||||
let (auth_method, client) = match credentials {
|
||||
ClientCredentials::Pair {
|
||||
client_id,
|
||||
client_secret,
|
||||
via,
|
||||
} => {
|
||||
let client = clients_config
|
||||
.iter()
|
||||
.find(|client| client.client_id == client_id)
|
||||
.ok_or_else(|| ClientAuthenticationError::ClientNotFound {
|
||||
client_id: client_id.to_string(),
|
||||
let client = lookup_client_by_client_id(&mut *conn, &client_id)
|
||||
.await
|
||||
.map_err(|source| ClientAuthenticationError::ClientFetch {
|
||||
client_id: client_id.clone(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
let auth_method = match (&client.client_auth_method, client_secret, via) {
|
||||
(ClientAuthMethodConfig::None, None, _) => OAuthClientAuthenticationMethod::None,
|
||||
|
||||
(
|
||||
ClientAuthMethodConfig::ClientSecretBasic {
|
||||
client_secret: ref expected_client_secret,
|
||||
let auth_method = client.token_endpoint_auth_method.ok_or(
|
||||
ClientAuthenticationError::WrongAuthenticationMethod {
|
||||
client_id: client.client_id.clone(),
|
||||
},
|
||||
Some(ref given_client_secret),
|
||||
)?;
|
||||
|
||||
// Let's match the authentication method
|
||||
match (auth_method, client_secret, via) {
|
||||
(OAuthClientAuthenticationMethod::None, None, _) => {}
|
||||
(
|
||||
OAuthClientAuthenticationMethod::ClientSecretBasic,
|
||||
Some(client_secret),
|
||||
CredentialsVia::AuthorizationHeader,
|
||||
) => {
|
||||
if expected_client_secret != given_client_secret {
|
||||
return Err(
|
||||
ClientAuthenticationError::ClientSecretMismatch { client_id }.into(),
|
||||
);
|
||||
}
|
||||
|
||||
OAuthClientAuthenticationMethod::ClientSecretBasic
|
||||
}
|
||||
|
||||
(
|
||||
ClientAuthMethodConfig::ClientSecretPost {
|
||||
client_secret: ref expected_client_secret,
|
||||
},
|
||||
Some(ref given_client_secret),
|
||||
)
|
||||
| (
|
||||
OAuthClientAuthenticationMethod::ClientSecretPost,
|
||||
Some(client_secret),
|
||||
CredentialsVia::FormBody,
|
||||
) => {
|
||||
if expected_client_secret != given_client_secret {
|
||||
return Err(
|
||||
ClientAuthenticationError::ClientSecretMismatch { client_id }.into(),
|
||||
);
|
||||
let decrypted =
|
||||
decrypt_client_secret(&client, &encrypter).map_err(|source| {
|
||||
ClientAuthenticationError::InvalidClientSecret {
|
||||
client_id: client.client_id.clone(),
|
||||
source,
|
||||
}
|
||||
})?;
|
||||
|
||||
OAuthClientAuthenticationMethod::ClientSecretPost
|
||||
if client_secret.as_bytes() != decrypted {
|
||||
return Err(warp::reject::custom(
|
||||
ClientAuthenticationError::ClientSecretMismatch {
|
||||
client_id: client.client_id,
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
_ => {
|
||||
return Err(
|
||||
ClientAuthenticationError::WrongAuthenticationMethod { client_id }.into(),
|
||||
)
|
||||
return Err(warp::reject::custom(
|
||||
ClientAuthenticationError::WrongAuthenticationMethod {
|
||||
client_id: client.client_id,
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
(auth_method, client)
|
||||
}
|
||||
@ -195,34 +294,52 @@ async fn authenticate_client<T>(
|
||||
// from the token, as per rfc7521 sec. 4.2
|
||||
let client_id = client_id.as_ref().unwrap_or(&sub);
|
||||
|
||||
let client = clients_config
|
||||
.iter()
|
||||
.find(|client| &client.client_id == client_id)
|
||||
.ok_or_else(|| ClientAuthenticationError::ClientNotFound {
|
||||
let client = lookup_client_by_client_id(&mut *conn, client_id)
|
||||
.await
|
||||
.map_err(|source| ClientAuthenticationError::ClientFetch {
|
||||
client_id: client_id.to_string(),
|
||||
source,
|
||||
})?;
|
||||
|
||||
let auth_method = match &client.client_auth_method {
|
||||
ClientAuthMethodConfig::PrivateKeyJwt(jwks) => {
|
||||
let store = jwks.key_store();
|
||||
let auth_method = client.token_endpoint_auth_method.ok_or(
|
||||
ClientAuthenticationError::WrongAuthenticationMethod {
|
||||
client_id: client.client_id.clone(),
|
||||
},
|
||||
)?;
|
||||
|
||||
match auth_method {
|
||||
OAuthClientAuthenticationMethod::ClientSecretJwt => {
|
||||
let client_secret =
|
||||
decrypt_client_secret(&client, &encrypter).map_err(|source| {
|
||||
ClientAuthenticationError::InvalidClientSecret {
|
||||
client_id: client.client_id.clone(),
|
||||
source,
|
||||
}
|
||||
})?;
|
||||
|
||||
let store = SharedSecret::new(&client_secret);
|
||||
let fut = token.verify(&decoded, &store);
|
||||
fut.await.wrap_error()?;
|
||||
OAuthClientAuthenticationMethod::PrivateKeyJwt
|
||||
}
|
||||
|
||||
ClientAuthMethodConfig::ClientSecretJwt { client_secret } => {
|
||||
let store = SharedSecret::new(client_secret);
|
||||
token.verify(&decoded, &store).await.wrap_error()?;
|
||||
OAuthClientAuthenticationMethod::ClientSecretJwt
|
||||
OAuthClientAuthenticationMethod::PrivateKeyJwt => {
|
||||
let jwks = client.jwks.as_ref().ok_or_else(|| {
|
||||
ClientAuthenticationError::InvalidJwks {
|
||||
client_id: client.client_id.clone(),
|
||||
}
|
||||
})?;
|
||||
|
||||
let store = jwks_key_store(jwks);
|
||||
let fut = token.verify(&decoded, &store);
|
||||
fut.await.wrap_error()?;
|
||||
}
|
||||
_ => {
|
||||
return Err(ClientAuthenticationError::WrongAuthenticationMethod {
|
||||
client_id: client_id.clone(),
|
||||
return Err(warp::reject::custom(
|
||||
ClientAuthenticationError::WrongAuthenticationMethod {
|
||||
client_id: client.client_id,
|
||||
},
|
||||
));
|
||||
}
|
||||
.into())
|
||||
}
|
||||
};
|
||||
|
||||
// rfc7523 sec. 3.3: the audience is the URL being called
|
||||
if !aud.contains(&audience) {
|
||||
@ -243,7 +360,7 @@ async fn authenticate_client<T>(
|
||||
|
||||
tracing::Span::current().record("enduser.id", &client.client_id.as_str());
|
||||
|
||||
Ok((auth_method, client.clone(), body))
|
||||
Ok((auth_method, client, body))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@ -291,6 +408,7 @@ struct ClientAuthForm<T> {
|
||||
body: T,
|
||||
}
|
||||
|
||||
/* TODO: all secrets are broken because there is no way to mock the DB yet
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use headers::authorization::Credentials;
|
||||
@ -651,3 +769,4 @@ mod tests {
|
||||
assert_eq!(body.bar, "foobar");
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
Reference in New Issue
Block a user