diff --git a/Cargo.lock b/Cargo.lock index 05aa924a..01edf785 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -497,6 +497,7 @@ dependencies = [ "bitflags", "bytes 1.1.0", "futures-util", + "headers", "http", "http-body", "hyper", diff --git a/crates/axum-utils/Cargo.toml b/crates/axum-utils/Cargo.toml index 3054b720..f72f9b18 100644 --- a/crates/axum-utils/Cargo.toml +++ b/crates/axum-utils/Cargo.toml @@ -7,7 +7,7 @@ license = "Apache-2.0" [dependencies] async-trait = "0.1.52" -axum = "0.4.8" +axum = { version = "0.4.8", features = ["headers"] } bincode = "1.3.3" chrono = "0.4.19" cookie = { version = "0.16.0", features = ["signed", "private", "percent-encode"] } diff --git a/crates/axum-utils/src/lib.rs b/crates/axum-utils/src/lib.rs index 6b4d16a6..bcc64107 100644 --- a/crates/axum-utils/src/lib.rs +++ b/crates/axum-utils/src/lib.rs @@ -17,6 +17,7 @@ pub mod csrf; pub mod fancy_error; pub mod session; pub mod url_builder; +pub mod user_authorization; pub use self::{ cookies::{Cookie, CookieExt, PrivateCookieJar}, diff --git a/crates/axum-utils/src/user_authorization.rs b/crates/axum-utils/src/user_authorization.rs new file mode 100644 index 00000000..33480c6b --- /dev/null +++ b/crates/axum-utils/src/user_authorization.rs @@ -0,0 +1,311 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{collections::HashMap, error::Error}; + +use async_trait::async_trait; +use axum::{ + body::HttpBody, + extract::{ + rejection::{FailedToDeserializeQueryString, FormRejection, TypedHeaderRejectionReason}, + Form, FromRequest, TypedHeader, + }, + response::{IntoResponse, Response}, +}; +use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName}; +use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, StatusCode}; +use mas_data_model::Session; +use mas_storage::{ + oauth2::access_token::{lookup_active_access_token, AccessTokenLookupError}, + PostgresqlBackend, +}; +use serde::{de::DeserializeOwned, Deserialize}; +use sqlx::{Acquire, Postgres}; + +#[derive(Debug, Deserialize)] +struct AuthorizedForm { + #[serde(default)] + access_token: Option, + + #[serde(flatten)] + inner: F, +} + +#[derive(Debug)] +enum AccessToken { + Form(String), + Header(String), + None, +} + +impl AccessToken { + pub async fn fetch( + &self, + conn: impl Acquire<'_, Database = Postgres> + Send, + ) -> Result< + ( + mas_data_model::AccessToken, + Session, + ), + AuthorizationVerificationError, + > { + let token = match &self { + AccessToken::Form(t) | AccessToken::Header(t) => t, + AccessToken::None => return Err(AuthorizationVerificationError::MissingToken), + }; + + let (token, session) = lookup_active_access_token(conn, token).await?; + + Ok((token, session)) + } +} + +#[derive(Debug)] +pub struct UserAuthorization { + access_token: AccessToken, + form: Option, +} + +impl UserAuthorization { + // TODO: take scopes to validate as parameter + pub async fn protected_form( + self, + conn: impl Acquire<'_, Database = Postgres> + Send, + ) -> Result<(Session, F), AuthorizationVerificationError> { + let form = match self.form { + Some(f) => f, + None => return Err(AuthorizationVerificationError::MissingForm), + }; + + let (_token, session) = self.access_token.fetch(conn).await?; + + Ok((session, form)) + } + + // TODO: take scopes to validate as parameter + pub async fn protected( + self, + conn: impl Acquire<'_, Database = Postgres> + Send, + ) -> Result, AuthorizationVerificationError> { + let (_token, session) = self.access_token.fetch(conn).await?; + + Ok(session) + } +} + +pub enum UserAuthorizationError { + InvalidHeader, + TokenInFormAndHeader, + BadForm(FailedToDeserializeQueryString), + InternalError(Box), +} + +pub enum AuthorizationVerificationError { + MissingToken, + InvalidToken, + MissingForm, + InternalError(Box), +} + +impl From for AuthorizationVerificationError { + fn from(e: AccessTokenLookupError) -> Self { + if e.not_found() { + Self::InvalidToken + } else { + Self::InternalError(Box::new(e)) + } + } +} + +enum BearerError { + InvalidRequest, + InvalidToken, + #[allow(dead_code)] + InsufficientScope { + scope: Option, + }, +} + +impl BearerError { + fn error(&self) -> HeaderValue { + match self { + BearerError::InvalidRequest => HeaderValue::from_static("invalid_request"), + BearerError::InvalidToken => HeaderValue::from_static("invalid_token"), + BearerError::InsufficientScope { .. } => HeaderValue::from_static("insufficient_scope"), + } + } + + fn params(&self) -> HashMap<&'static str, HeaderValue> { + match self { + BearerError::InsufficientScope { scope: Some(scope) } => { + let mut m = HashMap::new(); + m.insert("scope", scope.clone()); + m + } + _ => HashMap::new(), + } + } +} + +enum WwwAuthenticate { + #[allow(dead_code)] + Basic { realm: HeaderValue }, + Bearer { + realm: Option, + error: BearerError, + error_description: Option, + }, +} + +impl Header for WwwAuthenticate { + fn name() -> &'static HeaderName { + &WWW_AUTHENTICATE + } + + fn decode<'i, I>(_values: &mut I) -> Result + where + Self: Sized, + I: Iterator, + { + Err(headers::Error::invalid()) + } + + fn encode>(&self, values: &mut E) { + let (scheme, params) = match self { + WwwAuthenticate::Basic { realm } => { + let mut params = HashMap::new(); + params.insert("realm", realm.clone()); + ("Basic", params) + } + WwwAuthenticate::Bearer { + realm, + error, + error_description, + } => { + let mut params = error.params(); + params.insert("error", error.error()); + + if let Some(realm) = realm { + params.insert("realm", realm.clone()); + } + + if let Some(error_description) = error_description { + params.insert("error_description", error_description.clone()); + } + + ("Bearer", params) + } + }; + + let params = params.into_iter().map(|(k, v)| format!(" {}={:?}", k, v)); + let value: String = std::iter::once(scheme.to_string()).chain(params).collect(); + let value = HeaderValue::from_str(&value).unwrap(); + values.extend(std::iter::once(value)); + } +} + +impl IntoResponse for UserAuthorizationError { + fn into_response(self) -> Response { + match self { + Self::BadForm(_) | Self::InvalidHeader | Self::TokenInFormAndHeader => { + let mut headers = HeaderMap::new(); + + headers.typed_insert(WwwAuthenticate::Bearer { + realm: None, + error: BearerError::InvalidRequest, + error_description: None, + }); + (StatusCode::BAD_REQUEST, headers).into_response() + } + Self::InternalError(e) => { + (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() + } + } + } +} + +impl IntoResponse for AuthorizationVerificationError { + fn into_response(self) -> Response { + match self { + Self::MissingForm | Self::MissingToken => { + let mut headers = HeaderMap::new(); + + headers.typed_insert(WwwAuthenticate::Bearer { + realm: None, + error: BearerError::InvalidRequest, + error_description: None, + }); + (StatusCode::BAD_REQUEST, headers).into_response() + } + Self::InvalidToken => { + let mut headers = HeaderMap::new(); + + headers.typed_insert(WwwAuthenticate::Bearer { + realm: None, + error: BearerError::InvalidToken, + error_description: None, + }); + (StatusCode::BAD_REQUEST, headers).into_response() + } + Self::InternalError(e) => { + (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response() + } + } + } +} + +#[async_trait] +impl FromRequest for UserAuthorization +where + B: Send + HttpBody, + B::Data: Send, + B::Error: Error + Send + Sync + 'static, + F: DeserializeOwned, +{ + type Rejection = UserAuthorizationError; + + async fn from_request( + req: &mut axum::extract::RequestParts, + ) -> Result { + let header = TypedHeader::>::from_request(req).await; + + let token_from_header = match header { + Ok(header) => Some(header.token().to_string()), + Err(err) => match err.reason() { + TypedHeaderRejectionReason::Missing => None, + TypedHeaderRejectionReason::Error(_) => { + return Err(UserAuthorizationError::InvalidHeader) + } + }, + }; + + let (token_from_form, form) = match Form::>::from_request(req).await { + Ok(Form(form)) => (form.access_token, Some(form.inner)), + Err(FormRejection::InvalidFormContentType(_err)) => (None, None), + Err(FormRejection::FailedToDeserializeQueryString(err)) => { + return Err(UserAuthorizationError::BadForm(err)) + } + Err(e) => return Err(UserAuthorizationError::InternalError(Box::new(e))), + }; + + let access_token = match (token_from_header, token_from_form) { + (Some(_), Some(_)) => return Err(UserAuthorizationError::TokenInFormAndHeader), + (Some(t), None) => AccessToken::Header(t), + (None, Some(t)) => AccessToken::Form(t), + (None, None) => AccessToken::None, + }; + + Ok(UserAuthorization { access_token, form }) + } +} diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index eb8e04c1..3fe7fe56 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use axum::{ body::HttpBody, extract::Extension, - routing::{get, post}, + routing::{get, on, post, MethodFilter}, Router, }; use mas_axum_utils::UrlBuilder; @@ -83,6 +83,13 @@ where get(self::oauth2::discovery::get), ) .route("/oauth2/keys.json", get(self::oauth2::keys::get)) + .route( + "/oauth2/userinfo", + on( + MethodFilter::POST | MethodFilter::GET, + self::oauth2::userinfo::get, + ), + ) .fallback(mas_static_files::Assets) .layer(Extension(pool.clone())) .layer(Extension(templates.clone())) diff --git a/crates/handlers/src/oauth2/mod.rs b/crates/handlers/src/oauth2/mod.rs index 36b35cab..bfc68ffb 100644 --- a/crates/handlers/src/oauth2/mod.rs +++ b/crates/handlers/src/oauth2/mod.rs @@ -17,7 +17,7 @@ pub mod discovery; // pub mod introspection; pub mod keys; // pub mod token; -// pub mod userinfo; +pub mod userinfo; use hyper::{ http::uri::{Parts, PathAndQuery}, diff --git a/crates/handlers/src/oauth2/userinfo.rs b/crates/handlers/src/oauth2/userinfo.rs index 9612731f..d36d9af2 100644 --- a/crates/handlers/src/oauth2/userinfo.rs +++ b/crates/handlers/src/oauth2/userinfo.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,17 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use mas_data_model::{AccessToken, Session}; -use mas_storage::PostgresqlBackend; -use mas_warp_utils::filters::{ - self, - authenticate::{authentication, recover_unauthorized}, +use axum::{ + extract::Extension, + response::{IntoResponse, Response}, + Json, }; +use mas_axum_utils::{internal_error, user_authorization::UserAuthorization}; use oauth2_types::scope; use serde::Serialize; use serde_with::skip_serializing_none; use sqlx::PgPool; -use warp::{filters::BoxedFilter, Filter, Rejection, Reply}; #[skip_serializing_none] #[derive(Serialize)] @@ -33,25 +32,21 @@ struct UserInfo { email_verified: Option, } -pub(super) fn filter(pool: &PgPool) -> BoxedFilter<(Box,)> { - warp::path!("oauth2" / "userinfo") - .and(filters::trace::name("GET /oauth2/userinfo")) - .and( - warp::get() - .or(warp::post()) - .unify() - .and(authentication(pool)) - .and_then(userinfo) - .recover(recover_unauthorized) - .unify(), - ) - .boxed() -} +pub async fn get( + Extension(pool): Extension, + user_authorization: UserAuthorization, +) -> Result { + let mut conn = pool + .acquire() + .await + .map_err(internal_error) + .map_err(IntoResponse::into_response)?; + + let session = user_authorization + .protected(&mut conn) + .await + .map_err(IntoResponse::into_response)?; -async fn userinfo( - _token: AccessToken, - session: Session, -) -> Result, Rejection> { let user = session.browser_session.user; let mut res = UserInfo { sub: user.sub, @@ -67,5 +62,5 @@ async fn userinfo( } } - Ok(Box::new(warp::reply::json(&res))) + Ok(Json(res)) } diff --git a/crates/storage/src/oauth2/access_token.rs b/crates/storage/src/oauth2/access_token.rs index cfa271b1..109c0a08 100644 --- a/crates/storage/src/oauth2/access_token.rs +++ b/crates/storage/src/oauth2/access_token.rs @@ -15,7 +15,7 @@ use anyhow::Context; use chrono::{DateTime, Duration, Utc}; use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail}; -use sqlx::{PgConnection, PgExecutor}; +use sqlx::{Acquire, PgExecutor, Postgres}; use thiserror::Error; use super::client::{lookup_client_by_client_id, ClientFetchError}; @@ -93,14 +93,26 @@ impl AccessTokenLookupError { } } -#[allow(clippy::too_many_lines)] -pub async fn lookup_active_access_token( - conn: &mut PgConnection, - token: &str, -) -> Result<(AccessToken, Session), AccessTokenLookupError> { - let res = sqlx::query_as!( - OAuth2AccessTokenLookup, - r#" +// TODO: remove that manual async +#[allow(clippy::too_many_lines, clippy::manual_async_fn)] +pub fn lookup_active_access_token<'a, 'c, A>( + conn: A, + token: &'a str, +) -> impl std::future::Future< + Output = Result< + (AccessToken, Session), + AccessTokenLookupError, + >, +> + Send + + 'a +where + A: Acquire<'c, Database = Postgres> + Send + 'a, +{ + async move { + let mut conn = conn.acquire().await?; + let res = sqlx::query_as!( + OAuth2AccessTokenLookup, + r#" SELECT at.id AS "access_token_id", at.token AS "access_token", @@ -140,73 +152,74 @@ pub async fn lookup_active_access_token( ORDER BY usa.created_at DESC LIMIT 1 "#, - token, - ) - .fetch_one(&mut *conn) - .await?; + token, + ) + .fetch_one(&mut *conn) + .await?; - let access_token = AccessToken { - data: res.access_token_id, - jti: format!("{}", res.access_token_id), - token: res.access_token, - created_at: res.access_token_created_at, - expires_after: Duration::seconds(res.access_token_expires_after.into()), - }; + let access_token = AccessToken { + data: res.access_token_id, + jti: format!("{}", res.access_token_id), + token: res.access_token, + created_at: res.access_token_created_at, + expires_after: Duration::seconds(res.access_token_expires_after.into()), + }; - let client = lookup_client_by_client_id(&mut *conn, &res.client_id).await?; + let client = lookup_client_by_client_id(&mut *conn, &res.client_id).await?; - let primary_email = match ( - res.user_email_id, - res.user_email, - res.user_email_created_at, - res.user_email_confirmed_at, - ) { - (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { - data: id, - email, - created_at, - confirmed_at, - }), - (None, None, None, None) => None, - _ => return Err(DatabaseInconsistencyError.into()), - }; + let primary_email = match ( + res.user_email_id, + res.user_email, + res.user_email_created_at, + res.user_email_confirmed_at, + ) { + (Some(id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail { + data: id, + email, + created_at, + confirmed_at, + }), + (None, None, None, None) => None, + _ => return Err(DatabaseInconsistencyError.into()), + }; - let user = User { - data: res.user_id, - username: res.user_username, - sub: format!("fake-sub-{}", res.user_id), - primary_email, - }; + let user = User { + data: res.user_id, + username: res.user_username, + sub: format!("fake-sub-{}", res.user_id), + primary_email, + }; - let last_authentication = match ( - res.user_session_last_authentication_id, - res.user_session_last_authentication_created_at, - ) { - (None, None) => None, - (Some(id), Some(created_at)) => Some(Authentication { - data: id, - created_at, - }), - _ => return Err(DatabaseInconsistencyError.into()), - }; + let last_authentication = match ( + res.user_session_last_authentication_id, + res.user_session_last_authentication_created_at, + ) { + (None, None) => None, + (Some(id), Some(created_at)) => Some(Authentication { + data: id, + created_at, + }), + _ => return Err(DatabaseInconsistencyError.into()), + }; - let browser_session = BrowserSession { - data: res.user_session_id, - created_at: res.user_session_created_at, - user, - last_authentication, - }; + let browser_session = BrowserSession { + data: res.user_session_id, + created_at: res.user_session_created_at, + user, + last_authentication, + }; - let scope = res.scope.parse().map_err(|_e| DatabaseInconsistencyError)?; + let scope = res.scope.parse().map_err(|_e| DatabaseInconsistencyError)?; - let session = Session { - data: res.session_id, - client, - browser_session, - scope, - }; + let session = Session { + data: res.session_id, + client, + browser_session, + scope, + }; - Ok((access_token, session)) + Ok((access_token, session)) + } } pub async fn revoke_access_token(