1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Move storage module to its own crate

This commit is contained in:
Quentin Gliech
2021-12-17 12:15:07 +01:00
parent 584294538b
commit ceb17d3646
46 changed files with 116 additions and 71 deletions

54
crates/storage/src/lib.rs Normal file
View File

@ -0,0 +1,54 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Interactions with the database
#![allow(clippy::used_underscore_binding)] // This is needed by sqlx macros
use chrono::{DateTime, Utc};
use mas_data_model::{StorageBackend, StorageBackendMarker};
use serde::Serialize;
use sqlx::migrate::Migrator;
use thiserror::Error;
#[derive(Debug, Error)]
#[error("database query returned an inconsistent state")]
pub struct DatabaseInconsistencyError;
#[derive(Serialize, Debug, Clone, PartialEq)]
pub struct PostgresqlBackend;
impl StorageBackend for PostgresqlBackend {
type AccessTokenData = i64;
type AuthenticationData = i64;
type AuthorizationGrantData = i64;
type BrowserSessionData = i64;
type ClientData = ();
type RefreshTokenData = i64;
type SessionData = i64;
type UserData = i64;
}
impl StorageBackendMarker for PostgresqlBackend {}
struct IdAndCreationTime {
id: i64,
created_at: DateTime<Utc>,
}
pub mod oauth2;
pub mod user;
/// Embedded migrations, allowing them to run on startup
pub static MIGRATOR: Migrator = sqlx::migrate!();

View File

@ -0,0 +1,221 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{AccessToken, Authentication, BrowserSession, Client, Session, User};
use sqlx::PgExecutor;
use thiserror::Error;
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
pub async fn add_access_token(
executor: impl PgExecutor<'_>,
session: &Session<PostgresqlBackend>,
token: &str,
expires_after: Duration,
) -> anyhow::Result<AccessToken<PostgresqlBackend>> {
// Checked convertion of duration to i32, maxing at i32::MAX
let expires_after_seconds = i32::try_from(expires_after.num_seconds()).unwrap_or(i32::MAX);
let res = sqlx::query_as!(
IdAndCreationTime,
r#"
INSERT INTO oauth2_access_tokens
(oauth2_session_id, token, expires_after)
VALUES
($1, $2, $3)
RETURNING
id, created_at
"#,
session.data,
token,
expires_after_seconds,
)
.fetch_one(executor)
.await
.context("could not insert oauth2 access token")?;
Ok(AccessToken {
data: res.id,
expires_after,
token: token.to_string(),
jti: format!("{}", res.id),
created_at: res.created_at,
})
}
#[derive(Debug)]
pub struct OAuth2AccessTokenLookup {
access_token_id: i64,
access_token: String,
access_token_expires_after: i32,
access_token_created_at: DateTime<Utc>,
session_id: i64,
client_id: String,
scope: String,
user_session_id: i64,
user_session_created_at: DateTime<Utc>,
user_id: i64,
user_username: String,
user_session_last_authentication_id: Option<i64>,
user_session_last_authentication_created_at: Option<DateTime<Utc>>,
}
#[derive(Debug, Error)]
#[error("failed to lookup access token")]
pub enum AccessTokenLookupError {
Database(#[from] sqlx::Error),
Inconsistency(#[from] DatabaseInconsistencyError),
}
impl AccessTokenLookupError {
#[must_use]
pub fn not_found(&self) -> bool {
matches!(
self,
&AccessTokenLookupError::Database(sqlx::Error::RowNotFound)
)
}
}
pub async fn lookup_active_access_token(
executor: impl PgExecutor<'_>,
token: &str,
) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), AccessTokenLookupError> {
let res = sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT
at.id AS "access_token_id",
at.token AS "access_token",
at.expires_after AS "access_token_expires_after",
at.created_at AS "access_token_created_at",
os.id AS "session_id!",
os.client_id AS "client_id!",
os.scope AS "scope!",
us.id AS "user_session_id!",
us.created_at AS "user_session_created_at!",
u.id AS "user_id!",
u.username AS "user_username!",
usa.id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?"
FROM oauth2_access_tokens at
INNER JOIN oauth2_sessions os
ON os.id = at.oauth2_session_id
INNER JOIN user_sessions us
ON us.id = os.user_session_id
INNER JOIN users u
ON u.id = us.user_id
LEFT JOIN user_session_authentications usa
ON usa.session_id = us.id
WHERE at.token = $1
AND at.created_at + (at.expires_after * INTERVAL '1 second') >= now()
AND us.active
ORDER BY usa.created_at DESC
LIMIT 1
"#,
token,
)
.fetch_one(executor)
.await?;
let access_token = AccessToken {
data: res.access_token_id,
jti: format!("{}", res.access_token_id),
token: res.access_token,
created_at: res.access_token_created_at,
expires_after: Duration::seconds(res.access_token_expires_after.into()),
};
let client = Client {
data: (),
client_id: res.client_id,
};
let user = User {
data: res.user_id,
username: res.user_username,
sub: format!("fake-sub-{}", res.user_id),
};
let last_authentication = match (
res.user_session_last_authentication_id,
res.user_session_last_authentication_created_at,
) {
(None, None) => None,
(Some(id), Some(created_at)) => Some(Authentication {
data: id,
created_at,
}),
_ => return Err(DatabaseInconsistencyError.into()),
};
let browser_session = BrowserSession {
data: res.user_session_id,
created_at: res.user_session_created_at,
user,
last_authentication,
};
let scope = res.scope.parse().map_err(|_e| DatabaseInconsistencyError)?;
let session = Session {
data: res.session_id,
client,
browser_session,
scope,
};
Ok((access_token, session))
}
pub async fn revoke_access_token(
executor: impl PgExecutor<'_>,
access_token: &AccessToken<PostgresqlBackend>,
) -> anyhow::Result<()> {
let res = sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens
WHERE id = $1
"#,
access_token.data,
)
.execute(executor)
.await
.context("could not revoke access tokens")?;
if res.rows_affected() == 1 {
Ok(())
} else {
Err(anyhow::anyhow!("no row were affected when revoking token"))
}
}
pub async fn cleanup_expired(executor: impl PgExecutor<'_>) -> anyhow::Result<u64> {
let res = sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens
WHERE created_at + (expires_after * INTERVAL '1 second') + INTERVAL '15 minutes' < now()
"#,
)
.execute(executor)
.await
.context("could not cleanup expired access tokens")?;
Ok(res.rows_affected())
}

