1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-09 04:22:45 +03:00

Handle cookies better by setting the right flags & expiration

This commit is contained in:
Quentin Gliech
2023-08-24 17:38:33 +02:00
parent 2405a3c061
commit a39f71c181
31 changed files with 242 additions and 167 deletions

View File

@@ -14,8 +14,18 @@
//! Private (encrypted) cookie jar, based on axum-extra's cookie jar
use std::convert::Infallible;
use async_trait::async_trait;
use axum::{
extract::{FromRef, FromRequestParts},
response::{IntoResponseParts, ResponseParts},
};
use axum_extra::extract::cookie::{Cookie, Key, PrivateCookieJar, SameSite};
use http::request::Parts;
use serde::{de::DeserializeOwned, Serialize};
use thiserror::Error;
use url::Url;
#[derive(Debug, Error)]
#[error("could not decode cookie")]
@@ -23,32 +33,113 @@ pub enum CookieDecodeError {
Deserialize(#[from] serde_json::Error),
}
pub trait CookieExt {
fn decode<T>(&self) -> Result<T, CookieDecodeError>
where
T: DeserializeOwned;
/// Manages cookie options and encryption key
///
/// This is meant to be accessible through axum's state via the [`FromRef`]
/// trait
#[derive(Clone)]
pub struct CookieManager {
options: CookieOption,
key: Key,
}
impl CookieManager {
#[must_use]
pub const fn new(base_url: Url, key: Key) -> Self {
let options = CookieOption::new(base_url);
Self { options, key }
}
#[must_use]
fn encode<T>(self, t: &T) -> Self
where
T: Serialize;
pub fn derive_from(base_url: Url, key: &[u8]) -> Self {
let key = Key::derive_from(key);
Self::new(base_url, key)
}
}
impl<'a> CookieExt for axum_extra::extract::cookie::Cookie<'a> {
fn decode<T>(&self) -> Result<T, CookieDecodeError>
where
T: DeserializeOwned,
{
let decoded = serde_json::from_str(self.value())?;
Ok(decoded)
#[async_trait]
impl<S> FromRequestParts<S> for CookieJar
where
CookieManager: FromRef<S>,
S: Send + Sync,
{
type Rejection = Infallible;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let cookie_manager = CookieManager::from_ref(state);
let inner = PrivateCookieJar::from_headers(&parts.headers, cookie_manager.key.clone());
let options = cookie_manager.options.clone();
Ok(CookieJar { inner, options })
}
}
#[derive(Debug, Clone)]
struct CookieOption {
base_url: Url,
}
impl CookieOption {
const fn new(base_url: Url) -> Self {
Self { base_url }
}
fn encode<T>(mut self, t: &T) -> Self
where
T: Serialize,
{
let encoded = serde_json::to_string(t).unwrap();
self.set_value(encoded);
fn secure(&self) -> bool {
self.base_url.scheme() == "https"
}
fn path(&self) -> &str {
self.base_url.path()
}
fn apply<'a>(&self, mut cookie: Cookie<'a>) -> Cookie<'a> {
cookie.set_http_only(true);
cookie.set_secure(self.secure());
cookie.set_path(self.path().to_owned());
cookie.set_same_site(SameSite::Lax);
cookie
}
}
/// A cookie jar which encrypts cookies & sets secure options
pub struct CookieJar {
inner: PrivateCookieJar<Key>,
options: CookieOption,
}
impl CookieJar {
#[must_use]
pub fn save<T: Serialize>(mut self, key: &str, payload: &T, permanent: bool) -> Self {
let serialized =
serde_json::to_string(payload).expect("failed to serialize cookie payload");
let cookie = Cookie::new(key.to_owned(), serialized);
let mut cookie = self.options.apply(cookie);
if permanent {
// XXX: this should use a clock
cookie.make_permanent();
}
self.inner = self.inner.add(cookie);
self
}
pub fn load<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, CookieDecodeError> {
let Some(cookie) = self.inner.get(key) else {
return Ok(None);
};
let decoded = serde_json::from_str(cookie.value())?;
Ok(Some(decoded))
}
}
impl IntoResponseParts for CookieJar {
type Error = Infallible;
fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
self.inner.into_response_parts(res)
}
}

View File

@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
use chrono::{DateTime, Duration, Utc};
use data_encoding::{DecodeError, BASE64URL_NOPAD};
use mas_storage::Clock;
@@ -21,7 +20,7 @@ use serde::{Deserialize, Serialize};
use serde_with::{serde_as, TimestampSeconds};
use thiserror::Error;
use crate::{cookies::CookieDecodeError, CookieExt};
use crate::cookies::{CookieDecodeError, CookieJar};
/// Failed to validate CSRF token
#[derive(Debug, Error)]
@@ -118,36 +117,41 @@ pub trait CsrfExt {
C: Clock;
}
impl<K> CsrfExt for PrivateCookieJar<K> {
impl CsrfExt for CookieJar {
fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
where
R: RngCore,
C: Clock,
{
let jar = self;
let mut cookie = jar.get("csrf").unwrap_or_else(|| Cookie::new("csrf", ""));
cookie.set_path("/");
cookie.set_http_only(true);
let now = clock.now();
let new_token = cookie
.decode()
.ok()
.and_then(|token: CsrfToken| token.verify_expiration(now).ok())
.unwrap_or_else(|| CsrfToken::generate(now, rng, Duration::hours(1)))
.refresh(now, Duration::hours(1));
let maybe_token = match self.load::<CsrfToken>("csrf") {
Ok(Some(token)) => {
let token = token.verify_expiration(now);
let cookie = cookie.encode(&new_token);
let jar = jar.add(cookie);
(new_token, jar)
// If the token is expired, just ignore it
token.ok()
}
Ok(None) => None,
Err(e) => {
tracing::warn!("Failed to decode CSRF cookie: {}", e);
None
}
};
let token = maybe_token.map_or_else(
|| CsrfToken::generate(now, rng, Duration::hours(1)),
|token| token.refresh(now, Duration::hours(1)),
);
let jar = self.save("csrf", &token, false);
(token, jar)
}
fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
where
C: Clock,
{
let cookie = self.get("csrf").ok_or(CsrfError::Missing)?;
let token: CsrfToken = cookie.decode()?;
let token: CsrfToken = self.load("csrf")?.ok_or(CsrfError::Missing)?;
let token = token.verify_expiration(clock.now())?;
token.verify_form_value(&form.csrf)?;
Ok(form.inner)

View File

@@ -34,7 +34,6 @@ pub mod user_authorization;
pub use axum;
pub use self::{
cookies::CookieExt,
fancy_error::FancyError,
session::{SessionInfo, SessionInfoExt},
};

View File

@@ -12,13 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use axum_extra::extract::cookie::{Cookie, PrivateCookieJar};
use mas_data_model::BrowserSession;
use mas_storage::{user::BrowserSessionRepository, RepositoryAccess};
use serde::{Deserialize, Serialize};
use ulid::Ulid;
use crate::CookieExt;
use crate::cookies::CookieJar;
/// An encrypted cookie to save the session ID
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
@@ -79,26 +78,22 @@ pub trait SessionInfoExt {
}
}
impl<K> SessionInfoExt for PrivateCookieJar<K> {
impl SessionInfoExt for CookieJar {
fn session_info(self) -> (SessionInfo, Self) {
let jar = self;
let mut cookie = jar
.get("session")
.unwrap_or_else(|| Cookie::new("session", ""));
cookie.set_path("/");
cookie.set_http_only(true);
let session_info = cookie.decode().unwrap_or_default();
let info = match self.load("session") {
Ok(Some(s)) => s,
Ok(None) => SessionInfo::default(),
Err(e) => {
tracing::error!("failed to load session cookie: {}", e);
SessionInfo::default()
}
};
let cookie = cookie.encode(&session_info);
let jar = jar.add(cookie);
(session_info, jar)
let jar = self.update_session_info(&info);
(info, jar)
}
fn update_session_info(self, info: &SessionInfo) -> Self {
let mut cookie = Cookie::new("session", "");
cookie.set_path("/");
cookie.set_http_only(true);
let cookie = cookie.encode(&info);
self.add(cookie)
self.save("session", info, true)
}
}