1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-20 12:02:22 +03:00

Have a unified URL builder/router

This commit is contained in:
Quentin Gliech
2022-05-10 09:52:27 +02:00
parent 0ac4fddee4
commit f4353b660e
28 changed files with 684 additions and 371 deletions

View File

@@ -28,11 +28,11 @@ use axum::{
Router,
};
use hyper::header::{ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_TYPE};
use mas_axum_utils::UrlBuilder;
use mas_config::Encrypter;
use mas_email::Mailer;
use mas_http::CorsLayerExt;
use mas_jose::StaticKeystore;
use mas_router::{Route, UrlBuilder};
use mas_templates::Templates;
use sqlx::PgPool;
use tower_http::cors::{Any, CorsLayer};
@@ -58,25 +58,34 @@ where
// All those routes are API-like, with a common CORS layer
let api_router = Router::new()
.route(
"/.well-known/openid-configuration",
mas_router::OidcConfiguration::route(),
get(self::oauth2::discovery::get),
)
.route("/.well-known/webfinger", get(self::oauth2::webfinger::get))
.route("/oauth2/keys.json", get(self::oauth2::keys::get))
.route(
"/oauth2/userinfo",
mas_router::Webfinger::route(),
get(self::oauth2::webfinger::get),
)
.route(
mas_router::OAuth2Keys::route(),
get(self::oauth2::keys::get),
)
.route(
mas_router::OidcUserinfo::route(),
on(
MethodFilter::POST | MethodFilter::GET,
self::oauth2::userinfo::get,
),
)
.route(
"/oauth2/introspect",
mas_router::OAuth2Introspection::route(),
post(self::oauth2::introspection::post),
)
.route("/oauth2/token", post(self::oauth2::token::post))
.route(
"/oauth2/registration",
mas_router::OAuth2TokenEndpoint::route(),
post(self::oauth2::token::post),
)
.route(
mas_router::OAuth2RegistrationEndpoint::route(),
post(self::oauth2::registration::post),
)
.layer(
@@ -94,38 +103,44 @@ where
);
Router::new()
.route("/", get(self::views::index::get))
.route("/health", get(self::health::get))
.route(mas_router::Index::route(), get(self::views::index::get))
.route(mas_router::Healthcheck::route(), get(self::health::get))
.route(
"/login",
mas_router::Login::route(),
get(self::views::login::get).post(self::views::login::post),
)
.route("/logout", post(self::views::logout::post))
.route(mas_router::Logout::route(), post(self::views::logout::post))
.route(
"/reauth",
mas_router::Reauth::route(),
get(self::views::reauth::get).post(self::views::reauth::post),
)
.route(
"/register",
mas_router::Register::route(),
get(self::views::register::get).post(self::views::register::post),
)
.route("/verify/:code", get(self::views::verify::get))
.route("/account", get(self::views::account::get))
.route(
"/account/password",
mas_router::VerifyEmail::route(),
get(self::views::verify::get),
)
.route(mas_router::Account::route(), get(self::views::account::get))
.route(
mas_router::AccountPassword::route(),
get(self::views::account::password::get).post(self::views::account::password::post),
)
.route(
"/account/emails",
mas_router::AccountEmails::route(),
get(self::views::account::emails::get).post(self::views::account::emails::post),
)
.route("/authorize", get(self::oauth2::authorization::get))
.route(
"/authorize/:grant_id",
mas_router::OAuth2AuthorizationEndpoint::route(),
get(self::oauth2::authorization::get),
)
.route(
mas_router::ContinueAuthorizationGrant::route(),
get(self::oauth2::authorization::complete::get),
)
.route(
"/consent/:grant_id",
mas_router::Consent::route(),
get(self::oauth2::consent::get).post(self::oauth2::consent::post),
)
.merge(api_router)

View File

