You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-09 04:22:45 +03:00
Handle cookies better by setting the right flags & expiration
This commit is contained in:
@@ -14,8 +14,18 @@
|
||||
|
||||
//! Private (encrypted) cookie jar, based on axum-extra's cookie jar
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
extract::{FromRef, FromRequestParts},
|
||||
response::{IntoResponseParts, ResponseParts},
|
||||
};
|
||||
use axum_extra::extract::cookie::{Cookie, Key, PrivateCookieJar, SameSite};
|
||||
use http::request::Parts;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use thiserror::Error;
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("could not decode cookie")]
|
||||
@@ -23,32 +33,113 @@ pub enum CookieDecodeError {
|
||||
Deserialize(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
pub trait CookieExt {
|
||||
fn decode<T>(&self) -> Result<T, CookieDecodeError>
|
||||
where
|
||||
T: DeserializeOwned;
|
||||
/// Manages cookie options and encryption key
|
||||
///
|
||||
/// This is meant to be accessible through axum's state via the [`FromRef`]
|
||||
/// trait
|
||||
#[derive(Clone)]
|
||||
pub struct CookieManager {
|
||||
options: CookieOption,
|
||||
key: Key,
|
||||
}
|
||||
|
||||
impl CookieManager {
|
||||
#[must_use]
|
||||
pub const fn new(base_url: Url, key: Key) -> Self {
|
||||
let options = CookieOption::new(base_url);
|
||||
Self { options, key }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn encode<T>(self, t: &T) -> Self
|
||||
where
|
||||
T: Serialize;
|
||||
pub fn derive_from(base_url: Url, key: &[u8]) -> Self {
|
||||
let key = Key::derive_from(key);
|
||||
Self::new(base_url, key)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> CookieExt for axum_extra::extract::cookie::Cookie<'a> {
|
||||
fn decode<T>(&self) -> Result<T, CookieDecodeError>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
let decoded = serde_json::from_str(self.value())?;
|
||||
Ok(decoded)
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for CookieJar
|
||||
where
|
||||
CookieManager: FromRef<S>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = Infallible;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let cookie_manager = CookieManager::from_ref(state);
|
||||
let inner = PrivateCookieJar::from_headers(&parts.headers, cookie_manager.key.clone());
|
||||
let options = cookie_manager.options.clone();
|
||||
|
||||
Ok(CookieJar { inner, options })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct CookieOption {
|
||||
base_url: Url,
|
||||
}
|
||||
|
||||
impl CookieOption {
|
||||
const fn new(base_url: Url) -> Self {
|
||||
Self { base_url }
|
||||
}
|
||||
|
||||
fn encode<T>(mut self, t: &T) -> Self
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
let encoded = serde_json::to_string(t).unwrap();
|
||||
self.set_value(encoded);
|
||||
fn secure(&self) -> bool {
|
||||
self.base_url.scheme() == "https"
|
||||
}
|
||||
|
||||
fn path(&self) -> &str {
|
||||
self.base_url.path()
|
||||
}
|
||||
|
||||
fn apply<'a>(&self, mut cookie: Cookie<'a>) -> Cookie<'a> {
|
||||
cookie.set_http_only(true);
|
||||
cookie.set_secure(self.secure());
|
||||
cookie.set_path(self.path().to_owned());
|
||||
cookie.set_same_site(SameSite::Lax);
|
||||
cookie
|
||||
}
|
||||
}
|
||||
|
||||
/// A cookie jar which encrypts cookies & sets secure options
|
||||
pub struct CookieJar {
|
||||
inner: PrivateCookieJar<Key>,
|
||||
options: CookieOption,
|
||||
}
|
||||
|
||||
impl CookieJar {
|
||||
#[must_use]
|
||||
pub fn save<T: Serialize>(mut self, key: &str, payload: &T, permanent: bool) -> Self {
|
||||
let serialized =
|
||||
serde_json::to_string(payload).expect("failed to serialize cookie payload");
|
||||
|
||||
let cookie = Cookie::new(key.to_owned(), serialized);
|
||||
let mut cookie = self.options.apply(cookie);
|
||||
|
||||
if permanent {
|
||||
// XXX: this should use a clock
|
||||
cookie.make_permanent();
|
||||
}
|
||||
|
||||
self.inner = self.inner.add(cookie);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn load<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, CookieDecodeError> {
|
||||
let Some(cookie) = self.inner.get(key) else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let decoded = serde_json::from_str(cookie.value())?;
|
||||
Ok(Some(decoded))
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponseParts for CookieJar {
|
||||
type Error = Infallible;
|
||||
|
||||
fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
|
||||
self.inner.into_response_parts(res)
|
||||
}
|
||||
}
|
||||
|
@@ -12,7 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use data_encoding::{DecodeError, BASE64URL_NOPAD};
|
||||
use mas_storage::Clock;
|
||||
@@ -21,7 +20,7 @@ use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, TimestampSeconds};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::{cookies::CookieDecodeError, CookieExt};
|
||||
use crate::cookies::{CookieDecodeError, CookieJar};
|
||||
|
||||
/// Failed to validate CSRF token
|
||||
#[derive(Debug, Error)]
|
||||
@@ -118,36 +117,41 @@ pub trait CsrfExt {
|
||||
C: Clock;
|
||||
}
|
||||
|
||||
impl<K> CsrfExt for PrivateCookieJar<K> {
|
||||
impl CsrfExt for CookieJar {
|
||||
fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
|
||||
where
|
||||
R: RngCore,
|
||||
C: Clock,
|
||||
{
|
||||
let jar = self;
|
||||
let mut cookie = jar.get("csrf").unwrap_or_else(|| Cookie::new("csrf", ""));
|
||||
cookie.set_path("/");
|
||||
cookie.set_http_only(true);
|
||||
|
||||
let now = clock.now();
|
||||
let new_token = cookie
|
||||
.decode()
|
||||
.ok()
|
||||
.and_then(|token: CsrfToken| token.verify_expiration(now).ok())
|
||||
.unwrap_or_else(|| CsrfToken::generate(now, rng, Duration::hours(1)))
|
||||
.refresh(now, Duration::hours(1));
|
||||
let maybe_token = match self.load::<CsrfToken>("csrf") {
|
||||
Ok(Some(token)) => {
|
||||
let token = token.verify_expiration(now);
|
||||
|
||||
let cookie = cookie.encode(&new_token);
|
||||
let jar = jar.add(cookie);
|
||||
(new_token, jar)
|
||||
// If the token is expired, just ignore it
|
||||
token.ok()
|
||||
}
|
||||
Ok(None) => None,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to decode CSRF cookie: {}", e);
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let token = maybe_token.map_or_else(
|
||||
|| CsrfToken::generate(now, rng, Duration::hours(1)),
|
||||
|token| token.refresh(now, Duration::hours(1)),
|
||||
);
|
||||
|
||||
let jar = self.save("csrf", &token, false);
|
||||
(token, jar)
|
||||
}
|
||||
|
||||
fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
|
||||
where
|
||||
C: Clock,
|
||||
{
|
||||
let cookie = self.get("csrf").ok_or(CsrfError::Missing)?;
|
||||
let token: CsrfToken = cookie.decode()?;
|
||||
let token: CsrfToken = self.load("csrf")?.ok_or(CsrfError::Missing)?;
|
||||
let token = token.verify_expiration(clock.now())?;
|
||||
token.verify_form_value(&form.csrf)?;
|
||||
Ok(form.inner)
|
||||
|
@@ -34,7 +34,6 @@ pub mod user_authorization;
|
||||
pub use axum;
|
||||
|
||||
pub use self::{
|
||||
cookies::CookieExt,
|
||||
fancy_error::FancyError,
|
||||
session::{SessionInfo, SessionInfoExt},
|
||||
};
|
||||
|
@@ -12,13 +12,12 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
|
||||
use mas_data_model::BrowserSession;
|
||||
use mas_storage::{user::BrowserSessionRepository, RepositoryAccess};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::CookieExt;
|
||||
use crate::cookies::CookieJar;
|
||||
|
||||
/// An encrypted cookie to save the session ID
|
||||
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
|
||||
@@ -79,26 +78,22 @@ pub trait SessionInfoExt {
|
||||
}
|
||||
}
|
||||
|
||||
impl<K> SessionInfoExt for PrivateCookieJar<K> {
|
||||
impl SessionInfoExt for CookieJar {
|
||||
fn session_info(self) -> (SessionInfo, Self) {
|
||||
let jar = self;
|
||||
let mut cookie = jar
|
||||
.get("session")
|
||||
.unwrap_or_else(|| Cookie::new("session", ""));
|
||||
cookie.set_path("/");
|
||||
cookie.set_http_only(true);
|
||||
let session_info = cookie.decode().unwrap_or_default();
|
||||
let info = match self.load("session") {
|
||||
Ok(Some(s)) => s,
|
||||
Ok(None) => SessionInfo::default(),
|
||||
Err(e) => {
|
||||
tracing::error!("failed to load session cookie: {}", e);
|
||||
SessionInfo::default()
|
||||
}
|
||||
};
|
||||
|
||||
let cookie = cookie.encode(&session_info);
|
||||
let jar = jar.add(cookie);
|
||||
(session_info, jar)
|
||||
let jar = self.update_session_info(&info);
|
||||
(info, jar)
|
||||
}
|
||||
|
||||
fn update_session_info(self, info: &SessionInfo) -> Self {
|
||||
let mut cookie = Cookie::new("session", "");
|
||||
cookie.set_path("/");
|
||||
cookie.set_http_only(true);
|
||||
let cookie = cookie.encode(&info);
|
||||
self.add(cookie)
|
||||
self.save("session", info, true)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user