1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

Better post-login/auth redirects

This commit is contained in:
Quentin Gliech
2021-11-16 15:09:14 +01:00
parent 0a2fda35fd
commit 04f8c5fe97
6 changed files with 188 additions and 58 deletions

View File

@@ -14,7 +14,8 @@
use std::{ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
convert::TryFrom, convert::{TryFrom, TryInto},
num::NonZeroU32,
}; };
use chrono::Duration; use chrono::Duration;
@@ -24,7 +25,8 @@ use hyper::{
StatusCode, StatusCode,
}; };
use mas_data_model::{ use mas_data_model::{
Authentication, AuthorizationCode, AuthorizationGrantStage, BrowserSession, Pkce, Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession,
Pkce, StorageBackend,
}; };
use mas_templates::{FormPostContext, Templates}; use mas_templates::{FormPostContext, Templates};
use oauth2_types::{ use oauth2_types::{
@@ -55,7 +57,7 @@ use crate::{
session::{optional_session, session}, session::{optional_session, session},
with_templates, with_templates,
}, },
handlers::views::LoginRequest, handlers::views::{LoginRequest, PostAuthAction, ReauthRequest},
storage::{ storage::{
oauth2::{ oauth2::{
access_token::add_access_token, access_token::add_access_token,
@@ -227,7 +229,7 @@ pub fn filter(
let step = warp::path!("oauth2" / "authorize" / "step") let step = warp::path!("oauth2" / "authorize" / "step")
.and(warp::get()) .and(warp::get())
.and(warp::query().map(|s: StepRequest| s.id)) .and(warp::query())
.and(session(pool, cookies_config)) .and(session(pool, cookies_config))
.and(transaction(pool)) .and(transaction(pool))
.and_then(step); .and_then(step);
@@ -352,6 +354,14 @@ async fn get(
None None
}; };
let max_age: Option<NonZeroU32> = params
.auth
.max_age
.as_ref()
.map(|d| d.num_seconds().try_into().and_then(|d: u32| d.try_into()))
.transpose()
.wrap_error()?;
let grant = new_authorization_grant( let grant = new_authorization_grant(
&mut txn, &mut txn,
client.client_id.clone(), client.client_id.clone(),
@@ -360,8 +370,7 @@ async fn get(
code, code,
params.auth.state, params.auth.state,
params.auth.nonce, params.auth.nonce,
// TODO: support max_age and acr_values max_age,
None,
None, None,
response_mode, response_mode,
response_type.contains(&ResponseType::Token), response_type.contains(&ResponseType::Token),
@@ -370,33 +379,36 @@ async fn get(
.await .await
.wrap_error()?; .wrap_error()?;
let next = ContinueAuthorizationGrant::from_authorization_grant(grant);
if let Some(user_session) = maybe_session { if let Some(user_session) = maybe_session {
step(grant.data, user_session, txn).await step(next, user_session, txn).await
} else { } else {
// If not, redirect the user to the login page // If not, redirect the user to the login page
txn.commit().await.wrap_error()?; txn.commit().await.wrap_error()?;
let next = StepRequest::new(grant.data) let next: PostAuthAction<_> = next.into();
.build_uri() let next: LoginRequest<_> = next.into();
.wrap_error()? let next = next.build_uri().wrap_error()?;
.to_string();
let destination = LoginRequest::new(Some(next)).build_uri().wrap_error()?; Ok(ReplyOrBackToClient::Reply(Box::new(see_other(next))))
Ok(ReplyOrBackToClient::Reply(Box::new(see_other(destination))))
} }
} }
#[derive(Deserialize, Serialize)] #[derive(Serialize, Deserialize)]
struct StepRequest { pub(crate) struct ContinueAuthorizationGrant<S: StorageBackend> {
id: i64, data: S::AuthorizationGrantData,
} }
impl StepRequest { impl<S: StorageBackend> ContinueAuthorizationGrant<S> {
fn new(id: i64) -> Self { pub fn from_authorization_grant(grant: AuthorizationGrant<S>) -> Self {
Self { id } Self { data: grant.data }
} }
fn build_uri(&self) -> anyhow::Result<Uri> { pub fn build_uri(&self) -> anyhow::Result<Uri>
where
S::AuthorizationGrantData: Serialize,
{
let qs = serde_urlencoded::to_string(self)?; let qs = serde_urlencoded::to_string(self)?;
let path_and_query = PathAndQuery::try_from(format!("/oauth2/authorize/step?{}", qs))?; let path_and_query = PathAndQuery::try_from(format!("/oauth2/authorize/step?{}", qs))?;
let uri = Uri::from_parts({ let uri = Uri::from_parts({
@@ -408,20 +420,14 @@ impl StepRequest {
} }
} }
fn reauth() -> ReplyOrBackToClient {
// Ask for a reauth
// TODO: have the OAuth2 session ID in there
ReplyOrBackToClient::Reply(Box::new(see_other(Uri::from_static("/reauth"))))
}
async fn step( async fn step(
grant_id: i64, next: ContinueAuthorizationGrant<PostgresqlBackend>,
browser_session: BrowserSession<PostgresqlBackend>, browser_session: BrowserSession<PostgresqlBackend>,
mut txn: Transaction<'_, Postgres>, mut txn: Transaction<'_, Postgres>,
) -> Result<ReplyOrBackToClient, Rejection> { ) -> Result<ReplyOrBackToClient, Rejection> {
// TODO: we should check if the grant here was started by the browser doing that // TODO: we should check if the grant here was started by the browser doing that
// request using a signed cookie // request using a signed cookie
let grant = get_grant_by_id(&mut txn, grant_id).await.wrap_error()?; let grant = get_grant_by_id(&mut txn, next.data).await.wrap_error()?;
if !matches!(grant.stage, AuthorizationGrantStage::Pending) { if !matches!(grant.stage, AuthorizationGrantStage::Pending) {
return Err(anyhow::anyhow!("authorization grant not pending")).wrap_error(); return Err(anyhow::anyhow!("authorization grant not pending")).wrap_error();
@@ -485,7 +491,13 @@ async fn step(
params, params,
} }
} }
_ => reauth(), _ => {
let next: PostAuthAction<_> = next.into();
let next: ReauthRequest<_> = next.into();
let next = next.build_uri().wrap_error()?;
ReplyOrBackToClient::Reply(Box::new(see_other(next)))
}
}; };
txn.commit().await.wrap_error()?; txn.commit().await.wrap_error()?;

View File

@@ -25,6 +25,7 @@ mod keys;
mod token; mod token;
mod userinfo; mod userinfo;
pub(crate) use self::authorization::ContinueAuthorizationGrant;
use self::{ use self::{
authorization::filter as authorization, discovery::filter as discovery, authorization::filter as authorization, discovery::filter as discovery,
introspection::filter as introspection, keys::filter as keys, token::filter as token, introspection::filter as introspection, keys::filter as keys, token::filter as token,

View File

@@ -15,12 +15,13 @@
use std::convert::TryFrom; use std::convert::TryFrom;
use hyper::http::uri::{Parts, PathAndQuery, Uri}; use hyper::http::uri::{Parts, PathAndQuery, Uri};
use mas_data_model::{errors::WrapFormError, BrowserSession}; use mas_data_model::{errors::WrapFormError, BrowserSession, StorageBackend};
use mas_templates::{LoginContext, LoginFormField, TemplateContext, Templates}; use mas_templates::{LoginContext, LoginFormField, TemplateContext, Templates};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::{pool::PoolConnection, PgPool, Postgres}; use sqlx::{pool::PoolConnection, PgPool, Postgres};
use warp::{reply::html, Filter, Rejection, Reply}; use warp::{reply::html, Filter, Rejection, Reply};
use super::shared::PostAuthAction;
use crate::{ use crate::{
config::{CookiesConfig, CsrfConfig}, config::{CookiesConfig, CsrfConfig},
errors::WrapError, errors::WrapError,
@@ -34,19 +35,33 @@ use crate::{
storage::{login, PostgresqlBackend}, storage::{login, PostgresqlBackend},
}; };
#[derive(Serialize, Deserialize)] #[derive(Deserialize)]
pub struct LoginRequest { #[serde(
next: Option<String>, rename_all = "snake_case",
bound = "<S as StorageBackend>::AuthorizationGrantData: Deserialize<'de>"
)]
pub(crate) struct LoginRequest<S: StorageBackend> {
#[serde(flatten, skip_serializing_if = "Option::is_none")]
next: Option<PostAuthAction<S>>,
} }
impl LoginRequest { impl<S: StorageBackend> From<PostAuthAction<S>> for LoginRequest<S> {
pub fn new(next: Option<String>) -> Self { fn from(next: PostAuthAction<S>) -> Self {
Self { next } Self { next: Some(next) }
}
} }
pub fn build_uri(&self) -> anyhow::Result<Uri> { impl<S: StorageBackend> LoginRequest<S> {
let qs = serde_urlencoded::to_string(self)?; pub fn build_uri(&self) -> anyhow::Result<Uri>
let path_and_query = PathAndQuery::try_from(format!("/login?{}", qs))?; where
S::AuthorizationGrantData: Serialize,
{
let path_and_query = if let Some(next) = &self.next {
let qs = serde_urlencoded::to_string(next)?;
PathAndQuery::try_from(format!("/login?{}", qs))?
} else {
PathAndQuery::from_static("/login")
};
let uri = Uri::from_parts({ let uri = Uri::from_parts({
let mut parts = Parts::default(); let mut parts = Parts::default();
parts.path_and_query = Some(path_and_query); parts.path_and_query = Some(path_and_query);
@@ -55,19 +70,17 @@ impl LoginRequest {
Ok(uri) Ok(uri)
} }
fn redirect(self) -> Result<impl Reply, Rejection> { fn redirect(self) -> Result<impl Reply, Rejection>
let uri: Uri = Uri::from_parts({ where
let mut parts = Parts::default(); S::AuthorizationGrantData: Serialize,
parts.path_and_query = Some( {
self.next let uri = self
.map(warp::http::uri::PathAndQuery::try_from) .next
.as_ref()
.map(PostAuthAction::build_uri)
.transpose() .transpose()
.wrap_error()? .wrap_error()?
.unwrap_or_else(|| PathAndQuery::from_static("/")), .unwrap_or_else(|| Uri::from_static("/"));
);
parts
})
.wrap_error()?;
Ok(warp::redirect::see_other(uri)) Ok(warp::redirect::see_other(uri))
} }
} }
@@ -108,7 +121,7 @@ async fn get(
templates: Templates, templates: Templates,
cookie_saver: EncryptedCookieSaver, cookie_saver: EncryptedCookieSaver,
csrf_token: CsrfToken, csrf_token: CsrfToken,
query: LoginRequest, query: LoginRequest<PostgresqlBackend>,
maybe_session: Option<BrowserSession<PostgresqlBackend>>, maybe_session: Option<BrowserSession<PostgresqlBackend>>,
) -> Result<Box<dyn Reply>, Rejection> { ) -> Result<Box<dyn Reply>, Rejection> {
if maybe_session.is_some() { if maybe_session.is_some() {
@@ -128,7 +141,7 @@ async fn post(
cookie_saver: EncryptedCookieSaver, cookie_saver: EncryptedCookieSaver,
csrf_token: CsrfToken, csrf_token: CsrfToken,
form: LoginForm, form: LoginForm,
query: LoginRequest, query: LoginRequest<PostgresqlBackend>,
) -> Result<Box<dyn Reply>, Rejection> { ) -> Result<Box<dyn Reply>, Rejection> {
use crate::storage::user::LoginError; use crate::storage::user::LoginError;
// TODO: recover // TODO: recover

View File

@@ -23,12 +23,13 @@ mod login;
mod logout; mod logout;
mod reauth; mod reauth;
mod register; mod register;
mod shared;
pub use self::login::LoginRequest;
use self::{ use self::{
index::filter as index, login::filter as login, logout::filter as logout, index::filter as index, login::filter as login, logout::filter as logout,
reauth::filter as reauth, register::filter as register, reauth::filter as reauth, register::filter as register,
}; };
pub(crate) use self::{login::LoginRequest, reauth::ReauthRequest, shared::PostAuthAction};
pub(super) fn filter( pub(super) fn filter(
pool: &PgPool, pool: &PgPool,

View File

@@ -12,12 +12,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use mas_data_model::BrowserSession; use std::convert::TryFrom;
use hyper::http::uri::{Parts, PathAndQuery};
use mas_data_model::{BrowserSession, StorageBackend};
use mas_templates::{EmptyContext, TemplateContext, Templates}; use mas_templates::{EmptyContext, TemplateContext, Templates};
use serde::Deserialize; use serde::{Deserialize, Serialize};
use sqlx::{PgPool, Postgres, Transaction}; use sqlx::{PgPool, Postgres, Transaction};
use warp::{hyper::Uri, reply::html, Filter, Rejection, Reply}; use warp::{hyper::Uri, reply::html, Filter, Rejection, Reply};
use super::PostAuthAction;
use crate::{ use crate::{
config::{CookiesConfig, CsrfConfig}, config::{CookiesConfig, CsrfConfig},
errors::WrapError, errors::WrapError,
@@ -30,6 +34,55 @@ use crate::{
}, },
storage::{user::authenticate_session, PostgresqlBackend}, storage::{user::authenticate_session, PostgresqlBackend},
}; };
#[derive(Deserialize)]
#[serde(
rename_all = "snake_case",
bound = "<S as StorageBackend>::AuthorizationGrantData: Deserialize<'de>"
)]
pub(crate) struct ReauthRequest<S: StorageBackend> {
#[serde(flatten, skip_serializing_if = "Option::is_none")]
next: Option<PostAuthAction<S>>,
}
impl<S: StorageBackend> From<PostAuthAction<S>> for ReauthRequest<S> {
fn from(next: PostAuthAction<S>) -> Self {
Self { next: Some(next) }
}
}
impl<S: StorageBackend> ReauthRequest<S> {
pub fn build_uri(&self) -> anyhow::Result<Uri>
where
S::AuthorizationGrantData: Serialize,
{
let path_and_query = if let Some(next) = &self.next {
let qs = serde_urlencoded::to_string(next)?;
PathAndQuery::try_from(format!("/reauth?{}", qs))?
} else {
PathAndQuery::from_static("/reauth")
};
let uri = Uri::from_parts({
let mut parts = Parts::default();
parts.path_and_query = Some(path_and_query);
parts
})?;
Ok(uri)
}
fn redirect(self) -> Result<impl Reply, Rejection>
where
S::AuthorizationGrantData: Serialize,
{
let uri = self
.next
.as_ref()
.map(PostAuthAction::build_uri)
.transpose()
.wrap_error()?
.unwrap_or_else(|| Uri::from_static("/"));
Ok(warp::redirect::see_other(uri))
}
}
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
struct ReauthForm { struct ReauthForm {
@@ -47,12 +100,14 @@ pub(super) fn filter(
.and(encrypted_cookie_saver(cookies_config)) .and(encrypted_cookie_saver(cookies_config))
.and(updated_csrf_token(cookies_config, csrf_config)) .and(updated_csrf_token(cookies_config, csrf_config))
.and(session(pool, cookies_config)) .and(session(pool, cookies_config))
.and(warp::query())
.and_then(get); .and_then(get);
let post = warp::post() let post = warp::post()
.and(session(pool, cookies_config)) .and(session(pool, cookies_config))
.and(transaction(pool)) .and(transaction(pool))
.and(protected_form(cookies_config)) .and(protected_form(cookies_config))
.and(warp::query())
.and_then(post); .and_then(post);
warp::path!("reauth").and(get.or(post)) warp::path!("reauth").and(get.or(post))
@@ -63,6 +118,7 @@ async fn get(
cookie_saver: EncryptedCookieSaver, cookie_saver: EncryptedCookieSaver,
csrf_token: CsrfToken, csrf_token: CsrfToken,
session: BrowserSession<PostgresqlBackend>, session: BrowserSession<PostgresqlBackend>,
_query: ReauthRequest<PostgresqlBackend>,
) -> Result<impl Reply, Rejection> { ) -> Result<impl Reply, Rejection> {
let ctx = EmptyContext let ctx = EmptyContext
.with_session(session) .with_session(session)
@@ -78,11 +134,12 @@ async fn post(
session: BrowserSession<PostgresqlBackend>, session: BrowserSession<PostgresqlBackend>,
mut txn: Transaction<'_, Postgres>, mut txn: Transaction<'_, Postgres>,
form: ReauthForm, form: ReauthForm,
query: ReauthRequest<PostgresqlBackend>,
) -> Result<impl Reply, Rejection> { ) -> Result<impl Reply, Rejection> {
authenticate_session(&mut txn, &session, form.password) authenticate_session(&mut txn, &session, form.password)
.await .await
.wrap_error()?; .wrap_error()?;
txn.commit().await.wrap_error()?; txn.commit().await.wrap_error()?;
Ok(warp::redirect(Uri::from_static("/"))) Ok(query.redirect()?)
} }

View File

@@ -0,0 +1,46 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use hyper::Uri;
use mas_data_model::StorageBackend;
use serde::{Deserialize, Serialize};
use super::super::oauth2::ContinueAuthorizationGrant;
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "snake_case", tag = "next")]
pub(crate) enum PostAuthAction<S: StorageBackend> {
#[serde(bound(
deserialize = "S::AuthorizationGrantData: Deserialize<'de>",
serialize = "S::AuthorizationGrantData: Serialize"
))]
ContinueAuthorizationGrant(ContinueAuthorizationGrant<S>),
}
impl<S: StorageBackend> PostAuthAction<S> {
pub fn build_uri(&self) -> anyhow::Result<Uri>
where
S::AuthorizationGrantData: Serialize,
{
match self {
PostAuthAction::ContinueAuthorizationGrant(c) => c.build_uri(),
}
}
}
impl<S: StorageBackend> From<ContinueAuthorizationGrant<S>> for PostAuthAction<S> {
fn from(g: ContinueAuthorizationGrant<S>) -> Self {
Self::ContinueAuthorizationGrant(g)
}
}