1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Have a unified URL builder/router

This commit is contained in:
Quentin Gliech
2022-05-10 09:52:27 +02:00
parent 0ac4fddee4
commit f4353b660e
28 changed files with 684 additions and 371 deletions

14
Cargo.lock generated
View File

@ -2014,11 +2014,11 @@ dependencies = [
"futures 0.3.21", "futures 0.3.21",
"hyper", "hyper",
"indoc", "indoc",
"mas-axum-utils",
"mas-config", "mas-config",
"mas-email", "mas-email",
"mas-handlers", "mas-handlers",
"mas-http", "mas-http",
"mas-router",
"mas-static-files", "mas-static-files",
"mas-storage", "mas-storage",
"mas-tasks", "mas-tasks",
@ -2127,6 +2127,7 @@ dependencies = [
"mas-http", "mas-http",
"mas-iana", "mas-iana",
"mas-jose", "mas-jose",
"mas-router",
"mas-storage", "mas-storage",
"mas-templates", "mas-templates",
"mime", "mime",
@ -2236,6 +2237,17 @@ dependencies = [
"url", "url",
] ]
[[package]]
name = "mas-router"
version = "0.1.0"
dependencies = [
"axum",
"serde",
"serde_urlencoded",
"serde_with",
"url",
]
[[package]] [[package]]
name = "mas-static-files" name = "mas-static-files"
version = "0.1.0" version = "0.1.0"

View File

@ -17,12 +17,10 @@ pub mod cookies;
pub mod csrf; pub mod csrf;
pub mod fancy_error; pub mod fancy_error;
pub mod session; pub mod session;
pub mod url_builder;
pub mod user_authorization; pub mod user_authorization;
pub use self::{ pub use self::{
cookies::CookieExt, cookies::CookieExt,
fancy_error::{fancy_error, internal_error, FancyError}, fancy_error::{fancy_error, internal_error, FancyError},
session::{SessionInfo, SessionInfoExt}, session::{SessionInfo, SessionInfoExt},
url_builder::UrlBuilder,
}; };

View File

@ -32,11 +32,11 @@ opentelemetry-jaeger = { version = "0.16.0", features = ["rt-tokio", "reqwest_co
opentelemetry-otlp = { version = "0.10.0", features = ["trace", "metrics"], optional = true } opentelemetry-otlp = { version = "0.10.0", features = ["trace", "metrics"], optional = true }
opentelemetry-zipkin = { version = "0.15.0", features = ["reqwest-client", "reqwest-rustls"], default-features = false, optional = true } opentelemetry-zipkin = { version = "0.15.0", features = ["reqwest-client", "reqwest-rustls"], default-features = false, optional = true }
mas-axum-utils = { path = "../axum-utils" }
mas-config = { path = "../config" } mas-config = { path = "../config" }
mas-email = { path = "../email" } mas-email = { path = "../email" }
mas-handlers = { path = "../handlers" } mas-handlers = { path = "../handlers" }
mas-http = { path = "../http" } mas-http = { path = "../http" }
mas-router = { path = "../router" }
mas-static-files = { path = "../static-files" } mas-static-files = { path = "../static-files" }
mas-storage = { path = "../storage" } mas-storage = { path = "../storage" }
mas-tasks = { path = "../tasks" } mas-tasks = { path = "../tasks" }

View File

@ -22,10 +22,10 @@ use anyhow::Context;
use clap::Parser; use clap::Parser;
use futures::{future::TryFutureExt, stream::TryStreamExt}; use futures::{future::TryFutureExt, stream::TryStreamExt};
use hyper::Server; use hyper::Server;
use mas_axum_utils::UrlBuilder;
use mas_config::RootConfig; use mas_config::RootConfig;
use mas_email::{MailTransport, Mailer}; use mas_email::{MailTransport, Mailer};
use mas_http::ServerLayer; use mas_http::ServerLayer;
use mas_router::UrlBuilder;
use mas_storage::MIGRATOR; use mas_storage::MIGRATOR;
use mas_tasks::TaskQueue; use mas_tasks::TaskQueue;
use mas_templates::Templates; use mas_templates::Templates;

View File

@ -64,6 +64,7 @@ mas-iana = { path = "../iana" }
mas-jose = { path = "../jose" } mas-jose = { path = "../jose" }
mas-storage = { path = "../storage" } mas-storage = { path = "../storage" }
mas-templates = { path = "../templates" } mas-templates = { path = "../templates" }
mas-router = { path = "../router" }
[dev-dependencies] [dev-dependencies]
indoc = "1.0.6" indoc = "1.0.6"

View File

@ -28,11 +28,11 @@ use axum::{
Router, Router,
}; };
use hyper::header::{ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_TYPE}; use hyper::header::{ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_TYPE};
use mas_axum_utils::UrlBuilder;
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_email::Mailer; use mas_email::Mailer;
use mas_http::CorsLayerExt; use mas_http::CorsLayerExt;
use mas_jose::StaticKeystore; use mas_jose::StaticKeystore;
use mas_router::{Route, UrlBuilder};
use mas_templates::Templates; use mas_templates::Templates;
use sqlx::PgPool; use sqlx::PgPool;
use tower_http::cors::{Any, CorsLayer}; use tower_http::cors::{Any, CorsLayer};
@ -58,25 +58,34 @@ where
// All those routes are API-like, with a common CORS layer // All those routes are API-like, with a common CORS layer
let api_router = Router::new() let api_router = Router::new()
.route( .route(
"/.well-known/openid-configuration", mas_router::OidcConfiguration::route(),
get(self::oauth2::discovery::get), get(self::oauth2::discovery::get),
) )
.route("/.well-known/webfinger", get(self::oauth2::webfinger::get))
.route("/oauth2/keys.json", get(self::oauth2::keys::get))
.route( .route(
"/oauth2/userinfo", mas_router::Webfinger::route(),
get(self::oauth2::webfinger::get),
)
.route(
mas_router::OAuth2Keys::route(),
get(self::oauth2::keys::get),
)
.route(
mas_router::OidcUserinfo::route(),
on( on(
MethodFilter::POST | MethodFilter::GET, MethodFilter::POST | MethodFilter::GET,
self::oauth2::userinfo::get, self::oauth2::userinfo::get,
), ),
) )
.route( .route(
"/oauth2/introspect", mas_router::OAuth2Introspection::route(),
post(self::oauth2::introspection::post), post(self::oauth2::introspection::post),
) )
.route("/oauth2/token", post(self::oauth2::token::post))
.route( .route(
"/oauth2/registration", mas_router::OAuth2TokenEndpoint::route(),
post(self::oauth2::token::post),
)
.route(
mas_router::OAuth2RegistrationEndpoint::route(),
post(self::oauth2::registration::post), post(self::oauth2::registration::post),
) )
.layer( .layer(
@ -94,38 +103,44 @@ where
); );
Router::new() Router::new()
.route("/", get(self::views::index::get)) .route(mas_router::Index::route(), get(self::views::index::get))
.route("/health", get(self::health::get)) .route(mas_router::Healthcheck::route(), get(self::health::get))
.route( .route(
"/login", mas_router::Login::route(),
get(self::views::login::get).post(self::views::login::post), get(self::views::login::get).post(self::views::login::post),
) )
.route("/logout", post(self::views::logout::post)) .route(mas_router::Logout::route(), post(self::views::logout::post))
.route( .route(
"/reauth", mas_router::Reauth::route(),
get(self::views::reauth::get).post(self::views::reauth::post), get(self::views::reauth::get).post(self::views::reauth::post),
) )
.route( .route(
"/register", mas_router::Register::route(),
get(self::views::register::get).post(self::views::register::post), get(self::views::register::get).post(self::views::register::post),
) )
.route("/verify/:code", get(self::views::verify::get))
.route("/account", get(self::views::account::get))
.route( .route(
"/account/password", mas_router::VerifyEmail::route(),
get(self::views::verify::get),
)
.route(mas_router::Account::route(), get(self::views::account::get))
.route(
mas_router::AccountPassword::route(),
get(self::views::account::password::get).post(self::views::account::password::post), get(self::views::account::password::get).post(self::views::account::password::post),
) )
.route( .route(
"/account/emails", mas_router::AccountEmails::route(),
get(self::views::account::emails::get).post(self::views::account::emails::post), get(self::views::account::emails::get).post(self::views::account::emails::post),
) )
.route("/authorize", get(self::oauth2::authorization::get))
.route( .route(
"/authorize/:grant_id", mas_router::OAuth2AuthorizationEndpoint::route(),
get(self::oauth2::authorization::get),
)
.route(
mas_router::ContinueAuthorizationGrant::route(),
get(self::oauth2::authorization::complete::get), get(self::oauth2::authorization::complete::get),
) )
.route( .route(
"/consent/:grant_id", mas_router::Consent::route(),
get(self::oauth2::consent::get).post(self::oauth2::consent::post), get(self::oauth2::consent::get).post(self::oauth2::consent::post),
) )
.merge(api_router) .merge(api_router)

View File

@ -24,6 +24,7 @@ use hyper::StatusCode;
use mas_axum_utils::SessionInfoExt; use mas_axum_utils::SessionInfoExt;
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_data_model::{AuthorizationGrant, BrowserSession, TokenType}; use mas_data_model::{AuthorizationGrant, BrowserSession, TokenType};
use mas_router::{PostAuthAction, Route};
use mas_storage::{ use mas_storage::{
oauth2::{ oauth2::{
access_token::add_access_token, access_token::add_access_token,
@ -41,10 +42,6 @@ use sqlx::{PgPool, Postgres, Transaction};
use thiserror::Error; use thiserror::Error;
use super::callback::{CallbackDestination, CallbackDestinationError, InvalidRedirectUriError}; use super::callback::{CallbackDestination, CallbackDestinationError, InvalidRedirectUriError};
use crate::{
oauth2::consent::ConsentRequest,
views::{LoginRequest, PostAuthAction, ReauthRequest},
};
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum RouteError { pub enum RouteError {
@ -122,15 +119,14 @@ pub(crate) async fn get(
let grant = get_grant_by_id(&mut txn, grant_id).await?; let grant = get_grant_by_id(&mut txn, grant_id).await?;
let callback_destination = CallbackDestination::try_from(&grant)?; let callback_destination = CallbackDestination::try_from(&grant)?;
let continue_grant = PostAuthAction::continue_grant(&grant); let continue_grant = PostAuthAction::continue_grant(grant_id);
let consent_request = ConsentRequest::for_grant(&grant);
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
} else { } else {
// If there is no session, redirect to the login screen, redirecting here after // If there is no session, redirect to the login screen, redirecting here after
// logout // logout
return Ok((cookie_jar, LoginRequest::from(continue_grant).go()).into_response()); return Ok((cookie_jar, mas_router::Login::and_then(continue_grant).go()).into_response());
}; };
match complete(grant, session, txn).await { match complete(grant, session, txn).await {
@ -138,11 +134,14 @@ pub(crate) async fn get(
let res = callback_destination.go(&templates, params).await?; let res = callback_destination.go(&templates, params).await?;
Ok((cookie_jar, res).into_response()) Ok((cookie_jar, res).into_response())
} }
Err(GrantCompletionError::RequiresReauth) => { Err(GrantCompletionError::RequiresReauth) => Ok((
Ok((cookie_jar, ReauthRequest::from(continue_grant).go()).into_response()) cookie_jar,
} mas_router::Reauth::and_then(continue_grant).go(),
)
.into_response()),
Err(GrantCompletionError::RequiresConsent) => { Err(GrantCompletionError::RequiresConsent) => {
Ok((cookie_jar, consent_request.go()).into_response()) let next = mas_router::Consent(grant_id);
Ok((cookie_jar, next.go()).into_response())
} }
Err(GrantCompletionError::NotPending) => Err(RouteError::NotPending), Err(GrantCompletionError::NotPending) => Err(RouteError::NotPending),
Err(GrantCompletionError::Internal(e)) => Err(RouteError::Internal(e)), Err(GrantCompletionError::Internal(e)) => Err(RouteError::Internal(e)),

View File

@ -23,6 +23,7 @@ use mas_axum_utils::SessionInfoExt;
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_data_model::{AuthorizationCode, Pkce}; use mas_data_model::{AuthorizationCode, Pkce};
use mas_iana::oauth::OAuthAuthorizationEndpointResponseType; use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
use mas_router::{PostAuthAction, Route};
use mas_storage::oauth2::{ use mas_storage::oauth2::{
authorization_grant::new_authorization_grant, authorization_grant::new_authorization_grant,
client::{lookup_client_by_client_id, ClientFetchError}, client::{lookup_client_by_client_id, ClientFetchError},
@ -45,8 +46,6 @@ use sqlx::PgPool;
use thiserror::Error; use thiserror::Error;
use self::{callback::CallbackDestination, complete::GrantCompletionError}; use self::{callback::CallbackDestination, complete::GrantCompletionError};
use super::consent::ConsentRequest;
use crate::views::{LoginRequest, PostAuthAction, ReauthRequest, RegisterRequest};
mod callback; mod callback;
pub mod complete; pub mod complete;
@ -287,8 +286,7 @@ pub(crate) async fn get(
requires_consent, requires_consent,
) )
.await?; .await?;
let continue_grant = PostAuthAction::continue_grant(&grant); let continue_grant = PostAuthAction::continue_grant(grant.data);
let consent_request = ConsentRequest::for_grant(&grant);
let res = match (maybe_session, params.auth.prompt) { let res = match (maybe_session, params.auth.prompt) {
// Cases where there is no active session, redirect to the relevant page // Cases where there is no active session, redirect to the relevant page
@ -300,13 +298,17 @@ pub(crate) async fn get(
// Client asked for a registration, show the registration prompt // Client asked for a registration, show the registration prompt
txn.commit().await?; txn.commit().await?;
RegisterRequest::from(continue_grant).go().into_response() mas_router::Register::and_then(continue_grant)
.go()
.into_response()
} }
(None, _) => { (None, _) => {
// Other cases where we don't have a session, ask for a login // Other cases where we don't have a session, ask for a login
txn.commit().await?; txn.commit().await?;
LoginRequest::from(continue_grant).go().into_response() mas_router::Login::and_then(continue_grant)
.go()
.into_response()
} }
// Special case when we already have a sesion but prompt=login|select_account // Special case when we already have a sesion but prompt=login|select_account
@ -314,7 +316,9 @@ pub(crate) async fn get(
// TODO: better pages here // TODO: better pages here
txn.commit().await?; txn.commit().await?;
ReauthRequest::from(continue_grant).go().into_response() mas_router::Reauth::and_then(continue_grant)
.go()
.into_response()
} }
// Else, we immediately try to complete the authorization grant // Else, we immediately try to complete the authorization grant
@ -343,14 +347,17 @@ pub(crate) async fn get(
} }
} }
(Some(user_session), _) => { (Some(user_session), _) => {
let grant_id = grant.data;
// Else, we show the relevant reauth/consent page if necessary // Else, we show the relevant reauth/consent page if necessary
match self::complete::complete(grant, user_session, txn).await { match self::complete::complete(grant, user_session, txn).await {
Ok(params) => callback_destination.go(&templates, params).await?, Ok(params) => callback_destination.go(&templates, params).await?,
Err(GrantCompletionError::RequiresConsent) => { Err(GrantCompletionError::RequiresConsent) => {
consent_request.go().into_response() mas_router::Consent(grant_id).go().into_response()
} }
Err(GrantCompletionError::RequiresReauth) => { Err(GrantCompletionError::RequiresReauth) => {
ReauthRequest::from(continue_grant).go().into_response() mas_router::Reauth::and_then(continue_grant)
.go()
.into_response()
} }
Err(GrantCompletionError::Anyhow(a)) => return Err(RouteError::Anyhow(a)), Err(GrantCompletionError::Anyhow(a)) => return Err(RouteError::Anyhow(a)),
Err(GrantCompletionError::Internal(e)) => { Err(GrantCompletionError::Internal(e)) => {

View File

@ -15,7 +15,7 @@
use anyhow::Context; use anyhow::Context;
use axum::{ use axum::{
extract::{Extension, Form, Path}, extract::{Extension, Form, Path},
response::{Html, IntoResponse, Redirect, Response}, response::{Html, IntoResponse, Response},
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use hyper::StatusCode; use hyper::StatusCode;
@ -24,20 +24,16 @@ use mas_axum_utils::{
SessionInfoExt, SessionInfoExt,
}; };
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_data_model::{AuthorizationGrant, AuthorizationGrantStage}; use mas_data_model::AuthorizationGrantStage;
use mas_storage::{ use mas_router::{PostAuthAction, Route};
oauth2::{ use mas_storage::oauth2::{
authorization_grant::{get_grant_by_id, give_consent_to_grant}, authorization_grant::{get_grant_by_id, give_consent_to_grant},
consent::insert_client_consent, consent::insert_client_consent,
},
PostgresqlBackend,
}; };
use mas_templates::{ConsentContext, TemplateContext, Templates}; use mas_templates::{ConsentContext, TemplateContext, Templates};
use sqlx::PgPool; use sqlx::PgPool;
use thiserror::Error; use thiserror::Error;
use crate::views::{LoginRequest, PostAuthAction};
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum RouteError { pub enum RouteError {
#[error(transparent)] #[error(transparent)]
@ -50,23 +46,6 @@ impl IntoResponse for RouteError {
} }
} }
pub(crate) struct ConsentRequest {
grant_id: i64,
}
impl ConsentRequest {
pub fn for_grant(grant: &AuthorizationGrant<PostgresqlBackend>) -> Self {
Self {
grant_id: grant.data,
}
}
pub fn go(&self) -> Redirect {
let uri = format!("/consent/{}", self.grant_id);
Redirect::to(&uri)
}
}
pub(crate) async fn get( pub(crate) async fn get(
Extension(templates): Extension<Templates>, Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
@ -105,7 +84,7 @@ pub(crate) async fn get(
Ok((cookie_jar, Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} else { } else {
let login = LoginRequest::from(PostAuthAction::continue_grant(&grant)); let login = mas_router::Login::and_continue_grant(grant_id);
Ok((cookie_jar, login.go()).into_response()) Ok((cookie_jar, login.go()).into_response())
} }
} }
@ -133,12 +112,12 @@ pub(crate) async fn post(
.context("could not load session")?; .context("could not load session")?;
let grant = get_grant_by_id(&mut txn, grant_id).await?; let grant = get_grant_by_id(&mut txn, grant_id).await?;
let next = PostAuthAction::continue_grant(&grant); let next = PostAuthAction::continue_grant(grant_id);
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
} else { } else {
let login = LoginRequest::from(next); let login = mas_router::Login::and_then(next);
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
@ -163,5 +142,5 @@ pub(crate) async fn post(
txn.commit().await.context("could not commit txn")?; txn.commit().await.context("could not commit txn")?;
Ok((cookie_jar, next.redirect()).into_response()) Ok((cookie_jar, next.go_next()).into_response())
} }

View File

@ -15,7 +15,6 @@
use std::sync::Arc; use std::sync::Arc;
use axum::{extract::Extension, response::IntoResponse, Json}; use axum::{extract::Extension, response::IntoResponse, Json};
use mas_axum_utils::UrlBuilder;
use mas_iana::{ use mas_iana::{
jose::JsonWebSignatureAlg, jose::JsonWebSignatureAlg,
oauth::{ oauth::{
@ -24,6 +23,7 @@ use mas_iana::{
}, },
}; };
use mas_jose::{SigningKeystore, StaticKeystore}; use mas_jose::{SigningKeystore, StaticKeystore};
use mas_router::UrlBuilder;
use oauth2_types::{ use oauth2_types::{
oidc::{ClaimType, Metadata, SubjectType}, oidc::{ClaimType, Metadata, SubjectType},
requests::{Display, GrantType, Prompt, ResponseMode}, requests::{Display, GrantType, Prompt, ResponseMode},

View File

@ -20,10 +20,7 @@ use chrono::{DateTime, Duration, Utc};
use data_encoding::BASE64URL_NOPAD; use data_encoding::BASE64URL_NOPAD;
use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma}; use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma};
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::{ use mas_axum_utils::client_authorization::{ClientAuthorization, CredentialsVerificationError};
client_authorization::{ClientAuthorization, CredentialsVerificationError},
UrlBuilder,
};
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_data_model::{AuthorizationGrantStage, Client, TokenType}; use mas_data_model::{AuthorizationGrantStage, Client, TokenType};
use mas_iana::jose::JsonWebSignatureAlg; use mas_iana::jose::JsonWebSignatureAlg;
@ -31,6 +28,7 @@ use mas_jose::{
claims::{self, ClaimError}, claims::{self, ClaimError},
DecodedJsonWebToken, SigningKeystore, StaticKeystore, DecodedJsonWebToken, SigningKeystore, StaticKeystore,
}; };
use mas_router::UrlBuilder;
use mas_storage::{ use mas_storage::{
oauth2::{ oauth2::{
access_token::{add_access_token, revoke_access_token}, access_token::{add_access_token, revoke_access_token},

View File

@ -21,8 +21,9 @@ use axum::{
}; };
use headers::ContentType; use headers::ContentType;
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::{internal_error, user_authorization::UserAuthorization, UrlBuilder}; use mas_axum_utils::{internal_error, user_authorization::UserAuthorization};
use mas_jose::{DecodedJsonWebToken, SigningKeystore, StaticKeystore}; use mas_jose::{DecodedJsonWebToken, SigningKeystore, StaticKeystore};
use mas_router::UrlBuilder;
use mime::Mime; use mime::Mime;
use oauth2_types::scope; use oauth2_types::scope;
use serde::Serialize; use serde::Serialize;

View File

@ -14,7 +14,7 @@
use axum::{extract::Query, response::IntoResponse, Extension, Json, TypedHeader}; use axum::{extract::Query, response::IntoResponse, Extension, Json, TypedHeader};
use headers::ContentType; use headers::ContentType;
use mas_axum_utils::UrlBuilder; use mas_router::UrlBuilder;
use oauth2_types::webfinger::WebFingerResponse; use oauth2_types::webfinger::WebFingerResponse;
use serde::Deserialize; use serde::Deserialize;

View File

@ -20,11 +20,12 @@ use axum_extra::extract::PrivateCookieJar;
use lettre::{message::Mailbox, Address}; use lettre::{message::Mailbox, Address};
use mas_axum_utils::{ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm}, csrf::{CsrfExt, ProtectedForm},
fancy_error, FancyError, SessionInfoExt, UrlBuilder, fancy_error, FancyError, SessionInfoExt,
}; };
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_data_model::{BrowserSession, User, UserEmail}; use mas_data_model::{BrowserSession, User, UserEmail};
use mas_email::Mailer; use mas_email::Mailer;
use mas_router::{Route, UrlBuilder};
use mas_storage::{ use mas_storage::{
user::{ user::{
add_user_email, add_user_email_verification_code, get_user_email, get_user_emails, add_user_email, add_user_email_verification_code, get_user_email, get_user_emails,
@ -38,8 +39,6 @@ use serde::Deserialize;
use sqlx::{PgExecutor, PgPool}; use sqlx::{PgExecutor, PgPool};
use tracing::info; use tracing::info;
use crate::views::LoginRequest;
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
#[serde(tag = "action", rename_all = "snake_case")] #[serde(tag = "action", rename_all = "snake_case")]
pub enum ManagementForm { pub enum ManagementForm {
@ -69,7 +68,7 @@ pub(crate) async fn get(
if let Some(session) = maybe_session { if let Some(session) = maybe_session {
render(templates, session, cookie_jar, &mut conn).await render(templates, session, cookie_jar, &mut conn).await
} else { } else {
let login = LoginRequest::default(); let login = mas_router::Login::default();
Ok((cookie_jar, login.go()).into_response()) Ok((cookie_jar, login.go()).into_response())
} }
} }
@ -119,7 +118,7 @@ async fn start_email_verification(
let mailbox = Mailbox::new(Some(user.username.clone()), address); let mailbox = Mailbox::new(Some(user.username.clone()), address);
let link = url_builder.email_verification(&code); let link = url_builder.email_verification(code);
let context = EmailVerificationContext::new(user.clone().into(), link); let context = EmailVerificationContext::new(user.clone().into(), link);
@ -149,7 +148,7 @@ pub(crate) async fn post(
let mut session = if let Some(session) = maybe_session { let mut session = if let Some(session) = maybe_session {
session session
} else { } else {
let login = LoginRequest::default(); let login = mas_router::Login::default();
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };

View File

@ -22,12 +22,11 @@ use axum::{
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{csrf::CsrfExt, fancy_error, FancyError, SessionInfoExt}; use mas_axum_utils::{csrf::CsrfExt, fancy_error, FancyError, SessionInfoExt};
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_router::Route;
use mas_storage::user::{count_active_sessions, get_user_emails}; use mas_storage::user::{count_active_sessions, get_user_emails};
use mas_templates::{AccountContext, TemplateContext, Templates}; use mas_templates::{AccountContext, TemplateContext, Templates};
use sqlx::PgPool; use sqlx::PgPool;
use super::LoginRequest;
pub(crate) async fn get( pub(crate) async fn get(
Extension(templates): Extension<Templates>, Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
@ -49,7 +48,7 @@ pub(crate) async fn get(
let session = if let Some(session) = maybe_session { let session = if let Some(session) = maybe_session {
session session
} else { } else {
let login = LoginRequest::default(); let login = mas_router::Login::default();
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };

View File

@ -24,6 +24,7 @@ use mas_axum_utils::{
}; };
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_data_model::BrowserSession; use mas_data_model::BrowserSession;
use mas_router::Route;
use mas_storage::{ use mas_storage::{
user::{authenticate_session, set_password}, user::{authenticate_session, set_password},
PostgresqlBackend, PostgresqlBackend,
@ -32,8 +33,6 @@ use mas_templates::{EmptyContext, TemplateContext, Templates};
use serde::Deserialize; use serde::Deserialize;
use sqlx::PgPool; use sqlx::PgPool;
use crate::views::LoginRequest;
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct ChangeForm { pub struct ChangeForm {
current_password: String, current_password: String,
@ -61,7 +60,7 @@ pub(crate) async fn get(
if let Some(session) = maybe_session { if let Some(session) = maybe_session {
render(templates, session, cookie_jar).await render(templates, session, cookie_jar).await
} else { } else {
let login = LoginRequest::default(); let login = mas_router::Login::default();
Ok((cookie_jar, login.go()).into_response()) Ok((cookie_jar, login.go()).into_response())
} }
} }
@ -107,7 +106,7 @@ pub(crate) async fn post(
let mut session = if let Some(session) = maybe_session { let mut session = if let Some(session) = maybe_session {
session session
} else { } else {
let login = LoginRequest::default(); let login = mas_router::Login::default();
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };

View File

@ -17,8 +17,9 @@ use axum::{
response::{Html, IntoResponse}, response::{Html, IntoResponse},
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{csrf::CsrfExt, fancy_error, FancyError, SessionInfoExt, UrlBuilder}; use mas_axum_utils::{csrf::CsrfExt, fancy_error, FancyError, SessionInfoExt};
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_router::UrlBuilder;
use mas_templates::{IndexContext, TemplateContext, Templates}; use mas_templates::{IndexContext, TemplateContext, Templates};
use sqlx::PgPool; use sqlx::PgPool;

View File

@ -12,11 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::borrow::Cow;
use axum::{ use axum::{
extract::{Extension, Form, Query}, extract::{Extension, Form, Query},
response::{Html, IntoResponse, Redirect, Response}, response::{Html, IntoResponse, Response},
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{ use mas_axum_utils::{
@ -25,53 +23,13 @@ use mas_axum_utils::{
}; };
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_data_model::errors::WrapFormError; use mas_data_model::errors::WrapFormError;
use mas_router::Route;
use mas_storage::user::login; use mas_storage::user::login;
use mas_templates::{LoginContext, LoginFormField, TemplateContext, Templates}; use mas_templates::{LoginContext, LoginFormField, TemplateContext, Templates};
use serde::Deserialize; use serde::Deserialize;
use sqlx::PgPool; use sqlx::PgPool;
use super::{shared::PostAuthAction, RegisterRequest}; use super::shared::OptionalPostAuthAction;
#[derive(Deserialize, Default, Debug)]
pub(crate) struct LoginRequest {
#[serde(flatten)]
post_auth_action: Option<PostAuthAction>,
}
impl From<PostAuthAction> for LoginRequest {
fn from(post_auth_action: PostAuthAction) -> Self {
Some(post_auth_action).into()
}
}
impl From<Option<PostAuthAction>> for LoginRequest {
fn from(post_auth_action: Option<PostAuthAction>) -> Self {
Self { post_auth_action }
}
}
impl LoginRequest {
pub fn as_link(&self) -> Cow<'static, str> {
if let Some(next) = &self.post_auth_action {
let qs = serde_urlencoded::to_string(next).unwrap();
Cow::Owned(format!("/login?{}", qs))
} else {
Cow::Borrowed("/login")
}
}
pub fn go(&self) -> Redirect {
Redirect::to(&self.as_link())
}
fn redirect(self) -> Redirect {
if let Some(action) = self.post_auth_action {
action.redirect()
} else {
Redirect::to("/")
}
}
}
#[derive(Deserialize)] #[derive(Deserialize)]
pub(crate) struct LoginForm { pub(crate) struct LoginForm {
@ -83,7 +41,7 @@ pub(crate) struct LoginForm {
pub(crate) async fn get( pub(crate) async fn get(
Extension(templates): Extension<Templates>, Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
Query(query): Query<LoginRequest>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let mut conn = pool let mut conn = pool
@ -100,23 +58,23 @@ pub(crate) async fn get(
.map_err(fancy_error(templates.clone()))?; .map_err(fancy_error(templates.clone()))?;
if maybe_session.is_some() { if maybe_session.is_some() {
let response = query.redirect().into_response(); let reply = query.go_next();
Ok(response) Ok((cookie_jar, reply).into_response())
} else { } else {
let ctx = LoginContext::default(); let ctx = LoginContext::default();
let ctx = match query.post_auth_action { let next = query
Some(next) => { .load_context(&mut conn)
let register_link = RegisterRequest::from(next.clone()).as_link(); .await
let next = next .map_err(fancy_error(templates.clone()))?;
.load_context(&mut conn) let ctx = if let Some(next) = next {
.await ctx.with_post_action(next)
.map_err(fancy_error(templates.clone()))?; } else {
ctx.with_post_action(next) ctx
.with_register_link(register_link.to_string())
}
None => ctx,
}; };
let ctx = ctx.with_csrf(csrf_token.form_value()); let register_link = mas_router::Register::from(query.post_auth_action).relative_url();
let ctx = ctx
.with_register_link(register_link.to_string())
.with_csrf(csrf_token.form_value());
let content = templates let content = templates
.render_login(&ctx) .render_login(&ctx)
@ -130,7 +88,7 @@ pub(crate) async fn get(
pub(crate) async fn post( pub(crate) async fn post(
Extension(templates): Extension<Templates>, Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
Query(query): Query<LoginRequest>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<LoginForm>>, Form(form): Form<ProtectedForm<LoginForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
@ -150,7 +108,7 @@ pub(crate) async fn post(
match login(&mut conn, &form.username, form.password).await { match login(&mut conn, &form.username, form.password).await {
Ok(session_info) => { Ok(session_info) => {
let cookie_jar = cookie_jar.set_session(&session_info); let cookie_jar = cookie_jar.set_session(&session_info);
let reply = query.redirect(); let reply = query.go_next();
Ok((cookie_jar, reply).into_response()) Ok((cookie_jar, reply).into_response())
} }
Err(e) => { Err(e) => {
@ -172,15 +130,3 @@ pub(crate) async fn post(
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn deserialize_login_request() {
let res: Result<LoginRequest, _> =
serde_urlencoded::from_str("next=continue_authorization_grant&data=13");
res.unwrap().post_auth_action.unwrap();
}
}

View File

@ -20,7 +20,3 @@ pub mod reauth;
pub mod register; pub mod register;
pub mod shared; pub mod shared;
pub mod verify; pub mod verify;
pub(crate) use self::{
login::LoginRequest, reauth::ReauthRequest, register::RegisterRequest, shared::PostAuthAction,
};

View File

@ -12,11 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::borrow::Cow;
use axum::{ use axum::{
extract::{Extension, Form, Query}, extract::{Extension, Form, Query},
response::{Html, IntoResponse, Redirect, Response}, response::{Html, IntoResponse, Response},
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{ use mas_axum_utils::{
@ -24,49 +22,13 @@ use mas_axum_utils::{
fancy_error, FancyError, SessionInfoExt, fancy_error, FancyError, SessionInfoExt,
}; };
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_router::Route;
use mas_storage::user::authenticate_session; use mas_storage::user::authenticate_session;
use mas_templates::{ReauthContext, TemplateContext, Templates}; use mas_templates::{ReauthContext, TemplateContext, Templates};
use serde::Deserialize; use serde::Deserialize;
use sqlx::PgPool; use sqlx::PgPool;
use super::{LoginRequest, PostAuthAction}; use super::shared::OptionalPostAuthAction;
#[derive(Deserialize)]
pub(crate) struct ReauthRequest {
#[serde(flatten)]
post_auth_action: Option<PostAuthAction>,
}
impl From<PostAuthAction> for ReauthRequest {
fn from(post_auth_action: PostAuthAction) -> Self {
Self {
post_auth_action: Some(post_auth_action),
}
}
}
impl ReauthRequest {
pub fn as_link(&self) -> Cow<'static, str> {
if let Some(next) = &self.post_auth_action {
let qs = serde_urlencoded::to_string(next).unwrap();
Cow::Owned(format!("/reauth?{}", qs))
} else {
Cow::Borrowed("/reauth")
}
}
pub fn go(&self) -> Redirect {
Redirect::to(&self.as_link())
}
fn redirect(self) -> Redirect {
if let Some(action) = self.post_auth_action {
action.redirect()
} else {
Redirect::to("/")
}
}
}
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub(crate) struct ReauthForm { pub(crate) struct ReauthForm {
@ -76,7 +38,7 @@ pub(crate) struct ReauthForm {
pub(crate) async fn get( pub(crate) async fn get(
Extension(templates): Extension<Templates>, Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
Query(query): Query<ReauthRequest>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let mut conn = pool let mut conn = pool
@ -97,20 +59,19 @@ pub(crate) async fn get(
} else { } else {
// If there is no session, redirect to the login screen, keeping the // If there is no session, redirect to the login screen, keeping the
// PostAuthAction // PostAuthAction
let login: LoginRequest = query.post_auth_action.into(); let login = mas_router::Login::from(query.post_auth_action);
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
let ctx = ReauthContext::default(); let ctx = ReauthContext::default();
let ctx = match query.post_auth_action { let next = query
Some(next) => { .load_context(&mut conn)
let next = next .await
.load_context(&mut conn) .map_err(fancy_error(templates.clone()))?;
.await let ctx = if let Some(next) = next {
.map_err(fancy_error(templates.clone()))?; ctx.with_post_action(next)
ctx.with_post_action(next) } else {
} ctx
None => ctx,
}; };
let ctx = ctx.with_session(session).with_csrf(csrf_token.form_value()); let ctx = ctx.with_session(session).with_csrf(csrf_token.form_value());
@ -125,7 +86,7 @@ pub(crate) async fn get(
pub(crate) async fn post( pub(crate) async fn post(
Extension(templates): Extension<Templates>, Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
Query(query): Query<ReauthRequest>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<ReauthForm>>, Form(form): Form<ProtectedForm<ReauthForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
@ -147,7 +108,7 @@ pub(crate) async fn post(
} else { } else {
// If there is no session, redirect to the login screen, keeping the // If there is no session, redirect to the login screen, keeping the
// PostAuthAction // PostAuthAction
let login: LoginRequest = query.post_auth_action.into(); let login = mas_router::Login::from(query.post_auth_action);
return Ok((cookie_jar, login.go()).into_response()); return Ok((cookie_jar, login.go()).into_response());
}; };
@ -158,6 +119,6 @@ pub(crate) async fn post(
let cookie_jar = cookie_jar.set_session(&session); let cookie_jar = cookie_jar.set_session(&session);
txn.commit().await.map_err(fancy_error(templates.clone()))?; txn.commit().await.map_err(fancy_error(templates.clone()))?;
let redirection = query.redirect(); let reply = query.go_next();
Ok((cookie_jar, redirection).into_response()) Ok((cookie_jar, reply).into_response())
} }

View File

@ -14,12 +14,10 @@
#![allow(clippy::trait_duplication_in_bounds)] #![allow(clippy::trait_duplication_in_bounds)]
use std::borrow::Cow;
use argon2::Argon2; use argon2::Argon2;
use axum::{ use axum::{
extract::{Extension, Form, Query}, extract::{Extension, Form, Query},
response::{Html, IntoResponse, Redirect, Response}, response::{Html, IntoResponse, Response},
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{ use mas_axum_utils::{
@ -27,49 +25,13 @@ use mas_axum_utils::{
fancy_error, FancyError, SessionInfoExt, fancy_error, FancyError, SessionInfoExt,
}; };
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_router::Route;
use mas_storage::user::{register_user, start_session}; use mas_storage::user::{register_user, start_session};
use mas_templates::{RegisterContext, TemplateContext, Templates}; use mas_templates::{RegisterContext, TemplateContext, Templates};
use serde::Deserialize; use serde::Deserialize;
use sqlx::PgPool; use sqlx::PgPool;
use super::{LoginRequest, PostAuthAction}; use super::shared::OptionalPostAuthAction;
#[derive(Deserialize)]
pub(crate) struct RegisterRequest {
#[serde(flatten)]
post_auth_action: Option<PostAuthAction>,
}
impl From<PostAuthAction> for RegisterRequest {
fn from(post_auth_action: PostAuthAction) -> Self {
Self {
post_auth_action: Some(post_auth_action),
}
}
}
impl RegisterRequest {
pub fn as_link(&self) -> Cow<'static, str> {
if let Some(next) = &self.post_auth_action {
let qs = serde_urlencoded::to_string(next).unwrap();
Cow::Owned(format!("/register?{}", qs))
} else {
Cow::Borrowed("/register")
}
}
pub fn go(&self) -> Redirect {
Redirect::to(&self.as_link())
}
fn redirect(self) -> Redirect {
if let Some(action) = self.post_auth_action {
action.redirect()
} else {
Redirect::to("/")
}
}
}
#[derive(Deserialize)] #[derive(Deserialize)]
pub(crate) struct RegisterForm { pub(crate) struct RegisterForm {
@ -81,7 +43,7 @@ pub(crate) struct RegisterForm {
pub(crate) async fn get( pub(crate) async fn get(
Extension(templates): Extension<Templates>, Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
Query(query): Query<RegisterRequest>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let mut conn = pool let mut conn = pool
@ -98,21 +60,20 @@ pub(crate) async fn get(
.map_err(fancy_error(templates.clone()))?; .map_err(fancy_error(templates.clone()))?;
if maybe_session.is_some() { if maybe_session.is_some() {
let response = query.redirect().into_response(); let reply = query.go_next();
Ok(response) Ok((cookie_jar, reply).into_response())
} else { } else {
let ctx = RegisterContext::default(); let ctx = RegisterContext::default();
let ctx = match &query.post_auth_action { let next = query
Some(next) => { .load_context(&mut conn)
let next = next .await
.load_context(&mut conn) .map_err(fancy_error(templates.clone()))?;
.await let ctx = if let Some(next) = next {
.map_err(fancy_error(templates.clone()))?; ctx.with_post_action(next)
ctx.with_post_action(next) } else {
} ctx
None => ctx,
}; };
let login_link = LoginRequest::from(query.post_auth_action).as_link(); let login_link = mas_router::Login::from(query.post_auth_action).relative_url();
let ctx = ctx.with_login_link(login_link.to_string()); let ctx = ctx.with_login_link(login_link.to_string());
let ctx = ctx.with_csrf(csrf_token.form_value()); let ctx = ctx.with_csrf(csrf_token.form_value());
@ -128,7 +89,7 @@ pub(crate) async fn get(
pub(crate) async fn post( pub(crate) async fn post(
Extension(templates): Extension<Templates>, Extension(templates): Extension<Templates>,
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
Query(query): Query<RegisterRequest>, Query(query): Query<OptionalPostAuthAction>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Form(form): Form<ProtectedForm<RegisterForm>>, Form(form): Form<ProtectedForm<RegisterForm>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
@ -155,6 +116,6 @@ pub(crate) async fn post(
txn.commit().await.map_err(fancy_error(templates.clone()))?; txn.commit().await.map_err(fancy_error(templates.clone()))?;
let cookie_jar = cookie_jar.set_session(&session); let cookie_jar = cookie_jar.set_session(&session);
let reply = query.redirect(); let reply = query.go_next();
Ok((cookie_jar, reply).into_response()) Ok((cookie_jar, reply).into_response())
} }

View File

@ -12,62 +12,36 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use axum::response::Redirect; use mas_router::{PostAuthAction, Route};
use mas_data_model::AuthorizationGrant; use mas_storage::oauth2::authorization_grant::get_grant_by_id;
use mas_storage::{oauth2::authorization_grant::get_grant_by_id, PostgresqlBackend};
use mas_templates::PostAuthContext; use mas_templates::PostAuthContext;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::PgConnection; use sqlx::PgConnection;
#[derive(Deserialize, Serialize, Clone, Debug)] #[derive(Serialize, Deserialize, Default, Debug, Clone)]
#[serde(rename_all = "snake_case", tag = "next")] pub(crate) struct OptionalPostAuthAction {
pub(crate) enum PostAuthAction { #[serde(flatten)]
ContinueAuthorizationGrant { pub post_auth_action: Option<PostAuthAction>,
#[serde(deserialize_with = "serde_with::rust::display_fromstr::deserialize")]
data: i64,
},
} }
impl PostAuthAction { impl OptionalPostAuthAction {
pub fn continue_grant(grant: &AuthorizationGrant<PostgresqlBackend>) -> Self { pub fn go_next(&self) -> axum::response::Redirect {
Self::ContinueAuthorizationGrant { data: grant.data } self.post_auth_action.as_ref().map_or_else(
|| mas_router::Index.go(),
mas_router::PostAuthAction::go_next,
)
} }
pub fn redirect(&self) -> Redirect {
match self {
PostAuthAction::ContinueAuthorizationGrant { data } => {
let url = format!("/authorize/{}", data);
Redirect::to(&url)
}
}
}
pub async fn load_context<'e>( pub async fn load_context<'e>(
&self, &self,
conn: &mut PgConnection, conn: &mut PgConnection,
) -> anyhow::Result<PostAuthContext> { ) -> anyhow::Result<Option<PostAuthContext>> {
match self { match &self.post_auth_action {
Self::ContinueAuthorizationGrant { data } => { Some(PostAuthAction::ContinueAuthorizationGrant { data }) => {
let grant = get_grant_by_id(conn, *data).await?; let grant = get_grant_by_id(conn, *data).await?;
let grant = grant.into(); let grant = grant.into();
Ok(PostAuthContext::ContinueAuthorizationGrant { grant }) Ok(Some(PostAuthContext::ContinueAuthorizationGrant { grant }))
} }
None => Ok(None),
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_post_auth_action() {
let action: PostAuthAction =
serde_urlencoded::from_str("next=continue_authorization_grant&data=123").unwrap();
assert!(matches!(
action,
PostAuthAction::ContinueAuthorizationGrant { data: 123 }
));
}
}

13
crates/router/Cargo.toml Normal file
View File

@ -0,0 +1,13 @@
[package]
name = "mas-router"
version = "0.1.0"
authors = ["Quentin Gliech <quenting@element.io>"]
edition = "2021"
license = "Apache-2.0"
[dependencies]
axum = { version = "0.5.4", default-features = false }
serde = { version = "1.0.137", features = ["derive"] }
serde_urlencoded = "0.7.1"
serde_with = "1.13.0"
url = "2.2.2"

View File

@ -0,0 +1,359 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
pub use crate::traits::*;
#[derive(Deserialize, Serialize, Clone, Debug)]
#[serde(rename_all = "snake_case", tag = "next")]
pub enum PostAuthAction {
ContinueAuthorizationGrant {
#[serde(deserialize_with = "serde_with::rust::display_fromstr::deserialize")]
data: i64,
},
}
impl PostAuthAction {
#[must_use]
pub fn continue_grant(data: i64) -> Self {
PostAuthAction::ContinueAuthorizationGrant { data }
}
#[must_use]
pub fn go_next(&self) -> axum::response::Redirect {
match self {
Self::ContinueAuthorizationGrant { data } => ContinueAuthorizationGrant(*data).go(),
}
}
}
/// `GET /.well-known/openid-configuration`
#[derive(Debug, Clone)]
pub struct OidcConfiguration;
impl SimpleRoute for OidcConfiguration {
const PATH: &'static str = "/.well-known/openid-configuration";
}
/// `GET /.well-known/webfinger`
#[derive(Debug, Clone)]
pub struct Webfinger;
impl SimpleRoute for Webfinger {
const PATH: &'static str = "/.well-known/webfinger";
}
/// `GET /oauth2/keys.json`
#[derive(Debug, Clone)]
pub struct OAuth2Keys;
impl SimpleRoute for OAuth2Keys {
const PATH: &'static str = "/oauth2/keys.json";
}
/// `GET /oauth2/userinfo`
#[derive(Debug, Clone)]
pub struct OidcUserinfo;
impl SimpleRoute for OidcUserinfo {
const PATH: &'static str = "/oauth2/userinfo";
}
/// `POST /oauth2/userinfo`
#[derive(Debug, Clone)]
pub struct OAuth2Introspection;
impl SimpleRoute for OAuth2Introspection {
const PATH: &'static str = "/oauth2/introspect";
}
/// `POST /oauth2/token`
#[derive(Debug, Clone)]
pub struct OAuth2TokenEndpoint;
impl SimpleRoute for OAuth2TokenEndpoint {
const PATH: &'static str = "/oauth2/token";
}
/// `POST /oauth2/registration`
#[derive(Debug, Clone)]
pub struct OAuth2RegistrationEndpoint;
impl SimpleRoute for OAuth2RegistrationEndpoint {
const PATH: &'static str = "/oauth2/registration";
}
/// `GET /authorize`
#[derive(Debug, Clone)]
pub struct OAuth2AuthorizationEndpoint;
impl SimpleRoute for OAuth2AuthorizationEndpoint {
const PATH: &'static str = "/authorize";
}
/// `GET /`
#[derive(Debug, Clone)]
pub struct Index;
impl SimpleRoute for Index {
const PATH: &'static str = "/";
}
/// `GET /health`
#[derive(Debug, Clone)]
pub struct Healthcheck;
impl SimpleRoute for Healthcheck {
const PATH: &'static str = "/health";
}
/// `GET|POST /login`
#[derive(Default, Debug, Clone)]
pub struct Login {
post_auth_action: Option<PostAuthAction>,
}
impl Route for Login {
type Query = PostAuthAction;
fn route() -> &'static str {
"/login"
}
fn query(&self) -> Option<&Self::Query> {
self.post_auth_action.as_ref()
}
}
impl Login {
#[must_use]
pub fn and_then(action: PostAuthAction) -> Self {
Self {
post_auth_action: Some(action),
}
}
#[must_use]
pub fn and_continue_grant(data: i64) -> Self {
Self {
post_auth_action: Some(PostAuthAction::continue_grant(data)),
}
}
/// Get a reference to the login's post auth action.
#[must_use]
pub fn post_auth_action(&self) -> Option<&PostAuthAction> {
self.post_auth_action.as_ref()
}
#[must_use]
pub fn go_next(&self) -> axum::response::Redirect {
match &self.post_auth_action {
Some(action) => action.go_next(),
None => Index.go(),
}
}
}
impl From<Option<PostAuthAction>> for Login {
fn from(post_auth_action: Option<PostAuthAction>) -> Self {
Self { post_auth_action }
}
}
/// `POST /logout`
#[derive(Debug, Clone)]
pub struct Logout;
impl SimpleRoute for Logout {
const PATH: &'static str = "/logout";
}
/// `GET|POST /reauth`
#[derive(Default, Debug, Clone)]
pub struct Reauth {
post_auth_action: Option<PostAuthAction>,
}
impl Reauth {
#[must_use]
pub fn and_then(action: PostAuthAction) -> Self {
Self {
post_auth_action: Some(action),
}
}
#[must_use]
pub fn and_continue_grant(data: i64) -> Self {
Self {
post_auth_action: Some(PostAuthAction::continue_grant(data)),
}
}
/// Get a reference to the reauth's post auth action.
#[must_use]
pub fn post_auth_action(&self) -> Option<&PostAuthAction> {
self.post_auth_action.as_ref()
}
#[must_use]
pub fn go_next(&self) -> axum::response::Redirect {
match &self.post_auth_action {
Some(action) => action.go_next(),
None => Index.go(),
}
}
}
impl Route for Reauth {
type Query = PostAuthAction;
fn route() -> &'static str {
"/reauth"
}
fn query(&self) -> Option<&Self::Query> {
self.post_auth_action.as_ref()
}
}
impl From<Option<PostAuthAction>> for Reauth {
fn from(post_auth_action: Option<PostAuthAction>) -> Self {
Self { post_auth_action }
}
}
/// `GET|POST /register`
#[derive(Default, Debug, Clone)]
pub struct Register {
post_auth_action: Option<PostAuthAction>,
}
impl Register {
#[must_use]
pub fn and_then(action: PostAuthAction) -> Self {
Self {
post_auth_action: Some(action),
}
}
#[must_use]
pub fn and_continue_grant(data: i64) -> Self {
Self {
post_auth_action: Some(PostAuthAction::continue_grant(data)),
}
}
/// Get a reference to the reauth's post auth action.
#[must_use]
pub fn post_auth_action(&self) -> Option<&PostAuthAction> {
self.post_auth_action.as_ref()
}
#[must_use]
pub fn go_next(&self) -> axum::response::Redirect {
match &self.post_auth_action {
Some(action) => action.go_next(),
None => Index.go(),
}
}
}
impl Route for Register {
type Query = PostAuthAction;
fn route() -> &'static str {
"/register"
}
fn query(&self) -> Option<&Self::Query> {
self.post_auth_action.as_ref()
}
}
impl From<Option<PostAuthAction>> for Register {
fn from(post_auth_action: Option<PostAuthAction>) -> Self {
Self { post_auth_action }
}
}
/// `GET /verify/:code`
#[derive(Debug, Clone)]
pub struct VerifyEmail(pub String);
impl Route for VerifyEmail {
type Query = ();
fn route() -> &'static str {
"/verify/:code"
}
fn path(&self) -> std::borrow::Cow<'static, str> {
format!("/verify/{}", self.0).into()
}
}
/// `GET /account`
#[derive(Debug, Clone)]
pub struct Account;
impl SimpleRoute for Account {
const PATH: &'static str = "/account";
}
/// `GET|POST /account/password`
#[derive(Debug, Clone)]
pub struct AccountPassword;
impl SimpleRoute for AccountPassword {
const PATH: &'static str = "/account/password";
}
/// `GET|POST /account/emails`
#[derive(Debug, Clone)]
pub struct AccountEmails;
impl SimpleRoute for AccountEmails {
const PATH: &'static str = "/account/emails";
}
/// `GET /authorize/:grant_id`
#[derive(Debug, Clone)]
pub struct ContinueAuthorizationGrant(pub i64);
impl Route for ContinueAuthorizationGrant {
type Query = ();
fn route() -> &'static str {
"/authorize/:grant_id"
}
fn path(&self) -> std::borrow::Cow<'static, str> {
format!("/authorize/{}", self.0).into()
}
}
/// `GET /consent/:grant_id`
#[derive(Debug, Clone)]
pub struct Consent(pub i64);
impl Route for Consent {
type Query = ();
fn route() -> &'static str {
"/consent/:grant_id"
}
fn path(&self) -> std::borrow::Cow<'static, str> {
format!("/consent/{}", self.0).into()
}
}

53
crates/router/src/lib.rs Normal file
View File

@ -0,0 +1,53 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![deny(clippy::pedantic)]
pub(crate) mod endpoints;
pub(crate) mod traits;
mod url_builder;
pub use self::{endpoints::*, traits::Route, url_builder::UrlBuilder};
#[cfg(test)]
mod tests {
use std::borrow::Cow;
use url::Url;
use super::*;
#[test]
fn test_relative_urls() {
assert_eq!(
OidcConfiguration.relative_url(),
Cow::Borrowed("/.well-known/openid-configuration")
);
assert_eq!(Index.relative_url(), Cow::Borrowed("/"));
assert_eq!(
Login::and_continue_grant(42).relative_url(),
Cow::Borrowed("/login?next=continue_authorization_grant&data=42")
);
}
#[test]
fn test_absolute_urls() {
let base = Url::try_from("https://example.com/").unwrap();
assert_eq!(Index.absolute_url(&base).as_str(), "https://example.com/");
assert_eq!(
OidcConfiguration.absolute_url(&base).as_str(),
"https://example.com/.well-known/openid-configuration"
);
}
}

View File

@ -0,0 +1,60 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::{Borrow, Cow};
use serde::Serialize;
use url::Url;
pub trait Route {
type Query: Serialize;
fn route() -> &'static str;
fn query(&self) -> Option<&Self::Query> {
None
}
fn path(&self) -> Cow<'static, str> {
Cow::Borrowed(Self::route())
}
fn relative_url(&self) -> Cow<'static, str> {
let path = self.path();
if let Some(query) = self.query() {
let query = serde_urlencoded::to_string(query).unwrap();
format!("{}?{}", path, query).into()
} else {
path
}
}
fn absolute_url(&self, base: &Url) -> Url {
let relative = self.relative_url();
base.join(relative.borrow()).unwrap()
}
fn go(&self) -> axum::response::Redirect {
axum::response::Redirect::to(&self.relative_url())
}
}
pub trait SimpleRoute {
const PATH: &'static str;
}
impl<T: SimpleRoute> Route for T {
type Query = ();
fn route() -> &'static str {
Self::PATH
}
}

View File

@ -16,13 +16,21 @@
use url::Url; use url::Url;
/// Helps building absolute URLs use crate::traits::Route;
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub struct UrlBuilder { pub struct UrlBuilder {
base: Url, base: Url,
} }
impl UrlBuilder { impl UrlBuilder {
fn url_for<U>(&self, destination: &U) -> Url
where
U: Route,
{
destination.absolute_url(&self.base)
}
/// Create a new [`UrlBuilder`] from a base URL /// Create a new [`UrlBuilder`] from a base URL
#[must_use] #[must_use]
pub fn new(base: Url) -> Self { pub fn new(base: Url) -> Self {
@ -38,55 +46,49 @@ impl UrlBuilder {
/// OIDC dicovery document URL /// OIDC dicovery document URL
#[must_use] #[must_use]
pub fn oidc_discovery(&self) -> Url { pub fn oidc_discovery(&self) -> Url {
self.base self.url_for(&crate::endpoints::OidcConfiguration)
.join(".well-known/openid-configuration")
.expect("build URL")
} }
/// OAuth 2.0 authorization endpoint /// OAuth 2.0 authorization endpoint
#[must_use] #[must_use]
pub fn oauth_authorization_endpoint(&self) -> Url { pub fn oauth_authorization_endpoint(&self) -> Url {
self.base.join("authorize").expect("build URL") self.url_for(&crate::endpoints::OAuth2AuthorizationEndpoint)
} }
/// OAuth 2.0 token endpoint /// OAuth 2.0 token endpoint
#[must_use] #[must_use]
pub fn oauth_token_endpoint(&self) -> Url { pub fn oauth_token_endpoint(&self) -> Url {
self.base.join("oauth2/token").expect("build URL") self.url_for(&crate::endpoints::OAuth2TokenEndpoint)
} }
/// OAuth 2.0 introspection endpoint /// OAuth 2.0 introspection endpoint
#[must_use] #[must_use]
pub fn oauth_introspection_endpoint(&self) -> Url { pub fn oauth_introspection_endpoint(&self) -> Url {
self.base.join("oauth2/introspect").expect("build URL") self.url_for(&crate::endpoints::OAuth2Introspection)
} }
/// OAuth 2.0 client registration endpoint /// OAuth 2.0 client registration endpoint
#[must_use] #[must_use]
pub fn oauth_registration_endpoint(&self) -> Url { pub fn oauth_registration_endpoint(&self) -> Url {
self.base.join("oauth2/registration").expect("build URL") self.url_for(&crate::endpoints::OAuth2RegistrationEndpoint)
} }
/// OpenID Connect userinfo endpoint // OIDC userinfo endpoint
#[must_use] #[must_use]
pub fn oidc_userinfo_endpoint(&self) -> Url { pub fn oidc_userinfo_endpoint(&self) -> Url {
self.base.join("oauth2/userinfo").expect("build URL") self.url_for(&crate::endpoints::OidcUserinfo)
} }
/// JWKS URI /// JWKS URI
#[must_use] #[must_use]
pub fn jwks_uri(&self) -> Url { pub fn jwks_uri(&self) -> Url {
self.base.join("oauth2/keys.json").expect("build URL") self.url_for(&crate::endpoints::OAuth2Keys)
} }
/// Email verification URL /// Email verification URL
#[must_use] #[must_use]
pub fn email_verification(&self, code: &str) -> Url { pub fn email_verification(&self, code: String) -> Url {
self.base self.url_for(&crate::endpoints::VerifyEmail(code))
.join("verify/")
.expect("build URL")
.join(code)
.expect("build URL")
} }
} }
@ -99,7 +101,7 @@ mod tests {
let base = Url::parse("https://example.com/").unwrap(); let base = Url::parse("https://example.com/").unwrap();
let builder = UrlBuilder::new(base); let builder = UrlBuilder::new(base);
assert_eq!( assert_eq!(
builder.email_verification("123456abcdef").as_str(), builder.email_verification("123456abcdef".into()).as_str(),
"https://example.com/verify/123456abcdef" "https://example.com/verify/123456abcdef"
); );
} }

View File

@ -241,7 +241,7 @@ pub enum PostAuthContext {
} }
/// Context used by the `login.html` template /// Context used by the `login.html` template
#[derive(Serialize)] #[derive(Serialize, Default)]
pub struct LoginContext { pub struct LoginContext {
form: ErroredForm<LoginFormField>, form: ErroredForm<LoginFormField>,
next: Option<PostAuthContext>, next: Option<PostAuthContext>,
@ -288,16 +288,6 @@ impl LoginContext {
} }
} }
impl Default for LoginContext {
fn default() -> Self {
Self {
form: ErroredForm::new(),
next: None,
register_link: "/register".to_string(),
}
}
}
/// Fields of the registration form /// Fields of the registration form
#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)] #[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
@ -313,7 +303,7 @@ pub enum RegisterFormField {
} }
/// Context used by the `register.html` template /// Context used by the `register.html` template
#[derive(Serialize)] #[derive(Serialize, Default)]
pub struct RegisterContext { pub struct RegisterContext {
form: ErroredForm<LoginFormField>, form: ErroredForm<LoginFormField>,
next: Option<PostAuthContext>, next: Option<PostAuthContext>,
@ -357,16 +347,6 @@ impl RegisterContext {
} }
} }
impl Default for RegisterContext {
fn default() -> Self {
Self {
form: ErroredForm::new(),
next: None,
login_link: "/login".to_string(),
}
}
}
/// Context used by the `consent.html` template /// Context used by the `consent.html` template
#[derive(Serialize)] #[derive(Serialize)]
pub struct ConsentContext { pub struct ConsentContext {