View File

@ -0,0 +1,510 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(clippy::unused_async)]
use std::num::NonZeroU32;
use anyhow::Context;
use chrono::{DateTime, Utc};
use mas_data_model::{
Authentication, AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, BrowserSession,
Client, Pkce, Session, User,
};
use oauth2_types::{pkce::CodeChallengeMethod, requests::ResponseMode, scope::Scope};
use sqlx::PgExecutor;
use url::Url;
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
#[allow(clippy::too_many_arguments)]
pub async fn new_authorization_grant(
executor: impl PgExecutor<'_>,
client_id: String,
redirect_uri: Url,
scope: Scope,
code: Option<AuthorizationCode>,
state: Option<String>,
nonce: Option<String>,
max_age: Option<NonZeroU32>,
acr_values: Option<String>,
response_mode: ResponseMode,
response_type_token: bool,
response_type_id_token: bool,
) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> {
let code_challenge = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| &p.challenge);
let code_challenge_method = code
.as_ref()
.and_then(|c| c.pkce.as_ref())
.map(|p| p.challenge_method.to_string());
let code_str = code.as_ref().map(|c| &c.code);
let res = sqlx::query_as!(
IdAndCreationTime,
r#"
INSERT INTO oauth2_authorization_grants
(client_id, redirect_uri, scope, state, nonce, max_age,
acr_values, response_mode, code_challenge, code_challenge_method,
response_type_code, response_type_token, response_type_id_token,
code)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
RETURNING id, created_at
"#,
&client_id,
redirect_uri.to_string(),
scope.to_string(),
state,
nonce,
// TODO: this conversion is a bit ugly
max_age.map(|x| i32::try_from(u32::from(x)).unwrap_or(i32::MAX)),
acr_values,
response_mode.to_string(),
code_challenge,
code_challenge_method,
code.is_some(),
response_type_token,
response_type_id_token,
code_str,
)
.fetch_one(executor)
.await
.context("could not insert oauth2 authorization grant")?;
let client = Client {
data: (),
client_id,
};
Ok(AuthorizationGrant {
data: res.id,
stage: AuthorizationGrantStage::Pending,
code,
redirect_uri,
client,
scope,
state,
nonce,
max_age,
acr_values,
response_mode,
created_at: res.created_at,
response_type_token,
response_type_id_token,
})
}
struct GrantLookup {
grant_id: i64,
grant_created_at: DateTime<Utc>,
grant_cancelled_at: Option<DateTime<Utc>>,
grant_fulfilled_at: Option<DateTime<Utc>>,
grant_exchanged_at: Option<DateTime<Utc>>,
grant_scope: String,
grant_state: Option<String>,
grant_redirect_uri: String,
grant_response_mode: String,
grant_nonce: Option<String>,
grant_max_age: Option<i32>,
grant_acr_values: Option<String>,
grant_response_type_code: bool,
grant_response_type_token: bool,
grant_response_type_id_token: bool,
grant_code: Option<String>,
grant_code_challenge: Option<String>,
grant_code_challenge_method: Option<String>,
client_id: String,
session_id: Option<i64>,
user_session_id: Option<i64>,
user_session_created_at: Option<DateTime<Utc>>,
user_id: Option<i64>,
user_username: Option<String>,
user_session_last_authentication_id: Option<i64>,
user_session_last_authentication_created_at: Option<DateTime<Utc>>,
}
impl TryInto<AuthorizationGrant<PostgresqlBackend>> for GrantLookup {
type Error = DatabaseInconsistencyError;
#[allow(clippy::too_many_lines)]
fn try_into(self) -> Result<AuthorizationGrant<PostgresqlBackend>, Self::Error> {
let scope: Scope = self
.grant_scope
.parse()
.map_err(|_e| DatabaseInconsistencyError)?;
let client = Client {
data: (),
client_id: self.client_id,
};
let last_authentication = match (
self.user_session_last_authentication_id,
self.user_session_last_authentication_created_at,
) {
(Some(id), Some(created_at)) => Some(Authentication {
data: id,
created_at,
}),
(None, None) => None,
_ => return Err(DatabaseInconsistencyError),
};
let session = match (
self.session_id,
self.user_session_id,
self.user_session_created_at,
self.user_id,
self.user_username,
last_authentication,
) {
(
Some(session_id),
Some(user_session_id),
Some(user_session_created_at),
Some(user_id),
Some(user_username),
last_authentication,
) => {
let user = User {
data: user_id,
username: user_username,
sub: format!("fake-sub-{}", user_id),
};
let browser_session = BrowserSession {
data: user_session_id,
user,
created_at: user_session_created_at,
last_authentication,
};
let client = client.clone();
let scope = scope.clone();
let session = Session {
data: session_id,
client,
browser_session,
scope,
};
Some(session)
}
(None, None, None, None, None, None) => None,
_ => return Err(DatabaseInconsistencyError),
};
let stage = match (
self.grant_fulfilled_at,
self.grant_exchanged_at,
self.grant_cancelled_at,
session,
) {
(None, None, None, None) => AuthorizationGrantStage::Pending,
(Some(fulfilled_at), None, None, Some(session)) => AuthorizationGrantStage::Fulfilled {
session,
fulfilled_at,
},
(Some(fulfilled_at), Some(exchanged_at), None, Some(session)) => {
AuthorizationGrantStage::Exchanged {
session,
fulfilled_at,
exchanged_at,
}
}
(None, None, Some(cancelled_at), None) => {
AuthorizationGrantStage::Cancelled { cancelled_at }
}
_ => {
return Err(DatabaseInconsistencyError);
}
};
let pkce = match (self.grant_code_challenge, self.grant_code_challenge_method) {
(Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
Some(Pkce {
challenge_method: CodeChallengeMethod::Plain,
challenge,
})
}
(Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
challenge_method: CodeChallengeMethod::S256,
challenge,
}),
(None, None) => None,
_ => {
return Err(DatabaseInconsistencyError);
}
};
let code: Option<AuthorizationCode> =
match (self.grant_response_type_code, self.grant_code, pkce) {
(false, None, None) => None,
(true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
_ => {
return Err(DatabaseInconsistencyError);
}
};
let redirect_uri = self
.grant_redirect_uri
.parse()
.map_err(|_e| DatabaseInconsistencyError)?;
let response_mode = self
.grant_response_mode
.parse()
.map_err(|_e| DatabaseInconsistencyError)?;
let max_age = self
.grant_max_age
.map(u32::try_from)
.transpose()
.map_err(|_e| DatabaseInconsistencyError)?
.map(NonZeroU32::try_from)
.transpose()
.map_err(|_e| DatabaseInconsistencyError)?;
Ok(AuthorizationGrant {
data: self.grant_id,
stage,
client,
code,
acr_values: self.grant_acr_values,
scope,
state: self.grant_state,
nonce: self.grant_nonce,
max_age, // TODO
response_mode,
redirect_uri,
created_at: self.grant_created_at,
response_type_token: self.grant_response_type_token,
response_type_id_token: self.grant_response_type_id_token,
})
}
}
pub async fn get_grant_by_id(
executor: impl PgExecutor<'_>,
id: i64,
) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> {
// TODO: handle "not found" cases
let res = sqlx::query_as!(
GrantLookup,
r#"
SELECT
og.id AS grant_id,
og.created_at AS grant_created_at,
og.cancelled_at AS grant_cancelled_at,
og.fulfilled_at AS grant_fulfilled_at,
og.exchanged_at AS grant_exchanged_at,
og.scope AS grant_scope,
og.state AS grant_state,
og.redirect_uri AS grant_redirect_uri,
og.response_mode AS grant_response_mode,
og.nonce AS grant_nonce,
og.max_age AS grant_max_age,
og.acr_values AS grant_acr_values,
og.client_id AS client_id,
og.code AS grant_code,
og.response_type_code AS grant_response_type_code,
og.response_type_token AS grant_response_type_token,
og.response_type_id_token AS grant_response_type_id_token,
og.code_challenge AS grant_code_challenge,
og.code_challenge_method AS grant_code_challenge_method,
os.id AS "session_id?",
us.id AS "user_session_id?",
us.created_at AS "user_session_created_at?",
u.id AS "user_id?",
u.username AS "user_username?",
usa.id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?"
FROM
oauth2_authorization_grants og
LEFT JOIN oauth2_sessions os
ON os.id = og.oauth2_session_id
LEFT JOIN user_sessions us
ON us.id = os.user_session_id
LEFT JOIN users u
ON u.id = us.user_id
LEFT JOIN user_session_authentications usa
ON usa.session_id = us.id
WHERE
og.id = $1
ORDER BY usa.created_at DESC
LIMIT 1
"#,
id,
)
.fetch_one(executor)
.await
.context("failed to get grant by id")?;
let grant = res.try_into()?;
Ok(grant)
}
pub async fn lookup_grant_by_code(
executor: impl PgExecutor<'_>,
code: &str,
) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> {
// TODO: handle "not found" cases
let res = sqlx::query_as!(
GrantLookup,
r#"
SELECT
og.id AS grant_id,
og.created_at AS grant_created_at,
og.cancelled_at AS grant_cancelled_at,
og.fulfilled_at AS grant_fulfilled_at,
og.exchanged_at AS grant_exchanged_at,
og.scope AS grant_scope,
og.state AS grant_state,
og.redirect_uri AS grant_redirect_uri,
og.response_mode AS grant_response_mode,
og.nonce AS grant_nonce,
og.max_age AS grant_max_age,
og.acr_values AS grant_acr_values,
og.client_id AS client_id,
og.code AS grant_code,
og.response_type_code AS grant_response_type_code,
og.response_type_token AS grant_response_type_token,
og.response_type_id_token AS grant_response_type_id_token,
og.code_challenge AS grant_code_challenge,
og.code_challenge_method AS grant_code_challenge_method,
os.id AS "session_id?",
us.id AS "user_session_id?",
us.created_at AS "user_session_created_at?",
u.id AS "user_id?",
u.username AS "user_username?",
usa.id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?"
FROM
oauth2_authorization_grants og
LEFT JOIN oauth2_sessions os
ON os.id = og.oauth2_session_id
LEFT JOIN user_sessions us
ON us.id = os.user_session_id
LEFT JOIN users u
ON u.id = us.user_id
LEFT JOIN user_session_authentications usa
ON usa.session_id = us.id
WHERE
og.code = $1
ORDER BY usa.created_at DESC
LIMIT 1
"#,
code,
)
.fetch_one(executor)
.await
.context("failed to lookup grant by code")?;
let grant = res.try_into()?;
Ok(grant)
}
pub async fn derive_session(
executor: impl PgExecutor<'_>,
grant: &AuthorizationGrant<PostgresqlBackend>,
browser_session: BrowserSession<PostgresqlBackend>,
) -> anyhow::Result<Session<PostgresqlBackend>> {
let res = sqlx::query_as!(
IdAndCreationTime,
r#"
INSERT INTO oauth2_sessions
(user_session_id, client_id, scope)
SELECT
$1,
og.client_id,
og.scope
FROM
oauth2_authorization_grants og
WHERE
og.id = $2
RETURNING id, created_at
"#,
browser_session.data,
grant.data,
)
.fetch_one(executor)
.await
.context("could not insert oauth2 session")?;
Ok(Session {
data: res.id,
browser_session,
client: grant.client.clone(),
scope: grant.scope.clone(),
})
}
pub async fn fulfill_grant(
executor: impl PgExecutor<'_>,
mut grant: AuthorizationGrant<PostgresqlBackend>,
session: Session<PostgresqlBackend>,
) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> {
let fulfilled_at = sqlx::query_scalar!(
r#"
UPDATE oauth2_authorization_grants AS og
SET
oauth2_session_id = os.id,
fulfilled_at = os.created_at
FROM oauth2_sessions os
WHERE
og.id = $1 AND os.id = $2
RETURNING fulfilled_at AS "fulfilled_at!: DateTime<Utc>"
"#,
grant.data,
session.data,
)
.fetch_one(executor)
.await
.context("could not makr grant as fulfilled")?;
grant.stage = grant.stage.fulfill(fulfilled_at, session)?;
Ok(grant)
}
pub async fn exchange_grant(
executor: impl PgExecutor<'_>,
mut grant: AuthorizationGrant<PostgresqlBackend>,
) -> anyhow::Result<AuthorizationGrant<PostgresqlBackend>> {
let exchanged_at = sqlx::query_scalar!(
r#"
UPDATE oauth2_authorization_grants
SET
exchanged_at = NOW()
WHERE
id = $1
RETURNING exchanged_at AS "exchanged_at!: DateTime<Utc>"
"#,
grant.data,
)
.fetch_one(executor)
.await
.context("could not mark grant as exchanged")?;
grant.stage = grant.stage.exchange(exchanged_at)?;
Ok(grant)
}

