diff --git a/crates/handlers/src/compat/mod.rs b/crates/handlers/src/compat/mod.rs index 172838b7..59b4d25a 100644 --- a/crates/handlers/src/compat/mod.rs +++ b/crates/handlers/src/compat/mod.rs @@ -18,6 +18,7 @@ use serde::Serialize; pub(crate) mod login; pub(crate) mod logout; +pub(crate) mod refresh; #[derive(Debug, Serialize)] struct MatrixError { diff --git a/crates/handlers/src/compat/refresh.rs b/crates/handlers/src/compat/refresh.rs new file mode 100644 index 00000000..085a2a50 --- /dev/null +++ b/crates/handlers/src/compat/refresh.rs @@ -0,0 +1,138 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use axum::{response::IntoResponse, Extension, Json}; +use chrono::Duration; +use hyper::StatusCode; +use mas_data_model::{TokenFormatError, TokenType}; +use mas_storage::compat::{ + add_compat_access_token, add_compat_refresh_token, expire_compat_access_token, + lookup_active_compat_refresh_token, replace_compat_refresh_token, + CompatRefreshTokenLookupError, +}; +use rand::thread_rng; +use serde::{Deserialize, Serialize}; +use serde_with::{serde_as, DurationMilliSeconds}; +use sqlx::PgPool; +use thiserror::Error; + +use super::MatrixError; + +#[derive(Debug, Deserialize)] +pub struct RequestBody { + refresh_token: String, +} + +#[derive(Debug, Error)] +pub enum RouteError { + #[error(transparent)] + Internal(Box), + + #[error(transparent)] + Anyhow(#[from] anyhow::Error), + + #[error("invalid token")] + InvalidToken, +} + +impl IntoResponse for RouteError { + fn into_response(self) -> axum::response::Response { + match self { + Self::Internal(_) | Self::Anyhow(_) => MatrixError { + errcode: "M_UNKNOWN", + error: "Internal error", + status: StatusCode::INTERNAL_SERVER_ERROR, + }, + Self::InvalidToken => MatrixError { + errcode: "M_UNKNOWN_TOKEN", + error: "Invalid refresh token", + status: StatusCode::UNAUTHORIZED, + }, + } + .into_response() + } +} + +impl From for RouteError { + fn from(e: sqlx::Error) -> Self { + Self::Internal(Box::new(e)) + } +} + +impl From for RouteError { + fn from(_e: TokenFormatError) -> Self { + Self::InvalidToken + } +} + +impl From for RouteError { + fn from(e: CompatRefreshTokenLookupError) -> Self { + if e.not_found() { + Self::InvalidToken + } else { + Self::Internal(Box::new(e)) + } + } +} + +#[serde_as] +#[derive(Debug, Serialize)] +pub struct ResponseBody { + access_token: String, + refresh_token: String, + #[serde_as(as = "DurationMilliSeconds")] + expires_in_ms: Duration, +} + +pub(crate) async fn post( + Extension(pool): Extension, + Json(input): Json, +) -> Result { + let mut txn = pool.begin().await?; + + let token_type = TokenType::check(&input.refresh_token)?; + + if token_type != TokenType::CompatRefreshToken { + return Err(RouteError::InvalidToken); + } + + let (refresh_token, access_token, session) = + lookup_active_compat_refresh_token(&mut txn, &input.refresh_token).await?; + + let (new_refresh_token_str, new_access_token_str) = { + let mut rng = thread_rng(); + ( + TokenType::CompatRefreshToken.generate(&mut rng), + TokenType::CompatAccessToken.generate(&mut rng), + ) + }; + + let expires_in = Duration::minutes(5); + let new_access_token = + add_compat_access_token(&mut txn, &session, new_access_token_str, Some(expires_in)).await?; + let new_refresh_token = + add_compat_refresh_token(&mut txn, &session, &new_access_token, new_refresh_token_str) + .await?; + + replace_compat_refresh_token(&mut txn, &refresh_token, &new_refresh_token).await?; + expire_compat_access_token(&mut txn, access_token).await?; + + txn.commit().await?; + + Ok(Json(ResponseBody { + access_token: new_access_token.token, + refresh_token: new_refresh_token.token, + expires_in_ms: expires_in, + })) +} diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 5f9efa11..a0c91c54 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -121,6 +121,10 @@ where mas_router::CompatLogout::route(), post(self::compat::logout::post), ) + .route( + mas_router::CompatRefresh::route(), + post(self::compat::refresh::post), + ) .layer( CorsLayer::new() .allow_origin(Any) diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 4e7c537f..7a4bb9c4 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -248,7 +248,8 @@ pub(crate) async fn post( } } TokenType::CompatRefreshToken => { - let (token, session) = lookup_active_compat_refresh_token(&mut conn, token).await?; + let (refresh_token, _access_token, session) = + lookup_active_compat_refresh_token(&mut conn, token).await?; let device_scope = session.device.to_scope_token(); let scope = [device_scope].into_iter().collect(); @@ -260,8 +261,8 @@ pub(crate) async fn post( username: Some(session.user.username), token_type: Some(OAuthTokenTypeHint::RefreshToken), exp: None, - iat: Some(token.created_at), - nbf: Some(token.created_at), + iat: Some(refresh_token.created_at), + nbf: Some(refresh_token.created_at), sub: Some(session.user.sub), aud: None, iss: None, diff --git a/crates/router/src/endpoints.rs b/crates/router/src/endpoints.rs index 52052d6b..b964f7df 100644 --- a/crates/router/src/endpoints.rs +++ b/crates/router/src/endpoints.rs @@ -380,3 +380,10 @@ pub struct CompatLogout; impl SimpleRoute for CompatLogout { const PATH: &'static str = "/_matrix/client/:version/logout"; } + +/// `POST /_matrix/client/v3/refresh` +pub struct CompatRefresh; + +impl SimpleRoute for CompatRefresh { + const PATH: &'static str = "/_matrix/client/:version/refresh"; +} diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index b2bba71c..3d3156e6 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -1,5 +1,18 @@ { "db": "PostgreSQL", + "02be1a7451e890cb0cc07b32c937881ac9bd1707eb498e20a3cf27737c95a949": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + } + }, + "query": "\n UPDATE compat_refresh_tokens\n SET next_token_id = $2\n WHERE id = $1\n " + }, "08896e50738af687ac53dc5ac5ae0b19bcac7503230ba90e11de799978d7a026": { "describe": { "columns": [ @@ -333,6 +346,18 @@ }, "query": "\n INSERT INTO user_sessions (user_id)\n VALUES ($1)\n RETURNING id, created_at\n " }, + "366ea127c7b220960f17fd1b651600826ac10b8baf92f0e936fd07f34a7dc0fc": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Int8" + ] + } + }, + "query": "\n UPDATE compat_access_tokens\n SET expires_at = NOW()\n WHERE id = $1\n " + }, "41b5ecd6860791ac6f90417ac51eb977b8c69a3dd81af4672b2592efb65963eb": { "describe": { "columns": [ @@ -1497,6 +1522,122 @@ }, "query": "\n INSERT INTO oauth2_client_redirect_uris (oauth2_client_id, redirect_uri)\n SELECT $1, uri FROM UNNEST($2::text[]) uri\n " }, + "ab800ea65b9c703a56b6c3b7dd47402dbbe0c9900f6d965c908b84332b2aa148": { + "describe": { + "columns": [ + { + "name": "compat_refresh_token_id", + "ordinal": 0, + "type_info": "Int8" + }, + { + "name": "compat_refresh_token", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "compat_refresh_token_created_at", + "ordinal": 2, + "type_info": "Timestamptz" + }, + { + "name": "compat_access_token_id", + "ordinal": 3, + "type_info": "Int8" + }, + { + "name": "compat_access_token", + "ordinal": 4, + "type_info": "Text" + }, + { + "name": "compat_access_token_created_at", + "ordinal": 5, + "type_info": "Timestamptz" + }, + { + "name": "compat_access_token_expires_at", + "ordinal": 6, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_id", + "ordinal": 7, + "type_info": "Int8" + }, + { + "name": "compat_session_created_at", + "ordinal": 8, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_deleted_at", + "ordinal": 9, + "type_info": "Timestamptz" + }, + { + "name": "compat_session_device_id", + "ordinal": 10, + "type_info": "Text" + }, + { + "name": "user_id!", + "ordinal": 11, + "type_info": "Int8" + }, + { + "name": "user_username!", + "ordinal": 12, + "type_info": "Text" + }, + { + "name": "user_email_id?", + "ordinal": 13, + "type_info": "Int8" + }, + { + "name": "user_email?", + "ordinal": 14, + "type_info": "Text" + }, + { + "name": "user_email_created_at?", + "ordinal": 15, + "type_info": "Timestamptz" + }, + { + "name": "user_email_confirmed_at?", + "ordinal": 16, + "type_info": "Timestamptz" + } + ], + "nullable": [ + false, + false, + false, + false, + false, + false, + true, + false, + false, + true, + false, + false, + false, + false, + false, + false, + true + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n SELECT\n cr.id AS \"compat_refresh_token_id\",\n cr.token AS \"compat_refresh_token\",\n cr.created_at AS \"compat_refresh_token_created_at\",\n ct.id AS \"compat_access_token_id\",\n ct.token AS \"compat_access_token\",\n ct.created_at AS \"compat_access_token_created_at\",\n ct.expires_at AS \"compat_access_token_expires_at\",\n cs.id AS \"compat_session_id\",\n cs.created_at AS \"compat_session_created_at\",\n cs.deleted_at AS \"compat_session_deleted_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n ue.id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_refresh_tokens cr\n INNER JOIN compat_access_tokens ct\n ON ct.id = cr.compat_access_token_id\n INNER JOIN compat_sessions cs\n ON cs.id = cr.compat_session_id\n INNER JOIN users u\n ON u.id = cs.user_id\n LEFT JOIN user_emails ue\n ON ue.id = u.primary_email_id\n\n WHERE cr.token = $1\n AND cr.next_token_id IS NULL\n AND cs.deleted_at IS NULL\n " + }, "aea289a04e151da235825305a5085bc6aa100fce139dbf10a2c1bed4867fc52a": { "describe": { "columns": [ @@ -2038,97 +2179,5 @@ } }, "query": "TRUNCATE oauth2_client_redirect_uris, oauth2_clients RESTART IDENTITY CASCADE" - }, - "fc5d32bab9999ad383f906dbf20a45dafba1149e809155eccb4d94506ff6cf6f": { - "describe": { - "columns": [ - { - "name": "compat_refresh_token_id", - "ordinal": 0, - "type_info": "Int8" - }, - { - "name": "compat_refresh_token", - "ordinal": 1, - "type_info": "Text" - }, - { - "name": "compat_refresh_token_created_at", - "ordinal": 2, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_id", - "ordinal": 3, - "type_info": "Int8" - }, - { - "name": "compat_session_created_at", - "ordinal": 4, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_deleted_at", - "ordinal": 5, - "type_info": "Timestamptz" - }, - { - "name": "compat_session_device_id", - "ordinal": 6, - "type_info": "Text" - }, - { - "name": "user_id!", - "ordinal": 7, - "type_info": "Int8" - }, - { - "name": "user_username!", - "ordinal": 8, - "type_info": "Text" - }, - { - "name": "user_email_id?", - "ordinal": 9, - "type_info": "Int8" - }, - { - "name": "user_email?", - "ordinal": 10, - "type_info": "Text" - }, - { - "name": "user_email_created_at?", - "ordinal": 11, - "type_info": "Timestamptz" - }, - { - "name": "user_email_confirmed_at?", - "ordinal": 12, - "type_info": "Timestamptz" - } - ], - "nullable": [ - false, - false, - false, - false, - false, - true, - false, - false, - false, - false, - false, - false, - true - ], - "parameters": { - "Left": [ - "Text" - ] - } - }, - "query": "\n SELECT\n cr.id AS \"compat_refresh_token_id\",\n cr.token AS \"compat_refresh_token\",\n cr.created_at AS \"compat_refresh_token_created_at\",\n cs.id AS \"compat_session_id\",\n cs.created_at AS \"compat_session_created_at\",\n cs.deleted_at AS \"compat_session_deleted_at\",\n cs.device_id AS \"compat_session_device_id\",\n u.id AS \"user_id!\",\n u.username AS \"user_username!\",\n ue.id AS \"user_email_id?\",\n ue.email AS \"user_email?\",\n ue.created_at AS \"user_email_created_at?\",\n ue.confirmed_at AS \"user_email_confirmed_at?\"\n\n FROM compat_refresh_tokens cr\n INNER JOIN compat_sessions cs\n ON cs.id = cr.compat_session_id\n INNER JOIN users u\n ON u.id = cs.user_id\n LEFT JOIN user_emails ue\n ON ue.id = u.primary_email_id\n\n WHERE cr.token = $1\n AND cs.deleted_at IS NULL\n " } } \ No newline at end of file diff --git a/crates/storage/src/compat.rs b/crates/storage/src/compat.rs index 43b9abf8..32d31482 100644 --- a/crates/storage/src/compat.rs +++ b/crates/storage/src/compat.rs @@ -153,6 +153,10 @@ pub struct CompatRefreshTokenLookup { compat_refresh_token_id: i64, compat_refresh_token: String, compat_refresh_token_created_at: DateTime, + compat_access_token_id: i64, + compat_access_token: String, + compat_access_token_created_at: DateTime, + compat_access_token_expires_at: Option>, compat_session_id: i64, compat_session_created_at: DateTime, compat_session_deleted_at: Option>, @@ -186,6 +190,7 @@ pub async fn lookup_active_compat_refresh_token( ) -> Result< ( CompatRefreshToken, + CompatAccessToken, CompatSession, ), CompatRefreshTokenLookupError, @@ -197,6 +202,10 @@ pub async fn lookup_active_compat_refresh_token( cr.id AS "compat_refresh_token_id", cr.token AS "compat_refresh_token", cr.created_at AS "compat_refresh_token_created_at", + ct.id AS "compat_access_token_id", + ct.token AS "compat_access_token", + ct.created_at AS "compat_access_token_created_at", + ct.expires_at AS "compat_access_token_expires_at", cs.id AS "compat_session_id", cs.created_at AS "compat_session_created_at", cs.deleted_at AS "compat_session_deleted_at", @@ -209,6 +218,8 @@ pub async fn lookup_active_compat_refresh_token( ue.confirmed_at AS "user_email_confirmed_at?" FROM compat_refresh_tokens cr + INNER JOIN compat_access_tokens ct + ON ct.id = cr.compat_access_token_id INNER JOIN compat_sessions cs ON cs.id = cr.compat_session_id INNER JOIN users u @@ -217,6 +228,7 @@ pub async fn lookup_active_compat_refresh_token( ON ue.id = u.primary_email_id WHERE cr.token = $1 + AND cr.next_token_id IS NULL AND cs.deleted_at IS NULL "#, token, @@ -225,12 +237,19 @@ pub async fn lookup_active_compat_refresh_token( .instrument(info_span!("Fetch compat refresh token")) .await?; - let token = CompatRefreshToken { + let refresh_token = CompatRefreshToken { data: res.compat_refresh_token_id, token: res.compat_refresh_token, created_at: res.compat_refresh_token_created_at, }; + let access_token = CompatAccessToken { + data: res.compat_access_token_id, + token: res.compat_access_token, + created_at: res.compat_access_token_created_at, + expires_at: res.compat_access_token_expires_at, + }; + let primary_email = match ( res.user_email_id, res.user_email, @@ -264,7 +283,7 @@ pub async fn lookup_active_compat_refresh_token( deleted_at: res.compat_session_deleted_at, }; - Ok((token, session)) + Ok((refresh_token, access_token, session)) } #[tracing::instrument(skip(conn, password), err)] @@ -392,6 +411,31 @@ pub async fn add_compat_access_token( } } +pub async fn expire_compat_access_token( + executor: impl PgExecutor<'_>, + access_token: CompatAccessToken, +) -> anyhow::Result<()> { + let res = sqlx::query!( + r#" + UPDATE compat_access_tokens + SET expires_at = NOW() + WHERE id = $1 + "#, + access_token.data, + ) + .execute(executor) + .await + .context("failed to update compat access token")?; + + if res.rows_affected() == 1 { + Ok(()) + } else { + Err(anyhow::anyhow!( + "no row were affected when updating access token" + )) + } +} + pub async fn add_compat_refresh_token( executor: impl PgExecutor<'_>, session: &CompatSession, @@ -447,3 +491,30 @@ pub async fn compat_logout( _ => anyhow::bail!("too many row affected"), } } + +pub async fn replace_compat_refresh_token( + executor: impl PgExecutor<'_>, + refresh_token: &CompatRefreshToken, + next_refresh_token: &CompatRefreshToken, +) -> anyhow::Result<()> { + let res = sqlx::query!( + r#" + UPDATE compat_refresh_tokens + SET next_token_id = $2 + WHERE id = $1 + "#, + refresh_token.data, + next_refresh_token.data + ) + .execute(executor) + .await + .context("failed to update compat refresh token")?; + + if res.rows_affected() == 1 { + Ok(()) + } else { + Err(anyhow::anyhow!( + "no row were affected when updating refresh token" + )) + } +}