1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Have a consent screen before continuing the SSO login

This commit is contained in:
Quentin Gliech
2022-05-20 15:03:38 +02:00
parent 033d60eb73
commit 1d61a94da4
10 changed files with 214 additions and 18 deletions

1
Cargo.lock generated
View File

@ -2294,6 +2294,7 @@ name = "mas-templates"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"chrono",
"mas-config", "mas-config",
"mas-data-model", "mas-data-model",
"oauth2-types", "oauth2-types",

View File

@ -90,6 +90,18 @@ pub struct CompatSession<T: StorageBackend> {
pub deleted_at: Option<DateTime<Utc>>, pub deleted_at: Option<DateTime<Utc>>,
} }
impl<S: StorageBackendMarker> From<CompatSession<S>> for CompatSession<()> {
fn from(t: CompatSession<S>) -> Self {
Self {
data: (),
user: t.user.into(),
device: t.device,
created_at: t.created_at,
deleted_at: t.deleted_at,
}
}
}
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct CompatAccessToken<T: StorageBackend> { pub struct CompatAccessToken<T: StorageBackend> {
pub data: T::CompatAccessTokenData, pub data: T::CompatAccessTokenData,
@ -98,6 +110,17 @@ pub struct CompatAccessToken<T: StorageBackend> {
pub expires_at: Option<DateTime<Utc>>, pub expires_at: Option<DateTime<Utc>>,
} }
impl<S: StorageBackendMarker> From<CompatAccessToken<S>> for CompatAccessToken<()> {
fn from(t: CompatAccessToken<S>) -> Self {
Self {
data: (),
token: t.token,
created_at: t.created_at,
expires_at: t.expires_at,
}
}
}
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct CompatRefreshToken<T: StorageBackend> { pub struct CompatRefreshToken<T: StorageBackend> {
pub data: T::RefreshTokenData, pub data: T::RefreshTokenData,
@ -105,13 +128,12 @@ pub struct CompatRefreshToken<T: StorageBackend> {
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
} }
impl<S: StorageBackendMarker> From<CompatAccessToken<S>> for CompatAccessToken<()> { impl<S: StorageBackendMarker> From<CompatRefreshToken<S>> for CompatRefreshToken<()> {
fn from(t: CompatAccessToken<S>) -> Self { fn from(t: CompatRefreshToken<S>) -> Self {
CompatAccessToken { Self {
data: (), data: (),
token: t.token, token: t.token,
created_at: t.created_at, created_at: t.created_at,
expires_at: t.expires_at,
} }
} }
} }
@ -131,6 +153,30 @@ pub enum CompatSsoLoginState<T: StorageBackend> {
}, },
} }
impl<S: StorageBackendMarker> From<CompatSsoLoginState<S>> for CompatSsoLoginState<()> {
fn from(t: CompatSsoLoginState<S>) -> Self {
match t {
CompatSsoLoginState::Pending => Self::Pending,
CompatSsoLoginState::Fullfilled {
fullfilled_at,
session,
} => Self::Fullfilled {
fullfilled_at,
session: session.into(),
},
CompatSsoLoginState::Exchanged {
fullfilled_at,
exchanged_at,
session,
} => Self::Exchanged {
fullfilled_at,
exchanged_at,
session: session.into(),
},
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize)] #[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(bound = "T: StorageBackend")] #[serde(bound = "T: StorageBackend")]
pub struct CompatSsoLogin<T: StorageBackend> { pub struct CompatSsoLogin<T: StorageBackend> {
@ -141,3 +187,15 @@ pub struct CompatSsoLogin<T: StorageBackend> {
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
pub state: CompatSsoLoginState<T>, pub state: CompatSsoLoginState<T>,
} }
impl<S: StorageBackendMarker> From<CompatSsoLogin<S>> for CompatSsoLogin<()> {
fn from(t: CompatSsoLogin<S>) -> Self {
Self {
data: (),
redirect_uri: t.redirect_uri,
token: t.token,
created_at: t.created_at,
state: t.state.into(),
}
}
}

View File

@ -40,6 +40,7 @@ enum LoginType {
#[serde(rename = "m.login.sso")] #[serde(rename = "m.login.sso")]
Sso { Sso {
#[serde(skip_serializing_if = "Vec::is_empty")]
identity_providers: Vec<SsoIdentityProvider>, identity_providers: Vec<SsoIdentityProvider>,
}, },
} }

View File

@ -16,16 +16,20 @@
use std::collections::HashMap; use std::collections::HashMap;
use axum::{ use axum::{
extract::Path, extract::{Form, Path},
response::{IntoResponse, Redirect, Response}, response::{Html, IntoResponse, Redirect, Response},
Extension, Extension,
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use mas_axum_utils::{FancyError, SessionInfoExt}; use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm},
FancyError, SessionInfoExt,
};
use mas_config::Encrypter; use mas_config::Encrypter;
use mas_data_model::Device; use mas_data_model::Device;
use mas_router::Route; use mas_router::Route;
use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id}; use mas_storage::compat::{fullfill_compat_sso_login, get_compat_sso_login_by_id};
use mas_templates::{CompatSsoContext, TemplateContext, Templates};
use rand::thread_rng; use rand::thread_rng;
use serde::Serialize; use serde::Serialize;
use sqlx::PgPool; use sqlx::PgPool;
@ -41,12 +45,46 @@ struct AllParams<'s> {
pub async fn get( pub async fn get(
Extension(pool): Extension<PgPool>, Extension(pool): Extension<PgPool>,
Extension(templates): Extension<Templates>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
Path(id): Path<i64>, Path(id): Path<i64>,
) -> Result<Response, FancyError> {
let mut conn = pool.acquire().await?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, cookie_jar) = cookie_jar.csrf_token();
let maybe_session = session_info.load_session(&mut conn).await?;
let session = if let Some(session) = maybe_session {
session
} else {
// If there is no session, redirect to the login screen
let login = mas_router::Login::and_continue_compat_sso_login(id);
return Ok((cookie_jar, login.go()).into_response());
};
let login = get_compat_sso_login_by_id(&mut conn, id).await?;
let ctx = CompatSsoContext::new(login)
.with_session(session)
.with_csrf(csrf_token.form_value());
let content = templates.render_sso_login(&ctx).await?;
Ok((cookie_jar, Html(content)).into_response())
}
pub async fn post(
Extension(pool): Extension<PgPool>,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(id): Path<i64>,
Form(form): Form<ProtectedForm<()>>,
) -> Result<Response, FancyError> { ) -> Result<Response, FancyError> {
let mut txn = pool.begin().await?; let mut txn = pool.begin().await?;
let (session_info, cookie_jar) = cookie_jar.session_info(); let (session_info, cookie_jar) = cookie_jar.session_info();
cookie_jar.verify_form(form)?;
let maybe_session = session_info.load_session(&mut txn).await?; let maybe_session = session_info.load_session(&mut txn).await?;

View File

@ -193,7 +193,8 @@ where
) )
.route( .route(
mas_router::CompatLoginSsoComplete::route(), mas_router::CompatLoginSsoComplete::route(),
get(self::compat::login_sso_complete::get), get(self::compat::login_sso_complete::get)
.post(self::compat::login_sso_complete::post),
) )
.layer(ThenLayer::new( .layer(ThenLayer::new(
move |result: Result<axum::response::Response, Infallible>| async move { move |result: Result<axum::response::Response, Infallible>| async move {

View File

@ -13,7 +13,9 @@
// limitations under the License. // limitations under the License.
use mas_router::{PostAuthAction, Route}; use mas_router::{PostAuthAction, Route};
use mas_storage::oauth2::authorization_grant::get_grant_by_id; use mas_storage::{
compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id,
};
use mas_templates::PostAuthContext; use mas_templates::PostAuthContext;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::PgConnection; use sqlx::PgConnection;
@ -41,8 +43,10 @@ impl OptionalPostAuthAction {
let grant = Box::new(grant.into()); let grant = Box::new(grant.into());
Ok(Some(PostAuthContext::ContinueAuthorizationGrant { grant })) Ok(Some(PostAuthContext::ContinueAuthorizationGrant { grant }))
} }
Some(PostAuthAction::ContinueCompatSsoLogin { .. }) => { Some(PostAuthAction::ContinueCompatSsoLogin { data }) => {
Ok(Some(PostAuthContext::ContinueCompatSsoLogin)) let login = get_compat_sso_login_by_id(conn, *data).await?;
let login = Box::new(login.into());
Ok(Some(PostAuthContext::ContinueCompatSsoLogin { login }))
} }
Some(PostAuthAction::ChangePassword) => Ok(Some(PostAuthContext::ChangePassword)), Some(PostAuthAction::ChangePassword) => Ok(Some(PostAuthContext::ChangePassword)),
None => Ok(None), None => Ok(None),

View File

@ -20,6 +20,7 @@ serde = { version = "1.0.137", features = ["derive"] }
serde_json = "1.0.81" serde_json = "1.0.81"
serde_urlencoded = "0.7.1" serde_urlencoded = "0.7.1"
chrono = "0.4.19"
url = "2.2.2" url = "2.2.2"
oauth2-types = { path = "../oauth2-types" } oauth2-types = { path = "../oauth2-types" }

View File

@ -16,7 +16,11 @@
#![allow(clippy::trait_duplication_in_bounds)] #![allow(clippy::trait_duplication_in_bounds)]
use mas_data_model::{AuthorizationGrant, BrowserSession, StorageBackend, User, UserEmail}; use chrono::Utc;
use mas_data_model::{
AuthorizationGrant, BrowserSession, CompatSsoLogin, CompatSsoLoginState, StorageBackend, User,
UserEmail,
};
use serde::{ser::SerializeStruct, Deserialize, Serialize}; use serde::{ser::SerializeStruct, Deserialize, Serialize};
use url::Url; use url::Url;
@ -250,7 +254,10 @@ pub enum PostAuthContext {
/// Continue legacy login /// Continue legacy login
/// TODO: add the login context in there /// TODO: add the login context in there
ContinueCompatSsoLogin, ContinueCompatSsoLogin {
/// The compat SSO login request
login: Box<CompatSsoLogin<()>>,
},
/// Change the account password /// Change the account password
ChangePassword, ChangePassword,
@ -454,6 +461,42 @@ impl ReauthContext {
} }
} }
/// Context used by the `sso.html` template
#[derive(Serialize)]
pub struct CompatSsoContext {
login: CompatSsoLogin<()>,
}
impl TemplateContext for CompatSsoContext {
fn sample() -> Vec<Self>
where
Self: Sized,
{
vec![CompatSsoContext {
login: CompatSsoLogin {
data: (),
redirect_uri: Url::parse("https://app.element.io/").unwrap(),
token: "abcdefghijklmnopqrstuvwxyz012345".into(),
created_at: Utc::now(),
state: CompatSsoLoginState::Pending,
},
}]
}
}
impl CompatSsoContext {
/// Constructs a context for the legacy SSO login page
#[must_use]
pub fn new<T>(login: T) -> Self
where
T: Into<CompatSsoLogin<()>>,
{
Self {
login: login.into(),
}
}
}
/// Context used by the `account/index.html` template /// Context used by the `account/index.html` template
#[derive(Serialize)] #[derive(Serialize)]
pub struct AccountContext { pub struct AccountContext {

View File

@ -45,10 +45,11 @@ mod macros;
pub use self::{ pub use self::{
context::{ context::{
AccountContext, AccountEmailsContext, ConsentContext, EmailVerificationContext, AccountContext, AccountEmailsContext, CompatSsoContext, ConsentContext,
EmptyContext, ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField, EmailVerificationContext, EmptyContext, ErrorContext, FormPostContext, IndexContext,
PostAuthContext, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField, LoginContext, LoginFormField, PostAuthContext, ReauthContext, ReauthFormField,
TemplateContext, WithCsrf, WithOptionalSession, WithSession, RegisterContext, RegisterFormField, TemplateContext, WithCsrf, WithOptionalSession,
WithSession,
}, },
forms::{FieldError, FormError, FormField, FormState, ToFormState}, forms::{FieldError, FormError, FormField, FormState, ToFormState},
}; };
@ -294,9 +295,12 @@ register_templates! {
/// Render the registration page /// Render the registration page
pub fn render_register(WithCsrf<RegisterContext>) { "pages/register.html" } pub fn render_register(WithCsrf<RegisterContext>) { "pages/register.html" }
/// Render the registration page /// Render the client consent page
pub fn render_consent(WithCsrf<WithSession<ConsentContext>>) { "pages/consent.html" } pub fn render_consent(WithCsrf<WithSession<ConsentContext>>) { "pages/consent.html" }
/// Render the client consent page
pub fn render_sso_login(WithCsrf<WithSession<CompatSsoContext>>) { "pages/sso.html" }
/// Render the home page /// Render the home page
pub fn render_index(WithCsrf<WithOptionalSession<IndexContext>>) { "pages/index.html" } pub fn render_index(WithCsrf<WithOptionalSession<IndexContext>>) { "pages/index.html" }

View File

@ -0,0 +1,45 @@
{#
Copyright 2022 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
#}
{% extends "base.html" %}
{% block content %}
<section class="flex items-center justify-center flex-1">
<div class="w-96 m-2">
<form method="POST" class="grid grid-cols-1 gap-6">
<div class="rounded-lg bg-grey-25 dark:bg-grey-450 p-2 flex flex-col">
<div class="text-center">
<h1 class="text-lg text-center font-medium">{{ login.redirect_uri }}</h1>
<h1>wants to access your Matrix account</h1>
</div>
</div>
<input type="hidden" name="csrf" value="{{ csrf_token }}" />
{{ button::button(text="Allow") }}
</form>
<div class="text-center mt-4">
<form method="POST" action="/logout">
<input type="hidden" name="csrf" value="{{ csrf_token }}" />
<div>
Not {{ current_session.user.username }}?
{{ button::button_text(text="Sign out", name="logout", type="submit") }}
</div>
</form>
</div>
</div>
</section>
{% endblock content %}