1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-31 09:24:31 +03:00

Bump the latest axum rc

This commit is contained in:
Quentin Gliech
2022-11-18 14:41:07 +01:00
parent 0ecda1e468
commit c76a1dd2e7
22 changed files with 225 additions and 189 deletions

View File

@ -7,8 +7,8 @@ license = "Apache-2.0"
[dependencies]
async-trait = "0.1.58"
axum = { version = "0.6.0-rc.2", features = ["headers"] }
axum-extra = { version = "0.4.0-rc.1", features = ["cookie-private"] }
axum = { version = "0.6.0-rc.4", features = ["headers"] }
axum-extra = { version = "0.4.0-rc.2", features = ["cookie-private"] }
bincode = "1.3.3"
chrono = "0.4.23"
data-encoding = "2.3.2"
@ -21,7 +21,7 @@ rand = "0.8.5"
serde = "1.0.147"
serde_with = "2.1.0"
serde_urlencoded = "0.7.1"
serde_json = "1.0.87"
serde_json = "1.0.88"
sqlx = "0.6.2"
thiserror = "1.0.37"
tokio = "1.21.2"

View File

@ -18,7 +18,7 @@ use async_trait::async_trait;
use axum::{
body::HttpBody,
extract::{
rejection::{FailedToDeserializeQueryString, FormRejection, TypedHeaderRejectionReason},
rejection::{FailedToDeserializeForm, FormRejection, TypedHeaderRejectionReason},
Form, FromRequest, FromRequestParts, TypedHeader,
},
response::IntoResponse,
@ -217,7 +217,7 @@ pub struct ClientAuthorization<F = ()> {
#[derive(Debug)]
pub enum ClientAuthorizationError {
InvalidHeader,
BadForm(FailedToDeserializeQueryString),
BadForm(FailedToDeserializeForm),
ClientIdMismatch { credential: String, form: String },
UnsupportedClientAssertion { client_assertion_type: String },
MissingCredentials,
@ -284,7 +284,7 @@ where
// If it is not a form, continue
Err(FormRejection::InvalidFormContentType(_err)) => (None, None, None, None, None),
// If the form could not be read, return a Bad Request error
Err(FormRejection::FailedToDeserializeQueryString(err)) => {
Err(FormRejection::FailedToDeserializeForm(err)) => {
return Err(ClientAuthorizationError::BadForm(err))
}
// Other errors (body read twice, byte stream broke) return an internal error

View File

@ -18,7 +18,7 @@ use async_trait::async_trait;
use axum::{
body::HttpBody,
extract::{
rejection::{FailedToDeserializeQueryString, FormRejection, TypedHeaderRejectionReason},
rejection::{FailedToDeserializeForm, FormRejection, TypedHeaderRejectionReason},
Form, FromRequest, FromRequestParts, TypedHeader,
},
response::{IntoResponse, Response},
@ -109,7 +109,7 @@ impl<F: Send> UserAuthorization<F> {
pub enum UserAuthorizationError {
InvalidHeader,
TokenInFormAndHeader,
BadForm(FailedToDeserializeQueryString),
BadForm(FailedToDeserializeForm),
InternalError(Box<dyn Error>),
}
@ -311,7 +311,7 @@ where
// If it is not a form, continue
Err(FormRejection::InvalidFormContentType(_err)) => (None, None),
// If the form could not be read, return a Bad Request error
Err(FormRejection::FailedToDeserializeQueryString(err)) => {
Err(FormRejection::FailedToDeserializeForm(err)) => {
return Err(UserAuthorizationError::BadForm(err))
}
// Other errors (body read twice, byte stream broke) return an internal error

View File

@ -9,7 +9,7 @@ license = "Apache-2.0"
anyhow = "1.0.66"
argon2 = { version = "0.4.1", features = ["password-hash"] }
atty = "0.2.14"
axum = "0.6.0-rc.2"
axum = "0.6.0-rc.4"
clap = { version = "4.0.26", features = ["derive"] }
dotenv = "0.15.0"
futures-util = "0.3.25"
@ -19,7 +19,7 @@ listenfd = "1.0.0"
rand = "0.8.5"
rand_chacha = "0.3.1"
rustls = "0.20.7"
serde_json = "1.0.87"
serde_json = "1.0.88"
serde_yaml = "0.9.14"
tokio = { version = "1.21.2", features = ["full"] }
tower = { version = "0.4.13", features = ["full"] }

View File

@ -205,7 +205,7 @@ impl Options {
let graphql_schema = mas_handlers::graphql_schema(&pool);
let state = Arc::new(AppState {
let state = AppState {
pool,
templates,
key_store,
@ -215,7 +215,7 @@ impl Options {
homeserver,
policy_factory,
graphql_schema,
});
};
let mut fd_manager = listenfd::ListenFd::from_env();
@ -234,8 +234,9 @@ impl Options {
};
// and build the router
let router = crate::server::build_router(&state, &config.resources)
.layer(ServerLayer::new(config.name.clone()));
let router = crate::server::build_router(state.clone(), &config.resources)
.layer(ServerLayer::new(config.name.clone()))
.into_service();
// Display some informations about where we'll be serving connections
let is_tls = config.tls.is_some();

View File

@ -16,11 +16,10 @@ use std::{
future::ready,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener, ToSocketAddrs},
os::unix::net::UnixListener,
sync::Arc,
};
use anyhow::Context;
use axum::{body::HttpBody, error_handling::HandleErrorLayer, Extension, Router};
use axum::{body::HttpBody, error_handling::HandleErrorLayer, extract::FromRef, Extension, Router};
use hyper::StatusCode;
use listenfd::ListenFd;
use mas_config::{HttpBindConfig, HttpResource, HttpTlsConfig, UnixOrTcp};
@ -28,45 +27,47 @@ use mas_handlers::AppState;
use mas_listener::{unix_or_tcp::UnixOrTcpListener, ConnectionInfo};
use mas_router::Route;
use mas_spa::ViteManifestService;
use mas_templates::Templates;
use rustls::ServerConfig;
use tower::Layer;
use tower_http::services::ServeDir;
#[allow(clippy::trait_duplication_in_bounds)]
pub fn build_router<B>(state: &Arc<AppState>, resources: &[HttpResource]) -> Router<AppState, B>
pub fn build_router<B>(state: AppState, resources: &[HttpResource]) -> Router<AppState, B>
where
B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Into<axum::body::Bytes> + Send,
<B as HttpBody>::Error: std::error::Error + Send + Sync,
{
let mut router = Router::with_state_arc(state.clone());
let templates = Templates::from_ref(&state);
let mut router = Router::with_state(state);
for resource in resources {
router = match resource {
mas_config::HttpResource::Health => {
router.merge(mas_handlers::healthcheck_router(state.clone()))
router.merge(mas_handlers::healthcheck_router::<AppState, B>())
}
mas_config::HttpResource::Prometheus => {
router.route_service("/metrics", crate::telemetry::prometheus_service())
}
mas_config::HttpResource::Discovery => {
router.merge(mas_handlers::discovery_router(state.clone()))
router.merge(mas_handlers::discovery_router::<AppState, B>())
}
mas_config::HttpResource::Human => {
router.merge(mas_handlers::human_router(state.clone()))
router.merge(mas_handlers::human_router::<AppState, B>(templates.clone()))
}
mas_config::HttpResource::GraphQL { playground } => {
router.merge(mas_handlers::graphql_router(state.clone(), *playground))
router.merge(mas_handlers::graphql_router::<AppState, B>(*playground))
}
mas_config::HttpResource::Static { web_root } => {
let handler = mas_static_files::service(web_root);
router.nest(mas_router::StaticAsset::route(), handler)
router.nest_service(mas_router::StaticAsset::route(), handler)
}
mas_config::HttpResource::OAuth => {
router.merge(mas_handlers::api_router(state.clone()))
router.merge(mas_handlers::api_router::<AppState, B>())
}
mas_config::HttpResource::Compat => {
router.merge(mas_handlers::compat_router(state.clone()))
router.merge(mas_handlers::compat_router::<AppState, B>())
}
// TODO: do a better handler here
mas_config::HttpResource::ConnectionInfo => router.route(
@ -99,8 +100,8 @@ where
let static_service = ServeDir::new(assets).append_index_html_on_directories(false);
router
.nest(app_base, error_layer.layer(index_service))
.nest(assets_base, error_layer.layer(static_service))
.nest_service(app_base, error_layer.layer(index_service))
.nest_service(assets_base, error_layer.layer(static_service))
}
}
}

View File

@ -21,7 +21,7 @@ ulid = { version = "1.0.0", features = ["serde"] }
serde = { version = "1.0.147", features = ["derive"] }
serde_with = { version = "2.1.0", features = ["hex", "chrono"] }
serde_json = "1.0.87"
serde_json = "1.0.88"
sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] }
lettre = { version = "0.10.1", default-features = false, features = ["serde", "builder"] }

View File

@ -21,9 +21,9 @@ anyhow = "1.0.66"
hyper = { version = "0.14.23", features = ["full"] }
tower = "0.4.13"
tower-http = { version = "0.3.4", features = ["cors"] }
axum = { version = "0.6.0-rc.2", features = ["ws"] }
axum-macros = "0.3.0-rc.1"
axum-extra = { version = "0.4.0-rc.1", features = ["cookie-private"] }
axum = { version = "0.6.0-rc.4", features = ["ws"] }
axum-macros = "0.3.0-rc.2"
axum-extra = { version = "0.4.0-rc.2", features = ["cookie-private"] }
async-graphql = { version = "4.0.16", features = ["tracing", "apollo_tracing"] }
@ -36,7 +36,7 @@ sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] }
# Various structure (de)serialization
serde = { version = "1.0.147", features = ["derive"] }
serde_with = { version = "2.1.0", features = ["hex", "chrono"] }
serde_json = "1.0.87"
serde_json = "1.0.88"
serde_urlencoded = "0.7.1"
# Password hashing

View File

@ -39,7 +39,7 @@ mod tests {
#[sqlx::test(migrator = "mas_storage::MIGRATOR")]
async fn test_get_health(pool: PgPool) -> Result<(), anyhow::Error> {
let state = crate::test_state(pool).await?;
let app = crate::router(state);
let app = crate::router(state).into_service();
let request = Request::builder().uri("/health").body(Body::empty())?;

View File

@ -59,35 +59,35 @@ pub use compat::MatrixHomeserver;
pub use self::{app_state::AppState, graphql::schema as graphql_schema};
#[must_use]
pub fn empty_router<S, B>(state: Arc<S>) -> Router<S, B>
pub fn empty_router<S, B>(state: S) -> Router<S, B>
where
B: HttpBody + Send + 'static,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
Router::with_state_arc(state)
Router::with_state(state)
}
#[must_use]
pub fn healthcheck_router<S, B>(state: Arc<S>) -> Router<S, B>
pub fn healthcheck_router<S, B>() -> Router<S, B>
where
B: HttpBody + Send + 'static,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
PgPool: FromRef<S>,
{
Router::with_state_arc(state).route(mas_router::Healthcheck::route(), get(self::health::get))
Router::inherit_state().route(mas_router::Healthcheck::route(), get(self::health::get))
}
#[must_use]
pub fn graphql_router<S, B>(state: Arc<S>, playground: bool) -> Router<S, B>
pub fn graphql_router<S, B>(playground: bool) -> Router<S, B>
where
B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Into<Bytes>,
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
mas_graphql::Schema: FromRef<S>,
Encrypter: FromRef<S>,
{
let mut router = Router::with_state_arc(state)
let mut router = Router::inherit_state()
.route(
"/graphql",
get(self::graphql::get).post(self::graphql::post),
@ -102,14 +102,14 @@ where
}
#[must_use]
pub fn discovery_router<S, B>(state: Arc<S>) -> Router<S, B>
pub fn discovery_router<S, B>() -> Router<S, B>
where
B: HttpBody + Send + 'static,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
Keystore: FromRef<S>,
UrlBuilder: FromRef<S>,
{
Router::with_state_arc(state)
Router::inherit_state()
.route(
mas_router::OidcConfiguration::route(),
get(self::oauth2::discovery::get),
@ -135,12 +135,12 @@ where
#[must_use]
#[allow(clippy::trait_duplication_in_bounds)]
pub fn api_router<S, B>(state: Arc<S>) -> Router<S, B>
pub fn api_router<S, B>() -> Router<S, B>
where
B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Send,
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
Keystore: FromRef<S>,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
@ -148,7 +148,7 @@ where
Encrypter: FromRef<S>,
{
// All those routes are API-like, with a common CORS layer
Router::with_state_arc(state)
Router::inherit_state()
.route(
mas_router::OAuth2Keys::route(),
get(self::oauth2::keys::get),
@ -189,17 +189,17 @@ where
#[must_use]
#[allow(clippy::trait_duplication_in_bounds)]
pub fn compat_router<S, B>(state: Arc<S>) -> Router<S, B>
pub fn compat_router<S, B>() -> Router<S, B>
where
B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Send,
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
UrlBuilder: FromRef<S>,
PgPool: FromRef<S>,
MatrixHomeserver: FromRef<S>,
{
Router::with_state_arc(state)
Router::inherit_state()
.route(
mas_router::CompatLogin::route(),
get(self::compat::login::get).post(self::compat::login::post),
@ -230,12 +230,12 @@ where
#[must_use]
#[allow(clippy::trait_duplication_in_bounds)]
pub fn human_router<S, B>(state: Arc<S>) -> Router<S, B>
pub fn human_router<S, B>(templates: Templates) -> Router<S, B>
where
B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Send,
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
PgPool: FromRef<S>,
@ -243,8 +243,7 @@ where
Templates: FromRef<S>,
Mailer: FromRef<S>,
{
let templates = Templates::from_ref(&state);
Router::with_state_arc(state)
Router::inherit_state()
.route(
mas_router::ChangePasswordDiscovery::route(),
get(|| async { mas_router::AccountPassword.go() }),
@ -327,12 +326,12 @@ where
#[must_use]
#[allow(clippy::trait_duplication_in_bounds)]
pub fn router<S, B>(state: Arc<S>) -> Router<S, B>
pub fn router<S, B>(state: S) -> Router<S, B>
where
B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Into<Bytes> + Send,
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
Keystore: FromRef<S>,
UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>,
@ -343,14 +342,14 @@ where
MatrixHomeserver: FromRef<S>,
mas_graphql::Schema: FromRef<S>,
{
let healthcheck_router = healthcheck_router(state.clone());
let discovery_router = discovery_router(state.clone());
let api_router = api_router(state.clone());
let graphql_router = graphql_router(state.clone(), true);
let compat_router = compat_router(state.clone());
let human_router = human_router(state.clone());
let healthcheck_router = healthcheck_router();
let discovery_router = discovery_router();
let api_router = api_router();
let graphql_router = graphql_router(true);
let compat_router = compat_router();
let human_router = human_router(Templates::from_ref(&state));
Router::with_state_arc(state)
Router::with_state(state)
.merge(healthcheck_router)
.merge(discovery_router)
.merge(human_router)
@ -360,7 +359,7 @@ where
}
#[cfg(test)]
async fn test_state(pool: PgPool) -> Result<Arc<AppState>, anyhow::Error> {
async fn test_state(pool: PgPool) -> Result<AppState, anyhow::Error> {
use mas_email::MailTransport;
let url_builder = UrlBuilder::new("https://example.com/".parse()?);
@ -382,7 +381,7 @@ async fn test_state(pool: PgPool) -> Result<Arc<AppState>, anyhow::Error> {
let graphql_schema = graphql_schema(&pool);
Ok(Arc::new(AppState {
Ok(AppState {
pool,
templates,
key_store,
@ -392,7 +391,7 @@ async fn test_state(pool: PgPool) -> Result<Arc<AppState>, anyhow::Error> {
homeserver,
policy_factory,
graphql_schema,
}))
})
}
// XXX: that should be moved somewhere else

View File

@ -8,7 +8,7 @@ license = "Apache-2.0"
[dependencies]
aws-smithy-http = { version = "0.51.0", optional = true }
aws-types = { version = "0.51.0", optional = true }
axum = { version = "0.6.0-rc.2", optional = true }
axum = { version = "0.6.0-rc.4", optional = true }
bytes = "1.2.1"
futures-util = "0.3.25"
headers = "0.3.8"
@ -23,7 +23,7 @@ opentelemetry-semantic-conventions = "0.10.0"
rustls = { version = "0.20.7", optional = true }
rustls-native-certs = { version = "0.6.2", optional = true }
serde = "1.0.147"
serde_json = "1.0.87"
serde_json = "1.0.88"
serde_urlencoded = "0.7.1"
thiserror = "1.0.37"
tokio = { version = "1.21.2", features = ["sync", "parking_lot"], optional = true }

View File

@ -27,6 +27,15 @@ pub struct ServerLayer<ReqBody> {
_t: PhantomData<ReqBody>,
}
impl<B> Clone for ServerLayer<B> {
fn clone(&self) -> Self {
Self {
listener_name: self.listener_name.clone(),
_t: PhantomData,
}
}
}
impl<B> ServerLayer<B> {
#[must_use]
pub fn new(listener_name: Option<String>) -> Self {

View File

@ -11,7 +11,7 @@ async-trait = "0.1.58"
convert_case = "0.6.0"
csv = "1.1.6"
futures-util = "0.3.25"
reqwest = { version = "0.11.12", features = ["blocking", "rustls-tls"], default-features = false }
reqwest = { version = "0.11.13", features = ["blocking", "rustls-tls"], default-features = false }
serde = { version = "1.0.147", features = ["derive"] }
tokio = { version = "1.21.2", features = ["full"] }
tracing = "0.1.37"

View File

@ -22,7 +22,7 @@ rsa = "0.7.2"
schemars = "0.8.11"
sec1 = "0.3.0"
serde = { version = "1.0.147", features = ["derive"] }
serde_json = "1.0.87"
serde_json = "1.0.88"
serde_with = { version = "2.1.0", features = ["base64"] }
sha2 = { version = "0.10.6", features = ["oid"] }
signature = "1.6.4"

View File

@ -8,7 +8,7 @@ license = "Apache-2.0"
[dependencies]
http = "0.2.8"
serde = "1.0.147"
serde_json = "1.0.87"
serde_json = "1.0.88"
language-tags = { version = "0.3.2", features = ["serde"] }
url = { version = "2.3.1", features = ["serde"] }
parse-display = "0.6.0"

View File

@ -9,7 +9,7 @@ license = "Apache-2.0"
anyhow = "1.0.66"
opa-wasm = { git = "https://github.com/matrix-org/rust-opa-wasm.git" }
serde = { version = "1.0.147", features = ["derive"] }
serde_json = "1.0.87"
serde_json = "1.0.88"
thiserror = "1.0.37"
tokio = { version = "1.21.2", features = ["io-util", "rt"] }
tracing = "0.1.37"

View File

@ -6,7 +6,7 @@ edition = "2021"
license = "Apache-2.0"
[dependencies]
axum = { version = "0.6.0-rc.2", default-features = false }
axum = { version = "0.6.0-rc.4", default-features = false }
serde = { version = "1.0.147", features = ["derive"] }
serde_urlencoded = "0.7.1"
serde_with = "2.1.0"

View File

@ -7,10 +7,10 @@ license = "Apache-2.0"
[dependencies]
serde = { version = "1.0.147", features = ["derive"] }
serde_json = "1.0.87"
serde_json = "1.0.88"
thiserror = "1.0.37"
camino = { version = "1.1.1", features = ["serde1"] }
headers = "0.3.2"
headers = "0.3.8"
http = "0.2.8"
tower-service = "0.3.2"
tower-http = { version = "0.3.4", features = ["fs"] }

View File

@ -9,7 +9,7 @@ license = "Apache-2.0"
dev = []
[dependencies]
axum = { version = "0.6.0-rc.2", features = ["headers"] }
axum = { version = "0.6.0-rc.4", features = ["headers"] }
headers = "0.3.8"
http = "0.2.8"
http-body = "0.4.5"

View File

@ -10,7 +10,7 @@ tokio = "1.21.2"
sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline", "json", "uuid"] }
chrono = { version = "0.4.23", features = ["serde"] }
serde = { version = "1.0.147", features = ["derive"] }
serde_json = "1.0.87"
serde_json = "1.0.88"
thiserror = "1.0.37"
anyhow = "1.0.66"
tracing = "0.1.37"
@ -21,7 +21,7 @@ password-hash = { version = "0.4.2", features = ["std"] }
rand = "0.8.5"
rand_chacha = "0.3.1"
url = { version = "2.3.1", features = ["serde"] }
uuid = "1.2.1"
uuid = "1.2.2"
ulid = { version = "1.0.0", features = ["uuid", "serde"] }
oauth2-types = { path = "../oauth2-types" }

View File

@ -17,7 +17,7 @@ thiserror = "1.0.37"
tera = "1.17.1"
serde = { version = "1.0.147", features = ["derive"] }
serde_json = "1.0.87"
serde_json = "1.0.88"
serde_urlencoded = "0.7.1"
chrono = "0.4.23"