You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Save the post auth action during upstream OAuth login
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@ -2638,7 +2638,6 @@ dependencies = [
|
|||||||
"async-trait",
|
"async-trait",
|
||||||
"axum 0.6.1",
|
"axum 0.6.1",
|
||||||
"axum-extra",
|
"axum-extra",
|
||||||
"bincode",
|
|
||||||
"chrono",
|
"chrono",
|
||||||
"data-encoding",
|
"data-encoding",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
@ -3125,6 +3124,7 @@ dependencies = [
|
|||||||
"anyhow",
|
"anyhow",
|
||||||
"camino",
|
"camino",
|
||||||
"chrono",
|
"chrono",
|
||||||
|
"http",
|
||||||
"mas-data-model",
|
"mas-data-model",
|
||||||
"mas-router",
|
"mas-router",
|
||||||
"oauth2-types",
|
"oauth2-types",
|
||||||
|
@ -9,7 +9,6 @@ license = "Apache-2.0"
|
|||||||
async-trait = "0.1.59"
|
async-trait = "0.1.59"
|
||||||
axum = { version = "0.6.1", features = ["headers"] }
|
axum = { version = "0.6.1", features = ["headers"] }
|
||||||
axum-extra = { version = "0.4.2", features = ["cookie-private"] }
|
axum-extra = { version = "0.4.2", features = ["cookie-private"] }
|
||||||
bincode = "1.3.3"
|
|
||||||
chrono = "0.4.23"
|
chrono = "0.4.23"
|
||||||
data-encoding = "2.3.2"
|
data-encoding = "2.3.2"
|
||||||
futures-util = "0.3.25"
|
futures-util = "0.3.25"
|
||||||
|
@ -14,15 +14,13 @@
|
|||||||
|
|
||||||
//! Private (encrypted) cookie jar, based on axum-extra's cookie jar
|
//! Private (encrypted) cookie jar, based on axum-extra's cookie jar
|
||||||
|
|
||||||
use data_encoding::BASE64URL_NOPAD;
|
|
||||||
use serde::{de::DeserializeOwned, Serialize};
|
use serde::{de::DeserializeOwned, Serialize};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
#[error("could not decode cookie")]
|
#[error("could not decode cookie")]
|
||||||
pub enum CookieDecodeError {
|
pub enum CookieDecodeError {
|
||||||
Deserialize(#[from] bincode::Error),
|
Deserialize(#[from] serde_json::Error),
|
||||||
Decode(#[from] data_encoding::DecodeError),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait CookieExt {
|
pub trait CookieExt {
|
||||||
@ -41,10 +39,7 @@ impl<'a> CookieExt for axum_extra::extract::cookie::Cookie<'a> {
|
|||||||
where
|
where
|
||||||
T: DeserializeOwned,
|
T: DeserializeOwned,
|
||||||
{
|
{
|
||||||
let bytes = BASE64URL_NOPAD.decode(self.value().as_bytes())?;
|
let decoded = serde_json::from_str(self.value())?;
|
||||||
|
|
||||||
let decoded = bincode::deserialize(&bytes)?;
|
|
||||||
|
|
||||||
Ok(decoded)
|
Ok(decoded)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,8 +47,7 @@ impl<'a> CookieExt for axum_extra::extract::cookie::Cookie<'a> {
|
|||||||
where
|
where
|
||||||
T: Serialize,
|
T: Serialize,
|
||||||
{
|
{
|
||||||
let bytes = bincode::serialize(t).unwrap();
|
let encoded = serde_json::to_string(t).unwrap();
|
||||||
let encoded = BASE64URL_NOPAD.encode(&bytes);
|
|
||||||
self.set_value(encoded);
|
self.set_value(encoded);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Path, State},
|
extract::{Path, Query, State},
|
||||||
response::{IntoResponse, Redirect},
|
response::{IntoResponse, Redirect},
|
||||||
};
|
};
|
||||||
use axum_extra::extract::PrivateCookieJar;
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
@ -28,7 +28,7 @@ use thiserror::Error;
|
|||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
|
|
||||||
use super::UpstreamSessionsCookie;
|
use super::UpstreamSessionsCookie;
|
||||||
use crate::impl_from_error_for_route;
|
use crate::{impl_from_error_for_route, views::shared::OptionalPostAuthAction};
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub(crate) enum RouteError {
|
pub(crate) enum RouteError {
|
||||||
@ -68,6 +68,7 @@ pub(crate) async fn get(
|
|||||||
State(url_builder): State<UrlBuilder>,
|
State(url_builder): State<UrlBuilder>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Path(provider_id): Path<Ulid>,
|
Path(provider_id): Path<Ulid>,
|
||||||
|
Query(query): Query<OptionalPostAuthAction>,
|
||||||
) -> Result<impl IntoResponse, RouteError> {
|
) -> Result<impl IntoResponse, RouteError> {
|
||||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||||
|
|
||||||
@ -115,7 +116,7 @@ pub(crate) async fn get(
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar)
|
let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar)
|
||||||
.add(session.id, provider.id, data.state)
|
.add(session.id, provider.id, data.state, query.post_auth_action)
|
||||||
.save(cookie_jar, clock.now());
|
.save(cookie_jar, clock.now());
|
||||||
|
|
||||||
txn.commit().await?;
|
txn.commit().await?;
|
||||||
|
@ -137,7 +137,7 @@ pub(crate) async fn get(
|
|||||||
let mut txn = pool.begin().await?;
|
let mut txn = pool.begin().await?;
|
||||||
|
|
||||||
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
||||||
let session_id = sessions_cookie
|
let (session_id, _post_auth_action) = sessions_cookie
|
||||||
.find_session(provider_id, ¶ms.state)
|
.find_session(provider_id, ¶ms.state)
|
||||||
.map_err(|_| RouteError::MissingCookie)?;
|
.map_err(|_| RouteError::MissingCookie)?;
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
|
use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
|
||||||
use chrono::{DateTime, Duration, NaiveDateTime, Utc};
|
use chrono::{DateTime, Duration, NaiveDateTime, Utc};
|
||||||
use mas_axum_utils::CookieExt;
|
use mas_axum_utils::CookieExt;
|
||||||
|
use mas_router::PostAuthAction;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use time::OffsetDateTime;
|
use time::OffsetDateTime;
|
||||||
@ -28,12 +29,13 @@ static COOKIE_NAME: &str = "upstream-oauth2-sessions";
|
|||||||
/// Sessions expire after 10 minutes
|
/// Sessions expire after 10 minutes
|
||||||
static SESSION_MAX_TIME_SECS: i64 = 60 * 10;
|
static SESSION_MAX_TIME_SECS: i64 = 60 * 10;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct Payload {
|
pub struct Payload {
|
||||||
session: Ulid,
|
session: Ulid,
|
||||||
provider: Ulid,
|
provider: Ulid,
|
||||||
state: String,
|
state: String,
|
||||||
link: Option<Ulid>,
|
link: Option<Ulid>,
|
||||||
|
post_auth_action: Option<PostAuthAction>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Payload {
|
impl Payload {
|
||||||
@ -46,7 +48,7 @@ impl Payload {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Default)]
|
#[derive(Serialize, Deserialize, Default, Debug)]
|
||||||
pub struct UpstreamSessions(Vec<Payload>);
|
pub struct UpstreamSessions(Vec<Payload>);
|
||||||
|
|
||||||
#[derive(Debug, Error, PartialEq, Eq)]
|
#[derive(Debug, Error, PartialEq, Eq)]
|
||||||
@ -87,12 +89,19 @@ impl UpstreamSessions {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Add a new session, for a provider and a random state
|
/// Add a new session, for a provider and a random state
|
||||||
pub fn add(mut self, session: Ulid, provider: Ulid, state: String) -> Self {
|
pub fn add(
|
||||||
|
mut self,
|
||||||
|
session: Ulid,
|
||||||
|
provider: Ulid,
|
||||||
|
state: String,
|
||||||
|
post_auth_action: Option<PostAuthAction>,
|
||||||
|
) -> Self {
|
||||||
self.0.push(Payload {
|
self.0.push(Payload {
|
||||||
session,
|
session,
|
||||||
provider,
|
provider,
|
||||||
state,
|
state,
|
||||||
link: None,
|
link: None,
|
||||||
|
post_auth_action,
|
||||||
});
|
});
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
@ -102,11 +111,11 @@ impl UpstreamSessions {
|
|||||||
&self,
|
&self,
|
||||||
provider: Ulid,
|
provider: Ulid,
|
||||||
state: &str,
|
state: &str,
|
||||||
) -> Result<Ulid, UpstreamSessionNotFound> {
|
) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> {
|
||||||
self.0
|
self.0
|
||||||
.iter()
|
.iter()
|
||||||
.find(|p| p.provider == provider && p.state == state && p.link.is_none())
|
.find(|p| p.provider == provider && p.state == state && p.link.is_none())
|
||||||
.map(|p| p.session)
|
.map(|p| (p.session, p.post_auth_action.as_ref()))
|
||||||
.ok_or(UpstreamSessionNotFound)
|
.ok_or(UpstreamSessionNotFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -127,11 +136,14 @@ impl UpstreamSessions {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Find a session from its link
|
/// Find a session from its link
|
||||||
pub fn lookup_link(&self, link_id: Ulid) -> Result<Ulid, UpstreamSessionNotFound> {
|
pub fn lookup_link(
|
||||||
|
&self,
|
||||||
|
link_id: Ulid,
|
||||||
|
) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> {
|
||||||
self.0
|
self.0
|
||||||
.iter()
|
.iter()
|
||||||
.find(|p| p.link == Some(link_id))
|
.find(|p| p.link == Some(link_id))
|
||||||
.map(|p| p.session)
|
.map(|p| (p.session, p.post_auth_action.as_ref()))
|
||||||
.ok_or(UpstreamSessionNotFound)
|
.ok_or(UpstreamSessionNotFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -171,22 +183,22 @@ mod tests {
|
|||||||
|
|
||||||
let first_session = Ulid::from_datetime_with_source(now.into(), &mut rng);
|
let first_session = Ulid::from_datetime_with_source(now.into(), &mut rng);
|
||||||
let first_state = "first-state";
|
let first_state = "first-state";
|
||||||
let sessions = sessions.add(first_session, provider_a, first_state.into());
|
let sessions = sessions.add(first_session, provider_a, first_state.into(), None);
|
||||||
|
|
||||||
let now = now + Duration::minutes(5);
|
let now = now + Duration::minutes(5);
|
||||||
|
|
||||||
let second_session = Ulid::from_datetime_with_source(now.into(), &mut rng);
|
let second_session = Ulid::from_datetime_with_source(now.into(), &mut rng);
|
||||||
let second_state = "second-state";
|
let second_state = "second-state";
|
||||||
let sessions = sessions.add(second_session, provider_b, second_state.into());
|
let sessions = sessions.add(second_session, provider_b, second_state.into(), None);
|
||||||
|
|
||||||
let sessions = sessions.expire(now);
|
let sessions = sessions.expire(now);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
sessions.find_session(provider_a, first_state),
|
sessions.find_session(provider_a, first_state).unwrap().0,
|
||||||
Ok(first_session)
|
first_session,
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
sessions.find_session(provider_b, second_state),
|
sessions.find_session(provider_b, second_state).unwrap().0,
|
||||||
Ok(second_session)
|
second_session
|
||||||
);
|
);
|
||||||
assert!(sessions.find_session(provider_b, first_state).is_err());
|
assert!(sessions.find_session(provider_b, first_state).is_err());
|
||||||
assert!(sessions.find_session(provider_a, second_state).is_err());
|
assert!(sessions.find_session(provider_a, second_state).is_err());
|
||||||
@ -196,8 +208,8 @@ mod tests {
|
|||||||
let sessions = sessions.expire(now);
|
let sessions = sessions.expire(now);
|
||||||
assert!(sessions.find_session(provider_a, first_state).is_err());
|
assert!(sessions.find_session(provider_a, first_state).is_err());
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
sessions.find_session(provider_b, second_state),
|
sessions.find_session(provider_b, second_state).unwrap().0,
|
||||||
Ok(second_session)
|
second_session
|
||||||
);
|
);
|
||||||
|
|
||||||
// Associate a link with the second
|
// Associate a link with the second
|
||||||
@ -210,7 +222,7 @@ mod tests {
|
|||||||
assert!(sessions.find_session(provider_b, second_state).is_err());
|
assert!(sessions.find_session(provider_b, second_state).is_err());
|
||||||
|
|
||||||
// But it can be looked up by its link
|
// But it can be looked up by its link
|
||||||
assert_eq!(sessions.lookup_link(second_link), Ok(second_session));
|
assert_eq!(sessions.lookup_link(second_link).unwrap().0, second_session);
|
||||||
// And it can be consumed
|
// And it can be consumed
|
||||||
let sessions = sessions.consume_link(second_link).unwrap();
|
let sessions = sessions.consume_link(second_link).unwrap();
|
||||||
// But only once
|
// But only once
|
||||||
|
@ -24,7 +24,6 @@ use mas_axum_utils::{
|
|||||||
SessionInfoExt,
|
SessionInfoExt,
|
||||||
};
|
};
|
||||||
use mas_keystore::Encrypter;
|
use mas_keystore::Encrypter;
|
||||||
use mas_router::Route;
|
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
upstream_oauth2::{
|
upstream_oauth2::{
|
||||||
associate_link_to_user, consume_session, lookup_link, lookup_session_on_link,
|
associate_link_to_user, consume_session, lookup_link, lookup_session_on_link,
|
||||||
@ -44,7 +43,7 @@ use thiserror::Error;
|
|||||||
use ulid::Ulid;
|
use ulid::Ulid;
|
||||||
|
|
||||||
use super::UpstreamSessionsCookie;
|
use super::UpstreamSessionsCookie;
|
||||||
use crate::impl_from_error_for_route;
|
use crate::{impl_from_error_for_route, views::shared::OptionalPostAuthAction};
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub(crate) enum RouteError {
|
pub(crate) enum RouteError {
|
||||||
@ -114,7 +113,7 @@ pub(crate) async fn get(
|
|||||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||||
|
|
||||||
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
||||||
let session_id = sessions_cookie
|
let (session_id, _post_auth_action) = sessions_cookie
|
||||||
.lookup_link(link_id)
|
.lookup_link(link_id)
|
||||||
.map_err(|_| RouteError::MissingCookie)?;
|
.map_err(|_| RouteError::MissingCookie)?;
|
||||||
|
|
||||||
@ -213,10 +212,14 @@ pub(crate) async fn post(
|
|||||||
let form = cookie_jar.verify_form(clock.now(), form)?;
|
let form = cookie_jar.verify_form(clock.now(), form)?;
|
||||||
|
|
||||||
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
||||||
let session_id = sessions_cookie
|
let (session_id, post_auth_action) = sessions_cookie
|
||||||
.lookup_link(link_id)
|
.lookup_link(link_id)
|
||||||
.map_err(|_| RouteError::MissingCookie)?;
|
.map_err(|_| RouteError::MissingCookie)?;
|
||||||
|
|
||||||
|
let post_auth_action = OptionalPostAuthAction {
|
||||||
|
post_auth_action: post_auth_action.cloned(),
|
||||||
|
};
|
||||||
|
|
||||||
let link = lookup_link(&mut txn, link_id)
|
let link = lookup_link(&mut txn, link_id)
|
||||||
.await
|
.await
|
||||||
.to_option()?
|
.to_option()?
|
||||||
@ -267,5 +270,5 @@ pub(crate) async fn post(
|
|||||||
|
|
||||||
txn.commit().await?;
|
txn.commit().await?;
|
||||||
|
|
||||||
Ok((cookie_jar, mas_router::Index.go()))
|
Ok((cookie_jar, post_auth_action.go_next()))
|
||||||
}
|
}
|
||||||
|
@ -22,7 +22,6 @@ use mas_axum_utils::{
|
|||||||
FancyError, SessionInfoExt,
|
FancyError, SessionInfoExt,
|
||||||
};
|
};
|
||||||
use mas_keystore::Encrypter;
|
use mas_keystore::Encrypter;
|
||||||
use mas_router::Route;
|
|
||||||
use mas_storage::user::{login, LoginError};
|
use mas_storage::user::{login, LoginError};
|
||||||
use mas_templates::{
|
use mas_templates::{
|
||||||
FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState,
|
FieldError, FormError, LoginContext, LoginFormField, TemplateContext, Templates, ToFormState,
|
||||||
@ -160,10 +159,7 @@ async fn render(
|
|||||||
} else {
|
} else {
|
||||||
ctx
|
ctx
|
||||||
};
|
};
|
||||||
let register_link = mas_router::Register::from(action.post_auth_action).relative_url();
|
let ctx = ctx.with_csrf(csrf_token.form_value());
|
||||||
let ctx = ctx
|
|
||||||
.with_register_link(register_link.to_string())
|
|
||||||
.with_csrf(csrf_token.form_value());
|
|
||||||
|
|
||||||
let content = templates.render_login(&ctx).await?;
|
let content = templates.render_login(&ctx).await?;
|
||||||
Ok(content)
|
Ok(content)
|
||||||
|
@ -243,10 +243,7 @@ async fn render(
|
|||||||
} else {
|
} else {
|
||||||
ctx
|
ctx
|
||||||
};
|
};
|
||||||
let login_link = mas_router::Login::from(action.post_auth_action).relative_url();
|
let ctx = ctx.with_csrf(csrf_token.form_value());
|
||||||
let ctx = ctx
|
|
||||||
.with_login_link(login_link.to_string())
|
|
||||||
.with_csrf(csrf_token.form_value());
|
|
||||||
|
|
||||||
let content = templates.render_register(&ctx).await?;
|
let content = templates.render_register(&ctx).await?;
|
||||||
Ok(content)
|
Ok(content)
|
||||||
|
@ -16,7 +16,7 @@ use mas_router::{PostAuthAction, Route};
|
|||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id,
|
compat::get_compat_sso_login_by_id, oauth2::authorization_grant::get_grant_by_id,
|
||||||
};
|
};
|
||||||
use mas_templates::PostAuthContext;
|
use mas_templates::{PostAuthContext, PostAuthContextInner};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use sqlx::PgConnection;
|
use sqlx::PgConnection;
|
||||||
|
|
||||||
@ -41,23 +41,24 @@ impl OptionalPostAuthAction {
|
|||||||
&self,
|
&self,
|
||||||
conn: &mut PgConnection,
|
conn: &mut PgConnection,
|
||||||
) -> anyhow::Result<Option<PostAuthContext>> {
|
) -> anyhow::Result<Option<PostAuthContext>> {
|
||||||
match &self.post_auth_action {
|
let Some(action) = self.post_auth_action.clone() else { return Ok(None) };
|
||||||
Some(PostAuthAction::ContinueAuthorizationGrant { data }) => {
|
let ctx = match action {
|
||||||
let grant = get_grant_by_id(conn, *data).await?;
|
PostAuthAction::ContinueAuthorizationGrant { data } => {
|
||||||
|
let grant = get_grant_by_id(conn, data).await?;
|
||||||
let grant = Box::new(grant.into());
|
let grant = Box::new(grant.into());
|
||||||
Ok(Some(PostAuthContext::ContinueAuthorizationGrant { grant }))
|
PostAuthContextInner::ContinueAuthorizationGrant { grant }
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(PostAuthAction::ContinueCompatSsoLogin { data }) => {
|
PostAuthAction::ContinueCompatSsoLogin { data } => {
|
||||||
let login = get_compat_sso_login_by_id(conn, *data).await?;
|
let login = get_compat_sso_login_by_id(conn, data).await?;
|
||||||
let login = Box::new(login.into());
|
let login = Box::new(login.into());
|
||||||
Ok(Some(PostAuthContext::ContinueCompatSsoLogin { login }))
|
PostAuthContextInner::ContinueCompatSsoLogin { login }
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(PostAuthAction::ChangePassword) => Ok(Some(PostAuthContext::ChangePassword)),
|
PostAuthAction::ChangePassword => PostAuthContextInner::ChangePassword,
|
||||||
|
|
||||||
Some(PostAuthAction::LinkUpstream { id }) => {
|
PostAuthAction::LinkUpstream { id } => {
|
||||||
let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, *id).await?;
|
let link = mas_storage::upstream_oauth2::lookup_link(&mut *conn, id).await?;
|
||||||
|
|
||||||
let provider =
|
let provider =
|
||||||
mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id)
|
mas_storage::upstream_oauth2::lookup_provider(&mut *conn, link.provider_id)
|
||||||
@ -65,10 +66,13 @@ impl OptionalPostAuthAction {
|
|||||||
|
|
||||||
let provider = Box::new(provider);
|
let provider = Box::new(provider);
|
||||||
let link = Box::new(link);
|
let link = Box::new(link);
|
||||||
Ok(Some(PostAuthContext::LinkUpstream { provider, link }))
|
PostAuthContextInner::LinkUpstream { provider, link }
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
None => Ok(None),
|
Ok(Some(PostAuthContext {
|
||||||
}
|
params: action.clone(),
|
||||||
|
ctx,
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -532,17 +532,27 @@ impl Route for CompatLoginSsoComplete {
|
|||||||
/// `GET /upstream/authorize/:id`
|
/// `GET /upstream/authorize/:id`
|
||||||
pub struct UpstreamOAuth2Authorize {
|
pub struct UpstreamOAuth2Authorize {
|
||||||
id: Ulid,
|
id: Ulid,
|
||||||
|
post_auth_action: Option<PostAuthAction>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UpstreamOAuth2Authorize {
|
impl UpstreamOAuth2Authorize {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub const fn new(id: Ulid) -> Self {
|
pub const fn new(id: Ulid) -> Self {
|
||||||
Self { id }
|
Self {
|
||||||
|
id,
|
||||||
|
post_auth_action: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn and_then(mut self, action: PostAuthAction) -> Self {
|
||||||
|
self.post_auth_action = Some(action);
|
||||||
|
self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Route for UpstreamOAuth2Authorize {
|
impl Route for UpstreamOAuth2Authorize {
|
||||||
type Query = ();
|
type Query = PostAuthAction;
|
||||||
fn route() -> &'static str {
|
fn route() -> &'static str {
|
||||||
"/upstream/authorize/:provider_id"
|
"/upstream/authorize/:provider_id"
|
||||||
}
|
}
|
||||||
@ -550,6 +560,10 @@ impl Route for UpstreamOAuth2Authorize {
|
|||||||
fn path(&self) -> std::borrow::Cow<'static, str> {
|
fn path(&self) -> std::borrow::Cow<'static, str> {
|
||||||
format!("/upstream/authorize/{}", self.id).into()
|
format!("/upstream/authorize/{}", self.id).into()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn query(&self) -> Option<&Self::Query> {
|
||||||
|
self.post_auth_action.as_ref()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// `GET /upstream/callback/:id`
|
/// `GET /upstream/callback/:id`
|
||||||
|
@ -20,6 +20,7 @@ serde_urlencoded = "0.7.1"
|
|||||||
camino = "1.1.1"
|
camino = "1.1.1"
|
||||||
chrono = "0.4.23"
|
chrono = "0.4.23"
|
||||||
url = "2.3.1"
|
url = "2.3.1"
|
||||||
|
http = "0.2.8"
|
||||||
ulid = { version = "1.0.0", features = ["serde"] }
|
ulid = { version = "1.0.0", features = ["serde"] }
|
||||||
|
|
||||||
oauth2-types = { path = "../oauth2-types" }
|
oauth2-types = { path = "../oauth2-types" }
|
||||||
|
@ -244,10 +244,10 @@ impl FormField for LoginFormField {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Context used in login and reauth screens, for the post-auth action to do
|
/// Inner context used in login and reauth screens. See [`PostAuthContext`].
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
#[serde(tag = "kind", rename_all = "snake_case")]
|
#[serde(tag = "kind", rename_all = "snake_case")]
|
||||||
pub enum PostAuthContext {
|
pub enum PostAuthContextInner {
|
||||||
/// Continue an authorization grant
|
/// Continue an authorization grant
|
||||||
ContinueAuthorizationGrant {
|
ContinueAuthorizationGrant {
|
||||||
/// The authorization grant that will be continued after authentication
|
/// The authorization grant that will be continued after authentication
|
||||||
@ -274,13 +274,23 @@ pub enum PostAuthContext {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Context used in login and reauth screens, for the post-auth action to do
|
||||||
|
#[derive(Serialize)]
|
||||||
|
pub struct PostAuthContext {
|
||||||
|
/// The post auth action params from the URL
|
||||||
|
pub params: PostAuthAction,
|
||||||
|
|
||||||
|
/// The loaded post auth context
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub ctx: PostAuthContextInner,
|
||||||
|
}
|
||||||
|
|
||||||
/// Context used by the `login.html` template
|
/// Context used by the `login.html` template
|
||||||
#[derive(Serialize, Default)]
|
#[derive(Serialize, Default)]
|
||||||
pub struct LoginContext {
|
pub struct LoginContext {
|
||||||
form: FormState<LoginFormField>,
|
form: FormState<LoginFormField>,
|
||||||
next: Option<PostAuthContext>,
|
next: Option<PostAuthContext>,
|
||||||
providers: Vec<UpstreamOAuthProvider>,
|
providers: Vec<UpstreamOAuthProvider>,
|
||||||
register_link: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TemplateContext for LoginContext {
|
impl TemplateContext for LoginContext {
|
||||||
@ -293,7 +303,6 @@ impl TemplateContext for LoginContext {
|
|||||||
form: FormState::default(),
|
form: FormState::default(),
|
||||||
next: None,
|
next: None,
|
||||||
providers: Vec::new(),
|
providers: Vec::new(),
|
||||||
register_link: "/register".to_owned(),
|
|
||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -313,18 +322,9 @@ impl LoginContext {
|
|||||||
|
|
||||||
/// Add a post authentication action to the context
|
/// Add a post authentication action to the context
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn with_post_action(self, next: PostAuthContext) -> Self {
|
pub fn with_post_action(self, context: PostAuthContext) -> Self {
|
||||||
Self {
|
Self {
|
||||||
next: Some(next),
|
next: Some(context),
|
||||||
..self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add a registration link to the context
|
|
||||||
#[must_use]
|
|
||||||
pub fn with_register_link(self, register_link: String) -> Self {
|
|
||||||
Self {
|
|
||||||
register_link,
|
|
||||||
..self
|
..self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -361,7 +361,6 @@ impl FormField for RegisterFormField {
|
|||||||
pub struct RegisterContext {
|
pub struct RegisterContext {
|
||||||
form: FormState<RegisterFormField>,
|
form: FormState<RegisterFormField>,
|
||||||
next: Option<PostAuthContext>,
|
next: Option<PostAuthContext>,
|
||||||
login_link: String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TemplateContext for RegisterContext {
|
impl TemplateContext for RegisterContext {
|
||||||
@ -373,7 +372,6 @@ impl TemplateContext for RegisterContext {
|
|||||||
vec![RegisterContext {
|
vec![RegisterContext {
|
||||||
form: FormState::default(),
|
form: FormState::default(),
|
||||||
next: None,
|
next: None,
|
||||||
login_link: "/login".to_owned(),
|
|
||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -393,12 +391,6 @@ impl RegisterContext {
|
|||||||
..self
|
..self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add a login link to the context
|
|
||||||
#[must_use]
|
|
||||||
pub fn with_login_link(self, login_link: String) -> Self {
|
|
||||||
Self { login_link, ..self }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Context used by the `consent.html` template
|
/// Context used by the `consent.html` template
|
||||||
|
@ -22,7 +22,9 @@ use url::Url;
|
|||||||
|
|
||||||
pub fn register(tera: &mut Tera, url_builder: UrlBuilder) {
|
pub fn register(tera: &mut Tera, url_builder: UrlBuilder) {
|
||||||
tera.register_tester("empty", self::tester_empty);
|
tera.register_tester("empty", self::tester_empty);
|
||||||
tera.register_function("add_params_to_uri", function_add_params_to_uri);
|
tera.register_filter("to_params", filter_to_params);
|
||||||
|
tera.register_filter("safe_get", filter_safe_get);
|
||||||
|
tera.register_function("add_params_to_url", function_add_params_to_url);
|
||||||
tera.register_function("merge", function_merge);
|
tera.register_function("merge", function_merge);
|
||||||
tera.register_function("dict", function_dict);
|
tera.register_function("dict", function_dict);
|
||||||
tera.register_function("static_asset", make_static_asset(url_builder));
|
tera.register_function("static_asset", make_static_asset(url_builder));
|
||||||
@ -37,12 +39,42 @@ fn tester_empty(value: Option<&Value>, params: &[Value]) -> Result<bool, tera::E
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn filter_to_params(params: &Value, kv: &HashMap<String, Value>) -> Result<Value, tera::Error> {
|
||||||
|
let prefix = kv.get("prefix").and_then(Value::as_str).unwrap_or("");
|
||||||
|
let params = serde_urlencoded::to_string(params)
|
||||||
|
.map_err(|e| tera::Error::chain(e, "Could not serialize parameters"))?;
|
||||||
|
|
||||||
|
if params.is_empty() {
|
||||||
|
Ok(Value::String(String::new()))
|
||||||
|
} else {
|
||||||
|
Ok(Value::String(format!("{prefix}{params}")))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Alternative to `get` which does not crash on `None` and defaults to `None`
|
||||||
|
pub fn filter_safe_get(value: &Value, args: &HashMap<String, Value>) -> Result<Value, tera::Error> {
|
||||||
|
let default = args.get("default").unwrap_or(&Value::Null);
|
||||||
|
let key = args
|
||||||
|
.get("key")
|
||||||
|
.and_then(Value::as_str)
|
||||||
|
.ok_or_else(|| tera::Error::msg("Invalid parameter `uri`"))?;
|
||||||
|
|
||||||
|
match value.as_object() {
|
||||||
|
Some(o) => match o.get(key) {
|
||||||
|
Some(val) => Ok(val.clone()),
|
||||||
|
// If the value is not present, allow for an optional default value
|
||||||
|
None => Ok(default.clone()),
|
||||||
|
},
|
||||||
|
None => Ok(default.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
enum ParamsWhere {
|
enum ParamsWhere {
|
||||||
Fragment,
|
Fragment,
|
||||||
Query,
|
Query,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn function_add_params_to_uri(params: &HashMap<String, Value>) -> Result<Value, tera::Error> {
|
fn function_add_params_to_url(params: &HashMap<String, Value>) -> Result<Value, tera::Error> {
|
||||||
use ParamsWhere::{Fragment, Query};
|
use ParamsWhere::{Fragment, Query};
|
||||||
|
|
||||||
// First, get the `uri`, `mode` and `params` parameters
|
// First, get the `uri`, `mode` and `params` parameters
|
||||||
@ -77,12 +109,7 @@ fn function_add_params_to_uri(params: &HashMap<String, Value>) -> Result<Value,
|
|||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
// Merge the exising and the additional parameters together
|
// Merge the exising and the additional parameters together
|
||||||
let params: HashMap<&String, &Value> = params
|
let params: HashMap<&String, &Value> = params.iter().chain(existing.iter()).collect();
|
||||||
.iter()
|
|
||||||
// Filter out the `uri` and `mode` params
|
|
||||||
.filter(|(k, _v)| k != &"uri" && k != &"mode")
|
|
||||||
.chain(existing.iter())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// Transform them back to urlencoded
|
// Transform them back to urlencoded
|
||||||
let params = serde_urlencoded::to_string(params)
|
let params = serde_urlencoded::to_string(params)
|
||||||
|
@ -48,9 +48,9 @@ pub use self::{
|
|||||||
AccountContext, AccountEmailsContext, CompatSsoContext, ConsentContext, EmailAddContext,
|
AccountContext, AccountEmailsContext, CompatSsoContext, ConsentContext, EmailAddContext,
|
||||||
EmailVerificationContext, EmailVerificationPageContext, EmptyContext, ErrorContext,
|
EmailVerificationContext, EmailVerificationPageContext, EmptyContext, ErrorContext,
|
||||||
FormPostContext, IndexContext, LoginContext, LoginFormField, PolicyViolationContext,
|
FormPostContext, IndexContext, LoginContext, LoginFormField, PolicyViolationContext,
|
||||||
PostAuthContext, ReauthContext, ReauthFormField, RegisterContext, RegisterFormField,
|
PostAuthContext, PostAuthContextInner, ReauthContext, ReauthFormField, RegisterContext,
|
||||||
TemplateContext, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
|
RegisterFormField, TemplateContext, UpstreamExistingLinkContext, UpstreamRegister,
|
||||||
WithCsrf, WithOptionalSession, WithSession,
|
UpstreamSuggestLink, WithCsrf, WithOptionalSession, WithSession,
|
||||||
},
|
},
|
||||||
forms::{FieldError, FormError, FormField, FormState, ToFormState},
|
forms::{FieldError, FormError, FormField, FormState, ToFormState},
|
||||||
};
|
};
|
||||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
|||||||
<button class="{{ class }}" type="submit">{{ text }}</button>
|
<button class="{{ class }}" type="submit">{{ text }}</button>
|
||||||
</form>
|
</form>
|
||||||
{% elif mode == "fragment" or mode == "query" %}
|
{% elif mode == "fragment" or mode == "query" %}
|
||||||
<a class="{{ class }}" href="{{ add_params_to_uri(uri=uri, mode=mode, params=params) }}">{{ text }}</a>
|
<a class="{{ class }}" href="{{ add_params_to_url(uri=uri, mode=mode, params=params) }}">{{ text }}</a>
|
||||||
{% else %}
|
{% else %}
|
||||||
{{ throw(message="Invalid mode") }}
|
{{ throw(message="Invalid mode") }}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
@ -62,7 +62,8 @@ limitations under the License.
|
|||||||
{% if not next or next.kind != "link_upstream" %}
|
{% if not next or next.kind != "link_upstream" %}
|
||||||
<div class="text-center mt-4">
|
<div class="text-center mt-4">
|
||||||
Don't have an account yet?
|
Don't have an account yet?
|
||||||
{{ button::link_text(text="Create an account", href=register_link) }}
|
{% set params = next | safe_get(key="params") | to_params(prefix="?") %}
|
||||||
|
{{ button::link_text(text="Create an account", href="/register" ~ params) }}
|
||||||
</div>
|
</div>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
@ -74,7 +75,8 @@ limitations under the License.
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
{% for provider in providers %}
|
{% for provider in providers %}
|
||||||
{{ button::link(text="Continue with " ~ provider.issuer, href="/upstream/authorize/" ~ provider.id) }}
|
{% set params = next | safe_get(key="params") | to_params(prefix="?") %}
|
||||||
|
{{ button::link(text="Continue with " ~ provider.issuer, href="/upstream/authorize/" ~ provider.id ~ params) }}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
</form>
|
</form>
|
||||||
|
@ -55,8 +55,8 @@ limitations under the License.
|
|||||||
{% endif %}
|
{% endif %}
|
||||||
<div class="text-center mt-4">
|
<div class="text-center mt-4">
|
||||||
Already have an account?
|
Already have an account?
|
||||||
{# TODO: proper link #}
|
{% set params = next | safe_get(key="params") | to_params(prefix="?") %}
|
||||||
{{ button::link_text(text="Sign in instead", href=login_link) }}
|
{{ button::link_text(text="Sign in instead", href="/login" ~ params) }}
|
||||||
</div>
|
</div>
|
||||||
</form>
|
</form>
|
||||||
</section>
|
</section>
|
||||||
|
Reference in New Issue
Block a user