@@ -24,6 +24,7 @@ use hyper::StatusCode;
use mas_axum_utils::SessionInfoExt;
use mas_config::Encrypter;
use mas_data_model::{AuthorizationGrant, BrowserSession, TokenType};
use mas_router::{PostAuthAction, Route};
use mas_storage::{
oauth2::{
access_token::add_access_token,
@@ -41,10 +42,6 @@ use sqlx::{PgPool, Postgres, Transaction};
use thiserror::Error;
use super::callback::{CallbackDestination, CallbackDestinationError, InvalidRedirectUriError};
use crate::{
oauth2::consent::ConsentRequest,
views::{LoginRequest, PostAuthAction, ReauthRequest},
};
#[derive(Debug, Error)]
pub enum RouteError {
@@ -122,15 +119,14 @@ pub(crate) async fn get(
let grant = get_grant_by_id(&mut txn, grant_id).await?;
let callback_destination = CallbackDestination::try_from(&grant)?;
let continue_grant = PostAuthAction::continue_grant(&grant);
let consent_request = ConsentRequest::for_grant(&grant);
let continue_grant = PostAuthAction::continue_grant(grant_id);
let session = if let Some(session) = maybe_session {
session
} else {
// If there is no session, redirect to the login screen, redirecting here after
// logout
return Ok((cookie_jar, LoginRequest::from(continue_grant).go()).into_response());
return Ok((cookie_jar, mas_router::Login::and_then(continue_grant).go()).into_response());
};
match complete(grant, session, txn).await {
@@ -138,11 +134,14 @@ pub(crate) async fn get(
let res = callback_destination.go(&templates, params).await?;
Ok((cookie_jar, res).into_response())
}
Err(GrantCompletionError::RequiresReauth) => {
Ok((cookie_jar, ReauthRequest::from(continue_grant).go()).into_response())
}
Err(GrantCompletionError::RequiresReauth) => Ok((
cookie_jar,
mas_router::Reauth::and_then(continue_grant).go(),
)
.into_response()),
Err(GrantCompletionError::RequiresConsent) => {
Ok((cookie_jar, consent_request.go()).into_response())
let next = mas_router::Consent(grant_id);
Ok((cookie_jar, next.go()).into_response())
}
Err(GrantCompletionError::NotPending) => Err(RouteError::NotPending),
Err(GrantCompletionError::Internal(e)) => Err(RouteError::Internal(e)),

View File

@@ -23,6 +23,7 @@ use mas_axum_utils::SessionInfoExt;
use mas_config::Encrypter;
use mas_data_model::{AuthorizationCode, Pkce};
use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
use mas_router::{PostAuthAction, Route};
use mas_storage::oauth2::{
authorization_grant::new_authorization_grant,
client::{lookup_client_by_client_id, ClientFetchError},
@@ -45,8 +46,6 @@ use sqlx::PgPool;
use thiserror::Error;
use self::{callback::CallbackDestination, complete::GrantCompletionError};
use super::consent::ConsentRequest;
use crate::views::{LoginRequest, PostAuthAction, ReauthRequest, RegisterRequest};
mod callback;
pub mod complete;
@@ -287,8 +286,7 @@ pub(crate) async fn get(
requires_consent,
)
.await?;
let continue_grant = PostAuthAction::continue_grant(&grant);
let consent_request = ConsentRequest::for_grant(&grant);
let continue_grant = PostAuthAction::continue_grant(grant.data);
let res = match (maybe_session, params.auth.prompt) {
// Cases where there is no active session, redirect to the relevant page
@@ -300,13 +298,17 @@ pub(crate) async fn get(
// Client asked for a registration, show the registration prompt
txn.commit().await?;
RegisterRequest::from(continue_grant).go().into_response()
mas_router::Register::and_then(continue_grant)
.go()
.into_response()
}
(None, _) => {
// Other cases where we don't have a session, ask for a login
txn.commit().await?;
LoginRequest::from(continue_grant).go().into_response()
mas_router::Login::and_then(continue_grant)
.go()
.into_response()
}
// Special case when we already have a sesion but prompt=login|select_account
@@ -314,7 +316,9 @@ pub(crate) async fn get(
// TODO: better pages here
txn.commit().await?;
ReauthRequest::from(continue_grant).go().into_response()
mas_router::Reauth::and_then(continue_grant)
.go()
.into_response()
}
// Else, we immediately try to complete the authorization grant
@@ -343,14 +347,17 @@ pub(crate) async fn get(
}
}
(Some(user_session), _) => {
let grant_id = grant.data;
// Else, we show the relevant reauth/consent page if necessary
match self::complete::complete(grant, user_session, txn).await {
Ok(params) => callback_destination.go(&templates, params).await?,
Err(GrantCompletionError::RequiresConsent) => {
consent_request.go().into_response()
mas_router::Consent(grant_id).go().into_response()
}
Err(GrantCompletionError::RequiresReauth) => {
ReauthRequest::from(continue_grant).go().into_response()
mas_router::Reauth::and_then(continue_grant)
.go()
.into_response()
}
Err(GrantCompletionError::Anyhow(a)) => return Err(RouteError::Anyhow(a)),
Err(GrantCompletionError::Internal(e)) => {

View File

@@ -15,7 +15,7 @@
use anyhow::Context;
use axum::{
extract::{Extension, Form, Path},
response::{Html, IntoResponse, Redirect, Response},
response::{Html, IntoResponse, Response},
};
use axum_extra::extract::PrivateCookieJar;
use hyper::StatusCode;
@@ -24,20 +24,16 @@ use mas_axum_utils::{
SessionInfoExt,
};
use mas_config::Encrypter;
use mas_data_model::{AuthorizationGrant, AuthorizationGrantStage};
use mas_storage::{
oauth2::{
authorization_grant::{get_grant_by_id, give_consent_to_grant},
consent::insert_client_consent,
},
PostgresqlBackend,
use mas_data_model::AuthorizationGrantStage;
use mas_router::{PostAuthAction, Route};
use mas_storage::oauth2::{
authorization_grant::{get_grant_by_id, give_consent_to_grant},
consent::insert_client_consent,
};
use mas_templates::{ConsentContext, TemplateContext, Templates};
use sqlx::PgPool;
use thiserror::Error;
use crate::views::{LoginRequest, PostAuthAction};
#[derive(Debug, Error)]
pub enum RouteError {
#[error(transparent)]
@@ -50,23 +46,6 @@ impl IntoResponse for RouteError {
}
}
pub(crate) struct ConsentRequest {
grant_id: i64,
}
impl ConsentRequest {
pub fn for_grant(grant: &AuthorizationGrant<PostgresqlBackend>) -> Self {
Self {
grant_id: grant.data,
}
}
pub fn go(&self) -> Redirect {
let uri = format!("/consent/{}", self.grant_id);
Redirect::to(&uri)
}
}
pub(crate) async fn get(
Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>,
@@ -105,7 +84,7 @@ pub(crate) async fn get(
Ok((cookie_jar, Html(content)).into_response())
} else {
let login = LoginRequest::from(PostAuthAction::continue_grant(&grant));
let login = mas_router::Login::and_continue_grant(grant_id);
Ok((cookie_jar, login.go()).into_response())
}
}
@@ -133,12 +112,12 @@ pub(crate) async fn post(
.context("could not load session")?;
let grant = get_grant_by_id(&mut txn, grant_id).await?;
let next = PostAuthAction::continue_grant(&grant);
let next = PostAuthAction::continue_grant(grant_id);
let session = if let Some(session) = maybe_session {
session
} else {
let login = LoginRequest::from(next);
let login = mas_router::Login::and_then(next);
return Ok((cookie_jar, login.go()).into_response());
};
@@ -163,5 +142,5 @@ pub(crate) async fn post(
txn.commit().await.context("could not commit txn")?;
Ok((cookie_jar, next.redirect()).into_response())
Ok((cookie_jar, next.go_next()).into_response())
}

View File

@@ -15,7 +15,6 @@
use std::sync::Arc;
use axum::{extract::Extension, response::IntoResponse, Json};
use mas_axum_utils::UrlBuilder;
use mas_iana::{
jose::JsonWebSignatureAlg,
oauth::{
@@ -24,6 +23,7 @@ use mas_iana::{
},
};
use mas_jose::{SigningKeystore, StaticKeystore};
use mas_router::UrlBuilder;
use oauth2_types::{
oidc::{ClaimType, Metadata, SubjectType},
requests::{Display, GrantType, Prompt, ResponseMode},

View File

@@ -20,10 +20,7 @@ use chrono::{DateTime, Duration, Utc};
use data_encoding::BASE64URL_NOPAD;
use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma};
use hyper::StatusCode;
use mas_axum_utils::{
client_authorization::{ClientAuthorization, CredentialsVerificationError},
UrlBuilder,
};
use mas_axum_utils::client_authorization::{ClientAuthorization, CredentialsVerificationError};
use mas_config::Encrypter;
use mas_data_model::{AuthorizationGrantStage, Client, TokenType};
use mas_iana::jose::JsonWebSignatureAlg;
@@ -31,6 +28,7 @@ use mas_jose::{
claims::{self, ClaimError},
DecodedJsonWebToken, SigningKeystore, StaticKeystore,
};
use mas_router::UrlBuilder;
use mas_storage::{
oauth2::{
access_token::{add_access_token, revoke_access_token},

View File

@@ -21,8 +21,9 @@ use axum::{
};
use headers::ContentType;
use hyper::StatusCode;
use mas_axum_utils::{internal_error, user_authorization::UserAuthorization, UrlBuilder};
use mas_axum_utils::{internal_error, user_authorization::UserAuthorization};
use mas_jose::{DecodedJsonWebToken, SigningKeystore, StaticKeystore};
use mas_router::UrlBuilder;
use mime::Mime;
use oauth2_types::scope;
use serde::Serialize;

View File

@@ -14,7 +14,7 @@
use axum::{extract::Query, response::IntoResponse, Extension, Json, TypedHeader};
use headers::ContentType;
use mas_axum_utils::UrlBuilder;
use mas_router::UrlBuilder;
use oauth2_types::webfinger::WebFingerResponse;
use serde::Deserialize;

View File

@@ -20,11 +20,12 @@ use axum_extra::extract::PrivateCookieJar;
use lettre::{message::Mailbox, Address};
use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm},
fancy_error, FancyError, SessionInfoExt, UrlBuilder,
fancy_error, FancyError, SessionInfoExt,
};
use mas_config::Encrypter;
use mas_data_model::{BrowserSession, User, UserEmail};
use mas_email::Mailer;
use mas_router::{Route, UrlBuilder};
use mas_storage::{
user::{
add_user_email, add_user_email_verification_code, get_user_email, get_user_emails,
@@ -38,8 +39,6 @@ use serde::Deserialize;
use sqlx::{PgExecutor, PgPool};
use tracing::info;
use crate::views::LoginRequest;
#[derive(Deserialize, Debug)]
#[serde(tag = "action", rename_all = "snake_case")]
pub enum ManagementForm {
@@ -69,7 +68,7 @@ pub(crate) async fn get(
if let Some(session) = maybe_session {
render(templates, session, cookie_jar, &mut conn).await
} else {
let login = LoginRequest::default();
let login = mas_router::Login::default();
Ok((cookie_jar, login.go()).into_response())
}
}
@@ -119,7 +118,7 @@ async fn start_email_verification(
let mailbox = Mailbox::new(Some(user.username.clone()), address);
let link = url_builder.email_verification(&code);
let link = url_builder.email_verification(code);
let context = EmailVerificationContext::new(user.clone().into(), link);
@@ -149,7 +148,7 @@ pub(crate) async fn post(
let mut session = if let Some(session) = maybe_session {
session
} else {
let login = LoginRequest::default();
let login = mas_router::Login::default();
return Ok((cookie_jar, login.go()).into_response());
};

View File

@@ -22,12 +22,11 @@ use axum::{
use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{csrf::CsrfExt, fancy_error, FancyError, SessionInfoExt};
use mas_config::Encrypter;
use mas_router::Route;
use mas_storage::user::{count_active_sessions, get_user_emails};
use mas_templates::{AccountContext, TemplateContext, Templates};
use sqlx::PgPool;
use super::LoginRequest;
pub(crate) async fn get(
Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>,
@@ -49,7 +48,7 @@ pub(crate) async fn get(
let session = if let Some(session) = maybe_session {
session
} else {
let login = LoginRequest::default();
let login = mas_router::Login::default();
return Ok((cookie_jar, login.go()).into_response());
};

View File

@@ -24,6 +24,7 @@ use mas_axum_utils::{
};
use mas_config::Encrypter;
use mas_data_model::BrowserSession;
use mas_router::Route;
use mas_storage::{
user::{authenticate_session, set_password},
PostgresqlBackend,
@@ -32,8 +33,6 @@ use mas_templates::{EmptyContext, TemplateContext, Templates};
use serde::Deserialize;
use sqlx::PgPool;
use crate::views::LoginRequest;
#[derive(Deserialize)]
pub struct ChangeForm {
current_password: String,
@@ -61,7 +60,7 @@ pub(crate) async fn get(
if let Some(session) = maybe_session {
render(templates, session, cookie_jar).await
} else {
let login = LoginRequest::default();
let login = mas_router::Login::default();
Ok((cookie_jar, login.go()).into_response())
}
}
@@ -107,7 +106,7 @@ pub(crate) async fn post(
let mut session = if let Some(session) = maybe_session {
session
} else {
let login = LoginRequest::default();
let login = mas_router::Login::default();
return Ok((cookie_jar, login.go()).into_response());
};

View File

@@ -17,8 +17,9 @@ use axum::{
response::{Html, IntoResponse},
};
use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{csrf::CsrfExt, fancy_error, FancyError, SessionInfoExt, UrlBuilder};
use mas_axum_utils::{csrf::CsrfExt, fancy_error, FancyError, SessionInfoExt};
use mas_config::Encrypter;
use mas_router::UrlBuilder;
use mas_templates::{IndexContext, TemplateContext, Templates};
use sqlx::PgPool;

View File

@@ -12,11 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use axum::{
extract::{Extension, Form, Query},
response::{Html, IntoResponse, Redirect, Response},
response::{Html, IntoResponse, Response},
};
use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{
@@ -25,53 +23,13 @@ use mas_axum_utils::{
};
use mas_config::Encrypter;
use mas_data_model::errors::WrapFormError;
use mas_router::Route;
use mas_storage::user::login;
use mas_templates::{LoginContext, LoginFormField, TemplateContext, Templates};
use serde::Deserialize;
use sqlx::PgPool;
use super::{shared::PostAuthAction, RegisterRequest};
#[derive(Deserialize, Default, Debug)]
pub(crate) struct LoginRequest {
#[serde(flatten)]
post_auth_action: Option<PostAuthAction>,
}
impl From<PostAuthAction> for LoginRequest {
fn from(post_auth_action: PostAuthAction) -> Self {
Some(post_auth_action).into()
}
}
impl From<Option<PostAuthAction>> for LoginRequest {
fn from(post_auth_action: Option<PostAuthAction>) -> Self {
Self { post_auth_action }
}
}
impl LoginRequest {
pub fn as_link(&self) -> Cow<'static, str> {
if let Some(next) = &self.post_auth_action {
let qs = serde_urlencoded::to_string(next).unwrap();
Cow::Owned(format!("/login?{}", qs))
} else {
Cow::Borrowed("/login")
}
}
pub fn go(&self) -> Redirect {
Redirect::to(&self.as_link())
}
fn redirect(self) -> Redirect {
if let Some(action) = self.post_auth_action {
action.redirect()
} else {
Redirect::to("/")
}
}
}
use super::shared::OptionalPostAuthAction;
#[derive(Deserialize)]
pub(crate) struct LoginForm {
@@ -83,7 +41,7 @@ pub(crate) struct LoginForm {
pub(crate) async fn get(
Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>,
Query(query): Query<LoginRequest>,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let mut conn = pool
@@ -100,23 +58,23 @@ pub(crate) async fn get(
.map_err(fancy_error(templates.clone()))?;
if maybe_session.is_some() {
let response = query.redirect().into_response();
Ok(response)
let reply = query.go_next();
Ok((cookie_jar, reply).into_response())
} else {
let ctx = LoginContext::default();
let ctx = match query.post_auth_action {
Some(next) => {
let register_link = RegisterRequest::from(next.clone()).as_link();
let next = next
.load_context(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
ctx.with_post_action(next)
.with_register_link(register_link.to_string())
}
None => ctx,
let next = query
.load_context(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
let ctx = if let Some(next) = next {
ctx.with_post_action(next)
} else {
ctx
};
let ctx = ctx.with_csrf(csrf_token.form_value());
let register_link = mas_router::Register::from(query.post_auth_action).relative_url();
let ctx = ctx
.with_register_link(register_link.to_string())
.with_csrf(csrf_token.form_value());
let content = templates
.render_login(&ctx)
@@ -130,7 +88,7 @@ pub(crate) async fn get(
pub(crate) async fn post(
Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>,
Query(query): Query<LoginRequest>,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<LoginForm>>,
) -> Result<Response, FancyError> {
@@ -150,7 +108,7 @@ pub(crate) async fn post(
match login(&mut conn, &form.username, form.password).await {
Ok(session_info) => {
let cookie_jar = cookie_jar.set_session(&session_info);
let reply = query.redirect();
let reply = query.go_next();
Ok((cookie_jar, reply).into_response())
}
Err(e) => {
@@ -172,15 +130,3 @@ pub(crate) async fn post(
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn deserialize_login_request() {
let res: Result<LoginRequest, _> =
serde_urlencoded::from_str("next=continue_authorization_grant&data=13");
res.unwrap().post_auth_action.unwrap();
}
}

View File

@@ -20,7 +20,3 @@ pub mod reauth;
pub mod register;
pub mod shared;
pub mod verify;
pub(crate) use self::{
login::LoginRequest, reauth::ReauthRequest, register::RegisterRequest, shared::PostAuthAction,
};

View File

@@ -12,11 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use axum::{
extract::{Extension, Form, Query},
response::{Html, IntoResponse, Redirect, Response},
response::{Html, IntoResponse, Response},
};
use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{
@@ -24,49 +22,13 @@ use mas_axum_utils::{
fancy_error, FancyError, SessionInfoExt,
};
use mas_config::Encrypter;
use mas_router::Route;
use mas_storage::user::authenticate_session;
use mas_templates::{ReauthContext, TemplateContext, Templates};
use serde::Deserialize;
use sqlx::PgPool;
use super::{LoginRequest, PostAuthAction};
#[derive(Deserialize)]
pub(crate) struct ReauthRequest {
#[serde(flatten)]
post_auth_action: Option<PostAuthAction>,
}
impl From<PostAuthAction> for ReauthRequest {
fn from(post_auth_action: PostAuthAction) -> Self {
Self {
post_auth_action: Some(post_auth_action),
}
}
}
impl ReauthRequest {
pub fn as_link(&self) -> Cow<'static, str> {
if let Some(next) = &self.post_auth_action {
let qs = serde_urlencoded::to_string(next).unwrap();
Cow::Owned(format!("/reauth?{}", qs))
} else {
Cow::Borrowed("/reauth")
}
}
pub fn go(&self) -> Redirect {
Redirect::to(&self.as_link())
}
fn redirect(self) -> Redirect {
if let Some(action) = self.post_auth_action {
action.redirect()
} else {
Redirect::to("/")
}
}
}
use super::shared::OptionalPostAuthAction;
#[derive(Deserialize, Debug)]
pub(crate) struct ReauthForm {
@@ -76,7 +38,7 @@ pub(crate) struct ReauthForm {
pub(crate) async fn get(
Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>,
Query(query): Query<ReauthRequest>,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let mut conn = pool
@@ -97,20 +59,19 @@ pub(crate) async fn get(
} else {
// If there is no session, redirect to the login screen, keeping the
// PostAuthAction
let login: LoginRequest = query.post_auth_action.into();
let login = mas_router::Login::from(query.post_auth_action);
return Ok((cookie_jar, login.go()).into_response());
};
let ctx = ReauthContext::default();
let ctx = match query.post_auth_action {
Some(next) => {
let next = next
.load_context(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
ctx.with_post_action(next)
}
None => ctx,
let next = query
.load_context(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
let ctx = if let Some(next) = next {
ctx.with_post_action(next)
} else {
ctx
};
let ctx = ctx.with_session(session).with_csrf(csrf_token.form_value());
@@ -125,7 +86,7 @@ pub(crate) async fn get(
pub(crate) async fn post(
Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>,
Query(query): Query<ReauthRequest>,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ReauthForm>>,
) -> Result<Response, FancyError> {
@@ -147,7 +108,7 @@ pub(crate) async fn post(
} else {
// If there is no session, redirect to the login screen, keeping the
// PostAuthAction
let login: LoginRequest = query.post_auth_action.into();
let login = mas_router::Login::from(query.post_auth_action);
return Ok((cookie_jar, login.go()).into_response());
};
@@ -158,6 +119,6 @@ pub(crate) async fn post(
let cookie_jar = cookie_jar.set_session(&session);
txn.commit().await.map_err(fancy_error(templates.clone()))?;
let redirection = query.redirect();
Ok((cookie_jar, redirection).into_response())
let reply = query.go_next();
Ok((cookie_jar, reply).into_response())
}

View File

@@ -14,12 +14,10 @@
#![allow(clippy::trait_duplication_in_bounds)]
use std::borrow::Cow;
use argon2::Argon2;
use axum::{
extract::{Extension, Form, Query},
response::{Html, IntoResponse, Redirect, Response},
response::{Html, IntoResponse, Response},
};
use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{
@@ -27,49 +25,13 @@ use mas_axum_utils::{
fancy_error, FancyError, SessionInfoExt,
};
use mas_config::Encrypter;
use mas_router::Route;
use mas_storage::user::{register_user, start_session};
use mas_templates::{RegisterContext, TemplateContext, Templates};
use serde::Deserialize;
use sqlx::PgPool;
use super::{LoginRequest, PostAuthAction};
#[derive(Deserialize)]
pub(crate) struct RegisterRequest {
#[serde(flatten)]
post_auth_action: Option<PostAuthAction>,
}
impl From<PostAuthAction> for RegisterRequest {
fn from(post_auth_action: PostAuthAction) -> Self {
Self {
post_auth_action: Some(post_auth_action),
}
}
}
impl RegisterRequest {
pub fn as_link(&self) -> Cow<'static, str> {
if let Some(next) = &self.post_auth_action {
let qs = serde_urlencoded::to_string(next).unwrap();
Cow::Owned(format!("/register?{}", qs))
} else {
Cow::Borrowed("/register")
}
}
pub fn go(&self) -> Redirect {
Redirect::to(&self.as_link())
}
fn redirect(self) -> Redirect {
if let Some(action) = self.post_auth_action {
action.redirect()
} else {
Redirect::to("/")
}
}
}
use super::shared::OptionalPostAuthAction;
#[derive(Deserialize)]
pub(crate) struct RegisterForm {
@@ -81,7 +43,7 @@ pub(crate) struct RegisterForm {
pub(crate) async fn get(
Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>,
Query(query): Query<RegisterRequest>,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> {
let mut conn = pool
@@ -98,21 +60,20 @@ pub(crate) async fn get(
.map_err(fancy_error(templates.clone()))?;
if maybe_session.is_some() {
let response = query.redirect().into_response();
Ok(response)
let reply = query.go_next();
Ok((cookie_jar, reply).into_response())
} else {
let ctx = RegisterContext::default();
let ctx = match &query.post_auth_action {
Some(next) => {
let next = next
.load_context(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
ctx.with_post_action(next)
}
None => ctx,
let next = query
.load_context(&mut conn)
.await
.map_err(fancy_error(templates.clone()))?;
let ctx = if let Some(next) = next {
ctx.with_post_action(next)
} else {
ctx
};
let login_link = LoginRequest::from(query.post_auth_action).as_link();
let login_link = mas_router::Login::from(query.post_auth_action).relative_url();
let ctx = ctx.with_login_link(login_link.to_string());
let ctx = ctx.with_csrf(csrf_token.form_value());
@@ -128,7 +89,7 @@ pub(crate) async fn get(
pub(crate) async fn post(
Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>,
Query(query): Query<RegisterRequest>,
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<RegisterForm>>,
) -> Result<Response, FancyError> {
@@ -155,6 +116,6 @@ pub(crate) async fn post(
txn.commit().await.map_err(fancy_error(templates.clone()))?;
let cookie_jar = cookie_jar.set_session(&session);
let reply = query.redirect();
let reply = query.go_next();
Ok((cookie_jar, reply).into_response())
}

View File

@@ -12,62 +12,36 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use axum::response::Redirect;
use mas_data_model::AuthorizationGrant;
use mas_storage::{oauth2::authorization_grant::get_grant_by_id, PostgresqlBackend};
use mas_router::{PostAuthAction, Route};
use mas_storage::oauth2::authorization_grant::get_grant_by_id;
use mas_templates::PostAuthContext;
use serde::{Deserialize, Serialize};
use sqlx::PgConnection;
#[derive(Deserialize, Serialize, Clone, Debug)]
#[serde(rename_all = "snake_case", tag = "next")]
pub(crate) enum PostAuthAction {
ContinueAuthorizationGrant {
#[serde(deserialize_with = "serde_with::rust::display_fromstr::deserialize")]
data: i64,
},
#[derive(Serialize, Deserialize, Default, Debug, Clone)]
pub(crate) struct OptionalPostAuthAction {
#[serde(flatten)]
pub post_auth_action: Option<PostAuthAction>,
}
impl PostAuthAction {
pub fn continue_grant(grant: &AuthorizationGrant<PostgresqlBackend>) -> Self {
Self::ContinueAuthorizationGrant { data: grant.data }
impl OptionalPostAuthAction {
pub fn go_next(&self) -> axum::response::Redirect {
self.post_auth_action.as_ref().map_or_else(
|| mas_router::Index.go(),
mas_router::PostAuthAction::go_next,
)
}
pub fn redirect(&self) -> Redirect {
match self {
PostAuthAction::ContinueAuthorizationGrant { data } => {
let url = format!("/authorize/{}", data);
Redirect::to(&url)
}
}
}
pub async fn load_context<'e>(
&self,
conn: &mut PgConnection,
) -> anyhow::Result<PostAuthContext> {
match self {
Self::ContinueAuthorizationGrant { data } => {
) -> anyhow::Result<Option<PostAuthContext>> {
match &self.post_auth_action {
Some(PostAuthAction::ContinueAuthorizationGrant { data }) => {
let grant = get_grant_by_id(conn, *data).await?;
let grant = grant.into();
Ok(PostAuthContext::ContinueAuthorizationGrant { grant })
Ok(Some(PostAuthContext::ContinueAuthorizationGrant { grant }))
}
None => Ok(None),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_post_auth_action() {
let action: PostAuthAction =
serde_urlencoded::from_str("next=continue_authorization_grant&data=123").unwrap();
assert!(matches!(
action,
PostAuthAction::ContinueAuthorizationGrant { data: 123 }
));
}
}