View File

@ -0,0 +1,17 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod access_token;
pub mod authorization_grant;
pub mod refresh_token;

View File

@ -0,0 +1,214 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{
AccessToken, Authentication, BrowserSession, Client, RefreshToken, Session, User,
};
use sqlx::PgExecutor;
use crate::{DatabaseInconsistencyError, IdAndCreationTime, PostgresqlBackend};
pub async fn add_refresh_token(
executor: impl PgExecutor<'_>,
session: &Session<PostgresqlBackend>,
access_token: AccessToken<PostgresqlBackend>,
token: &str,
) -> anyhow::Result<RefreshToken<PostgresqlBackend>> {
let res = sqlx::query_as!(
IdAndCreationTime,
r#"
INSERT INTO oauth2_refresh_tokens
(oauth2_session_id, oauth2_access_token_id, token)
VALUES
($1, $2, $3)
RETURNING
id, created_at
"#,
session.data,
access_token.data,
token,
)
.fetch_one(executor)
.await
.context("could not insert oauth2 refresh token")?;
Ok(RefreshToken {
data: res.id,
token: token.to_string(),
access_token: Some(access_token),
created_at: res.created_at,
})
}
struct OAuth2RefreshTokenLookup {
refresh_token_id: i64,
refresh_token: String,
refresh_token_created_at: DateTime<Utc>,
access_token_id: Option<i64>,
access_token: Option<String>,
access_token_expires_after: Option<i32>,
access_token_created_at: Option<DateTime<Utc>>,
session_id: i64,
client_id: String,
scope: String,
user_session_id: i64,
user_session_created_at: DateTime<Utc>,
user_id: i64,
user_username: String,
user_session_last_authentication_id: Option<i64>,
user_session_last_authentication_created_at: Option<DateTime<Utc>>,
}
#[allow(clippy::too_many_lines)]
pub async fn lookup_active_refresh_token(
executor: impl PgExecutor<'_>,
token: &str,
) -> anyhow::Result<(RefreshToken<PostgresqlBackend>, Session<PostgresqlBackend>)> {
let res = sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT
rt.id AS refresh_token_id,
rt.token AS refresh_token,
rt.created_at AS refresh_token_created_at,
at.id AS "access_token_id?",
at.token AS "access_token?",
at.expires_after AS "access_token_expires_after?",
at.created_at AS "access_token_created_at?",
os.id AS "session_id!",
os.client_id AS "client_id!",
os.scope AS "scope!",
us.id AS "user_session_id!",
us.created_at AS "user_session_created_at!",
u.id AS "user_id!",
u.username AS "user_username!",
usa.id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?"
FROM oauth2_refresh_tokens rt
LEFT JOIN oauth2_access_tokens at
ON at.id = rt.oauth2_access_token_id
INNER JOIN oauth2_sessions os
ON os.id = rt.oauth2_session_id
INNER JOIN user_sessions us
ON us.id = os.user_session_id
INNER JOIN users u
ON u.id = us.user_id
LEFT JOIN user_session_authentications usa
ON usa.session_id = us.id
WHERE rt.token = $1
AND rt.next_token_id IS NULL
AND us.active
ORDER BY usa.created_at DESC
LIMIT 1
"#,
token,
)
.fetch_one(executor)
.await
.context("failed to fetch oauth2 refresh token")?;
let access_token = match (
res.access_token_id,
res.access_token,
res.access_token_created_at,
res.access_token_expires_after,
) {
(None, None, None, None) => None,
(Some(id), Some(token), Some(created_at), Some(expires_after)) => Some(AccessToken {
data: id,
jti: format!("{}", id),
token,
created_at,
expires_after: Duration::seconds(expires_after.into()),
}),
_ => return Err(DatabaseInconsistencyError.into()),
};
let refresh_token = RefreshToken {
data: res.refresh_token_id,
token: res.refresh_token,
created_at: res.refresh_token_created_at,
access_token,
};
let client = Client {
data: (),
client_id: res.client_id,
};
let user = User {
data: res.user_id,
username: res.user_username,
sub: format!("fake-sub-{}", res.user_id),
};
let last_authentication = match (
res.user_session_last_authentication_id,
res.user_session_last_authentication_created_at,
) {
(None, None) => None,
(Some(id), Some(created_at)) => Some(Authentication {
data: id,
created_at,
}),
_ => return Err(DatabaseInconsistencyError.into()),
};
let browser_session = BrowserSession {
data: res.user_session_id,
created_at: res.user_session_created_at,
user,
last_authentication,
};
let session = Session {
data: res.session_id,
client,
browser_session,
scope: res.scope.parse().context("invalid scope in database")?,
};
Ok((refresh_token, session))
}
pub async fn replace_refresh_token(
executor: impl PgExecutor<'_>,
refresh_token: &RefreshToken<PostgresqlBackend>,
next_refresh_token: &RefreshToken<PostgresqlBackend>,
) -> anyhow::Result<()> {
let res = sqlx::query!(
r#"
UPDATE oauth2_refresh_tokens
SET next_token_id = $2
WHERE id = $1
"#,
refresh_token.data,
next_refresh_token.data
)
.execute(executor)
.await
.context("failed to update oauth2 refresh token")?;
if res.rows_affected() == 1 {
Ok(())
} else {
Err(anyhow::anyhow!(
"no row were affected when updating refresh token"
))
}
}

