You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-20 12:02:22 +03:00
Save the post auth action during upstream OAuth login
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
extract::{Path, Query, State},
|
||||
response::{IntoResponse, Redirect},
|
||||
};
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
@@ -28,7 +28,7 @@ use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::UpstreamSessionsCookie;
|
||||
use crate::impl_from_error_for_route;
|
||||
use crate::{impl_from_error_for_route, views::shared::OptionalPostAuthAction};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum RouteError {
|
||||
@@ -68,6 +68,7 @@ pub(crate) async fn get(
|
||||
State(url_builder): State<UrlBuilder>,
|
||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||
Path(provider_id): Path<Ulid>,
|
||||
Query(query): Query<OptionalPostAuthAction>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
|
||||
@@ -115,7 +116,7 @@ pub(crate) async fn get(
|
||||
.await?;
|
||||
|
||||
let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar)
|
||||
.add(session.id, provider.id, data.state)
|
||||
.add(session.id, provider.id, data.state, query.post_auth_action)
|
||||
.save(cookie_jar, clock.now());
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
@@ -137,7 +137,7 @@ pub(crate) async fn get(
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
||||
let session_id = sessions_cookie
|
||||
let (session_id, _post_auth_action) = sessions_cookie
|
||||
.find_session(provider_id, ¶ms.state)
|
||||
.map_err(|_| RouteError::MissingCookie)?;
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
|
||||
use chrono::{DateTime, Duration, NaiveDateTime, Utc};
|
||||
use mas_axum_utils::CookieExt;
|
||||
use mas_router::PostAuthAction;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use time::OffsetDateTime;
|
||||
@@ -28,12 +29,13 @@ static COOKIE_NAME: &str = "upstream-oauth2-sessions";
|
||||
/// Sessions expire after 10 minutes
|
||||
static SESSION_MAX_TIME_SECS: i64 = 60 * 10;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Payload {
|
||||
session: Ulid,
|
||||
provider: Ulid,
|
||||
state: String,
|
||||
link: Option<Ulid>,
|
||||
post_auth_action: Option<PostAuthAction>,
|
||||
}
|
||||
|
||||
impl Payload {
|
||||
@@ -46,7 +48,7 @@ impl Payload {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
#[derive(Serialize, Deserialize, Default, Debug)]
|
||||
pub struct UpstreamSessions(Vec<Payload>);
|
||||
|
||||
#[derive(Debug, Error, PartialEq, Eq)]
|
||||
@@ -87,12 +89,19 @@ impl UpstreamSessions {
|
||||
}
|
||||
|
||||
/// Add a new session, for a provider and a random state
|
||||
pub fn add(mut self, session: Ulid, provider: Ulid, state: String) -> Self {
|
||||
pub fn add(
|
||||
mut self,
|
||||
session: Ulid,
|
||||
provider: Ulid,
|
||||
state: String,
|
||||
post_auth_action: Option<PostAuthAction>,
|
||||
) -> Self {
|
||||
self.0.push(Payload {
|
||||
session,
|
||||
provider,
|
||||
state,
|
||||
link: None,
|
||||
post_auth_action,
|
||||
});
|
||||
self
|
||||
}
|
||||
@@ -102,11 +111,11 @@ impl UpstreamSessions {
|
||||
&self,
|
||||
provider: Ulid,
|
||||
state: &str,
|
||||
) -> Result<Ulid, UpstreamSessionNotFound> {
|
||||
) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> {
|
||||
self.0
|
||||
.iter()
|
||||
.find(|p| p.provider == provider && p.state == state && p.link.is_none())
|
||||
.map(|p| p.session)
|
||||
.map(|p| (p.session, p.post_auth_action.as_ref()))
|
||||
.ok_or(UpstreamSessionNotFound)
|
||||
}
|
||||
|
||||
@@ -127,11 +136,14 @@ impl UpstreamSessions {
|
||||
}
|
||||
|
||||
/// Find a session from its link
|
||||
pub fn lookup_link(&self, link_id: Ulid) -> Result<Ulid, UpstreamSessionNotFound> {
|
||||
pub fn lookup_link(
|
||||
&self,
|
||||
link_id: Ulid,
|
||||
) -> Result<(Ulid, Option<&PostAuthAction>), UpstreamSessionNotFound> {
|
||||
self.0
|
||||
.iter()
|
||||
.find(|p| p.link == Some(link_id))
|
||||
.map(|p| p.session)
|
||||
.map(|p| (p.session, p.post_auth_action.as_ref()))
|
||||
.ok_or(UpstreamSessionNotFound)
|
||||
}
|
||||
|
||||
@@ -171,22 +183,22 @@ mod tests {
|
||||
|
||||
let first_session = Ulid::from_datetime_with_source(now.into(), &mut rng);
|
||||
let first_state = "first-state";
|
||||
let sessions = sessions.add(first_session, provider_a, first_state.into());
|
||||
let sessions = sessions.add(first_session, provider_a, first_state.into(), None);
|
||||
|
||||
let now = now + Duration::minutes(5);
|
||||
|
||||
let second_session = Ulid::from_datetime_with_source(now.into(), &mut rng);
|
||||
let second_state = "second-state";
|
||||
let sessions = sessions.add(second_session, provider_b, second_state.into());
|
||||
let sessions = sessions.add(second_session, provider_b, second_state.into(), None);
|
||||
|
||||
let sessions = sessions.expire(now);
|
||||
assert_eq!(
|
||||
sessions.find_session(provider_a, first_state),
|
||||
Ok(first_session)
|
||||
sessions.find_session(provider_a, first_state).unwrap().0,
|
||||
first_session,
|
||||
);
|
||||
assert_eq!(
|
||||
sessions.find_session(provider_b, second_state),
|
||||
Ok(second_session)
|
||||
sessions.find_session(provider_b, second_state).unwrap().0,
|
||||
second_session
|
||||
);
|
||||
assert!(sessions.find_session(provider_b, first_state).is_err());
|
||||
assert!(sessions.find_session(provider_a, second_state).is_err());
|
||||
@@ -196,8 +208,8 @@ mod tests {
|
||||
let sessions = sessions.expire(now);
|
||||
assert!(sessions.find_session(provider_a, first_state).is_err());
|
||||
assert_eq!(
|
||||
sessions.find_session(provider_b, second_state),
|
||||
Ok(second_session)
|
||||
sessions.find_session(provider_b, second_state).unwrap().0,
|
||||
second_session
|
||||
);
|
||||
|
||||
// Associate a link with the second
|
||||
@@ -210,7 +222,7 @@ mod tests {
|
||||
assert!(sessions.find_session(provider_b, second_state).is_err());
|
||||
|
||||
// But it can be looked up by its link
|
||||
assert_eq!(sessions.lookup_link(second_link), Ok(second_session));
|
||||
assert_eq!(sessions.lookup_link(second_link).unwrap().0, second_session);
|
||||
// And it can be consumed
|
||||
let sessions = sessions.consume_link(second_link).unwrap();
|
||||
// But only once
|
||||
|
||||
@@ -24,7 +24,6 @@ use mas_axum_utils::{
|
||||
SessionInfoExt,
|
||||
};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_router::Route;
|
||||
use mas_storage::{
|
||||
upstream_oauth2::{
|
||||
associate_link_to_user, consume_session, lookup_link, lookup_session_on_link,
|
||||
@@ -44,7 +43,7 @@ use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::UpstreamSessionsCookie;
|
||||
use crate::impl_from_error_for_route;
|
||||
use crate::{impl_from_error_for_route, views::shared::OptionalPostAuthAction};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum RouteError {
|
||||
@@ -114,7 +113,7 @@ pub(crate) async fn get(
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
|
||||
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
||||
let session_id = sessions_cookie
|
||||
let (session_id, _post_auth_action) = sessions_cookie
|
||||
.lookup_link(link_id)
|
||||
.map_err(|_| RouteError::MissingCookie)?;
|
||||
|
||||
@@ -213,10 +212,14 @@ pub(crate) async fn post(
|
||||
let form = cookie_jar.verify_form(clock.now(), form)?;
|
||||
|
||||
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
||||
let session_id = sessions_cookie
|
||||
let (session_id, post_auth_action) = sessions_cookie
|
||||
.lookup_link(link_id)
|
||||
.map_err(|_| RouteError::MissingCookie)?;
|
||||
|
||||
let post_auth_action = OptionalPostAuthAction {
|
||||
post_auth_action: post_auth_action.cloned(),
|
||||
};
|
||||
|
||||
let link = lookup_link(&mut txn, link_id)
|
||||
.await
|
||||
.to_option()?
|
||||
@@ -267,5 +270,5 @@ pub(crate) async fn post(
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
Ok((cookie_jar, mas_router::Index.go()))
|
||||
Ok((cookie_jar, post_auth_action.go_next()))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user