You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-20 12:02:22 +03:00
WIP: upstream OIDC provider support
This commit is contained in:
149
crates/handlers/src/upstream_oauth2/authorize.rs
Normal file
149
crates/handlers/src/upstream_oauth2/authorize.rs
Normal file
@@ -0,0 +1,149 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::{IntoResponse, Redirect},
|
||||
};
|
||||
use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
|
||||
use hyper::StatusCode;
|
||||
use mas_http::ClientInitError;
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_oidc_client::{
|
||||
error::{AuthorizationError, DiscoveryError},
|
||||
requests::authorization_code::AuthorizationRequestData,
|
||||
};
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::{upstream_oauth2::lookup_provider, LookupResultExt};
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::http_service;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum RouteError {
|
||||
#[error("Provider not found")]
|
||||
ProviderNotFound,
|
||||
|
||||
#[error(transparent)]
|
||||
Authorization(#[from] AuthorizationError),
|
||||
|
||||
#[error(transparent)]
|
||||
InternalError(Box<dyn std::error::Error>),
|
||||
|
||||
#[error(transparent)]
|
||||
Anyhow(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl From<sqlx::Error> for RouteError {
|
||||
fn from(e: sqlx::Error) -> Self {
|
||||
Self::InternalError(Box::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DiscoveryError> for RouteError {
|
||||
fn from(e: DiscoveryError) -> Self {
|
||||
Self::InternalError(Box::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mas_storage::upstream_oauth2::ProviderLookupError> for RouteError {
|
||||
fn from(e: mas_storage::upstream_oauth2::ProviderLookupError) -> Self {
|
||||
Self::InternalError(Box::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ClientInitError> for RouteError {
|
||||
fn from(e: ClientInitError) -> Self {
|
||||
Self::InternalError(Box::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
Self::ProviderNotFound => (StatusCode::NOT_FOUND, "Provider not found").into_response(),
|
||||
Self::Authorization(e) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
||||
}
|
||||
Self::InternalError(e) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
||||
}
|
||||
Self::Anyhow(e) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn get(
|
||||
State(pool): State<PgPool>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Path(provider_id): Path<Ulid>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let provider = lookup_provider(&mut txn, provider_id)
|
||||
.await
|
||||
.to_option()?
|
||||
.ok_or(RouteError::ProviderNotFound)?;
|
||||
|
||||
let http_service = http_service("upstream-discover").await?;
|
||||
|
||||
// First, discover the provider
|
||||
let metadata =
|
||||
mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?;
|
||||
|
||||
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
|
||||
|
||||
let data = AuthorizationRequestData {
|
||||
client_id: &provider.client_id,
|
||||
scope: &provider.scope,
|
||||
prompt: None,
|
||||
redirect_uri: &redirect_uri,
|
||||
code_challenge_methods_supported: metadata.code_challenge_methods_supported.as_deref(),
|
||||
};
|
||||
|
||||
// Build an authorization request for it
|
||||
let (url, data) = mas_oidc_client::requests::authorization_code::build_authorization_url(
|
||||
metadata.authorization_endpoint().clone(),
|
||||
data,
|
||||
&mut rng,
|
||||
)?;
|
||||
|
||||
let session = mas_storage::upstream_oauth2::add_session(
|
||||
&mut txn,
|
||||
&mut rng,
|
||||
&clock,
|
||||
&provider,
|
||||
data.state,
|
||||
data.code_challenge_verifier,
|
||||
data.nonce,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// TODO: handle that cookie somewhere else?
|
||||
let mut cookie = Cookie::new("upstream-oauth2-session-id", session.id.to_string());
|
||||
cookie.set_path("/");
|
||||
cookie.set_http_only(true);
|
||||
let cookie_jar = cookie_jar.add(cookie);
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
Ok((cookie_jar, Redirect::temporary(url.as_str())))
|
||||
}
|
||||
290
crates/handlers/src/upstream_oauth2/callback.rs
Normal file
290
crates/handlers/src/upstream_oauth2/callback.rs
Normal file
@@ -0,0 +1,290 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::Context;
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
response::IntoResponse,
|
||||
Json,
|
||||
};
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use hyper::StatusCode;
|
||||
use mas_http::ClientInitError;
|
||||
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||
use mas_keystore::{Encrypter, Keystore};
|
||||
use mas_oidc_client::{
|
||||
error::{DiscoveryError, JwksError, TokenAuthorizationCodeError},
|
||||
requests::{authorization_code::AuthorizationValidationData, jose::JwtVerificationData},
|
||||
types::client_credentials::ClientCredentials,
|
||||
};
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::{upstream_oauth2::lookup_session, LookupResultExt};
|
||||
use oauth2_types::errors::ClientErrorCode;
|
||||
use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::http_service;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct QueryParams {
|
||||
state: String,
|
||||
|
||||
#[serde(flatten)]
|
||||
code_or_error: CodeOrError,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum CodeOrError {
|
||||
Code {
|
||||
code: String,
|
||||
},
|
||||
Error {
|
||||
error: ClientErrorCode,
|
||||
error_description: Option<String>,
|
||||
#[allow(dead_code)]
|
||||
error_uri: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum RouteError {
|
||||
#[error("Session not found")]
|
||||
SessionNotFound,
|
||||
|
||||
#[error("Provider mismatch")]
|
||||
ProviderMismatch,
|
||||
|
||||
#[error("State parameter mismatch")]
|
||||
StateMismatch,
|
||||
|
||||
#[error("Error from the provider: {error}")]
|
||||
ClientError {
|
||||
error: ClientErrorCode,
|
||||
error_description: Option<String>,
|
||||
},
|
||||
|
||||
#[error("Provider is missing a client secret")]
|
||||
MissingClientSecret,
|
||||
|
||||
#[error("Missing session cookie")]
|
||||
MissingCookie,
|
||||
|
||||
#[error("Invalid session cookie")]
|
||||
InvalidCookie(#[source] ulid::DecodeError),
|
||||
|
||||
#[error(transparent)]
|
||||
InternalError(Box<dyn std::error::Error>),
|
||||
|
||||
#[error(transparent)]
|
||||
Anyhow(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl From<sqlx::Error> for RouteError {
|
||||
fn from(e: sqlx::Error) -> Self {
|
||||
Self::InternalError(Box::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DiscoveryError> for RouteError {
|
||||
fn from(e: DiscoveryError) -> Self {
|
||||
Self::InternalError(Box::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<JwksError> for RouteError {
|
||||
fn from(e: JwksError) -> Self {
|
||||
Self::InternalError(Box::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TokenAuthorizationCodeError> for RouteError {
|
||||
fn from(e: TokenAuthorizationCodeError) -> Self {
|
||||
Self::InternalError(Box::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<mas_storage::upstream_oauth2::SessionLookupError> for RouteError {
|
||||
fn from(e: mas_storage::upstream_oauth2::SessionLookupError) -> Self {
|
||||
Self::InternalError(Box::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ClientInitError> for RouteError {
|
||||
fn from(e: ClientInitError) -> Self {
|
||||
Self::InternalError(Box::new(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
Self::SessionNotFound => (StatusCode::NOT_FOUND, "Session not found").into_response(),
|
||||
Self::InternalError(e) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response()
|
||||
}
|
||||
Self::Anyhow(e) => {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:?}")).into_response()
|
||||
}
|
||||
e => (StatusCode::BAD_REQUEST, e.to_string()).into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub(crate) async fn get(
|
||||
State(pool): State<PgPool>,
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
State(encrypter): State<Encrypter>,
|
||||
State(keystore): State<Keystore>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Path(provider_id): Path<Ulid>,
|
||||
Query(params): Query<QueryParams>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
// XXX: that cookie should be managed elsewhere
|
||||
let cookie = cookie_jar
|
||||
.get("upstream-oauth2-session-id")
|
||||
.ok_or(RouteError::MissingCookie)?;
|
||||
|
||||
let session_id: Ulid = cookie.value().parse().map_err(RouteError::InvalidCookie)?;
|
||||
|
||||
let (provider, session) = lookup_session(&mut txn, session_id)
|
||||
.await
|
||||
.to_option()?
|
||||
.ok_or(RouteError::SessionNotFound)?;
|
||||
|
||||
if provider.id != provider_id {
|
||||
// The provider in the session cookie should match the one from the URL
|
||||
return Err(RouteError::ProviderMismatch);
|
||||
}
|
||||
|
||||
if params.state != session.state {
|
||||
// The state in the session cookie should match the one from the params
|
||||
return Err(RouteError::StateMismatch);
|
||||
}
|
||||
|
||||
// Let's extract the code from the params, and return if there was an error
|
||||
let code = match params.code_or_error {
|
||||
CodeOrError::Error {
|
||||
error,
|
||||
error_description,
|
||||
..
|
||||
} => {
|
||||
return Err(RouteError::ClientError {
|
||||
error,
|
||||
error_description,
|
||||
})
|
||||
}
|
||||
CodeOrError::Code { code } => code,
|
||||
};
|
||||
|
||||
let http_service = http_service("upstream-code-exchange").await?;
|
||||
|
||||
// XXX: we shouldn't discover on-the-fly
|
||||
// Discover the provider
|
||||
let metadata =
|
||||
mas_oidc_client::requests::discovery::discover(&http_service, &provider.issuer).await?;
|
||||
|
||||
// Fetch the JWKS
|
||||
let jwks =
|
||||
mas_oidc_client::requests::jose::fetch_jwks(&http_service, metadata.jwks_uri()).await?;
|
||||
|
||||
// Figure out the client credentials
|
||||
let client_id = provider.client_id.clone();
|
||||
// Decrypt the client secret
|
||||
let client_secret = provider
|
||||
.encrypted_client_secret
|
||||
.map(|encrypted_client_secret| {
|
||||
encrypter
|
||||
.decrypt_string(&encrypted_client_secret)
|
||||
.and_then(|client_secret| {
|
||||
String::from_utf8(client_secret)
|
||||
.context("Client secret contains non-UTF8 bytes")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let token_endpoint = metadata.token_endpoint();
|
||||
|
||||
let client_credentials = match provider.token_endpoint_auth_method {
|
||||
OAuthClientAuthenticationMethod::None => ClientCredentials::None { client_id },
|
||||
OAuthClientAuthenticationMethod::ClientSecretPost => ClientCredentials::ClientSecretPost {
|
||||
client_id,
|
||||
client_secret: client_secret.ok_or(RouteError::MissingClientSecret)?,
|
||||
},
|
||||
OAuthClientAuthenticationMethod::ClientSecretBasic => {
|
||||
ClientCredentials::ClientSecretBasic {
|
||||
client_id,
|
||||
client_secret: client_secret.ok_or(RouteError::MissingClientSecret)?,
|
||||
}
|
||||
}
|
||||
OAuthClientAuthenticationMethod::ClientSecretJwt => ClientCredentials::ClientSecretJwt {
|
||||
client_id,
|
||||
client_secret: client_secret.ok_or(RouteError::MissingClientSecret)?,
|
||||
signing_algorithm: provider
|
||||
.token_endpoint_signing_alg
|
||||
.unwrap_or(mas_iana::jose::JsonWebSignatureAlg::Rs256),
|
||||
token_endpoint: token_endpoint.clone(),
|
||||
},
|
||||
OAuthClientAuthenticationMethod::PrivateKeyJwt => ClientCredentials::PrivateKeyJwt {
|
||||
client_id,
|
||||
jwt_signing_method:
|
||||
mas_oidc_client::types::client_credentials::JwtSigningMethod::Keystore(keystore),
|
||||
signing_algorithm: provider
|
||||
.token_endpoint_signing_alg
|
||||
.unwrap_or(mas_iana::jose::JsonWebSignatureAlg::Rs256),
|
||||
token_endpoint: token_endpoint.clone(),
|
||||
},
|
||||
// XXX: The database should never have an unsupported method in it
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
|
||||
|
||||
let validation_data = AuthorizationValidationData {
|
||||
state: session.state,
|
||||
nonce: session.nonce,
|
||||
code_challenge_verifier: session.code_challenge_verifier,
|
||||
redirect_uri,
|
||||
};
|
||||
|
||||
let id_token_verification_data = JwtVerificationData {
|
||||
issuer: &provider.issuer,
|
||||
jwks: &jwks,
|
||||
// TODO: make that configurable
|
||||
signing_algorithm: &mas_iana::jose::JsonWebSignatureAlg::Rs256,
|
||||
client_id: &provider.client_id,
|
||||
};
|
||||
|
||||
let (response, _id_token) =
|
||||
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
|
||||
&http_service,
|
||||
client_credentials,
|
||||
token_endpoint,
|
||||
code,
|
||||
validation_data,
|
||||
Some(id_token_verification_data),
|
||||
clock.now(),
|
||||
&mut rng,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
35
crates/handlers/src/upstream_oauth2/mod.rs
Normal file
35
crates/handlers/src/upstream_oauth2/mod.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use axum::body::Full;
|
||||
use mas_http::{BodyToBytesResponseLayer, ClientInitError, ClientLayer, HttpService};
|
||||
use tower::{
|
||||
util::{MapErrLayer, MapRequestLayer},
|
||||
BoxError, Layer,
|
||||
};
|
||||
|
||||
pub(crate) mod authorize;
|
||||
pub(crate) mod callback;
|
||||
|
||||
async fn http_service(operation: &'static str) -> Result<HttpService, ClientInitError> {
|
||||
let client = (
|
||||
MapErrLayer::new(BoxError::from),
|
||||
MapRequestLayer::new(|req: hyper::Request<_>| req.map(Full::new)),
|
||||
BodyToBytesResponseLayer::default(),
|
||||
ClientLayer::new(operation),
|
||||
)
|
||||
.layer(mas_http::make_untraced_client().await?);
|
||||
|
||||
Ok(HttpService::new(client))
|
||||
}
|
||||
Reference in New Issue
Block a user