You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Lookup and save upstream links
This commit is contained in:
@ -46,3 +46,10 @@ pub struct UpstreamOAuthAuthorizationSession {
|
|||||||
pub created_at: DateTime<Utc>,
|
pub created_at: DateTime<Utc>,
|
||||||
pub completed_at: Option<DateTime<Utc>>,
|
pub completed_at: Option<DateTime<Utc>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl UpstreamOAuthAuthorizationSession {
|
||||||
|
#[must_use]
|
||||||
|
pub const fn completed(&self) -> bool {
|
||||||
|
self.completed_at.is_some()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -21,13 +21,17 @@ use axum_extra::extract::PrivateCookieJar;
|
|||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use mas_axum_utils::http_client_factory::HttpClientFactory;
|
use mas_axum_utils::http_client_factory::HttpClientFactory;
|
||||||
use mas_http::ClientInitError;
|
use mas_http::ClientInitError;
|
||||||
|
use mas_jose::claims::ClaimError;
|
||||||
use mas_keystore::{Encrypter, Keystore};
|
use mas_keystore::{Encrypter, Keystore};
|
||||||
use mas_oidc_client::{
|
use mas_oidc_client::{
|
||||||
error::{DiscoveryError, JwksError, TokenAuthorizationCodeError},
|
error::{DiscoveryError, JwksError, TokenAuthorizationCodeError},
|
||||||
requests::{authorization_code::AuthorizationValidationData, jose::JwtVerificationData},
|
requests::{authorization_code::AuthorizationValidationData, jose::JwtVerificationData},
|
||||||
};
|
};
|
||||||
use mas_router::UrlBuilder;
|
use mas_router::UrlBuilder;
|
||||||
use mas_storage::{upstream_oauth2::lookup_session, LookupResultExt};
|
use mas_storage::{
|
||||||
|
upstream_oauth2::{add_link, complete_session, lookup_link_by_subject, lookup_session},
|
||||||
|
GenericLookupError, LookupResultExt,
|
||||||
|
};
|
||||||
use oauth2_types::errors::ClientErrorCode;
|
use oauth2_types::errors::ClientErrorCode;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
@ -66,9 +70,21 @@ pub(crate) enum RouteError {
|
|||||||
#[error("Provider mismatch")]
|
#[error("Provider mismatch")]
|
||||||
ProviderMismatch,
|
ProviderMismatch,
|
||||||
|
|
||||||
|
#[error("Session already completed")]
|
||||||
|
AlreadyCompleted,
|
||||||
|
|
||||||
#[error("State parameter mismatch")]
|
#[error("State parameter mismatch")]
|
||||||
StateMismatch,
|
StateMismatch,
|
||||||
|
|
||||||
|
#[error("Missing ID token")]
|
||||||
|
MissingIDToken,
|
||||||
|
|
||||||
|
#[error("Invalid ID token")]
|
||||||
|
InvalidIdToken(#[from] ClaimError),
|
||||||
|
|
||||||
|
#[error("User already linked")]
|
||||||
|
UserAlreadyLinked,
|
||||||
|
|
||||||
#[error("Error from the provider: {error}")]
|
#[error("Error from the provider: {error}")]
|
||||||
ClientError {
|
ClientError {
|
||||||
error: ClientErrorCode,
|
error: ClientErrorCode,
|
||||||
@ -88,6 +104,12 @@ pub(crate) enum RouteError {
|
|||||||
Anyhow(#[from] anyhow::Error),
|
Anyhow(#[from] anyhow::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<GenericLookupError> for RouteError {
|
||||||
|
fn from(e: GenericLookupError) -> Self {
|
||||||
|
Self::InternalError(Box::new(e))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<sqlx::Error> for RouteError {
|
impl From<sqlx::Error> for RouteError {
|
||||||
fn from(e: sqlx::Error) -> Self {
|
fn from(e: sqlx::Error) -> Self {
|
||||||
Self::InternalError(Box::new(e))
|
Self::InternalError(Box::new(e))
|
||||||
@ -182,6 +204,11 @@ pub(crate) async fn get(
|
|||||||
return Err(RouteError::StateMismatch);
|
return Err(RouteError::StateMismatch);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if session.completed() {
|
||||||
|
// The session was already completed
|
||||||
|
return Err(RouteError::AlreadyCompleted);
|
||||||
|
}
|
||||||
|
|
||||||
// Let's extract the code from the params, and return if there was an error
|
// Let's extract the code from the params, and return if there was an error
|
||||||
let code = match params.code_or_error {
|
let code = match params.code_or_error {
|
||||||
CodeOrError::Error {
|
CodeOrError::Error {
|
||||||
@ -224,10 +251,11 @@ pub(crate) async fn get(
|
|||||||
|
|
||||||
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
|
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);
|
||||||
|
|
||||||
|
// TODO: all that should be borrowed
|
||||||
let validation_data = AuthorizationValidationData {
|
let validation_data = AuthorizationValidationData {
|
||||||
state: session.state,
|
state: session.state.clone(),
|
||||||
nonce: session.nonce,
|
nonce: session.nonce.clone(),
|
||||||
code_challenge_verifier: session.code_challenge_verifier,
|
code_challenge_verifier: session.code_challenge_verifier.clone(),
|
||||||
redirect_uri,
|
redirect_uri,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -243,7 +271,7 @@ pub(crate) async fn get(
|
|||||||
.http_service("upstream-exchange-code")
|
.http_service("upstream-exchange-code")
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let (response, _id_token) =
|
let (response, id_token) =
|
||||||
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
|
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
|
||||||
&http_service,
|
&http_service,
|
||||||
client_credentials,
|
client_credentials,
|
||||||
@ -256,5 +284,30 @@ pub(crate) async fn get(
|
|||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
let (_header, mut id_token) = id_token.ok_or(RouteError::MissingIDToken)?.into_parts();
|
||||||
|
|
||||||
|
// Extract the subject from the id_token
|
||||||
|
let subject = mas_jose::claims::SUB.extract_required(&mut id_token)?;
|
||||||
|
|
||||||
|
// Look for an existing link
|
||||||
|
let maybe_link = lookup_link_by_subject(&mut txn, &provider, &subject)
|
||||||
|
.await
|
||||||
|
.to_option()?;
|
||||||
|
|
||||||
|
let link = if let Some((link, maybe_user_id)) = maybe_link {
|
||||||
|
if let Some(_user_id) = maybe_user_id {
|
||||||
|
// TODO: Here we should login if the user is linked
|
||||||
|
return Err(RouteError::UserAlreadyLinked);
|
||||||
|
}
|
||||||
|
|
||||||
|
link
|
||||||
|
} else {
|
||||||
|
add_link(&mut txn, &mut rng, &clock, &provider, subject).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
let _session = complete_session(&mut txn, &clock, session, &link).await?;
|
||||||
|
|
||||||
|
txn.commit().await?;
|
||||||
|
|
||||||
Ok(Json(response))
|
Ok(Json(response))
|
||||||
}
|
}
|
||||||
|
@ -689,6 +689,45 @@
|
|||||||
},
|
},
|
||||||
"query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n AND ue.email = $2\n "
|
"query": "\n SELECT\n ue.user_email_id,\n ue.email AS \"user_email\",\n ue.created_at AS \"user_email_created_at\",\n ue.confirmed_at AS \"user_email_confirmed_at\"\n FROM user_emails ue\n\n WHERE ue.user_id = $1\n AND ue.email = $2\n "
|
||||||
},
|
},
|
||||||
|
"3a6de39a88ef93a91f3cc0465785bafd58ef7dbd4aae924a8bcfcefaf2f1a0d7": {
|
||||||
|
"describe": {
|
||||||
|
"columns": [
|
||||||
|
{
|
||||||
|
"name": "upstream_oauth_link_id",
|
||||||
|
"ordinal": 0,
|
||||||
|
"type_info": "Uuid"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "user_id",
|
||||||
|
"ordinal": 1,
|
||||||
|
"type_info": "Uuid"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "subject",
|
||||||
|
"ordinal": 2,
|
||||||
|
"type_info": "Text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "created_at",
|
||||||
|
"ordinal": 3,
|
||||||
|
"type_info": "Timestamptz"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nullable": [
|
||||||
|
false,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
false
|
||||||
|
],
|
||||||
|
"parameters": {
|
||||||
|
"Left": [
|
||||||
|
"Uuid",
|
||||||
|
"Text"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"query": "\n SELECT\n upstream_oauth_link_id,\n user_id,\n subject,\n created_at\n FROM upstream_oauth_links\n WHERE upstream_oauth_provider_id = $1\n AND subject = $2\n "
|
||||||
|
},
|
||||||
"3df0838b660466f69ee681337fe6753133748defb715e53c8381badcc3e8bca9": {
|
"3df0838b660466f69ee681337fe6753133748defb715e53c8381badcc3e8bca9": {
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [
|
"columns": [
|
||||||
@ -954,6 +993,19 @@
|
|||||||
},
|
},
|
||||||
"query": "\n INSERT INTO user_passwords (user_password_id, user_id, hashed_password, created_at)\n VALUES ($1, $2, $3, $4)\n "
|
"query": "\n INSERT INTO user_passwords (user_password_id, user_id, hashed_password, created_at)\n VALUES ($1, $2, $3, $4)\n "
|
||||||
},
|
},
|
||||||
|
"4b6a44d040a0dc849bb4e04abb11a181995b5847917605ef4c160389686a54f5": {
|
||||||
|
"describe": {
|
||||||
|
"columns": [],
|
||||||
|
"nullable": [],
|
||||||
|
"parameters": {
|
||||||
|
"Left": [
|
||||||
|
"Uuid",
|
||||||
|
"Timestamptz"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2\n "
|
||||||
|
},
|
||||||
"4f8ec19f3f1bfe0268fe102a24e5a9fa542e77eccbebdce65e6deb1c197adf36": {
|
"4f8ec19f3f1bfe0268fe102a24e5a9fa542e77eccbebdce65e6deb1c197adf36": {
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [
|
"columns": [
|
||||||
@ -2417,6 +2469,21 @@
|
|||||||
},
|
},
|
||||||
"query": "\n DELETE FROM user_emails\n WHERE user_emails.user_email_id = $1\n "
|
"query": "\n DELETE FROM user_emails\n WHERE user_emails.user_email_id = $1\n "
|
||||||
},
|
},
|
||||||
|
"e1dc9dd2bf26a341050a53151bf51f7638448ccc2bd458bbdfe87cc22f086313": {
|
||||||
|
"describe": {
|
||||||
|
"columns": [],
|
||||||
|
"nullable": [],
|
||||||
|
"parameters": {
|
||||||
|
"Left": [
|
||||||
|
"Uuid",
|
||||||
|
"Uuid",
|
||||||
|
"Text",
|
||||||
|
"Timestamptz"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"query": "\n INSERT INTO upstream_oauth_links (\n upstream_oauth_link_id,\n upstream_oauth_provider_id,\n user_id,\n subject,\n created_at\n ) VALUES ($1, $2, NULL, $3, $4)\n "
|
||||||
|
},
|
||||||
"e446e37d48c8838ef2e0d0fd82f8f7b04893c84ad46747cdf193ebd83755ceb2": {
|
"e446e37d48c8838ef2e0d0fd82f8f7b04893c84ad46747cdf193ebd83755ceb2": {
|
||||||
"describe": {
|
"describe": {
|
||||||
"columns": [],
|
"columns": [],
|
||||||
|
120
crates/storage/src/upstream_oauth2/link.rs
Normal file
120
crates/storage/src/upstream_oauth2/link.rs
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider};
|
||||||
|
use rand::Rng;
|
||||||
|
use sqlx::PgExecutor;
|
||||||
|
use ulid::Ulid;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::{Clock, GenericLookupError};
|
||||||
|
|
||||||
|
struct LinkLookup {
|
||||||
|
upstream_oauth_link_id: Uuid,
|
||||||
|
user_id: Option<Uuid>,
|
||||||
|
subject: String,
|
||||||
|
created_at: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
upstream_oauth_link.subject = subject,
|
||||||
|
%upstream_oauth_provider.id,
|
||||||
|
%upstream_oauth_provider.issuer,
|
||||||
|
%upstream_oauth_provider.client_id,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
pub async fn lookup_link_by_subject(
|
||||||
|
executor: impl PgExecutor<'_>,
|
||||||
|
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||||
|
subject: &str,
|
||||||
|
) -> Result<(UpstreamOAuthLink, Option<Ulid>), GenericLookupError> {
|
||||||
|
let res = sqlx::query_as!(
|
||||||
|
LinkLookup,
|
||||||
|
r#"
|
||||||
|
SELECT
|
||||||
|
upstream_oauth_link_id,
|
||||||
|
user_id,
|
||||||
|
subject,
|
||||||
|
created_at
|
||||||
|
FROM upstream_oauth_links
|
||||||
|
WHERE upstream_oauth_provider_id = $1
|
||||||
|
AND subject = $2
|
||||||
|
"#,
|
||||||
|
Uuid::from(upstream_oauth_provider.id),
|
||||||
|
subject,
|
||||||
|
)
|
||||||
|
.fetch_one(executor)
|
||||||
|
.await
|
||||||
|
.map_err(GenericLookupError::what("Upstream OAuth 2.0 link"))?;
|
||||||
|
|
||||||
|
Ok((
|
||||||
|
UpstreamOAuthLink {
|
||||||
|
id: Ulid::from(res.upstream_oauth_link_id),
|
||||||
|
subject: res.subject,
|
||||||
|
created_at: res.created_at,
|
||||||
|
},
|
||||||
|
res.user_id.map(Ulid::from),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
upstream_oauth_link.id,
|
||||||
|
upstream_oauth_link.subject = subject,
|
||||||
|
%upstream_oauth_provider.id,
|
||||||
|
%upstream_oauth_provider.issuer,
|
||||||
|
%upstream_oauth_provider.client_id,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
pub async fn add_link(
|
||||||
|
executor: impl PgExecutor<'_>,
|
||||||
|
mut rng: impl Rng + Send,
|
||||||
|
clock: &Clock,
|
||||||
|
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||||
|
subject: String,
|
||||||
|
) -> Result<UpstreamOAuthLink, sqlx::Error> {
|
||||||
|
let created_at = clock.now();
|
||||||
|
let id = Ulid::from_datetime_with_source(created_at.into(), &mut rng);
|
||||||
|
tracing::Span::current().record("upstream_oauth_link.id", tracing::field::display(id));
|
||||||
|
|
||||||
|
sqlx::query!(
|
||||||
|
r#"
|
||||||
|
INSERT INTO upstream_oauth_links (
|
||||||
|
upstream_oauth_link_id,
|
||||||
|
upstream_oauth_provider_id,
|
||||||
|
user_id,
|
||||||
|
subject,
|
||||||
|
created_at
|
||||||
|
) VALUES ($1, $2, NULL, $3, $4)
|
||||||
|
"#,
|
||||||
|
Uuid::from(id),
|
||||||
|
Uuid::from(upstream_oauth_provider.id),
|
||||||
|
&subject,
|
||||||
|
created_at,
|
||||||
|
)
|
||||||
|
.execute(executor)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(UpstreamOAuthLink {
|
||||||
|
id,
|
||||||
|
subject,
|
||||||
|
created_at,
|
||||||
|
})
|
||||||
|
}
|
@ -12,10 +12,12 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
mod link;
|
||||||
mod provider;
|
mod provider;
|
||||||
mod session;
|
mod session;
|
||||||
|
|
||||||
pub use self::{
|
pub use self::{
|
||||||
|
link::{add_link, lookup_link_by_subject},
|
||||||
provider::{add_provider, lookup_provider, ProviderLookupError},
|
provider::{add_provider, lookup_provider, ProviderLookupError},
|
||||||
session::{add_session, lookup_session, SessionLookupError},
|
session::{add_session, complete_session, lookup_session, SessionLookupError},
|
||||||
};
|
};
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthProvider};
|
use mas_data_model::{UpstreamOAuthAuthorizationSession, UpstreamOAuthLink, UpstreamOAuthProvider};
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use sqlx::PgExecutor;
|
use sqlx::PgExecutor;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
@ -128,9 +128,9 @@ pub async fn lookup_session(
|
|||||||
#[tracing::instrument(
|
#[tracing::instrument(
|
||||||
skip_all,
|
skip_all,
|
||||||
fields(
|
fields(
|
||||||
upstream_oauth_provider.id = %provider.id,
|
%upstream_oauth_provider.id,
|
||||||
upstream_oauth_provider.issuer = %provider.issuer,
|
%upstream_oauth_provider.issuer,
|
||||||
upstream_oauth_provider.client_id = %provider.client_id,
|
%upstream_oauth_provider.client_id,
|
||||||
upstream_oauth_authorization_session.id,
|
upstream_oauth_authorization_session.id,
|
||||||
),
|
),
|
||||||
err,
|
err,
|
||||||
@ -139,7 +139,7 @@ pub async fn add_session(
|
|||||||
executor: impl PgExecutor<'_>,
|
executor: impl PgExecutor<'_>,
|
||||||
mut rng: impl Rng + Send,
|
mut rng: impl Rng + Send,
|
||||||
clock: &Clock,
|
clock: &Clock,
|
||||||
provider: &UpstreamOAuthProvider,
|
upstream_oauth_provider: &UpstreamOAuthProvider,
|
||||||
state: String,
|
state: String,
|
||||||
code_challenge_verifier: Option<String>,
|
code_challenge_verifier: Option<String>,
|
||||||
nonce: String,
|
nonce: String,
|
||||||
@ -164,7 +164,7 @@ pub async fn add_session(
|
|||||||
) VALUES ($1, $2, $3, $4, $5, $6, NULL)
|
) VALUES ($1, $2, $3, $4, $5, $6, NULL)
|
||||||
"#,
|
"#,
|
||||||
Uuid::from(id),
|
Uuid::from(id),
|
||||||
Uuid::from(provider.id),
|
Uuid::from(upstream_oauth_provider.id),
|
||||||
&state,
|
&state,
|
||||||
code_challenge_verifier.as_deref(),
|
code_challenge_verifier.as_deref(),
|
||||||
nonce,
|
nonce,
|
||||||
@ -182,3 +182,35 @@ pub async fn add_session(
|
|||||||
completed_at: None,
|
completed_at: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
%upstream_oauth_authorization_session.id,
|
||||||
|
%upstream_oauth_link.id,
|
||||||
|
),
|
||||||
|
err,
|
||||||
|
)]
|
||||||
|
pub async fn complete_session(
|
||||||
|
executor: impl PgExecutor<'_>,
|
||||||
|
clock: &Clock,
|
||||||
|
mut upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||||
|
upstream_oauth_link: &UpstreamOAuthLink,
|
||||||
|
) -> Result<UpstreamOAuthAuthorizationSession, sqlx::Error> {
|
||||||
|
let completed_at = clock.now();
|
||||||
|
sqlx::query!(
|
||||||
|
r#"
|
||||||
|
UPDATE upstream_oauth_authorization_sessions
|
||||||
|
SET upstream_oauth_link_id = $1,
|
||||||
|
completed_at = $2
|
||||||
|
"#,
|
||||||
|
Uuid::from(upstream_oauth_link.id),
|
||||||
|
completed_at,
|
||||||
|
)
|
||||||
|
.execute(executor)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
upstream_oauth_authorization_session.completed_at = Some(completed_at);
|
||||||
|
|
||||||
|
Ok(upstream_oauth_authorization_session)
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user