You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-31 09:24:31 +03:00
Add test helpers for handlers and use them
Also expands the test coverage of the revoke handler.
This commit is contained in:
@ -31,24 +31,18 @@ pub async fn get(State(pool): State<PgPool>) -> Result<impl IntoResponse, FancyE
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use hyper::{Body, Request, StatusCode};
|
use hyper::{Request, StatusCode};
|
||||||
use tower::ServiceExt;
|
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState};
|
||||||
|
|
||||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||||
async fn test_get_health(pool: PgPool) -> Result<(), anyhow::Error> {
|
async fn test_get_health(pool: PgPool) {
|
||||||
let state = crate::test_state(pool).await?;
|
let state = TestState::from_pool(pool).await.unwrap();
|
||||||
let app = crate::healthcheck_router().with_state(state);
|
let request = Request::get("/health").empty();
|
||||||
|
|
||||||
let request = Request::builder().uri("/health").body(Body::empty())?;
|
let response = state.request(request).await;
|
||||||
|
response.assert_status(StatusCode::OK);
|
||||||
let response = app.oneshot(request).await?;
|
assert_eq!(response.body(), "ok");
|
||||||
|
|
||||||
assert_eq!(response.status(), StatusCode::OK);
|
|
||||||
let body = hyper::body::to_bytes(response.into_body()).await?;
|
|
||||||
assert_eq!(body, "ok");
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -59,6 +59,9 @@ pub mod passwords;
|
|||||||
mod upstream_oauth2;
|
mod upstream_oauth2;
|
||||||
mod views;
|
mod views;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test_utils;
|
||||||
|
|
||||||
/// Implement `From<E>` for `RouteError`, for "internal server error" kind of
|
/// Implement `From<E>` for `RouteError`, for "internal server error" kind of
|
||||||
/// errors.
|
/// errors.
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
@ -363,68 +366,3 @@ where
|
|||||||
},
|
},
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
async fn test_state(pool: sqlx::PgPool) -> Result<AppState, anyhow::Error> {
|
|
||||||
use mas_email::MailTransport;
|
|
||||||
use mas_keystore::{JsonWebKey, JsonWebKeySet, PrivateKey};
|
|
||||||
|
|
||||||
use crate::passwords::Hasher;
|
|
||||||
|
|
||||||
let workspace_root = camino::Utf8Path::new(env!("CARGO_MANIFEST_DIR"))
|
|
||||||
.join("..")
|
|
||||||
.join("..");
|
|
||||||
|
|
||||||
let url_builder = UrlBuilder::new("https://example.com/".parse()?);
|
|
||||||
|
|
||||||
let templates = Templates::load(workspace_root.join("templates"), url_builder.clone()).await?;
|
|
||||||
|
|
||||||
// TODO: add more test keys to the store
|
|
||||||
let rsa =
|
|
||||||
PrivateKey::load_pem(include_str!("../../keystore/tests/keys/rsa.pkcs1.pem")).unwrap();
|
|
||||||
let rsa = JsonWebKey::new(rsa).with_kid("test-rsa");
|
|
||||||
|
|
||||||
let jwks = JsonWebKeySet::new(vec![rsa]);
|
|
||||||
let key_store = Keystore::new(jwks);
|
|
||||||
|
|
||||||
let encrypter = Encrypter::new(&[0x42; 32]);
|
|
||||||
|
|
||||||
let password_manager = PasswordManager::new([(1, Hasher::argon2id(None))])?;
|
|
||||||
|
|
||||||
let transport = MailTransport::blackhole();
|
|
||||||
let mailbox: lettre::message::Mailbox = "server@example.com".parse()?;
|
|
||||||
let mailer = Mailer::new(templates.clone(), transport, mailbox.clone(), mailbox);
|
|
||||||
|
|
||||||
let homeserver = MatrixHomeserver::new("example.com".to_owned());
|
|
||||||
|
|
||||||
let file = tokio::fs::File::open(workspace_root.join("policies").join("policy.wasm")).await?;
|
|
||||||
|
|
||||||
let policy_factory = PolicyFactory::load(
|
|
||||||
file,
|
|
||||||
serde_json::json!({}),
|
|
||||||
"register/violation".to_owned(),
|
|
||||||
"client_registration/violation".to_owned(),
|
|
||||||
"authorization_grant/violation".to_owned(),
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let policy_factory = Arc::new(policy_factory);
|
|
||||||
|
|
||||||
let graphql_schema = graphql_schema();
|
|
||||||
|
|
||||||
let http_client_factory = HttpClientFactory::new(10);
|
|
||||||
|
|
||||||
Ok(AppState {
|
|
||||||
pool,
|
|
||||||
templates,
|
|
||||||
key_store,
|
|
||||||
encrypter,
|
|
||||||
url_builder,
|
|
||||||
mailer,
|
|
||||||
homeserver,
|
|
||||||
policy_factory,
|
|
||||||
graphql_schema,
|
|
||||||
http_client_factory,
|
|
||||||
password_manager,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
@ -201,24 +201,19 @@ pub(crate) async fn post(
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use hyper::{
|
use hyper::Request;
|
||||||
header::{AUTHORIZATION, CONTENT_TYPE},
|
|
||||||
Request,
|
|
||||||
};
|
|
||||||
use mas_data_model::AuthorizationCode;
|
use mas_data_model::AuthorizationCode;
|
||||||
use mas_router::SimpleRoute;
|
use mas_router::SimpleRoute;
|
||||||
use mas_storage::{RepositoryAccess, RepositoryTransaction, SystemClock};
|
use mas_storage::RepositoryAccess;
|
||||||
use mas_storage_pg::PgRepository;
|
|
||||||
use oauth2_types::{
|
use oauth2_types::{
|
||||||
registration::ClientRegistrationResponse,
|
registration::ClientRegistrationResponse,
|
||||||
requests::{AccessTokenResponse, ResponseMode},
|
requests::{AccessTokenResponse, ResponseMode},
|
||||||
scope::{Scope, OPENID},
|
scope::{Scope, OPENID},
|
||||||
};
|
};
|
||||||
use rand::SeedableRng;
|
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
use tower::{Service, ServiceExt};
|
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState};
|
||||||
|
|
||||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||||
async fn test_revoke_access_token(pool: PgPool) {
|
async fn test_revoke_access_token(pool: PgPool) {
|
||||||
@ -227,50 +222,40 @@ mod tests {
|
|||||||
.with_test_writer()
|
.with_test_writer()
|
||||||
.init();
|
.init();
|
||||||
|
|
||||||
let clock = SystemClock::default();
|
let state = TestState::from_pool(pool).await.unwrap();
|
||||||
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
|
|
||||||
|
|
||||||
let state = crate::test_state(pool.clone()).await.unwrap();
|
let request =
|
||||||
let mut app = crate::api_router().with_state(state);
|
Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({
|
||||||
|
|
||||||
let request = Request::post(mas_router::OAuth2RegistrationEndpoint::PATH)
|
|
||||||
.header(CONTENT_TYPE, "application/json")
|
|
||||||
.body(
|
|
||||||
serde_json::json!({
|
|
||||||
"client_uri": "https://example.com/",
|
"client_uri": "https://example.com/",
|
||||||
"redirect_uris": ["https://example.com/callback"],
|
"redirect_uris": ["https://example.com/callback"],
|
||||||
"contacts": ["contact@example.com"],
|
"contacts": ["contact@example.com"],
|
||||||
"token_endpoint_auth_method": "client_secret_post",
|
"token_endpoint_auth_method": "client_secret_post",
|
||||||
"response_types": ["code"],
|
"response_types": ["code"],
|
||||||
"grant_types": ["authorization_code"],
|
"grant_types": ["authorization_code"],
|
||||||
})
|
}));
|
||||||
.to_string(),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let response = app.ready().await.unwrap().call(request).await.unwrap();
|
let response = state.request(request).await;
|
||||||
assert_eq!(response.status(), StatusCode::CREATED);
|
response.assert_status(StatusCode::CREATED);
|
||||||
|
|
||||||
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
|
|
||||||
let client_registration: ClientRegistrationResponse =
|
let client_registration: ClientRegistrationResponse =
|
||||||
serde_json::from_slice(&body).unwrap();
|
serde_json::from_str(response.body()).unwrap();
|
||||||
|
|
||||||
let client_id = client_registration.client_id;
|
let client_id = client_registration.client_id;
|
||||||
let client_secret = client_registration.client_secret.unwrap();
|
let client_secret = client_registration.client_secret.unwrap();
|
||||||
|
|
||||||
// Let's provision a user and create a session for them. This part is hard to
|
// Let's provision a user and create a session for them. This part is hard to
|
||||||
// test with just HTTP requests, so we'll use the repository directly.
|
// test with just HTTP requests, so we'll use the repository directly.
|
||||||
let mut repo = PgRepository::from_pool(&pool).await.unwrap();
|
let mut repo = state.repository().await.unwrap();
|
||||||
|
|
||||||
let user = repo
|
let user = repo
|
||||||
.user()
|
.user()
|
||||||
.add(&mut rng, &clock, "alice".to_owned())
|
.add(&mut state.rng(), &state.clock, "alice".to_owned())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let browser_session = repo
|
let browser_session = repo
|
||||||
.browser_session()
|
.browser_session()
|
||||||
.add(&mut rng, &clock, &user)
|
.add(&mut state.rng(), &state.clock, &user)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -286,8 +271,8 @@ mod tests {
|
|||||||
let grant = repo
|
let grant = repo
|
||||||
.oauth2_authorization_grant()
|
.oauth2_authorization_grant()
|
||||||
.add(
|
.add(
|
||||||
&mut rng,
|
&mut state.rng(),
|
||||||
&clock,
|
&state.clock,
|
||||||
&client,
|
&client,
|
||||||
"https://example.com/redirect".parse().unwrap(),
|
"https://example.com/redirect".parse().unwrap(),
|
||||||
Scope::from_iter([OPENID]),
|
Scope::from_iter([OPENID]),
|
||||||
@ -299,7 +284,7 @@ mod tests {
|
|||||||
Some("nonce".to_owned()),
|
Some("nonce".to_owned()),
|
||||||
None,
|
None,
|
||||||
ResponseMode::Query,
|
ResponseMode::Query,
|
||||||
true,
|
false,
|
||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@ -307,69 +292,169 @@ mod tests {
|
|||||||
|
|
||||||
let session = repo
|
let session = repo
|
||||||
.oauth2_session()
|
.oauth2_session()
|
||||||
.create_from_grant(&mut rng, &clock, &grant, &browser_session)
|
.create_from_grant(&mut state.rng(), &state.clock, &grant, &browser_session)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let grant = repo
|
let grant = repo
|
||||||
.oauth2_authorization_grant()
|
.oauth2_authorization_grant()
|
||||||
.fulfill(&clock, &session, grant)
|
.fulfill(&state.clock, &session, grant)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
Box::new(repo).save().await.unwrap();
|
repo.save().await.unwrap();
|
||||||
|
|
||||||
// Now call the token endpoint to get an access token.
|
// Now call the token endpoint to get an access token.
|
||||||
let request = Request::post(mas_router::OAuth2TokenEndpoint::PATH)
|
let request =
|
||||||
.header(CONTENT_TYPE, "application/x-www-form-urlencoded")
|
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
|
||||||
.body(
|
"grant_type": "authorization_code",
|
||||||
format!(
|
"code": grant.code.unwrap().code,
|
||||||
"grant_type=authorization_code&code={code}&redirect_uri={redirect_uri}&client_id={client_id}&client_secret={client_secret}",
|
"redirect_uri": grant.redirect_uri,
|
||||||
code = grant.code.unwrap().code,
|
"client_id": client_id,
|
||||||
redirect_uri = grant.redirect_uri,
|
"client_secret": client_secret,
|
||||||
),
|
}));
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let response = app.ready().await.unwrap().call(request).await.unwrap();
|
let response = state.request(request).await;
|
||||||
let status = response.status();
|
response.assert_status(StatusCode::OK);
|
||||||
assert_eq!(status, StatusCode::OK);
|
|
||||||
|
|
||||||
let body = hyper::body::to_bytes(response.into_body()).await.unwrap();
|
let token: AccessTokenResponse = serde_json::from_str(response.body()).unwrap();
|
||||||
let token: AccessTokenResponse = serde_json::from_slice(&body).unwrap();
|
|
||||||
|
|
||||||
// Let's call the userinfo endpoint to make sure we can access it.
|
// Check that the token is valid
|
||||||
let request = Request::get(mas_router::OidcUserinfo::PATH)
|
assert!(state.is_access_token_valid(&token.access_token).await);
|
||||||
.header(AUTHORIZATION, format!("Bearer {}", token.access_token))
|
|
||||||
.body(String::new())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let response = app.ready().await.unwrap().call(request).await.unwrap();
|
|
||||||
let status = response.status();
|
|
||||||
assert_eq!(status, StatusCode::OK);
|
|
||||||
|
|
||||||
// Now let's revoke the access token.
|
// Now let's revoke the access token.
|
||||||
let request = Request::post(mas_router::OAuth2Revocation::PATH)
|
let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
|
||||||
.header(CONTENT_TYPE, "application/x-www-form-urlencoded")
|
"token": token.access_token,
|
||||||
.body(format!(
|
"token_type_hint": "access_token",
|
||||||
"token={token}&token_type_hint=access_token&client_id={client_id}&client_secret={client_secret}",
|
"client_id": client_id,
|
||||||
token = token.access_token
|
"client_secret": client_secret,
|
||||||
))
|
}));
|
||||||
|
|
||||||
|
let response = state.request(request).await;
|
||||||
|
response.assert_status(StatusCode::OK);
|
||||||
|
|
||||||
|
// Check that the token is no longer valid
|
||||||
|
assert!(!state.is_access_token_valid(&token.access_token).await);
|
||||||
|
|
||||||
|
// Try using the refresh token to get a new access token, it should fail.
|
||||||
|
let request =
|
||||||
|
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": token.refresh_token,
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_secret": client_secret,
|
||||||
|
}));
|
||||||
|
|
||||||
|
let response = state.request(request).await;
|
||||||
|
response.assert_status(StatusCode::BAD_REQUEST);
|
||||||
|
|
||||||
|
// Now try with a new grant, and by revoking the refresh token instead
|
||||||
|
let mut repo = state.repository().await.unwrap();
|
||||||
|
let grant = repo
|
||||||
|
.oauth2_authorization_grant()
|
||||||
|
.add(
|
||||||
|
&mut state.rng(),
|
||||||
|
&state.clock,
|
||||||
|
&client,
|
||||||
|
"https://example.com/redirect".parse().unwrap(),
|
||||||
|
Scope::from_iter([OPENID]),
|
||||||
|
Some(AuthorizationCode {
|
||||||
|
code: "anotherverysecretcode".to_owned(),
|
||||||
|
pkce: None,
|
||||||
|
}),
|
||||||
|
Some("state".to_owned()),
|
||||||
|
Some("nonce".to_owned()),
|
||||||
|
None,
|
||||||
|
ResponseMode::Query,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let response = app.ready().await.unwrap().call(request).await.unwrap();
|
let session = repo
|
||||||
let status = response.status();
|
.oauth2_session()
|
||||||
assert_eq!(status, StatusCode::OK);
|
.create_from_grant(&mut state.rng(), &state.clock, &grant, &browser_session)
|
||||||
|
.await
|
||||||
// Call the userinfo endpoint again to make sure we can't access it anymore.
|
|
||||||
let request = Request::get(mas_router::OidcUserinfo::PATH)
|
|
||||||
.header(AUTHORIZATION, format!("Bearer {}", token.access_token))
|
|
||||||
.body(String::new())
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let response = app.ready().await.unwrap().call(request).await.unwrap();
|
let grant = repo
|
||||||
let status = response.status();
|
.oauth2_authorization_grant()
|
||||||
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
.fulfill(&state.clock, &session, grant)
|
||||||
// TODO: test refreshing the access token, test refresh token revocation
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
repo.save().await.unwrap();
|
||||||
|
|
||||||
|
// Now call the token endpoint to get an access token.
|
||||||
|
let request =
|
||||||
|
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": grant.code.unwrap().code,
|
||||||
|
"redirect_uri": grant.redirect_uri,
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_secret": client_secret,
|
||||||
|
}));
|
||||||
|
|
||||||
|
let response = state.request(request).await;
|
||||||
|
response.assert_status(StatusCode::OK);
|
||||||
|
|
||||||
|
let token: AccessTokenResponse = serde_json::from_str(response.body()).unwrap();
|
||||||
|
|
||||||
|
// Use the refresh token to get a new access token.
|
||||||
|
let request =
|
||||||
|
Request::post(mas_router::OAuth2TokenEndpoint::PATH).form(serde_json::json!({
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": token.refresh_token,
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_secret": client_secret,
|
||||||
|
}));
|
||||||
|
|
||||||
|
let response = state.request(request).await;
|
||||||
|
response.assert_status(StatusCode::OK);
|
||||||
|
|
||||||
|
let old_token = token;
|
||||||
|
let token: AccessTokenResponse = serde_json::from_str(response.body()).unwrap();
|
||||||
|
assert!(state.is_access_token_valid(&token.access_token).await);
|
||||||
|
assert!(!state.is_access_token_valid(&old_token.access_token).await);
|
||||||
|
|
||||||
|
// Revoking the old access token shouldn't do anything.
|
||||||
|
let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
|
||||||
|
"token": old_token.access_token,
|
||||||
|
"token_type_hint": "access_token",
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_secret": client_secret,
|
||||||
|
}));
|
||||||
|
|
||||||
|
let response = state.request(request).await;
|
||||||
|
response.assert_status(StatusCode::OK);
|
||||||
|
|
||||||
|
assert!(state.is_access_token_valid(&token.access_token).await);
|
||||||
|
|
||||||
|
// Revoking the old refresh token shouldn't do anything.
|
||||||
|
let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
|
||||||
|
"token": old_token.refresh_token,
|
||||||
|
"token_type_hint": "refresh_token",
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_secret": client_secret,
|
||||||
|
}));
|
||||||
|
|
||||||
|
let response = state.request(request).await;
|
||||||
|
response.assert_status(StatusCode::OK);
|
||||||
|
|
||||||
|
assert!(state.is_access_token_valid(&token.access_token).await);
|
||||||
|
|
||||||
|
// Revoking the new refresh token should invalidate the session
|
||||||
|
let request = Request::post(mas_router::OAuth2Revocation::PATH).form(serde_json::json!({
|
||||||
|
"token": token.refresh_token,
|
||||||
|
"token_type_hint": "refresh_token",
|
||||||
|
"client_id": client_id,
|
||||||
|
"client_secret": client_secret,
|
||||||
|
}));
|
||||||
|
|
||||||
|
let response = state.request(request).await;
|
||||||
|
response.assert_status(StatusCode::OK);
|
||||||
|
|
||||||
|
assert!(!state.is_access_token_valid(&token.access_token).await);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
376
crates/handlers/src/test_utils.rs
Normal file
376
crates/handlers/src/test_utils.rs
Normal file
@ -0,0 +1,376 @@
|
|||||||
|
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::{convert::Infallible, sync::Arc};
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
async_trait,
|
||||||
|
body::HttpBody,
|
||||||
|
extract::{FromRef, FromRequestParts},
|
||||||
|
};
|
||||||
|
use headers::{Authorization, ContentType, HeaderMapExt};
|
||||||
|
use hyper::{Request, Response, StatusCode};
|
||||||
|
use mas_axum_utils::http_client_factory::HttpClientFactory;
|
||||||
|
use mas_email::{MailTransport, Mailer};
|
||||||
|
use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
|
||||||
|
use mas_policy::PolicyFactory;
|
||||||
|
use mas_router::{SimpleRoute, UrlBuilder};
|
||||||
|
use mas_storage::{clock::MockClock, BoxClock, BoxRepository, BoxRng, Repository};
|
||||||
|
use mas_storage_pg::PgRepository;
|
||||||
|
use mas_templates::Templates;
|
||||||
|
use rand::SeedableRng;
|
||||||
|
use rand_chacha::ChaChaRng;
|
||||||
|
use serde::Serialize;
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tower::{Service, ServiceExt};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
app_state::RepositoryError,
|
||||||
|
graphql_schema,
|
||||||
|
passwords::{Hasher, PasswordManager},
|
||||||
|
MatrixHomeserver,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub(crate) struct TestState {
|
||||||
|
pub pool: PgPool,
|
||||||
|
pub templates: Templates,
|
||||||
|
pub key_store: Keystore,
|
||||||
|
pub encrypter: Encrypter,
|
||||||
|
pub url_builder: UrlBuilder,
|
||||||
|
pub mailer: Mailer,
|
||||||
|
pub homeserver: MatrixHomeserver,
|
||||||
|
pub policy_factory: Arc<PolicyFactory>,
|
||||||
|
pub graphql_schema: mas_graphql::Schema,
|
||||||
|
pub http_client_factory: HttpClientFactory,
|
||||||
|
pub password_manager: PasswordManager,
|
||||||
|
pub clock: Arc<MockClock>,
|
||||||
|
pub rng: Arc<Mutex<ChaChaRng>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TestState {
|
||||||
|
/// Create a new test state from the given database pool
|
||||||
|
pub async fn from_pool(pool: PgPool) -> Result<Self, anyhow::Error> {
|
||||||
|
let workspace_root = camino::Utf8Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||||
|
.join("..")
|
||||||
|
.join("..");
|
||||||
|
|
||||||
|
let url_builder = UrlBuilder::new("https://example.com/".parse()?);
|
||||||
|
|
||||||
|
let templates =
|
||||||
|
Templates::load(workspace_root.join("templates"), url_builder.clone()).await?;
|
||||||
|
|
||||||
|
// TODO: add more test keys to the store
|
||||||
|
let rsa =
|
||||||
|
PrivateKey::load_pem(include_str!("../../keystore/tests/keys/rsa.pkcs1.pem")).unwrap();
|
||||||
|
let rsa = JsonWebKey::new(rsa).with_kid("test-rsa");
|
||||||
|
|
||||||
|
let jwks = JsonWebKeySet::new(vec![rsa]);
|
||||||
|
let key_store = Keystore::new(jwks);
|
||||||
|
|
||||||
|
let encrypter = Encrypter::new(&[0x42; 32]);
|
||||||
|
|
||||||
|
let password_manager = PasswordManager::new([(1, Hasher::argon2id(None))])?;
|
||||||
|
|
||||||
|
let transport = MailTransport::blackhole();
|
||||||
|
let mailbox: lettre::message::Mailbox = "server@example.com".parse()?;
|
||||||
|
let mailer = Mailer::new(templates.clone(), transport, mailbox.clone(), mailbox);
|
||||||
|
|
||||||
|
let homeserver = MatrixHomeserver::new("example.com".to_owned());
|
||||||
|
|
||||||
|
let file =
|
||||||
|
tokio::fs::File::open(workspace_root.join("policies").join("policy.wasm")).await?;
|
||||||
|
|
||||||
|
let policy_factory = PolicyFactory::load(
|
||||||
|
file,
|
||||||
|
serde_json::json!({}),
|
||||||
|
"register/violation".to_owned(),
|
||||||
|
"client_registration/violation".to_owned(),
|
||||||
|
"authorization_grant/violation".to_owned(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let policy_factory = Arc::new(policy_factory);
|
||||||
|
|
||||||
|
let graphql_schema = graphql_schema();
|
||||||
|
|
||||||
|
let http_client_factory = HttpClientFactory::new(10);
|
||||||
|
|
||||||
|
let clock = Arc::new(MockClock::default());
|
||||||
|
let rng = Arc::new(Mutex::new(ChaChaRng::seed_from_u64(42)));
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
pool,
|
||||||
|
templates,
|
||||||
|
key_store,
|
||||||
|
encrypter,
|
||||||
|
url_builder,
|
||||||
|
mailer,
|
||||||
|
homeserver,
|
||||||
|
policy_factory,
|
||||||
|
graphql_schema,
|
||||||
|
http_client_factory,
|
||||||
|
password_manager,
|
||||||
|
clock,
|
||||||
|
rng,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn request<B>(&self, request: Request<B>) -> Response<String>
|
||||||
|
where
|
||||||
|
B: HttpBody + Send + 'static,
|
||||||
|
B::Error: std::error::Error + Send + Sync,
|
||||||
|
B::Data: Send,
|
||||||
|
{
|
||||||
|
let app = crate::healthcheck_router()
|
||||||
|
.merge(crate::discovery_router())
|
||||||
|
.merge(crate::api_router())
|
||||||
|
.merge(crate::compat_router())
|
||||||
|
.merge(crate::human_router(self.templates.clone()))
|
||||||
|
.with_state(self.clone());
|
||||||
|
|
||||||
|
// Both unwrap are on Infallible, so this is safe
|
||||||
|
let response = app
|
||||||
|
.ready_oneshot()
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.call(request)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let (parts, body) = response.into_parts();
|
||||||
|
|
||||||
|
// This could actually fail, but do we really care about that?
|
||||||
|
let body = hyper::body::to_bytes(body)
|
||||||
|
.await
|
||||||
|
.expect("Failed to read response body");
|
||||||
|
let body = std::str::from_utf8(&body)
|
||||||
|
.expect("Response body is not valid UTF-8")
|
||||||
|
.to_owned();
|
||||||
|
|
||||||
|
Response::from_parts(parts, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn repository(&self) -> Result<BoxRepository, RepositoryError> {
|
||||||
|
let repo = PgRepository::from_pool(&self.pool).await?;
|
||||||
|
Ok(repo
|
||||||
|
.map_err(mas_storage::RepositoryError::from_error)
|
||||||
|
.boxed())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a new random number generator.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the RNG is already locked.
|
||||||
|
pub fn rng(&self) -> ChaChaRng {
|
||||||
|
let mut parent_rng = self.rng.try_lock().expect("Failed to lock RNG");
|
||||||
|
ChaChaRng::from_rng(&mut *parent_rng).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Do a call to the userinfo endpoint to check if the given token is valid.
|
||||||
|
/// Returns true if the token is valid.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the response status code is not 200 or 401.
|
||||||
|
pub async fn is_access_token_valid(&self, token: &str) -> bool {
|
||||||
|
let request = Request::get(mas_router::OidcUserinfo::PATH)
|
||||||
|
.bearer(token)
|
||||||
|
.empty();
|
||||||
|
|
||||||
|
let response = self.request(request).await;
|
||||||
|
|
||||||
|
match response.status() {
|
||||||
|
StatusCode::OK => true,
|
||||||
|
StatusCode::UNAUTHORIZED => false,
|
||||||
|
_ => panic!("Unexpected status code: {}", response.status()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<TestState> for PgPool {
|
||||||
|
fn from_ref(input: &TestState) -> Self {
|
||||||
|
input.pool.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<TestState> for mas_graphql::Schema {
|
||||||
|
fn from_ref(input: &TestState) -> Self {
|
||||||
|
input.graphql_schema.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<TestState> for Templates {
|
||||||
|
fn from_ref(input: &TestState) -> Self {
|
||||||
|
input.templates.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<TestState> for Keystore {
|
||||||
|
fn from_ref(input: &TestState) -> Self {
|
||||||
|
input.key_store.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<TestState> for Encrypter {
|
||||||
|
fn from_ref(input: &TestState) -> Self {
|
||||||
|
input.encrypter.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<TestState> for UrlBuilder {
|
||||||
|
fn from_ref(input: &TestState) -> Self {
|
||||||
|
input.url_builder.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<TestState> for Mailer {
|
||||||
|
fn from_ref(input: &TestState) -> Self {
|
||||||
|
input.mailer.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<TestState> for MatrixHomeserver {
|
||||||
|
fn from_ref(input: &TestState) -> Self {
|
||||||
|
input.homeserver.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<TestState> for Arc<PolicyFactory> {
|
||||||
|
fn from_ref(input: &TestState) -> Self {
|
||||||
|
input.policy_factory.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<TestState> for HttpClientFactory {
|
||||||
|
fn from_ref(input: &TestState) -> Self {
|
||||||
|
input.http_client_factory.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<TestState> for PasswordManager {
|
||||||
|
fn from_ref(input: &TestState) -> Self {
|
||||||
|
input.password_manager.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl FromRequestParts<TestState> for BoxClock {
|
||||||
|
type Rejection = Infallible;
|
||||||
|
|
||||||
|
async fn from_request_parts(
|
||||||
|
_parts: &mut axum::http::request::Parts,
|
||||||
|
state: &TestState,
|
||||||
|
) -> Result<Self, Self::Rejection> {
|
||||||
|
Ok(Box::new(state.clock.clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl FromRequestParts<TestState> for BoxRng {
|
||||||
|
type Rejection = Infallible;
|
||||||
|
|
||||||
|
async fn from_request_parts(
|
||||||
|
_parts: &mut axum::http::request::Parts,
|
||||||
|
state: &TestState,
|
||||||
|
) -> Result<Self, Self::Rejection> {
|
||||||
|
let mut parent_rng = state.rng.lock().await;
|
||||||
|
let rng = ChaChaRng::from_rng(&mut *parent_rng).expect("Failed to seed RNG");
|
||||||
|
Ok(Box::new(rng))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl FromRequestParts<TestState> for BoxRepository {
|
||||||
|
type Rejection = RepositoryError;
|
||||||
|
|
||||||
|
async fn from_request_parts(
|
||||||
|
_parts: &mut axum::http::request::Parts,
|
||||||
|
state: &TestState,
|
||||||
|
) -> Result<Self, Self::Rejection> {
|
||||||
|
let repo = PgRepository::from_pool(&state.pool).await?;
|
||||||
|
Ok(repo
|
||||||
|
.map_err(mas_storage::RepositoryError::from_error)
|
||||||
|
.boxed())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) trait RequestBuilderExt {
|
||||||
|
/// Builds the request with the given JSON value as body.
|
||||||
|
fn json<T: Serialize>(self, body: T) -> hyper::Request<String>;
|
||||||
|
|
||||||
|
/// Builds the request with the given form value as body.
|
||||||
|
fn form<T: Serialize>(self, body: T) -> hyper::Request<String>;
|
||||||
|
|
||||||
|
/// Sets the request Authorization header to the given bearer token.
|
||||||
|
fn bearer(self, token: &str) -> Self;
|
||||||
|
|
||||||
|
/// Builds the request with an empty body.
|
||||||
|
fn empty(self) -> hyper::Request<String>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RequestBuilderExt for hyper::http::request::Builder {
|
||||||
|
fn json<T: Serialize>(mut self, body: T) -> hyper::Request<String> {
|
||||||
|
self.headers_mut()
|
||||||
|
.unwrap()
|
||||||
|
.typed_insert(ContentType::json());
|
||||||
|
|
||||||
|
self.body(serde_json::to_string(&body).unwrap()).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn form<T: Serialize>(mut self, body: T) -> hyper::Request<String> {
|
||||||
|
self.headers_mut()
|
||||||
|
.unwrap()
|
||||||
|
.typed_insert(ContentType::form_url_encoded());
|
||||||
|
|
||||||
|
self.body(serde_urlencoded::to_string(&body).unwrap())
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bearer(mut self, token: &str) -> Self {
|
||||||
|
self.headers_mut()
|
||||||
|
.unwrap()
|
||||||
|
.typed_insert(Authorization::bearer(token).unwrap());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
fn empty(self) -> hyper::Request<String> {
|
||||||
|
self.body(String::new()).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) trait ResponseExt {
|
||||||
|
/// Asserts that the response has the given status code.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the response has a different status code.
|
||||||
|
fn assert_status(&self, status: StatusCode);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ResponseExt for Response<String> {
|
||||||
|
#[track_caller]
|
||||||
|
fn assert_status(&self, status: StatusCode) {
|
||||||
|
assert_eq!(
|
||||||
|
self.status(),
|
||||||
|
status,
|
||||||
|
"HTTP status code mismatch: got {}, expected {}. Body: {}",
|
||||||
|
self.status(),
|
||||||
|
status,
|
||||||
|
self.body()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
@ -18,7 +18,7 @@
|
|||||||
//! [`SystemClock`] which uses the system time, and a [`MockClock`], which can
|
//! [`SystemClock`] which uses the system time, and a [`MockClock`], which can
|
||||||
//! be used and freely manipulated in tests.
|
//! be used and freely manipulated in tests.
|
||||||
|
|
||||||
use std::sync::atomic::AtomicI64;
|
use std::sync::{atomic::AtomicI64, Arc};
|
||||||
|
|
||||||
use chrono::{DateTime, TimeZone, Utc};
|
use chrono::{DateTime, TimeZone, Utc};
|
||||||
|
|
||||||
@ -28,6 +28,12 @@ pub trait Clock: Sync {
|
|||||||
fn now(&self) -> DateTime<Utc>;
|
fn now(&self) -> DateTime<Utc>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<C: Clock + Send + ?Sized> Clock for Arc<C> {
|
||||||
|
fn now(&self) -> DateTime<Utc> {
|
||||||
|
(**self).now()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<C: Clock + ?Sized> Clock for Box<C> {
|
impl<C: Clock + ?Sized> Clock for Box<C> {
|
||||||
fn now(&self) -> DateTime<Utc> {
|
fn now(&self) -> DateTime<Utc> {
|
||||||
(**self).now()
|
(**self).now()
|
||||||
|
Reference in New Issue
Block a user