1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-20 12:02:22 +03:00

Save the post auth action during upstream OAuth login

This commit is contained in:
Quentin Gliech
2022-12-05 18:27:56 +01:00
parent 4d93f4d4f0
commit 23fd833d45
18 changed files with 142 additions and 100 deletions

View File

@@ -13,7 +13,7 @@
// limitations under the License.
use axum::{
extract::{Path, State},
extract::{Path, Query, State},
response::{IntoResponse, Redirect},
};
use axum_extra::extract::PrivateCookieJar;
@@ -28,7 +28,7 @@ use thiserror::Error;
use ulid::Ulid;
use super::UpstreamSessionsCookie;
use crate::impl_from_error_for_route;
use crate::{impl_from_error_for_route, views::shared::OptionalPostAuthAction};
#[derive(Debug, Error)]
pub(crate) enum RouteError {
@@ -68,6 +68,7 @@ pub(crate) async fn get(
State(url_builder): State<UrlBuilder>,
cookie_jar: PrivateCookieJar<Encrypter>,
Path(provider_id): Path<Ulid>,
Query(query): Query<OptionalPostAuthAction>,
) -> Result<impl IntoResponse, RouteError> {
let (clock, mut rng) = crate::rng_and_clock()?;
@@ -115,7 +116,7 @@ pub(crate) async fn get(
.await?;
let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar)
.add(session.id, provider.id, data.state)
.add(session.id, provider.id, data.state, query.post_auth_action)
.save(cookie_jar, clock.now());
txn.commit().await?;

View File

@@ -137,7 +137,7 @@ pub(crate) async fn get(
let mut txn = pool.begin().await?;
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
let session_id = sessions_cookie
let (session_id, _post_auth_action) = sessions_cookie
.find_session(provider_id, &params.state)
.map_err(|_| RouteError::MissingCookie)?;

View File

@@ -17,6 +17,7 @@
use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
use chrono::{DateTime, Duration, NaiveDateTime, Utc};
use mas_axum_utils::CookieExt;
use mas_router::PostAuthAction;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use time::OffsetDateTime;
@@ -28,12 +29,13 @@ static COOKIE_NAME: &str = "upstream-oauth2-sessions";
/// Sessions expire after 10 minutes
static SESSION_MAX_TIME_SECS: i64 = 60 * 10;
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Debug)]
pub struct Payload {
session: Ulid,
provider: Ulid,
state: String,
link: Option<Ulid>,
post_auth_action: Option<PostAuthAction>,
}
impl Payload {
@@ -46,7 +48,7 @@ impl Payload {
}
}
#[derive(Serialize, Deserialize, Default)]
#[derive(Serialize, Deserialize, Default, Debug)]
pub struct UpstreamSessions(Vec<Payload>);
#[derive(Debug, Error, PartialEq, Eq)]
@@ -87,12 +89,19 @@ impl UpstreamSessions {
}
/// Add a new session, for a provider and a random state
pub fn add(mut self, session: Ulid, provider: Ulid, state: String) -> Self {
pub fn add(
mut self,
session: Ulid,
provider: Ulid,
state: String,
post_auth_action: Option<PostAuthAction>,
) -> Self {
self.0.push(Payload {
session,
provider,
state,
link: None,
post_auth_action,
});
self
}
@@ -102,11 +111,11 @@ impl UpstreamSessions {
&self,
provider: Ulid,
state: &str,
) -> Result<Ulid, UpstreamSessionNotFound> {
) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> {
self.0
.iter()
.find(|p| p.provider == provider && p.state == state && p.link.is_none())
.map(|p| p.session)
.map(|p| (p.session, p.post_auth_action.as_ref()))
.ok_or(UpstreamSessionNotFound)
}
@@ -127,11 +136,14 @@ impl UpstreamSessions {
}
/// Find a session from its link
pub fn lookup_link(&self, link_id: Ulid) -> Result<Ulid, UpstreamSessionNotFound> {
pub fn lookup_link(
&self,
link_id: Ulid,
) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> {
self.0
.iter()
.find(|p| p.link == Some(link_id))
.map(|p| p.session)
.map(|p| (p.session, p.post_auth_action.as_ref()))
.ok_or(UpstreamSessionNotFound)
}
@@ -171,22 +183,22 @@ mod tests {
let first_session = Ulid::from_datetime_with_source(now.into(), &mut rng);
let first_state = "first-state";
let sessions = sessions.add(first_session, provider_a, first_state.into());
let sessions = sessions.add(first_session, provider_a, first_state.into(), None);
let now = now + Duration::minutes(5);
let second_session = Ulid::from_datetime_with_source(now.into(), &mut rng);
let second_state = "second-state";
let sessions = sessions.add(second_session, provider_b, second_state.into());
let sessions = sessions.add(second_session, provider_b, second_state.into(), None);
let sessions = sessions.expire(now);
assert_eq!(
sessions.find_session(provider_a, first_state),
Ok(first_session)
sessions.find_session(provider_a, first_state).unwrap().0,
first_session,
);
assert_eq!(
sessions.find_session(provider_b, second_state),
Ok(second_session)
sessions.find_session(provider_b, second_state).unwrap().0,
second_session
);
assert!(sessions.find_session(provider_b, first_state).is_err());
assert!(sessions.find_session(provider_a, second_state).is_err());
@@ -196,8 +208,8 @@ mod tests {
let sessions = sessions.expire(now);
assert!(sessions.find_session(provider_a, first_state).is_err());
assert_eq!(
sessions.find_session(provider_b, second_state),
Ok(second_session)
sessions.find_session(provider_b, second_state).unwrap().0,
second_session
);
// Associate a link with the second
@@ -210,7 +222,7 @@ mod tests {
assert!(sessions.find_session(provider_b, second_state).is_err());
// But it can be looked up by its link
assert_eq!(sessions.lookup_link(second_link), Ok(second_session));
assert_eq!(sessions.lookup_link(second_link).unwrap().0, second_session);
// And it can be consumed
let sessions = sessions.consume_link(second_link).unwrap();
// But only once

View File

@@ -24,7 +24,6 @@ use mas_axum_utils::{
SessionInfoExt,
};
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::{
upstream_oauth2::{
associate_link_to_user, consume_session, lookup_link, lookup_session_on_link,
@@ -44,7 +43,7 @@ use thiserror::Error;
use ulid::Ulid;
use super::UpstreamSessionsCookie;
use crate::impl_from_error_for_route;
use crate::{impl_from_error_for_route, views::shared::OptionalPostAuthAction};
#[derive(Debug, Error)]
pub(crate) enum RouteError {
@@ -114,7 +113,7 @@ pub(crate) async fn get(
let (clock, mut rng) = crate::rng_and_clock()?;
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
let session_id = sessions_cookie
let (session_id, _post_auth_action) = sessions_cookie
.lookup_link(link_id)
.map_err(|_| RouteError::MissingCookie)?;
@@ -213,10 +212,14 @@ pub(crate) async fn post(
let form = cookie_jar.verify_form(clock.now(), form)?;
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
let session_id = sessions_cookie
let (session_id, post_auth_action) = sessions_cookie
.lookup_link(link_id)
.map_err(|_| RouteError::MissingCookie)?;
let post_auth_action = OptionalPostAuthAction {
post_auth_action: post_auth_action.cloned(),
};
let link = lookup_link(&mut txn, link_id)
.await
.to_option()?
@@ -267,5 +270,5 @@ pub(crate) async fn post(
txn.commit().await?;
Ok((cookie_jar, mas_router::Index.go()))
Ok((cookie_jar, post_auth_action.go_next()))
}

View File

@@ -22,7 +22,6 @@ use mas_axum_utils::{
FancyError, SessionInfoExt,
};
use mas_keystore::Encrypter;
use mas_router::Route;
use mas_storage::user::{login, LoginError};
use mas_templates::{
FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState,
@@ -160,10 +159,7 @@ async fn render(
} else {
ctx
};
let register_link = mas_router::Register::from(action.post_auth_action).relative_url();
let ctx = ctx
.with_register_link(register_link.to_string())
.with_csrf(csrf_token.form_value());
let ctx = ctx.with_csrf(csrf_token.form_value());
let content = templates.render_login(&ctx).await?;
Ok(content)

View File

@@ -243,10 +243,7 @@ async fn render(
} else {
ctx
};
let login_link = mas_router::Login::from(action.post_auth_action).relative_url();
let ctx = ctx
.with_login_link(login_link.to_string())
.with_csrf(csrf_token.form_value());
let ctx = ctx.with_csrf(csrf_token.form_value());
let content = templates.render_register(&ctx).await?;
Ok(content)

View File

@@ -16,7 +16,7 @@ use mas_router::{PostAuthAction, Route};
use mas_storage::{
compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id,
};
use mas_templates::PostAuthContext;
use mas_templates::{PostAuthContext, PostAuthContextInner};
use serde::{Deserialize, Serialize};
use sqlx::PgConnection;
@@ -41,23 +41,24 @@ impl OptionalPostAuthAction {
&self,
conn: &mut PgConnection,
) -> anyhow::Result<Option<PostAuthContext>> {
match &self.post_auth_action {
Some(PostAuthAction::ContinueAuthorizationGrant { data }) => {
let grant = get_grant_by_id(conn, *data).await?;
let Some(action) = self.post_auth_action.clone() else { return Ok(None) };
let ctx = match action {
PostAuthAction::ContinueAuthorizationGrant { data } => {
let grant = get_grant_by_id(conn, data).await?;
let grant = Box::new(grant.into());
Ok(Some(PostAuthContext::ContinueAuthorizationGrant { grant }))
PostAuthContextInner::ContinueAuthorizationGrant { grant }
}
Some(PostAuthAction::ContinueCompatSsoLogin { data }) => {
let login = get_compat_sso_login_by_id(conn, *data).await?;
PostAuthAction::ContinueCompatSsoLogin { data } => {
let login = get_compat_sso_login_by_id(conn, data).await?;
let login = Box::new(login.into());
Ok(Some(PostAuthContext::ContinueCompatSsoLogin { login }))
PostAuthContextInner::ContinueCompatSsoLogin { login }
}
Some(PostAuthAction::ChangePassword) => Ok(Some(PostAuthContext::ChangePassword)),
PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword,
Some(PostAuthAction::LinkUpstream { id }) => {
let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, *id).await?;
PostAuthAction::LinkUpstream { id } => {
let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, id).await?;
let provider =
mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id)
@@ -65,10 +66,13 @@ impl OptionalPostAuthAction {
let provider = Box::new(provider);
let link = Box::new(link);
Ok(Some(PostAuthContext::LinkUpstream { provider, link }))
PostAuthContextInner::LinkUpstream { provider, link }
}
};
None => Ok(None),
}
Ok(Some(PostAuthContext {
params: action.clone(),
ctx,
}))
}
}