1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-11-26 10:44:51 +03:00
Files
authentication-service/crates/storage/src/oauth2/access_token.rs
Quentin Gliech 2d2127dcdb More cleanups
2022-11-02 18:59:00 +01:00

277 lines
8.4 KiB
Rust

// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use chrono::{DateTime, Duration, Utc};
use mas_data_model::{AccessToken, Authentication, BrowserSession, Session, User, UserEmail};
use rand::Rng;
use sqlx::{PgConnection, PgExecutor};
use thiserror::Error;
use ulid::Ulid;
use uuid::Uuid;
use super::client::{lookup_client, ClientFetchError};
use crate::{Clock, DatabaseInconsistencyError, PostgresqlBackend};
#[tracing::instrument(
skip_all,
fields(
session.id = %session.data,
client.id = %session.client.data,
user.id = %session.browser_session.user.data,
access_token.id,
),
err(Debug),
)]
pub async fn add_access_token(
executor: impl PgExecutor<'_>,
mut rng: impl Rng + Send,
clock: &Clock,
session: &Session<PostgresqlBackend>,
access_token: String,
expires_after: Duration,
) -> Result<AccessToken<PostgresqlBackend>, anyhow::Error> {
let created_at = clock.now();
let expires_at = created_at + expires_after;
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
tracing::Span::current().record("access_token.id", tracing::field::display(id));
sqlx::query!(
r#"
INSERT INTO oauth2_access_tokens
(oauth2_access_token_id, oauth2_session_id, access_token, created_at, expires_at)
VALUES
($1, $2, $3, $4, $5)
"#,
Uuid::from(id),
Uuid::from(session.data),
&access_token,
created_at,
expires_at,
)
.execute(executor)
.await
.context("could not insert oauth2 access token")?;
Ok(AccessToken {
data: id,
access_token,
jti: id.to_string(),
created_at,
expires_at,
})
}
/// Flat row produced by the query in `lookup_active_access_token`.
///
/// Each field matches a column (or column alias) of that `SELECT`. The
/// `Option` fields come from the `LEFT JOIN`ed tables and are `None` when no
/// matching row exists there.
#[derive(Debug)]
pub struct OAuth2AccessTokenLookup {
// Columns from `oauth2_access_tokens`.
oauth2_access_token_id: Uuid,
oauth2_access_token: String,
oauth2_access_token_created_at: DateTime<Utc>,
oauth2_access_token_expires_at: DateTime<Utc>,
// Columns from `oauth2_sessions`.
oauth2_session_id: Uuid,
oauth2_client_id: Uuid,
// Scope as stored in the database; parsed by the lookup function.
scope: String,
// Columns from `user_sessions`.
user_session_id: Uuid,
user_session_created_at: DateTime<Utc>,
// Columns from `users`.
user_id: Uuid,
user_username: String,
// Most recent row of `user_session_authentications`, if any.
user_session_last_authentication_id: Option<Uuid>,
user_session_last_authentication_created_at: Option<DateTime<Utc>>,
// The user's primary email (`users.primary_user_email_id`), if any.
user_email_id: Option<Uuid>,
user_email: Option<String>,
user_email_created_at: Option<DateTime<Utc>>,
// May be `None` even when the email row exists.
user_email_confirmed_at: Option<DateTime<Utc>>,
}
/// Error returned by `lookup_active_access_token`.
///
/// Every variant shares the same display message; the wrapped error is
/// attached through `#[from]`.
#[derive(Debug, Error)]
#[error("failed to lookup access token")]
pub enum AccessTokenLookupError {
/// The database query failed; see [`AccessTokenLookupError::not_found`]
/// for the "no such row" case.
Database(#[from] sqlx::Error),
/// Fetching the OAuth 2.0 client of the session failed.
ClientFetch(#[from] ClientFetchError),
/// The data read from the database was not in a consistent state.
Inconsistency(#[from] DatabaseInconsistencyError),
}
impl AccessTokenLookupError {
#[must_use]
pub fn not_found(&self) -> bool {
matches!(self, Self::Database(sqlx::Error::RowNotFound))
}
}
/// Finds an access token by its value and loads the session it belongs to.
///
/// Only tokens that are not revoked (`revoked_at IS NULL`) and whose OAuth 2.0
/// session is not finished (`finished_at IS NULL`) are matched. The expiry
/// timestamp is returned on the token but is not checked by this function.
///
/// The returned [`Session`] carries the client (fetched through
/// `lookup_client`), the browser session with its most recent authentication
/// (if any), and the user with their primary email (if any).
///
/// # Errors
///
/// Returns [`AccessTokenLookupError`] when the query fails (including when no
/// row matches), when the client cannot be fetched, or when the row holds
/// inconsistent data: partially present email or authentication columns, or
/// a scope that fails to parse.
#[allow(clippy::too_many_lines)]
pub async fn lookup_active_access_token(
    conn: &mut PgConnection,
    token: &str,
) -> Result<(AccessToken<PostgresqlBackend>, Session<PostgresqlBackend>), AccessTokenLookupError> {
    // `ORDER BY ... DESC LIMIT 1` keeps only the latest authentication of the
    // browser session when several exist.
    let row = sqlx::query_as!(
        OAuth2AccessTokenLookup,
        r#"
SELECT
at.oauth2_access_token_id,
at.access_token AS "oauth2_access_token",
at.created_at AS "oauth2_access_token_created_at",
at.expires_at AS "oauth2_access_token_expires_at",
os.oauth2_session_id AS "oauth2_session_id!",
os.oauth2_client_id AS "oauth2_client_id!",
os.scope AS "scope!",
us.user_session_id AS "user_session_id!",
us.created_at AS "user_session_created_at!",
u.user_id AS "user_id!",
u.username AS "user_username!",
usa.user_session_authentication_id AS "user_session_last_authentication_id?",
usa.created_at AS "user_session_last_authentication_created_at?",
ue.user_email_id AS "user_email_id?",
ue.email AS "user_email?",
ue.created_at AS "user_email_created_at?",
ue.confirmed_at AS "user_email_confirmed_at?"
FROM oauth2_access_tokens at
INNER JOIN oauth2_sessions os
USING (oauth2_session_id)
INNER JOIN user_sessions us
USING (user_session_id)
INNER JOIN users u
USING (user_id)
LEFT JOIN user_session_authentications usa
USING (user_session_id)
LEFT JOIN user_emails ue
ON ue.user_email_id = u.primary_user_email_id
WHERE at.access_token = $1
AND at.revoked_at IS NULL
AND os.finished_at IS NULL
ORDER BY usa.created_at DESC
LIMIT 1
"#,
        token,
    )
    .fetch_one(&mut *conn)
    .await?;

    // The client is not part of the joined row; it needs its own lookup.
    let client = lookup_client(&mut *conn, row.oauth2_client_id.into()).await?;

    let token_id = Ulid::from(row.oauth2_access_token_id);
    let access_token = AccessToken {
        data: token_id,
        jti: token_id.to_string(),
        access_token: row.oauth2_access_token,
        created_at: row.oauth2_access_token_created_at,
        expires_at: row.oauth2_access_token_expires_at,
    };

    // The email columns come from a LEFT JOIN: either the row is there
    // (`confirmed_at` may still be NULL) or every column is NULL.
    let email_columns = (
        row.user_email_id,
        row.user_email,
        row.user_email_created_at,
        row.user_email_confirmed_at,
    );
    let primary_email = match email_columns {
        (None, None, None, None) => None,
        (Some(email_id), Some(email), Some(created_at), confirmed_at) => Some(UserEmail {
            data: email_id.into(),
            email,
            created_at,
            confirmed_at,
        }),
        _ => return Err(DatabaseInconsistencyError.into()),
    };

    let user_id = Ulid::from(row.user_id);
    let user = User {
        data: user_id,
        username: row.user_username,
        sub: user_id.to_string(),
        primary_email,
    };

    // Same all-or-nothing rule for the last authentication columns.
    let authentication_columns = (
        row.user_session_last_authentication_id,
        row.user_session_last_authentication_created_at,
    );
    let last_authentication = match authentication_columns {
        (Some(authentication_id), Some(created_at)) => Some(Authentication {
            data: authentication_id.into(),
            created_at,
        }),
        (None, None) => None,
        (Some(_), None) | (None, Some(_)) => return Err(DatabaseInconsistencyError.into()),
    };

    let browser_session = BrowserSession {
        data: row.user_session_id.into(),
        created_at: row.user_session_created_at,
        user,
        last_authentication,
    };

    // A stored scope that does not parse is treated as corrupted data.
    let scope = row.scope.parse().map_err(|_e| DatabaseInconsistencyError)?;

    let session = Session {
        data: row.oauth2_session_id.into(),
        client,
        browser_session,
        scope,
    };

    Ok((access_token, session))
}
#[tracing::instrument(
skip_all,
fields(access_token.id = %access_token.data),
err(Debug),
)]
pub async fn revoke_access_token(
executor: impl PgExecutor<'_>,
clock: &Clock,
access_token: AccessToken<PostgresqlBackend>,
) -> anyhow::Result<()> {
let revoked_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE oauth2_access_tokens
SET revoked_at = $2
WHERE oauth2_access_token_id = $1
"#,
Uuid::from(access_token.data),
revoked_at,
)
.execute(executor)
.await
.context("could not revoke access tokens")?;
if res.rows_affected() == 1 {
Ok(())
} else {
Err(anyhow::anyhow!("no row were affected when revoking token"))
}
}
/// Deletes access tokens whose expiry lies more than 15 minutes in the past,
/// relative to the time given by `clock`.
///
/// Returns the number of deleted rows.
///
/// # Errors
///
/// Returns an error if the delete statement fails.
pub async fn cleanup_expired(executor: impl PgExecutor<'_>, clock: &Clock) -> anyhow::Result<u64> {
    // Tokens are kept for 15 minutes past their expiry before being removed.
    let grace_period = Duration::minutes(15);
    let cutoff = clock.now() - grace_period;

    sqlx::query!(
        r#"
DELETE FROM oauth2_access_tokens
WHERE expires_at < $1
"#,
        cutoff,
    )
    .execute(executor)
    .await
    .context("could not cleanup expired access tokens")
    .map(|outcome| outcome.rows_affected())
}