You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-11-20 12:02:22 +03:00
Refactor the upstream oauth session cookie
This commit is contained in:
@@ -16,7 +16,7 @@ use axum::{
|
||||
extract::{Path, State},
|
||||
response::{IntoResponse, Redirect},
|
||||
};
|
||||
use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use hyper::StatusCode;
|
||||
use mas_axum_utils::http_client_factory::HttpClientFactory;
|
||||
use mas_keystore::Encrypter;
|
||||
@@ -27,6 +27,7 @@ use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::UpstreamSessionsCookie;
|
||||
use crate::impl_from_error_for_route;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -107,17 +108,15 @@ pub(crate) async fn get(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&provider,
|
||||
data.state,
|
||||
data.state.clone(),
|
||||
data.code_challenge_verifier,
|
||||
data.nonce,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// TODO: handle that cookie somewhere else?
|
||||
let mut cookie = Cookie::new("upstream-oauth2-session-id", session.id.to_string());
|
||||
cookie.set_path("/");
|
||||
cookie.set_http_only(true);
|
||||
let cookie_jar = cookie_jar.add(cookie);
|
||||
let cookie_jar = UpstreamSessionsCookie::load(&cookie_jar)
|
||||
.add(session.id, provider.id, data.state)
|
||||
.save(cookie_jar, clock.now());
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::client_credentials_for_provider;
|
||||
use super::{client_credentials_for_provider, UpstreamSessionsCookie};
|
||||
use crate::impl_from_error_for_route;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -89,9 +89,6 @@ pub(crate) enum RouteError {
|
||||
#[error("Missing session cookie")]
|
||||
MissingCookie,
|
||||
|
||||
#[error("Invalid session cookie")]
|
||||
InvalidCookie(#[source] ulid::DecodeError),
|
||||
|
||||
#[error(transparent)]
|
||||
InternalError(Box<dyn std::error::Error>),
|
||||
|
||||
@@ -107,6 +104,7 @@ impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
|
||||
impl_from_error_for_route!(mas_oidc_client::error::JwksError);
|
||||
impl_from_error_for_route!(mas_oidc_client::error::TokenAuthorizationCodeError);
|
||||
impl_from_error_for_route!(super::ProviderCredentialsError);
|
||||
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
@@ -138,12 +136,10 @@ pub(crate) async fn get(
|
||||
|
||||
let mut txn = pool.begin().await?;
|
||||
|
||||
// XXX: that cookie should be managed elsewhere
|
||||
let cookie = cookie_jar
|
||||
.get("upstream-oauth2-session-id")
|
||||
.ok_or(RouteError::MissingCookie)?;
|
||||
|
||||
let session_id: Ulid = cookie.value().parse().map_err(RouteError::InvalidCookie)?;
|
||||
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
||||
let session_id = sessions_cookie
|
||||
.find_session(provider_id, ¶ms.state)
|
||||
.map_err(|_| RouteError::MissingCookie)?;
|
||||
|
||||
let (provider, session) = lookup_session(&mut txn, session_id)
|
||||
.await
|
||||
@@ -256,9 +252,15 @@ pub(crate) async fn get(
|
||||
add_link(&mut txn, &mut rng, &clock, &provider, subject).await?
|
||||
};
|
||||
|
||||
let _session = complete_session(&mut txn, &clock, session, &link, response.id_token).await?;
|
||||
let session = complete_session(&mut txn, &clock, session, &link, response.id_token).await?;
|
||||
let cookie_jar = sessions_cookie
|
||||
.add_link_to_session(session.id, link.id)?
|
||||
.save(cookie_jar, clock.now());
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
Ok(mas_router::UpstreamOAuth2Link::new(link.id).go())
|
||||
Ok((
|
||||
cookie_jar,
|
||||
mas_router::UpstreamOAuth2Link::new(link.id).go(),
|
||||
))
|
||||
}
|
||||
|
||||
219
crates/handlers/src/upstream_oauth2/cookie.rs
Normal file
219
crates/handlers/src/upstream_oauth2/cookie.rs
Normal file
@@ -0,0 +1,219 @@
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// TODO: move that to a standalone cookie manager
|
||||
|
||||
use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
|
||||
use chrono::{DateTime, Duration, NaiveDateTime, Utc};
|
||||
use mas_axum_utils::CookieExt;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use time::OffsetDateTime;
|
||||
use ulid::Ulid;
|
||||
|
||||
/// Name of the cookie
|
||||
static COOKIE_NAME: &str = "upstream-oauth2-sessions";
|
||||
|
||||
/// Sessions expire after 10 minutes
|
||||
static SESSION_MAX_TIME_SECS: i64 = 60 * 10;
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Payload {
|
||||
session: Ulid,
|
||||
provider: Ulid,
|
||||
state: String,
|
||||
link: Option<Ulid>,
|
||||
}
|
||||
|
||||
impl Payload {
|
||||
fn expired(&self, now: DateTime<Utc>) -> bool {
|
||||
let Ok(ts) = self.session.timestamp_ms().try_into() else { return true };
|
||||
let Some(when) = NaiveDateTime::from_timestamp_millis(ts) else { return true };
|
||||
let when = DateTime::from_utc(when, Utc);
|
||||
let max_age = Duration::seconds(SESSION_MAX_TIME_SECS);
|
||||
now - when > max_age
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
pub struct UpstreamSessions(Vec<Payload>);
|
||||
|
||||
#[derive(Debug, Error, PartialEq, Eq)]
|
||||
#[error("upstream session not found")]
|
||||
pub struct UpstreamSessionNotFound;
|
||||
|
||||
impl UpstreamSessions {
|
||||
/// Load the upstreams sessions cookie
|
||||
pub fn load<K>(cookie_jar: &PrivateCookieJar<K>) -> Self {
|
||||
cookie_jar
|
||||
.get(COOKIE_NAME)
|
||||
.and_then(|c| c.decode().ok())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Save the upstreams sessions to the cookie jar
|
||||
pub fn save<K>(
|
||||
self,
|
||||
cookie_jar: PrivateCookieJar<K>,
|
||||
now: DateTime<Utc>,
|
||||
) -> PrivateCookieJar<K> {
|
||||
let this = self.expire(now);
|
||||
let mut cookie = Cookie::named(COOKIE_NAME).encode(&this);
|
||||
cookie.set_path("/");
|
||||
cookie.set_http_only(true);
|
||||
|
||||
let expiration = now + Duration::seconds(SESSION_MAX_TIME_SECS);
|
||||
let expiration = OffsetDateTime::from_unix_timestamp(expiration.timestamp())
|
||||
.expect("invalid unix timestamp");
|
||||
cookie.set_expires(expiration);
|
||||
|
||||
cookie_jar.add(cookie)
|
||||
}
|
||||
|
||||
fn expire(mut self, now: DateTime<Utc>) -> Self {
|
||||
self.0.retain(|p| !p.expired(now));
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a new session, for a provider and a random state
|
||||
pub fn add(mut self, session: Ulid, provider: Ulid, state: String) -> Self {
|
||||
self.0.push(Payload {
|
||||
session,
|
||||
provider,
|
||||
state,
|
||||
link: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
// Find a session ID from the provider and the state
|
||||
pub fn find_session(
|
||||
&self,
|
||||
provider: Ulid,
|
||||
state: &str,
|
||||
) -> Result<Ulid, UpstreamSessionNotFound> {
|
||||
self.0
|
||||
.iter()
|
||||
.find(|p| p.provider == provider && p.state == state && p.link.is_none())
|
||||
.map(|p| p.session)
|
||||
.ok_or(UpstreamSessionNotFound)
|
||||
}
|
||||
|
||||
/// Save the link generated by a session
|
||||
pub fn add_link_to_session(
|
||||
mut self,
|
||||
session: Ulid,
|
||||
link: Ulid,
|
||||
) -> Result<Self, UpstreamSessionNotFound> {
|
||||
let payload = self
|
||||
.0
|
||||
.iter_mut()
|
||||
.find(|p| p.session == session && p.link.is_none())
|
||||
.ok_or(UpstreamSessionNotFound)?;
|
||||
|
||||
payload.link = Some(link);
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Find a session from its link
|
||||
pub fn lookup_link(&self, link_id: Ulid) -> Result<Ulid, UpstreamSessionNotFound> {
|
||||
self.0
|
||||
.iter()
|
||||
.find(|p| p.link == Some(link_id))
|
||||
.map(|p| p.session)
|
||||
.ok_or(UpstreamSessionNotFound)
|
||||
}
|
||||
|
||||
/// Mark a link as consumed to avoid replay
|
||||
pub fn consume_link(mut self, link_id: Ulid) -> Result<Self, UpstreamSessionNotFound> {
|
||||
let pos = self
|
||||
.0
|
||||
.iter()
|
||||
.position(|p| p.link == Some(link_id))
|
||||
.ok_or(UpstreamSessionNotFound)?;
|
||||
|
||||
self.0.remove(pos);
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use chrono::TimeZone;
|
||||
use rand::SeedableRng;
|
||||
use rand_chacha::ChaChaRng;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_session_cookie() {
|
||||
let now = chrono::Utc
|
||||
.with_ymd_and_hms(2018, 1, 18, 1, 30, 22)
|
||||
.unwrap();
|
||||
let mut rng = ChaChaRng::seed_from_u64(42);
|
||||
|
||||
let sessions = UpstreamSessions::default();
|
||||
|
||||
let provider_a = Ulid::from_datetime_with_source(now.into(), &mut rng);
|
||||
let provider_b = Ulid::from_datetime_with_source(now.into(), &mut rng);
|
||||
|
||||
let first_session = Ulid::from_datetime_with_source(now.into(), &mut rng);
|
||||
let first_state = "first-state";
|
||||
let sessions = sessions.add(first_session, provider_a, first_state.into());
|
||||
|
||||
let now = now + Duration::minutes(5);
|
||||
|
||||
let second_session = Ulid::from_datetime_with_source(now.into(), &mut rng);
|
||||
let second_state = "second-state";
|
||||
let sessions = sessions.add(second_session, provider_b, second_state.into());
|
||||
|
||||
let sessions = sessions.expire(now);
|
||||
assert_eq!(
|
||||
sessions.find_session(provider_a, first_state),
|
||||
Ok(first_session)
|
||||
);
|
||||
assert_eq!(
|
||||
sessions.find_session(provider_b, second_state),
|
||||
Ok(second_session)
|
||||
);
|
||||
assert!(sessions.find_session(provider_b, first_state).is_err());
|
||||
assert!(sessions.find_session(provider_a, second_state).is_err());
|
||||
|
||||
// Make the first session expire
|
||||
let now = now + Duration::minutes(6);
|
||||
let sessions = sessions.expire(now);
|
||||
assert!(sessions.find_session(provider_a, first_state).is_err());
|
||||
assert_eq!(
|
||||
sessions.find_session(provider_b, second_state),
|
||||
Ok(second_session)
|
||||
);
|
||||
|
||||
// Associate a link with the second
|
||||
let second_link = Ulid::from_datetime_with_source(now.into(), &mut rng);
|
||||
let sessions = sessions
|
||||
.add_link_to_session(second_session, second_link)
|
||||
.unwrap();
|
||||
|
||||
// Now the session can't be found with its state
|
||||
assert!(sessions.find_session(provider_b, second_state).is_err());
|
||||
|
||||
// But it can be looked up by its link
|
||||
assert_eq!(sessions.lookup_link(second_link), Ok(second_session));
|
||||
// And it can be consumed
|
||||
let sessions = sessions.consume_link(second_link).unwrap();
|
||||
// But only once
|
||||
assert!(sessions.consume_link(second_link).is_err());
|
||||
}
|
||||
}
|
||||
@@ -43,6 +43,7 @@ use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::UpstreamSessionsCookie;
|
||||
use crate::impl_from_error_for_route;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -62,9 +63,6 @@ pub(crate) enum RouteError {
|
||||
#[error("Missing session cookie")]
|
||||
MissingCookie,
|
||||
|
||||
#[error("Invalid session cookie")]
|
||||
InvalidCookie(#[source] ulid::DecodeError),
|
||||
|
||||
#[error("Invalid form action")]
|
||||
InvalidFormAction,
|
||||
|
||||
@@ -81,6 +79,7 @@ impl_from_error_for_route!(mas_storage::GenericLookupError);
|
||||
impl_from_error_for_route!(mas_storage::user::ActiveSessionLookupError);
|
||||
impl_from_error_for_route!(mas_storage::user::UserLookupError);
|
||||
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
|
||||
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
@@ -114,18 +113,16 @@ pub(crate) async fn get(
|
||||
let mut txn = pool.begin().await?;
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
|
||||
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
||||
let session_id = sessions_cookie
|
||||
.lookup_link(link_id)
|
||||
.map_err(|_| RouteError::MissingCookie)?;
|
||||
|
||||
let link = lookup_link(&mut txn, link_id)
|
||||
.await
|
||||
.to_option()?
|
||||
.ok_or(RouteError::LinkNotFound)?;
|
||||
|
||||
// XXX: that cookie should be managed elsewhere
|
||||
let cookie = cookie_jar
|
||||
.get("upstream-oauth2-session-id")
|
||||
.ok_or(RouteError::MissingCookie)?;
|
||||
|
||||
let session_id: Ulid = cookie.value().parse().map_err(RouteError::InvalidCookie)?;
|
||||
|
||||
// This checks that we're in a browser session which is allowed to consume this
|
||||
// link: the upstream auth session should have been started in this browser.
|
||||
let upstream_session = lookup_session_on_link(&mut txn, &link, session_id)
|
||||
@@ -215,18 +212,16 @@ pub(crate) async fn post(
|
||||
let (clock, mut rng) = crate::rng_and_clock()?;
|
||||
let form = cookie_jar.verify_form(clock.now(), form)?;
|
||||
|
||||
let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);
|
||||
let session_id = sessions_cookie
|
||||
.lookup_link(link_id)
|
||||
.map_err(|_| RouteError::MissingCookie)?;
|
||||
|
||||
let link = lookup_link(&mut txn, link_id)
|
||||
.await
|
||||
.to_option()?
|
||||
.ok_or(RouteError::LinkNotFound)?;
|
||||
|
||||
// XXX: that cookie should be managed elsewhere
|
||||
let cookie = cookie_jar
|
||||
.get("upstream-oauth2-session-id")
|
||||
.ok_or(RouteError::MissingCookie)?;
|
||||
|
||||
let session_id: Ulid = cookie.value().parse().map_err(RouteError::InvalidCookie)?;
|
||||
|
||||
// This checks that we're in a browser session which is allowed to consume this
|
||||
// link: the upstream auth session should have been started in this browser.
|
||||
let upstream_session = lookup_session_on_link(&mut txn, &link, session_id)
|
||||
@@ -265,6 +260,9 @@ pub(crate) async fn post(
|
||||
consume_session(&mut txn, &clock, upstream_session).await?;
|
||||
authenticate_session_with_upstream(&mut txn, &mut rng, &clock, &mut session, &link).await?;
|
||||
|
||||
let cookie_jar = sessions_cookie
|
||||
.consume_link(link_id)?
|
||||
.save(cookie_jar, clock.now());
|
||||
let cookie_jar = cookie_jar.set_session(&session);
|
||||
|
||||
txn.commit().await?;
|
||||
|
||||
@@ -22,8 +22,11 @@ use url::Url;
|
||||
|
||||
pub(crate) mod authorize;
|
||||
pub(crate) mod callback;
|
||||
mod cookie;
|
||||
pub(crate) mod link;
|
||||
|
||||
use self::cookie::UpstreamSessions as UpstreamSessionsCookie;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum ProviderCredentialsError {
|
||||
#[error("Provider doesn't have a client secret")]
|
||||
|
||||
Reference in New Issue
Block a user