You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
storage: OAuth2 client repository
This commit is contained in:
@ -19,8 +19,8 @@ use sqlx::{PgConnection, PgExecutor};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::client::lookup_client;
|
||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError};
|
||||
use super::client::OAuth2ClientRepository;
|
||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError, Repository};
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
@ -144,7 +144,9 @@ pub async fn lookup_active_access_token(
|
||||
};
|
||||
|
||||
let session_id = res.oauth2_session_id.into();
|
||||
let client = lookup_client(&mut *conn, res.oauth2_client_id.into())
|
||||
let client = conn
|
||||
.oauth2_client()
|
||||
.lookup(res.oauth2_client_id.into())
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
DatabaseInconsistencyError::on("oauth2_sessions")
|
||||
|
@ -27,8 +27,8 @@ use ulid::Ulid;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::client::lookup_client;
|
||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
|
||||
use super::client::OAuth2ClientRepository;
|
||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt, Repository};
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
@ -163,7 +163,7 @@ impl GrantLookup {
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn into_authorization_grant(
|
||||
self,
|
||||
executor: impl PgExecutor<'_>,
|
||||
conn: &mut PgConnection,
|
||||
) -> Result<AuthorizationGrant, DatabaseError> {
|
||||
let id = self.oauth2_authorization_grant_id.into();
|
||||
let scope: Scope = self.oauth2_authorization_grant_scope.parse().map_err(|e| {
|
||||
@ -173,8 +173,9 @@ impl GrantLookup {
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
// TODO: don't unwrap
|
||||
let client = lookup_client(executor, self.oauth2_client_id.into())
|
||||
let client = conn
|
||||
.oauth2_client()
|
||||
.lookup(self.oauth2_client_id.into())
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
DatabaseInconsistencyError::on("oauth2_authorization_grants")
|
||||
|
@ -12,8 +12,12 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::HashMap, string::ToString};
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
string::ToString,
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mas_data_model::{Client, JwksOrJwksUri};
|
||||
use mas_iana::{
|
||||
jose::JsonWebSignatureAlg,
|
||||
@ -21,17 +25,83 @@ use mas_iana::{
|
||||
};
|
||||
use mas_jose::jwk::PublicJsonWebKeySet;
|
||||
use oauth2_types::requests::GrantType;
|
||||
use rand::Rng;
|
||||
use sqlx::{PgConnection, PgExecutor};
|
||||
use rand::{Rng, RngCore};
|
||||
use sqlx::PgConnection;
|
||||
use tracing::{info_span, Instrument};
|
||||
use ulid::Ulid;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt};
|
||||
use crate::{
|
||||
tracing::ExecuteExt, Clock, DatabaseError, DatabaseInconsistencyError, LookupResultExt,
|
||||
};
|
||||
|
||||
#[async_trait]
|
||||
pub trait OAuth2ClientRepository: Send + Sync {
|
||||
type Error;
|
||||
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error>;
|
||||
|
||||
async fn find_by_client_id(&mut self, client_id: &str) -> Result<Option<Client>, Self::Error> {
|
||||
let Ok(id) = client_id.parse() else { return Ok(None) };
|
||||
self.lookup(id).await
|
||||
}
|
||||
|
||||
async fn load_batch(
|
||||
&mut self,
|
||||
ids: BTreeSet<Ulid>,
|
||||
) -> Result<BTreeMap<Ulid, Client>, Self::Error>;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
redirect_uris: Vec<Url>,
|
||||
encrypted_client_secret: Option<String>,
|
||||
grant_types: Vec<GrantType>,
|
||||
contacts: Vec<String>,
|
||||
client_name: Option<String>,
|
||||
logo_uri: Option<Url>,
|
||||
client_uri: Option<Url>,
|
||||
policy_uri: Option<Url>,
|
||||
tos_uri: Option<Url>,
|
||||
jwks_uri: Option<Url>,
|
||||
jwks: Option<PublicJsonWebKeySet>,
|
||||
id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
|
||||
userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
|
||||
token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
|
||||
token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
initiate_login_uri: Option<Url>,
|
||||
) -> Result<Client, Self::Error>;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn add_from_config(
|
||||
&mut self,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
client_id: Ulid,
|
||||
client_auth_method: OAuthClientAuthenticationMethod,
|
||||
encrypted_client_secret: Option<String>,
|
||||
jwks: Option<PublicJsonWebKeySet>,
|
||||
jwks_uri: Option<Url>,
|
||||
redirect_uris: Vec<Url>,
|
||||
) -> Result<Client, Self::Error>;
|
||||
}
|
||||
|
||||
pub struct PgOAuth2ClientRepository<'c> {
|
||||
conn: &'c mut PgConnection,
|
||||
}
|
||||
|
||||
impl<'c> PgOAuth2ClientRepository<'c> {
|
||||
pub fn new(conn: &'c mut PgConnection) -> Self {
|
||||
Self { conn }
|
||||
}
|
||||
}
|
||||
|
||||
// XXX: response_types & contacts
|
||||
#[derive(Debug)]
|
||||
pub struct OAuth2ClientLookup {
|
||||
struct OAuth2ClientLookup {
|
||||
oauth2_client_id: Uuid,
|
||||
encrypted_client_secret: Option<String>,
|
||||
redirect_uris: Vec<String>,
|
||||
@ -234,252 +304,305 @@ impl TryInto<Client> for OAuth2ClientLookup {
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, err)]
|
||||
pub async fn lookup_clients(
|
||||
executor: impl PgExecutor<'_>,
|
||||
ids: impl IntoIterator<Item = Ulid> + Send,
|
||||
) -> Result<HashMap<Ulid, Client>, DatabaseError> {
|
||||
let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2ClientLookup,
|
||||
r#"
|
||||
SELECT
|
||||
c.oauth2_client_id,
|
||||
c.encrypted_client_secret,
|
||||
ARRAY(
|
||||
SELECT redirect_uri
|
||||
FROM oauth2_client_redirect_uris r
|
||||
WHERE r.oauth2_client_id = c.oauth2_client_id
|
||||
) AS "redirect_uris!",
|
||||
c.grant_type_authorization_code,
|
||||
c.grant_type_refresh_token,
|
||||
c.client_name,
|
||||
c.logo_uri,
|
||||
c.client_uri,
|
||||
c.policy_uri,
|
||||
c.tos_uri,
|
||||
c.jwks_uri,
|
||||
c.jwks,
|
||||
c.id_token_signed_response_alg,
|
||||
c.userinfo_signed_response_alg,
|
||||
c.token_endpoint_auth_method,
|
||||
c.token_endpoint_auth_signing_alg,
|
||||
c.initiate_login_uri
|
||||
FROM oauth2_clients c
|
||||
#[async_trait]
|
||||
impl<'c> OAuth2ClientRepository for PgOAuth2ClientRepository<'c> {
|
||||
type Error = DatabaseError;
|
||||
|
||||
WHERE c.oauth2_client_id = ANY($1::uuid[])
|
||||
"#,
|
||||
&ids,
|
||||
)
|
||||
.fetch_all(executor)
|
||||
.await?;
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_client.lookup",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
oauth2_client.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2ClientLookup,
|
||||
r#"
|
||||
SELECT oauth2_client_id
|
||||
, encrypted_client_secret
|
||||
, ARRAY(
|
||||
SELECT redirect_uri
|
||||
FROM oauth2_client_redirect_uris r
|
||||
WHERE r.oauth2_client_id = c.oauth2_client_id
|
||||
) AS "redirect_uris!"
|
||||
, grant_type_authorization_code
|
||||
, grant_type_refresh_token
|
||||
, client_name
|
||||
, logo_uri
|
||||
, client_uri
|
||||
, policy_uri
|
||||
, tos_uri
|
||||
, jwks_uri
|
||||
, jwks
|
||||
, id_token_signed_response_alg
|
||||
, userinfo_signed_response_alg
|
||||
, token_endpoint_auth_method
|
||||
, token_endpoint_auth_signing_alg
|
||||
, initiate_login_uri
|
||||
FROM oauth2_clients c
|
||||
|
||||
res.into_iter()
|
||||
.map(|r| {
|
||||
r.try_into()
|
||||
.map(|c: Client| (c.id, c))
|
||||
.map_err(DatabaseError::from)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
WHERE oauth2_client_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_one(&mut *self.conn)
|
||||
.await
|
||||
.to_option()?;
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(client.id = %id),
|
||||
err,
|
||||
)]
|
||||
pub async fn lookup_client(
|
||||
executor: impl PgExecutor<'_>,
|
||||
id: Ulid,
|
||||
) -> Result<Option<Client>, DatabaseError> {
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2ClientLookup,
|
||||
r#"
|
||||
SELECT
|
||||
c.oauth2_client_id,
|
||||
c.encrypted_client_secret,
|
||||
ARRAY(
|
||||
SELECT redirect_uri
|
||||
FROM oauth2_client_redirect_uris r
|
||||
WHERE r.oauth2_client_id = c.oauth2_client_id
|
||||
) AS "redirect_uris!",
|
||||
c.grant_type_authorization_code,
|
||||
c.grant_type_refresh_token,
|
||||
c.client_name,
|
||||
c.logo_uri,
|
||||
c.client_uri,
|
||||
c.policy_uri,
|
||||
c.tos_uri,
|
||||
c.jwks_uri,
|
||||
c.jwks,
|
||||
c.id_token_signed_response_alg,
|
||||
c.userinfo_signed_response_alg,
|
||||
c.token_endpoint_auth_method,
|
||||
c.token_endpoint_auth_signing_alg,
|
||||
c.initiate_login_uri
|
||||
FROM oauth2_clients c
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
|
||||
WHERE c.oauth2_client_id = $1
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.to_option()?;
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
|
||||
let Some(res) = res else { return Ok(None) };
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_client.load_batch",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn load_batch(
|
||||
&mut self,
|
||||
ids: BTreeSet<Ulid>,
|
||||
) -> Result<BTreeMap<Ulid, Client>, Self::Error> {
|
||||
let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
|
||||
let res = sqlx::query_as!(
|
||||
OAuth2ClientLookup,
|
||||
r#"
|
||||
SELECT oauth2_client_id
|
||||
, encrypted_client_secret
|
||||
, ARRAY(
|
||||
SELECT redirect_uri
|
||||
FROM oauth2_client_redirect_uris r
|
||||
WHERE r.oauth2_client_id = c.oauth2_client_id
|
||||
) AS "redirect_uris!"
|
||||
, grant_type_authorization_code
|
||||
, grant_type_refresh_token
|
||||
, client_name
|
||||
, logo_uri
|
||||
, client_uri
|
||||
, policy_uri
|
||||
, tos_uri
|
||||
, jwks_uri
|
||||
, jwks
|
||||
, id_token_signed_response_alg
|
||||
, userinfo_signed_response_alg
|
||||
, token_endpoint_auth_method
|
||||
, token_endpoint_auth_signing_alg
|
||||
, initiate_login_uri
|
||||
FROM oauth2_clients c
|
||||
|
||||
Ok(Some(res.try_into()?))
|
||||
}
|
||||
WHERE oauth2_client_id = ANY($1::uuid[])
|
||||
"#,
|
||||
&ids,
|
||||
)
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(client.id = client_id),
|
||||
err,
|
||||
)]
|
||||
pub async fn lookup_client_by_client_id(
|
||||
executor: impl PgExecutor<'_>,
|
||||
client_id: &str,
|
||||
) -> Result<Option<Client>, DatabaseError> {
|
||||
let Ok(id) = client_id.parse() else { return Ok(None) };
|
||||
lookup_client(executor, id).await
|
||||
}
|
||||
res.into_iter()
|
||||
.map(|r| {
|
||||
r.try_into()
|
||||
.map(|c: Client| (c.id, c))
|
||||
.map_err(DatabaseError::from)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
fields(client.id = %client_id, client.name = client_name),
|
||||
err,
|
||||
)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn insert_client(
|
||||
conn: &mut PgConnection,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
client_id: Ulid,
|
||||
redirect_uris: &[Url],
|
||||
encrypted_client_secret: Option<&str>,
|
||||
grant_types: &[GrantType],
|
||||
_contacts: &[String],
|
||||
client_name: Option<&str>,
|
||||
logo_uri: Option<&Url>,
|
||||
client_uri: Option<&Url>,
|
||||
policy_uri: Option<&Url>,
|
||||
tos_uri: Option<&Url>,
|
||||
jwks_uri: Option<&Url>,
|
||||
jwks: Option<&PublicJsonWebKeySet>,
|
||||
id_token_signed_response_alg: Option<&JsonWebSignatureAlg>,
|
||||
userinfo_signed_response_alg: Option<&JsonWebSignatureAlg>,
|
||||
token_endpoint_auth_method: Option<&OAuthClientAuthenticationMethod>,
|
||||
token_endpoint_auth_signing_alg: Option<&JsonWebSignatureAlg>,
|
||||
initiate_login_uri: Option<&Url>,
|
||||
) -> Result<(), sqlx::Error> {
|
||||
let grant_type_authorization_code = grant_types.contains(&GrantType::AuthorizationCode);
|
||||
let grant_type_refresh_token = grant_types.contains(&GrantType::RefreshToken);
|
||||
let logo_uri = logo_uri.map(Url::as_str);
|
||||
let client_uri = client_uri.map(Url::as_str);
|
||||
let policy_uri = policy_uri.map(Url::as_str);
|
||||
let tos_uri = tos_uri.map(Url::as_str);
|
||||
let jwks = jwks.map(serde_json::to_value).transpose().unwrap(); // TODO
|
||||
let jwks_uri = jwks_uri.map(Url::as_str);
|
||||
let id_token_signed_response_alg = id_token_signed_response_alg.map(ToString::to_string);
|
||||
let userinfo_signed_response_alg = userinfo_signed_response_alg.map(ToString::to_string);
|
||||
let token_endpoint_auth_method = token_endpoint_auth_method.map(ToString::to_string);
|
||||
let token_endpoint_auth_signing_alg = token_endpoint_auth_signing_alg.map(ToString::to_string);
|
||||
let initiate_login_uri = initiate_login_uri.map(Url::as_str);
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_client.add",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
client.id,
|
||||
client.name = client_name
|
||||
),
|
||||
err,
|
||||
)]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn add(
|
||||
&mut self,
|
||||
mut rng: &mut (dyn RngCore + Send),
|
||||
clock: &Clock,
|
||||
redirect_uris: Vec<Url>,
|
||||
encrypted_client_secret: Option<String>,
|
||||
grant_types: Vec<GrantType>,
|
||||
contacts: Vec<String>,
|
||||
client_name: Option<String>,
|
||||
logo_uri: Option<Url>,
|
||||
client_uri: Option<Url>,
|
||||
policy_uri: Option<Url>,
|
||||
tos_uri: Option<Url>,
|
||||
jwks_uri: Option<Url>,
|
||||
jwks: Option<PublicJsonWebKeySet>,
|
||||
id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
|
||||
userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
|
||||
token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
|
||||
token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
initiate_login_uri: Option<Url>,
|
||||
) -> Result<Client, Self::Error> {
|
||||
let now = clock.now();
|
||||
let id = Ulid::from_datetime_with_source(now.into(), rng);
|
||||
tracing::Span::current().record("client.id", tracing::field::display(id));
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_clients
|
||||
(oauth2_client_id,
|
||||
encrypted_client_secret,
|
||||
grant_type_authorization_code,
|
||||
grant_type_refresh_token,
|
||||
client_name,
|
||||
logo_uri,
|
||||
client_uri,
|
||||
policy_uri,
|
||||
tos_uri,
|
||||
jwks_uri,
|
||||
jwks,
|
||||
id_token_signed_response_alg,
|
||||
userinfo_signed_response_alg,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_auth_signing_alg,
|
||||
initiate_login_uri)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
|
||||
"#,
|
||||
Uuid::from(client_id),
|
||||
encrypted_client_secret,
|
||||
grant_type_authorization_code,
|
||||
grant_type_refresh_token,
|
||||
client_name,
|
||||
logo_uri,
|
||||
client_uri,
|
||||
policy_uri,
|
||||
tos_uri,
|
||||
jwks_uri,
|
||||
jwks,
|
||||
id_token_signed_response_alg,
|
||||
userinfo_signed_response_alg,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_auth_signing_alg,
|
||||
initiate_login_uri,
|
||||
)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
let jwks_json = jwks
|
||||
.as_ref()
|
||||
.map(serde_json::to_value)
|
||||
.transpose()
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
let now = clock.now();
|
||||
let (ids, redirect_uris): (Vec<Uuid>, Vec<String>) = redirect_uris
|
||||
.iter()
|
||||
.map(|uri| {
|
||||
(
|
||||
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
|
||||
uri.as_str().to_owned(),
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_clients
|
||||
( oauth2_client_id
|
||||
, encrypted_client_secret
|
||||
, grant_type_authorization_code
|
||||
, grant_type_refresh_token
|
||||
, client_name
|
||||
, logo_uri
|
||||
, client_uri
|
||||
, policy_uri
|
||||
, tos_uri
|
||||
, jwks_uri
|
||||
, jwks
|
||||
, id_token_signed_response_alg
|
||||
, userinfo_signed_response_alg
|
||||
, token_endpoint_auth_method
|
||||
, token_endpoint_auth_signing_alg
|
||||
, initiate_login_uri
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
encrypted_client_secret,
|
||||
grant_types.contains(&GrantType::AuthorizationCode),
|
||||
grant_types.contains(&GrantType::RefreshToken),
|
||||
client_name,
|
||||
logo_uri.as_ref().map(Url::as_str),
|
||||
client_uri.as_ref().map(Url::as_str),
|
||||
policy_uri.as_ref().map(Url::as_str),
|
||||
tos_uri.as_ref().map(Url::as_str),
|
||||
jwks_uri.as_ref().map(Url::as_str),
|
||||
jwks_json,
|
||||
id_token_signed_response_alg
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
userinfo_signed_response_alg
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
token_endpoint_auth_method.as_ref().map(ToString::to_string),
|
||||
token_endpoint_auth_signing_alg
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
initiate_login_uri.as_ref().map(Url::as_str),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
{
|
||||
let span = info_span!(
|
||||
"db.oauth2_client.add.redirect_uris",
|
||||
db.statement = tracing::field::Empty,
|
||||
client.id = %id,
|
||||
);
|
||||
|
||||
let (uri_ids, redirect_uris): (Vec<Uuid>, Vec<String>) = redirect_uris
|
||||
.iter()
|
||||
.map(|uri| {
|
||||
(
|
||||
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
|
||||
uri.as_str().to_owned(),
|
||||
)
|
||||
})
|
||||
.unzip();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_client_redirect_uris
|
||||
(oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)
|
||||
SELECT id, $2, redirect_uri
|
||||
FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)
|
||||
"#,
|
||||
&uri_ids,
|
||||
Uuid::from(id),
|
||||
&redirect_uris,
|
||||
)
|
||||
.record(&span)
|
||||
.execute(&mut *self.conn)
|
||||
.instrument(span)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let jwks = match (jwks, jwks_uri) {
|
||||
(None, None) => None,
|
||||
(Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
|
||||
(None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
|
||||
_ => return Err(DatabaseError::invalid_operation()),
|
||||
};
|
||||
|
||||
Ok(Client {
|
||||
id,
|
||||
client_id: id.to_string(),
|
||||
encrypted_client_secret,
|
||||
redirect_uris,
|
||||
response_types: vec![
|
||||
OAuthAuthorizationEndpointResponseType::Code,
|
||||
OAuthAuthorizationEndpointResponseType::IdToken,
|
||||
OAuthAuthorizationEndpointResponseType::None,
|
||||
],
|
||||
grant_types,
|
||||
contacts,
|
||||
client_name,
|
||||
logo_uri,
|
||||
client_uri,
|
||||
policy_uri,
|
||||
tos_uri,
|
||||
jwks,
|
||||
id_token_signed_response_alg,
|
||||
userinfo_signed_response_alg,
|
||||
token_endpoint_auth_method,
|
||||
token_endpoint_auth_signing_alg,
|
||||
initiate_login_uri,
|
||||
})
|
||||
.unzip();
|
||||
}
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_client_redirect_uris
|
||||
(oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)
|
||||
SELECT id, $2, redirect_uri
|
||||
FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)
|
||||
"#,
|
||||
&ids,
|
||||
Uuid::from(client_id),
|
||||
&redirect_uris,
|
||||
)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
#[tracing::instrument(
|
||||
name = "db.oauth2_client.add_from_config",
|
||||
skip_all,
|
||||
fields(
|
||||
db.statement,
|
||||
client.id = %client_id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn add_from_config(
|
||||
&mut self,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
client_id: Ulid,
|
||||
client_auth_method: OAuthClientAuthenticationMethod,
|
||||
encrypted_client_secret: Option<String>,
|
||||
jwks: Option<PublicJsonWebKeySet>,
|
||||
jwks_uri: Option<Url>,
|
||||
redirect_uris: Vec<Url>,
|
||||
) -> Result<Client, Self::Error> {
|
||||
let jwks_json = jwks
|
||||
.as_ref()
|
||||
.map(serde_json::to_value)
|
||||
.transpose()
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
let client_auth_method = client_auth_method.to_string();
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn insert_client_from_config(
|
||||
conn: &mut PgConnection,
|
||||
mut rng: impl Rng + Send,
|
||||
clock: &Clock,
|
||||
client_id: Ulid,
|
||||
client_auth_method: OAuthClientAuthenticationMethod,
|
||||
encrypted_client_secret: Option<&str>,
|
||||
jwks: Option<&PublicJsonWebKeySet>,
|
||||
jwks_uri: Option<&Url>,
|
||||
redirect_uris: &[Url],
|
||||
) -> Result<(), DatabaseError> {
|
||||
let jwks = jwks
|
||||
.map(serde_json::to_value)
|
||||
.transpose()
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
let jwks_uri = jwks_uri.map(Url::as_str);
|
||||
|
||||
let client_auth_method = client_auth_method.to_string();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_clients
|
||||
( oauth2_client_id
|
||||
, encrypted_client_secret
|
||||
@ -500,41 +623,83 @@ pub async fn insert_client_from_config(
|
||||
, jwks = EXCLUDED.jwks
|
||||
, jwks_uri = EXCLUDED.jwks_uri
|
||||
"#,
|
||||
Uuid::from(client_id),
|
||||
encrypted_client_secret,
|
||||
true,
|
||||
true,
|
||||
client_auth_method,
|
||||
jwks,
|
||||
jwks_uri,
|
||||
)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
Uuid::from(client_id),
|
||||
encrypted_client_secret,
|
||||
true,
|
||||
true,
|
||||
client_auth_method,
|
||||
jwks_json,
|
||||
jwks_uri.as_ref().map(Url::as_str),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let now = clock.now();
|
||||
let (ids, redirect_uris): (Vec<Uuid>, Vec<String>) = redirect_uris
|
||||
.iter()
|
||||
.map(|uri| {
|
||||
(
|
||||
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
|
||||
uri.as_str().to_owned(),
|
||||
{
|
||||
let span = info_span!(
|
||||
"db.oauth2_client.add_from_config.redirect_uris",
|
||||
client.id = %client_id,
|
||||
db.statement = tracing::field::Empty,
|
||||
);
|
||||
|
||||
let now = clock.now();
|
||||
let (ids, redirect_uris): (Vec<Uuid>, Vec<String>) = redirect_uris
|
||||
.iter()
|
||||
.map(|uri| {
|
||||
(
|
||||
Uuid::from(Ulid::from_datetime_with_source(now.into(), &mut rng)),
|
||||
uri.as_str().to_owned(),
|
||||
)
|
||||
})
|
||||
.unzip();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_client_redirect_uris
|
||||
(oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)
|
||||
SELECT id, $2, redirect_uri
|
||||
FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)
|
||||
"#,
|
||||
&ids,
|
||||
Uuid::from(client_id),
|
||||
&redirect_uris,
|
||||
)
|
||||
.record(&span)
|
||||
.execute(&mut *self.conn)
|
||||
.instrument(span)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let jwks = match (jwks, jwks_uri) {
|
||||
(None, None) => None,
|
||||
(Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
|
||||
(None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
|
||||
_ => return Err(DatabaseError::invalid_operation()),
|
||||
};
|
||||
|
||||
Ok(Client {
|
||||
id: client_id,
|
||||
client_id: client_id.to_string(),
|
||||
encrypted_client_secret,
|
||||
redirect_uris,
|
||||
response_types: vec![
|
||||
OAuthAuthorizationEndpointResponseType::Code,
|
||||
OAuthAuthorizationEndpointResponseType::IdToken,
|
||||
OAuthAuthorizationEndpointResponseType::None,
|
||||
],
|
||||
grant_types: Vec::new(),
|
||||
contacts: Vec::new(),
|
||||
client_name: None,
|
||||
logo_uri: None,
|
||||
client_uri: None,
|
||||
policy_uri: None,
|
||||
tos_uri: None,
|
||||
jwks,
|
||||
id_token_signed_response_alg: None,
|
||||
userinfo_signed_response_alg: None,
|
||||
token_endpoint_auth_method: None,
|
||||
token_endpoint_auth_signing_alg: None,
|
||||
initiate_login_uri: None,
|
||||
})
|
||||
.unzip();
|
||||
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO oauth2_client_redirect_uris
|
||||
(oauth2_client_redirect_uri_id, oauth2_client_id, redirect_uri)
|
||||
SELECT id, $2, redirect_uri
|
||||
FROM UNNEST($1::uuid[], $3::text[]) r(id, redirect_uri)
|
||||
"#,
|
||||
&ids,
|
||||
Uuid::from(client_id),
|
||||
&redirect_uris,
|
||||
)
|
||||
.execute(&mut *conn)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ use tracing::{info_span, Instrument};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use self::client::lookup_clients;
|
||||
use self::client::OAuth2ClientRepository;
|
||||
use crate::{
|
||||
pagination::{process_page, QueryBuilderExt},
|
||||
user::BrowserSessionRepository,
|
||||
@ -128,7 +128,7 @@ pub async fn get_paginated_user_oauth_sessions(
|
||||
let browser_session_ids: BTreeSet<Ulid> =
|
||||
page.iter().map(|i| Ulid::from(i.user_session_id)).collect();
|
||||
|
||||
let clients = lookup_clients(&mut *conn, client_ids).await?;
|
||||
let clients = conn.oauth2_client().load_batch(client_ids).await?;
|
||||
|
||||
// TODO: this can generate N queries instead of batching. This is less than
|
||||
// ideal
|
||||
|
@ -19,8 +19,8 @@ use sqlx::{PgConnection, PgExecutor};
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::client::lookup_client;
|
||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError};
|
||||
use super::client::OAuth2ClientRepository;
|
||||
use crate::{Clock, DatabaseError, DatabaseInconsistencyError, Repository};
|
||||
|
||||
#[tracing::instrument(
|
||||
skip_all,
|
||||
@ -173,7 +173,9 @@ pub async fn lookup_active_refresh_token(
|
||||
};
|
||||
|
||||
let session_id = res.oauth2_session_id.into();
|
||||
let client = lookup_client(&mut *conn, res.oauth2_client_id.into())
|
||||
let client = conn
|
||||
.oauth2_client()
|
||||
.lookup(res.oauth2_client_id.into())
|
||||
.await?
|
||||
.ok_or_else(|| {
|
||||
DatabaseInconsistencyError::on("oauth2_sessions")
|
||||
|
@ -15,6 +15,7 @@
|
||||
use sqlx::{PgConnection, Postgres, Transaction};
|
||||
|
||||
use crate::{
|
||||
oauth2::client::PgOAuth2ClientRepository,
|
||||
upstream_oauth2::{
|
||||
PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
|
||||
PgUpstreamOAuthSessionRepository,
|
||||
@ -54,6 +55,10 @@ pub trait Repository {
|
||||
where
|
||||
Self: 'c;
|
||||
|
||||
type OAuth2ClientRepository<'c>
|
||||
where
|
||||
Self: 'c;
|
||||
|
||||
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_>;
|
||||
fn upstream_oauth_provider(&mut self) -> Self::UpstreamOAuthProviderRepository<'_>;
|
||||
fn upstream_oauth_session(&mut self) -> Self::UpstreamOAuthSessionRepository<'_>;
|
||||
@ -61,6 +66,7 @@ pub trait Repository {
|
||||
fn user_email(&mut self) -> Self::UserEmailRepository<'_>;
|
||||
fn user_password(&mut self) -> Self::UserPasswordRepository<'_>;
|
||||
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_>;
|
||||
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_>;
|
||||
}
|
||||
|
||||
impl Repository for PgConnection {
|
||||
@ -71,6 +77,7 @@ impl Repository for PgConnection {
|
||||
type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c;
|
||||
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
|
||||
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
|
||||
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
|
||||
|
||||
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
|
||||
PgUpstreamOAuthLinkRepository::new(self)
|
||||
@ -99,6 +106,10 @@ impl Repository for PgConnection {
|
||||
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> {
|
||||
PgBrowserSessionRepository::new(self)
|
||||
}
|
||||
|
||||
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> {
|
||||
PgOAuth2ClientRepository::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'t> Repository for Transaction<'t, Postgres> {
|
||||
@ -109,6 +120,7 @@ impl<'t> Repository for Transaction<'t, Postgres> {
|
||||
type UserEmailRepository<'c> = PgUserEmailRepository<'c> where Self: 'c;
|
||||
type UserPasswordRepository<'c> = PgUserPasswordRepository<'c> where Self: 'c;
|
||||
type BrowserSessionRepository<'c> = PgBrowserSessionRepository<'c> where Self: 'c;
|
||||
type OAuth2ClientRepository<'c> = PgOAuth2ClientRepository<'c> where Self: 'c;
|
||||
|
||||
fn upstream_oauth_link(&mut self) -> Self::UpstreamOAuthLinkRepository<'_> {
|
||||
PgUpstreamOAuthLinkRepository::new(self)
|
||||
@ -137,4 +149,8 @@ impl<'t> Repository for Transaction<'t, Postgres> {
|
||||
fn browser_session(&mut self) -> Self::BrowserSessionRepository<'_> {
|
||||
PgBrowserSessionRepository::new(self)
|
||||
}
|
||||
|
||||
fn oauth2_client(&mut self) -> Self::OAuth2ClientRepository<'_> {
|
||||
PgOAuth2ClientRepository::new(self)
|
||||
}
|
||||
}
|
||||
|
@ -12,9 +12,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub trait ExecuteExt<'q, DB> {
|
||||
use tracing::Span;
|
||||
|
||||
pub trait ExecuteExt<'q, DB>: Sized {
|
||||
/// Records the statement as `db.statement` in the current span
|
||||
fn traced(self) -> Self;
|
||||
fn traced(self) -> Self {
|
||||
self.record(&Span::current())
|
||||
}
|
||||
|
||||
/// Records the statement as `db.statement` in the given span
|
||||
fn record(self, span: &Span) -> Self;
|
||||
}
|
||||
|
||||
impl<'q, DB, T> ExecuteExt<'q, DB> for T
|
||||
@ -22,8 +29,8 @@ where
|
||||
T: sqlx::Execute<'q, DB>,
|
||||
DB: sqlx::Database,
|
||||
{
|
||||
fn traced(self) -> Self {
|
||||
tracing::Span::current().record("db.statement", self.sql());
|
||||
fn record(self, span: &Span) -> Self {
|
||||
span.record("db.statement", self.sql());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user