414
crates/storage/src/user.rs Normal file
View File

@ -0,0 +1,414 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::BorrowMut;
use anyhow::Context;
use argon2::Argon2;
use chrono::{DateTime, Utc};
use mas_data_model::{errors::HtmlError, Authentication, BrowserSession, User};
use password_hash::{PasswordHash, PasswordHasher, SaltString};
use rand::rngs::OsRng;
use sqlx::{Acquire, PgExecutor, Postgres, Transaction};
use thiserror::Error;
use tokio::task;
use tracing::{info_span, Instrument};
use warp::reject::Reject;
use super::{DatabaseInconsistencyError, PostgresqlBackend};
use crate::IdAndCreationTime;
#[derive(Debug, Clone)]
struct UserLookup {
pub id: i64,
pub username: String,
}
#[derive(Debug, Error)]
pub enum LoginError {
#[error("could not find user {username:?}")]
NotFound {
username: String,
#[source]
source: sqlx::Error,
},
#[error("authentication failed for {username:?}")]
Authentication {
username: String,
#[source]
source: AuthenticationError,
},
#[error("failed to login")]
Other(#[from] anyhow::Error),
}
impl HtmlError for LoginError {
fn html_display(&self) -> String {
match self {
LoginError::NotFound { .. } => "Could not find user".to_string(),
LoginError::Authentication { .. } => "Failed to authenticate user".to_string(),
LoginError::Other(e) => format!("Internal error: <pre>{}</pre>", e),
}
}
}
#[tracing::instrument(skip(conn, password))]
pub async fn login(
conn: impl Acquire<'_, Database = Postgres>,
username: &str,
password: String,
) -> Result<BrowserSession<PostgresqlBackend>, LoginError> {
let mut txn = conn.begin().await.context("could not start transaction")?;
let user = lookup_user_by_username(&mut txn, username)
.await
.map_err(|source| {
if matches!(source, sqlx::Error::RowNotFound) {
LoginError::NotFound {
username: username.to_string(),
source,
}
} else {
LoginError::Other(source.into())
}
})?;
let mut session = start_session(&mut txn, user).await?;
authenticate_session(&mut txn, &mut session, password)
.await
.map_err(|source| {
if matches!(source, AuthenticationError::Password { .. }) {
LoginError::Authentication {
username: username.to_string(),
source,
}
} else {
LoginError::Other(source.into())
}
})?;
txn.commit().await.context("could not commit transaction")?;
Ok(session)
}
#[derive(Debug, Error)]
#[error("could not fetch session")]
pub enum ActiveSessionLookupError {
Fetch(#[from] sqlx::Error),
Conversion(#[from] DatabaseInconsistencyError),
}
impl Reject for ActiveSessionLookupError {}
impl ActiveSessionLookupError {
#[must_use]
pub fn not_found(&self) -> bool {
matches!(
self,
ActiveSessionLookupError::Fetch(sqlx::Error::RowNotFound)
)
}
}
struct SessionLookup {
id: i64,
user_id: i64,
username: String,
created_at: DateTime<Utc>,
last_authentication_id: Option<i64>,
last_authd_at: Option<DateTime<Utc>>,
}
impl TryInto<BrowserSession<PostgresqlBackend>> for SessionLookup {
type Error = DatabaseInconsistencyError;
fn try_into(self) -> Result<BrowserSession<PostgresqlBackend>, Self::Error> {
let user = User {
data: self.user_id,
username: self.username,
sub: format!("fake-sub-{}", self.user_id),
};
let last_authentication = match (self.last_authentication_id, self.last_authd_at) {
(Some(id), Some(created_at)) => Some(Authentication {
data: id,
created_at,
}),
(None, None) => None,
_ => return Err(DatabaseInconsistencyError),
};
Ok(BrowserSession {
data: self.id,
user,
created_at: self.created_at,
last_authentication,
})
}
}
pub async fn lookup_active_session(
executor: impl PgExecutor<'_>,
id: i64,
) -> Result<BrowserSession<PostgresqlBackend>, ActiveSessionLookupError> {
let res = sqlx::query_as!(
SessionLookup,
r#"
SELECT
s.id,
u.id as user_id,
u.username,
s.created_at,
a.id as "last_authentication_id?",
a.created_at as "last_authd_at?"
FROM user_sessions s
INNER JOIN users u
ON s.user_id = u.id
LEFT JOIN user_session_authentications a
ON a.session_id = s.id
WHERE s.id = $1 AND s.active
ORDER BY a.created_at DESC
LIMIT 1
"#,
id,
)
.fetch_one(executor)
.await?
.try_into()?;
Ok(res)
}
pub async fn start_session(
executor: impl PgExecutor<'_>,
user: User<PostgresqlBackend>,
) -> anyhow::Result<BrowserSession<PostgresqlBackend>> {
let res = sqlx::query_as!(
IdAndCreationTime,
r#"
INSERT INTO user_sessions (user_id)
VALUES ($1)
RETURNING id, created_at
"#,
user.data,
)
.fetch_one(executor)
.await
.context("could not create session")?;
let session = BrowserSession {
data: res.id,
user,
created_at: res.created_at,
last_authentication: None,
};
Ok(session)
}
#[tracing::instrument(skip_all, fields(user.id = user.data))]
pub async fn count_active_sessions(
executor: impl PgExecutor<'_>,
user: &User<PostgresqlBackend>,
) -> Result<usize, anyhow::Error> {
let res = sqlx::query_scalar!(
r#"
SELECT COUNT(*) as "count!"
FROM user_sessions s
WHERE s.user_id = $1 AND s.active
"#,
user.data,
)
.fetch_one(executor)
.await?
.try_into()?;
Ok(res)
}
#[derive(Debug, Error)]
pub enum AuthenticationError {
#[error("could not verify password")]
Password(#[from] password_hash::Error),
#[error("could not fetch user password hash")]
Fetch(sqlx::Error),
#[error("could not save session auth")]
Save(sqlx::Error),
#[error("runtime error")]
Internal(#[from] tokio::task::JoinError),
}
#[tracing::instrument(skip_all, fields(session.id = session.data, user.id = session.user.data))]
pub async fn authenticate_session(
txn: &mut Transaction<'_, Postgres>,
session: &mut BrowserSession<PostgresqlBackend>,
password: String,
) -> Result<(), AuthenticationError> {
// First, fetch the hashed password from the user associated with that session
let hashed_password: String = sqlx::query_scalar!(
r#"
SELECT up.hashed_password
FROM user_passwords up
WHERE up.user_id = $1
ORDER BY up.created_at DESC
LIMIT 1
"#,
session.user.data,
)
.fetch_one(txn.borrow_mut())
.instrument(tracing::info_span!("Lookup hashed password"))
.await
.map_err(AuthenticationError::Fetch)?;
// TODO: pass verifiers list as parameter
// Verify the password in a blocking thread to avoid blocking the async executor
task::spawn_blocking(move || {
let context = Argon2::default();
let hasher = PasswordHash::new(&hashed_password).map_err(AuthenticationError::Password)?;
hasher
.verify_password(&[&context], &password)
.map_err(AuthenticationError::Password)
})
.instrument(tracing::info_span!("Verify hashed password"))
.await??;
// That went well, let's insert the auth info
let res = sqlx::query_as!(
IdAndCreationTime,
r#"
INSERT INTO user_session_authentications (session_id)
VALUES ($1)
RETURNING id, created_at
"#,
session.data,
)
.fetch_one(txn.borrow_mut())
.instrument(tracing::info_span!("Save authentication"))
.await
.map_err(AuthenticationError::Save)?;
session.last_authentication = Some(Authentication {
data: res.id,
created_at: res.created_at,
});
Ok(())
}
#[tracing::instrument(skip(txn, phf, password))]
pub async fn register_user(
txn: &mut Transaction<'_, Postgres>,
phf: impl PasswordHasher,
username: &str,
password: &str,
) -> anyhow::Result<User<PostgresqlBackend>> {
let id: i64 = sqlx::query_scalar!(
r#"
INSERT INTO users (username)
VALUES ($1)
RETURNING id
"#,
username,
)
.fetch_one(txn.borrow_mut())
.instrument(info_span!("Register user"))
.await
.context("could not insert user")?;
let user = User {
data: id,
username: username.to_string(),
sub: format!("fake-sub-{}", id),
};
set_password(txn.borrow_mut(), phf, &user, password).await?;
Ok(user)
}
#[tracing::instrument(skip_all, fields(user.id = user.data))]
pub async fn set_password(
executor: impl PgExecutor<'_>,
phf: impl PasswordHasher,
user: &User<PostgresqlBackend>,
password: &str,
) -> anyhow::Result<()> {
let salt = SaltString::generate(&mut OsRng);
let hashed_password = PasswordHash::generate(phf, password, salt.as_str())?;
sqlx::query_scalar!(
r#"
INSERT INTO user_passwords (user_id, hashed_password)
VALUES ($1, $2)
"#,
user.data,
hashed_password.to_string(),
)
.execute(executor)
.instrument(info_span!("Save user credentials"))
.await
.context("could not insert user password")?;
Ok(())
}
#[tracing::instrument(skip_all, fields(session.id = session.data))]
pub async fn end_session(
executor: impl PgExecutor<'_>,
session: &BrowserSession<PostgresqlBackend>,
) -> anyhow::Result<()> {
let res = sqlx::query!(
"UPDATE user_sessions SET active = FALSE WHERE id = $1",
session.data,
)
.execute(executor)
.instrument(info_span!("End session"))
.await
.context("could not end session")?;
match res.rows_affected() {
1 => Ok(()),
0 => Err(anyhow::anyhow!("no row affected")),
_ => Err(anyhow::anyhow!("too many row affected")),
}
}
#[tracing::instrument(skip(executor))]
pub async fn lookup_user_by_username(
executor: impl PgExecutor<'_>,
username: &str,
) -> Result<User<PostgresqlBackend>, sqlx::Error> {
let res = sqlx::query_as!(
UserLookup,
r#"
SELECT id, username
FROM users
WHERE username = $1
"#,
username,
)
.fetch_one(executor)
.instrument(info_span!("Fetch user"))
.await?;
Ok(User {
data: res.id,
username: res.username,
sub: format!("fake-sub-{}", res.id),
})
}