diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs new file mode 100644 index 00000000..5605f25b --- /dev/null +++ b/crates/handlers/src/compat/login.rs @@ -0,0 +1,149 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use axum::{response::IntoResponse, Extension, Json}; +use hyper::StatusCode; +use mas_config::MatrixConfig; +use mas_data_model::TokenType; +use mas_storage::compat::compat_login; +use rand::{distributions::Alphanumeric, thread_rng, Rng}; +use serde::{Deserialize, Serialize}; +use sqlx::PgPool; + +use super::MatrixError; + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +enum LoginType { + #[serde(rename = "m.login.password")] + Password, +} + +#[derive(Debug, Serialize, Deserialize)] +struct LoginTypes { + flows: Vec, +} + +pub(crate) async fn get() -> impl IntoResponse { + let res = LoginTypes { + flows: vec![LoginType::Password], + }; + + Json(res) +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum RequestBody { + #[serde(rename = "m.login.password")] + Password { + identifier: Identifier, + password: String, + }, + + #[serde(other)] + Unsupported, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum Identifier { + #[serde(rename = "m.id.user")] + User { user: String }, + + #[serde(other)] + Unsupported, +} + +#[derive(Debug, Serialize)] +pub struct ResponseBody { + access_token: String, + device_id: String, + user_id: String, +} + +pub enum RouteError { + Internal(Box), + Unsupported, + LoginFailed, +} + +impl From for RouteError { + fn from(e: sqlx::Error) -> Self { + Self::Internal(Box::new(e)) + } +} + +impl IntoResponse for RouteError { + fn into_response(self) -> axum::response::Response { + match self { + Self::Internal(_e) => MatrixError { + errcode: "M_UNKNOWN", + error: "Internal server error", + status: StatusCode::INTERNAL_SERVER_ERROR, + }, + Self::Unsupported => MatrixError { + errcode: "M_UNRECOGNIZED", + error: "Invalid login type", + status: StatusCode::BAD_REQUEST, + }, + Self::LoginFailed => MatrixError { + errcode: "M_UNAUTHORIZED", + error: "Invalid username/password", + status: StatusCode::FORBIDDEN, + }, + } + .into_response() + } +} + +pub(crate) async fn post( + Extension(pool): Extension, + Extension(config): Extension, + Json(input): Json, +) -> Result { + let mut conn = pool.acquire().await?; + let (username, password) = match input { + RequestBody::Password { + identifier: Identifier::User { user }, + password, + } => (user, password), + _ => { + return Err(RouteError::Unsupported); + } + }; + + let (token, device_id) = { + let mut rng = thread_rng(); + let token = TokenType::CompatAccessToken.generate(&mut rng); + let device_id: String = rng + .sample_iter(&Alphanumeric) + .take(10) + .map(char::from) + .collect(); + (token, device_id) + }; + + let (token, user) = compat_login(&mut conn, &username, &password, device_id, token) + .await + .map_err(|_| RouteError::LoginFailed)?; + + let user_id = format!("@{}:{}", user.username, config.homeserver); + + Ok(Json(ResponseBody { + access_token: token.token, + device_id: token.device_id, + user_id, + })) +} diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs new file mode 100644 index 00000000..613389e8 --- /dev/null +++ b/crates/handlers/src/compat/logout.rs @@ -0,0 +1,86 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use axum::{response::IntoResponse, Extension, Json, TypedHeader}; +use headers::{authorization::Bearer, Authorization}; +use hyper::StatusCode; +use mas_data_model::{TokenFormatError, TokenType}; +use mas_storage::compat::compat_logout; +use sqlx::PgPool; + +use super::MatrixError; + +pub enum RouteError { + Internal(Box), + MissingAuthorization, + InvalidAuthorization, + LogoutFailed, +} + +impl IntoResponse for RouteError { + fn into_response(self) -> axum::response::Response { + match self { + Self::Internal(_) => MatrixError { + errcode: "M_UNKNOWN", + error: "Internal error", + status: StatusCode::INTERNAL_SERVER_ERROR, + }, + Self::MissingAuthorization => MatrixError { + errcode: "M_MISSING_TOKEN", + error: "Missing access token", + status: StatusCode::UNAUTHORIZED, + }, + Self::InvalidAuthorization | Self::LogoutFailed => MatrixError { + errcode: "M_UNKNOWN_TOKEN", + error: "Invalid access token", + status: StatusCode::UNAUTHORIZED, + }, + } + .into_response() + } +} + +impl From for RouteError { + fn from(e: sqlx::Error) -> Self { + Self::Internal(Box::new(e)) + } +} + +impl From for RouteError { + fn from(_e: TokenFormatError) -> Self { + Self::InvalidAuthorization + } +} + +pub(crate) async fn post( + Extension(pool): Extension, + maybe_authorization: Option>>, +) -> Result { + let mut conn = pool.acquire().await?; + + let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?; + + let token = authorization.token(); + let token_type = TokenType::check(token)?; + + if token_type != TokenType::CompatAccessToken { + return Err(RouteError::InvalidAuthorization); + } + + compat_logout(&mut conn, token) + .await + .map_err(|_| RouteError::LogoutFailed)?; + + Ok(Json(serde_json::json!({}))) +} diff --git a/crates/handlers/src/compat/mod.rs b/crates/handlers/src/compat/mod.rs index f73356a3..37a978eb 100644 --- a/crates/handlers/src/compat/mod.rs +++ b/crates/handlers/src/compat/mod.rs @@ -12,14 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use axum::{response::IntoResponse, Extension, Json}; +use axum::{response::IntoResponse, Json}; use hyper::StatusCode; -use mas_config::MatrixConfig; -use mas_data_model::TokenType; -use mas_storage::compat::compat_login; -use rand::{distributions::Alphanumeric, thread_rng, Rng}; -use serde::{Deserialize, Serialize}; -use sqlx::PgPool; +use serde::Serialize; + +pub(crate) mod login; +pub(crate) mod logout; #[derive(Debug, Serialize)] struct MatrixError { @@ -35,127 +33,3 @@ impl IntoResponse for MatrixError { } } -#[derive(Debug, Serialize, Deserialize)] -#[serde(tag = "type")] -enum LoginType { - #[serde(rename = "m.login.password")] - Password, -} - -#[derive(Debug, Serialize, Deserialize)] -struct LoginTypes { - flows: Vec, -} - -pub(crate) async fn get() -> impl IntoResponse { - let res = LoginTypes { - flows: vec![LoginType::Password], - }; - - Json(res) -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(tag = "type")] -pub enum IncomingLogin { - #[serde(rename = "m.login.password")] - Password { - identifier: LoginIdentifier, - password: String, - }, - - #[serde(other)] - Unsupported, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(tag = "type")] -pub enum LoginIdentifier { - #[serde(rename = "m.id.user")] - User { user: String }, - - #[serde(other)] - Unsupported, -} - -#[derive(Debug, Serialize)] -pub struct SuccessfulLogin { - access_token: String, - device_id: String, - user_id: String, -} - -pub enum RouteError { - Internal(Box), - Unsupported, - LoginFailed, -} - -impl From for RouteError { - fn from(e: sqlx::Error) -> Self { - Self::Internal(Box::new(e)) - } -} - -impl IntoResponse for RouteError { - fn into_response(self) -> axum::response::Response { - match self { - Self::Internal(_e) => MatrixError { - errcode: "M_UNKNOWN", - error: "Internal server error", - status: StatusCode::INTERNAL_SERVER_ERROR, - }, - Self::Unsupported => MatrixError { - errcode: "M_UNRECOGNIZED", - error: "Invalid login type", - status: StatusCode::BAD_REQUEST, - }, - Self::LoginFailed => MatrixError { - errcode: "M_UNAUTHORIZED", - error: "Invalid username/password", - status: StatusCode::FORBIDDEN, - }, - } - .into_response() - } -} - -pub(crate) async fn post( - Extension(pool): Extension, - Extension(config): Extension, - Json(input): Json, -) -> Result { - let mut conn = pool.acquire().await?; - let (username, password) = match input { - IncomingLogin::Password { - identifier: LoginIdentifier::User { user }, - password, - } => (user, password), - _ => { - return Err(RouteError::Unsupported); - } - }; - - let (token, device_id) = { - let mut rng = thread_rng(); - let token = TokenType::CompatAccessToken.generate(&mut rng); - let device_id: String = rng - .sample_iter(&Alphanumeric) - .take(10) - .map(char::from) - .collect(); - (token, device_id) - }; - - let (token, user) = compat_login(&mut conn, &username, &password, device_id, token) - .await - .map_err(|_| RouteError::LoginFailed)?; - - let user_id = format!("@{}:{}", user.username, config.homeserver); - - Ok(Json(SuccessfulLogin { - access_token: token.token, - device_id: token.device_id, - user_id, - })) -} diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 191e61e9..82ffef62 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -99,7 +99,11 @@ where ) .route( mas_router::CompatLogin::route(), - get(self::compat::get).post(self::compat::post), + get(self::compat::login::get).post(self::compat::login::post), + ) + .route( + mas_router::CompatLogout::route(), + post(self::compat::logout::post), ) .layer( CorsLayer::new() diff --git a/crates/router/src/endpoints.rs b/crates/router/src/endpoints.rs index 07574eb7..52052d6b 100644 --- a/crates/router/src/endpoints.rs +++ b/crates/router/src/endpoints.rs @@ -373,3 +373,10 @@ pub struct CompatLogin; impl SimpleRoute for CompatLogin { const PATH: &'static str = "/_matrix/client/:version/login"; } + +/// `POST /_matrix/client/v3/logout` +pub struct CompatLogout; + +impl SimpleRoute for CompatLogout { + const PATH: &'static str = "/_matrix/client/:version/logout"; +} diff --git a/crates/storage/sqlx-data.json b/crates/storage/sqlx-data.json index b1177957..25fa0acb 100644 --- a/crates/storage/sqlx-data.json +++ b/crates/storage/sqlx-data.json @@ -1557,6 +1557,18 @@ }, "query": "\n UPDATE oauth2_refresh_tokens\n SET next_token_id = $2\n WHERE id = $1\n " }, + "c53f3f064920f8516a14b384760e2c30f18ab6f099e468cf019fb4eaa0547637": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "\n UPDATE compat_access_tokens\n SET deleted_at = NOW()\n WHERE token = $1 AND deleted_at IS NULL\n " + }, "d2f767218ec2489058db9a0382ca0eea20379c30aeae9f492da4ba35b66f4dc7": { "describe": { "columns": [], diff --git a/crates/storage/src/compat.rs b/crates/storage/src/compat.rs index a03b70cb..b7d5fee1 100644 --- a/crates/storage/src/compat.rs +++ b/crates/storage/src/compat.rs @@ -129,7 +129,7 @@ pub async fn lookup_active_compat_access_token( Ok((token, user)) } -#[tracing::instrument(skip(conn, password))] +#[tracing::instrument(skip(conn, password, token))] pub async fn compat_login( conn: impl Acquire<'_, Database = Postgres>, username: &str, @@ -200,3 +200,27 @@ pub async fn compat_login( txn.commit().await.context("could not commit transaction")?; Ok((token, user)) } + +#[tracing::instrument(skip_all)] +pub async fn compat_logout( + executor: impl PgExecutor<'_>, + token: &str, +) -> Result<(), anyhow::Error> { + let res = sqlx::query!( + r#" + UPDATE compat_access_tokens + SET deleted_at = NOW() + WHERE token = $1 AND deleted_at IS NULL + "#, + token, + ) + .execute(executor) + .await + .context("could not update compat access token")?; + + match res.rows_affected() { + 1 => Ok(()), + 0 => anyhow::bail!("no row affected"), + _ => anyhow::bail!("too many row affected"), + } +}