You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Upgrade axum to 0.6.0-rc.1
This commit is contained in:
16
Cargo.lock
generated
16
Cargo.lock
generated
@ -516,9 +516,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum"
|
name = "axum"
|
||||||
version = "0.5.15"
|
version = "0.6.0-rc.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9de18bc5f2e9df8f52da03856bf40e29b747de5a84e43aefff90e3dc4a21529b"
|
checksum = "d49958d54e0bab71947eb00a33175eb9164ccc0ea4c262d2139c5f8899a3616e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum-core",
|
"axum-core",
|
||||||
@ -548,9 +548,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum-core"
|
name = "axum-core"
|
||||||
version = "0.2.7"
|
version = "0.3.0-rc.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e4f44a0e6200e9d11a1cdc989e4b358f6e3d354fbf48478f345a17f4e43f8635"
|
checksum = "5e52ebadfce2f1e7fec9b2dd920952477ffeac9f07a5c492c0cbba2bb22cd294"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"bytes 1.2.1",
|
"bytes 1.2.1",
|
||||||
@ -562,9 +562,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum-extra"
|
name = "axum-extra"
|
||||||
version = "0.3.7"
|
version = "0.4.0-rc.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "69034b3b0fd97923eee2ce8a47540edb21e07f48f87f67d44bb4271cec622bdb"
|
checksum = "090ae29ae83a40882fb99bce421d8dce4a819325a013bb301dad4cc4b74ab40c"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"axum",
|
"axum",
|
||||||
"bytes 1.2.1",
|
"bytes 1.2.1",
|
||||||
@ -2730,9 +2730,9 @@ checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "matchit"
|
name = "matchit"
|
||||||
version = "0.5.0"
|
version = "0.6.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
|
checksum = "3dfc802da7b1cf80aefffa0c7b2f77247c8b32206cc83c270b61264f5b360a80"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "md-5"
|
name = "md-5"
|
||||||
|
@ -7,8 +7,8 @@ license = "Apache-2.0"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async-trait = "0.1.57"
|
async-trait = "0.1.57"
|
||||||
axum = { version = "0.5.15", features = ["headers"] }
|
axum = { version = "0.6.0-rc.1", features = ["headers"] }
|
||||||
axum-extra = { version = "0.3.7", features = ["cookie-private"] }
|
axum-extra = { version = "0.4.0-rc.1", features = ["cookie-private"] }
|
||||||
bincode = "1.3.3"
|
bincode = "1.3.3"
|
||||||
chrono = "0.4.22"
|
chrono = "0.4.22"
|
||||||
data-encoding = "2.3.2"
|
data-encoding = "2.3.2"
|
||||||
|
@ -19,13 +19,13 @@ use axum::{
|
|||||||
body::HttpBody,
|
body::HttpBody,
|
||||||
extract::{
|
extract::{
|
||||||
rejection::{FailedToDeserializeQueryString, FormRejection, TypedHeaderRejectionReason},
|
rejection::{FailedToDeserializeQueryString, FormRejection, TypedHeaderRejectionReason},
|
||||||
Form, FromRequest, RequestParts, TypedHeader,
|
Form, FromRequest, FromRequestParts, TypedHeader,
|
||||||
},
|
},
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
BoxError,
|
BoxError,
|
||||||
};
|
};
|
||||||
use headers::{authorization::Basic, Authorization};
|
use headers::{authorization::Basic, Authorization};
|
||||||
use http::StatusCode;
|
use http::{Request, StatusCode};
|
||||||
use mas_data_model::{Client, JwksOrJwksUri, StorageBackend};
|
use mas_data_model::{Client, JwksOrJwksUri, StorageBackend};
|
||||||
use mas_http::HttpServiceExt;
|
use mas_http::HttpServiceExt;
|
||||||
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||||
@ -234,18 +234,23 @@ impl IntoResponse for ClientAuthorizationError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<B, F> FromRequest<B> for ClientAuthorization<F>
|
impl<S, B, F> FromRequest<S, B> for ClientAuthorization<F>
|
||||||
where
|
where
|
||||||
B: Send + HttpBody,
|
|
||||||
B::Data: Send,
|
|
||||||
B::Error: std::error::Error + Send + Sync + 'static,
|
|
||||||
F: DeserializeOwned,
|
F: DeserializeOwned,
|
||||||
|
B: HttpBody + Send + 'static,
|
||||||
|
B::Data: Send,
|
||||||
|
B::Error: Into<BoxError>,
|
||||||
|
S: Send + Sync,
|
||||||
{
|
{
|
||||||
type Rejection = ClientAuthorizationError;
|
type Rejection = ClientAuthorizationError;
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
let header = TypedHeader::<Authorization<Basic>>::from_request(req).await;
|
// Split the request into parts so we can extract some headers
|
||||||
|
let (mut parts, body) = req.into_parts();
|
||||||
|
|
||||||
|
let header =
|
||||||
|
TypedHeader::<Authorization<Basic>>::from_request_parts(&mut parts, state).await;
|
||||||
|
|
||||||
// Take the Authorization header
|
// Take the Authorization header
|
||||||
let credentials_from_header = match header {
|
let credentials_from_header = match header {
|
||||||
@ -258,6 +263,9 @@ where
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Reconstruct the request from the parts
|
||||||
|
let req = Request::from_parts(parts, body);
|
||||||
|
|
||||||
// Take the form value
|
// Take the form value
|
||||||
let (
|
let (
|
||||||
client_id_from_form,
|
client_id_from_form,
|
||||||
@ -265,7 +273,7 @@ where
|
|||||||
client_assertion_type,
|
client_assertion_type,
|
||||||
client_assertion,
|
client_assertion,
|
||||||
form,
|
form,
|
||||||
) = match Form::<AuthorizedForm<F>>::from_request(req).await {
|
) = match Form::<AuthorizedForm<F>>::from_request(req, state).await {
|
||||||
Ok(Form(form)) => (
|
Ok(Form(form)) => (
|
||||||
form.client_id,
|
form.client_id,
|
||||||
form.client_secret,
|
form.client_secret,
|
||||||
@ -385,19 +393,17 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn none_test() {
|
async fn none_test() {
|
||||||
let mut req = RequestParts::new(
|
let req = Request::builder()
|
||||||
Request::builder()
|
.method(Method::POST)
|
||||||
.method(Method::POST)
|
.header(
|
||||||
.header(
|
http::header::CONTENT_TYPE,
|
||||||
http::header::CONTENT_TYPE,
|
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
)
|
||||||
)
|
.body(Full::<Bytes>::new("client_id=client-id&foo=bar".into()))
|
||||||
.body(Full::<Bytes>::new("client_id=client-id&foo=bar".into()))
|
.unwrap();
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
ClientAuthorization::<serde_json::Value>::from_request(req, &())
|
||||||
.await
|
.await
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
ClientAuthorization {
|
ClientAuthorization {
|
||||||
@ -411,23 +417,21 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn client_secret_basic_test() {
|
async fn client_secret_basic_test() {
|
||||||
let mut req = RequestParts::new(
|
let req = Request::builder()
|
||||||
Request::builder()
|
.method(Method::POST)
|
||||||
.method(Method::POST)
|
.header(
|
||||||
.header(
|
http::header::CONTENT_TYPE,
|
||||||
http::header::CONTENT_TYPE,
|
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
)
|
||||||
)
|
.header(
|
||||||
.header(
|
http::header::AUTHORIZATION,
|
||||||
http::header::AUTHORIZATION,
|
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
||||||
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
)
|
||||||
)
|
.body(Full::<Bytes>::new("foo=bar".into()))
|
||||||
.body(Full::<Bytes>::new("foo=bar".into()))
|
.unwrap();
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
ClientAuthorization::<serde_json::Value>::from_request(req, &())
|
||||||
.await
|
.await
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
ClientAuthorization {
|
ClientAuthorization {
|
||||||
@ -440,23 +444,21 @@ mod tests {
|
|||||||
);
|
);
|
||||||
|
|
||||||
// client_id in both header and body
|
// client_id in both header and body
|
||||||
let mut req = RequestParts::new(
|
let req = Request::builder()
|
||||||
Request::builder()
|
.method(Method::POST)
|
||||||
.method(Method::POST)
|
.header(
|
||||||
.header(
|
http::header::CONTENT_TYPE,
|
||||||
http::header::CONTENT_TYPE,
|
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
)
|
||||||
)
|
.header(
|
||||||
.header(
|
http::header::AUTHORIZATION,
|
||||||
http::header::AUTHORIZATION,
|
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
||||||
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
)
|
||||||
)
|
.body(Full::<Bytes>::new("client_id=client-id&foo=bar".into()))
|
||||||
.body(Full::<Bytes>::new("client_id=client-id&foo=bar".into()))
|
.unwrap();
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
ClientAuthorization::<serde_json::Value>::from_request(req, &())
|
||||||
.await
|
.await
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
ClientAuthorization {
|
ClientAuthorization {
|
||||||
@ -469,62 +471,56 @@ mod tests {
|
|||||||
);
|
);
|
||||||
|
|
||||||
// client_id in both header and body mismatch
|
// client_id in both header and body mismatch
|
||||||
let mut req = RequestParts::new(
|
let req = Request::builder()
|
||||||
Request::builder()
|
.method(Method::POST)
|
||||||
.method(Method::POST)
|
.header(
|
||||||
.header(
|
http::header::CONTENT_TYPE,
|
||||||
http::header::CONTENT_TYPE,
|
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
)
|
||||||
)
|
.header(
|
||||||
.header(
|
http::header::AUTHORIZATION,
|
||||||
http::header::AUTHORIZATION,
|
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
||||||
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
)
|
||||||
)
|
.body(Full::<Bytes>::new("client_id=mismatch-id&foo=bar".into()))
|
||||||
.body(Full::<Bytes>::new("client_id=mismatch-id&foo=bar".into()))
|
.unwrap();
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req).await,
|
ClientAuthorization::<serde_json::Value>::from_request(req, &()).await,
|
||||||
Err(ClientAuthorizationError::ClientIdMismatch { .. }),
|
Err(ClientAuthorizationError::ClientIdMismatch { .. }),
|
||||||
));
|
));
|
||||||
|
|
||||||
// Invalid header
|
// Invalid header
|
||||||
let mut req = RequestParts::new(
|
let req = Request::builder()
|
||||||
Request::builder()
|
.method(Method::POST)
|
||||||
.method(Method::POST)
|
.header(
|
||||||
.header(
|
http::header::CONTENT_TYPE,
|
||||||
http::header::CONTENT_TYPE,
|
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
)
|
||||||
)
|
.header(http::header::AUTHORIZATION, "Basic invalid")
|
||||||
.header(http::header::AUTHORIZATION, "Basic invalid")
|
.body(Full::<Bytes>::new("foo=bar".into()))
|
||||||
.body(Full::<Bytes>::new("foo=bar".into()))
|
.unwrap();
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req).await,
|
ClientAuthorization::<serde_json::Value>::from_request(req, &()).await,
|
||||||
Err(ClientAuthorizationError::InvalidHeader),
|
Err(ClientAuthorizationError::InvalidHeader),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn client_secret_post_test() {
|
async fn client_secret_post_test() {
|
||||||
let mut req = RequestParts::new(
|
let req = Request::builder()
|
||||||
Request::builder()
|
.method(Method::POST)
|
||||||
.method(Method::POST)
|
.header(
|
||||||
.header(
|
http::header::CONTENT_TYPE,
|
||||||
http::header::CONTENT_TYPE,
|
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
)
|
||||||
)
|
.body(Full::<Bytes>::new(
|
||||||
.body(Full::<Bytes>::new(
|
"client_id=client-id&client_secret=client-secret&foo=bar".into(),
|
||||||
"client_id=client-id&client_secret=client-secret&foo=bar".into(),
|
))
|
||||||
))
|
.unwrap();
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
ClientAuthorization::<serde_json::Value>::from_request(req, &())
|
||||||
.await
|
.await
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
ClientAuthorization {
|
ClientAuthorization {
|
||||||
@ -546,18 +542,16 @@ mod tests {
|
|||||||
JWT_BEARER_CLIENT_ASSERTION, jwt,
|
JWT_BEARER_CLIENT_ASSERTION, jwt,
|
||||||
));
|
));
|
||||||
|
|
||||||
let mut req = RequestParts::new(
|
let req = Request::builder()
|
||||||
Request::builder()
|
.method(Method::POST)
|
||||||
.method(Method::POST)
|
.header(
|
||||||
.header(
|
http::header::CONTENT_TYPE,
|
||||||
http::header::CONTENT_TYPE,
|
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
)
|
||||||
)
|
.body(Full::new(body))
|
||||||
.body(Full::new(body))
|
.unwrap();
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let authz = ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
let authz = ClientAuthorization::<serde_json::Value>::from_request(req, &())
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"})));
|
assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"})));
|
||||||
|
@ -19,12 +19,13 @@ use axum::{
|
|||||||
body::HttpBody,
|
body::HttpBody,
|
||||||
extract::{
|
extract::{
|
||||||
rejection::{FailedToDeserializeQueryString, FormRejection, TypedHeaderRejectionReason},
|
rejection::{FailedToDeserializeQueryString, FormRejection, TypedHeaderRejectionReason},
|
||||||
Form, FromRequest, TypedHeader,
|
Form, FromRequest, FromRequestParts, TypedHeader,
|
||||||
},
|
},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
|
BoxError,
|
||||||
};
|
};
|
||||||
use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName};
|
use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName};
|
||||||
use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, StatusCode};
|
use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode};
|
||||||
use mas_data_model::Session;
|
use mas_data_model::Session;
|
||||||
use mas_storage::{
|
use mas_storage::{
|
||||||
oauth2::access_token::{lookup_active_access_token, AccessTokenLookupError},
|
oauth2::access_token::{lookup_active_access_token, AccessTokenLookupError},
|
||||||
@ -275,19 +276,20 @@ impl IntoResponse for AuthorizationVerificationError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<B, F> FromRequest<B> for UserAuthorization<F>
|
impl<S, B, F> FromRequest<S, B> for UserAuthorization<F>
|
||||||
where
|
where
|
||||||
B: Send + HttpBody,
|
|
||||||
B::Data: Send,
|
|
||||||
B::Error: Error + Send + Sync + 'static,
|
|
||||||
F: DeserializeOwned,
|
F: DeserializeOwned,
|
||||||
|
B: HttpBody + Send + 'static,
|
||||||
|
B::Data: Send,
|
||||||
|
B::Error: Into<BoxError>,
|
||||||
|
S: Send + Sync,
|
||||||
{
|
{
|
||||||
type Rejection = UserAuthorizationError;
|
type Rejection = UserAuthorizationError;
|
||||||
|
|
||||||
async fn from_request(
|
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
req: &mut axum::extract::RequestParts<B>,
|
let (mut parts, body) = req.into_parts();
|
||||||
) -> Result<Self, Self::Rejection> {
|
let header =
|
||||||
let header = TypedHeader::<Authorization<Bearer>>::from_request(req).await;
|
TypedHeader::<Authorization<Bearer>>::from_request_parts(&mut parts, state).await;
|
||||||
|
|
||||||
// Take the Authorization header
|
// Take the Authorization header
|
||||||
let token_from_header = match header {
|
let token_from_header = match header {
|
||||||
@ -300,18 +302,21 @@ where
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let req = Request::from_parts(parts, body);
|
||||||
|
|
||||||
// Take the form value
|
// Take the form value
|
||||||
let (token_from_form, form) = match Form::<AuthorizedForm<F>>::from_request(req).await {
|
let (token_from_form, form) =
|
||||||
Ok(Form(form)) => (form.access_token, Some(form.inner)),
|
match Form::<AuthorizedForm<F>>::from_request(req, state).await {
|
||||||
// If it is not a form, continue
|
Ok(Form(form)) => (form.access_token, Some(form.inner)),
|
||||||
Err(FormRejection::InvalidFormContentType(_err)) => (None, None),
|
// If it is not a form, continue
|
||||||
// If the form could not be read, return a Bad Request error
|
Err(FormRejection::InvalidFormContentType(_err)) => (None, None),
|
||||||
Err(FormRejection::FailedToDeserializeQueryString(err)) => {
|
// If the form could not be read, return a Bad Request error
|
||||||
return Err(UserAuthorizationError::BadForm(err))
|
Err(FormRejection::FailedToDeserializeQueryString(err)) => {
|
||||||
}
|
return Err(UserAuthorizationError::BadForm(err))
|
||||||
// Other errors (body read twice, byte stream broke) return an internal error
|
}
|
||||||
Err(e) => return Err(UserAuthorizationError::InternalError(Box::new(e))),
|
// Other errors (body read twice, byte stream broke) return an internal error
|
||||||
};
|
Err(e) => return Err(UserAuthorizationError::InternalError(Box::new(e))),
|
||||||
|
};
|
||||||
|
|
||||||
let access_token = match (token_from_header, token_from_form) {
|
let access_token = match (token_from_header, token_from_form) {
|
||||||
// Ensure the token should not be in both the form and the access token
|
// Ensure the token should not be in both the form and the access token
|
||||||
|
@ -24,7 +24,7 @@ use futures::stream::{StreamExt, TryStreamExt};
|
|||||||
use hyper::Server;
|
use hyper::Server;
|
||||||
use mas_config::RootConfig;
|
use mas_config::RootConfig;
|
||||||
use mas_email::Mailer;
|
use mas_email::Mailer;
|
||||||
use mas_handlers::MatrixHomeserver;
|
use mas_handlers::{AppState, MatrixHomeserver};
|
||||||
use mas_http::ServerLayer;
|
use mas_http::ServerLayer;
|
||||||
use mas_policy::PolicyFactory;
|
use mas_policy::PolicyFactory;
|
||||||
use mas_router::UrlBuilder;
|
use mas_router::UrlBuilder;
|
||||||
@ -174,8 +174,6 @@ impl Options {
|
|||||||
.key_store()
|
.key_store()
|
||||||
.await
|
.await
|
||||||
.context("could not import keys from config")?;
|
.context("could not import keys from config")?;
|
||||||
// Wrap the key store in an Arc
|
|
||||||
let key_store = Arc::new(key_store);
|
|
||||||
|
|
||||||
let encrypter = config.secrets.encrypter();
|
let encrypter = config.secrets.encrypter();
|
||||||
|
|
||||||
@ -236,18 +234,20 @@ impl Options {
|
|||||||
.context("could not watch for templates changes")?;
|
.context("could not watch for templates changes")?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let router = mas_handlers::router(
|
let state = AppState {
|
||||||
&pool,
|
pool,
|
||||||
&templates,
|
templates,
|
||||||
&key_store,
|
key_store,
|
||||||
&encrypter,
|
encrypter,
|
||||||
&mailer,
|
url_builder,
|
||||||
&url_builder,
|
mailer,
|
||||||
&homeserver,
|
homeserver,
|
||||||
&policy_factory,
|
policy_factory,
|
||||||
)
|
};
|
||||||
.fallback(static_files)
|
|
||||||
.layer(ServerLayer::default());
|
let router = mas_handlers::router(state)
|
||||||
|
.fallback_service(static_files)
|
||||||
|
.layer(ServerLayer::default());
|
||||||
|
|
||||||
info!("Listening on http://{}", listener.local_addr().unwrap());
|
info!("Listening on http://{}", listener.local_addr().unwrap());
|
||||||
|
|
||||||
|
@ -20,9 +20,9 @@ anyhow = "1.0.64"
|
|||||||
hyper = { version = "0.14.20", features = ["full"] }
|
hyper = { version = "0.14.20", features = ["full"] }
|
||||||
tower = "0.4.13"
|
tower = "0.4.13"
|
||||||
tower-http = { version = "0.3.4", features = ["cors"] }
|
tower-http = { version = "0.3.4", features = ["cors"] }
|
||||||
axum = "0.5.15"
|
axum = "0.6.0-rc.1"
|
||||||
axum-macros = "0.2.3"
|
axum-macros = "0.2.3"
|
||||||
axum-extra = { version = "0.3.7", features = ["cookie-private"] }
|
axum-extra = { version = "0.4.0-rc.1", features = ["cookie-private"] }
|
||||||
|
|
||||||
# Emails
|
# Emails
|
||||||
lettre = { version = "0.10.1", default-features = false, features = ["builder"] }
|
lettre = { version = "0.10.1", default-features = false, features = ["builder"] }
|
||||||
|
85
crates/handlers/src/app_state.rs
Normal file
85
crates/handlers/src/app_state.rs
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::extract::FromRef;
|
||||||
|
use mas_email::Mailer;
|
||||||
|
use mas_keystore::{Encrypter, Keystore};
|
||||||
|
use mas_policy::PolicyFactory;
|
||||||
|
use mas_router::UrlBuilder;
|
||||||
|
use mas_templates::Templates;
|
||||||
|
use sqlx::PgPool;
|
||||||
|
|
||||||
|
use crate::MatrixHomeserver;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AppState {
|
||||||
|
pub pool: PgPool,
|
||||||
|
pub templates: Templates,
|
||||||
|
pub key_store: Keystore,
|
||||||
|
pub encrypter: Encrypter,
|
||||||
|
pub url_builder: UrlBuilder,
|
||||||
|
pub mailer: Mailer,
|
||||||
|
pub homeserver: MatrixHomeserver,
|
||||||
|
pub policy_factory: Arc<PolicyFactory>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for PgPool {
|
||||||
|
fn from_ref(input: &AppState) -> Self {
|
||||||
|
input.pool.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for Templates {
|
||||||
|
fn from_ref(input: &AppState) -> Self {
|
||||||
|
input.templates.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for Keystore {
|
||||||
|
fn from_ref(input: &AppState) -> Self {
|
||||||
|
input.key_store.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for Encrypter {
|
||||||
|
fn from_ref(input: &AppState) -> Self {
|
||||||
|
input.encrypter.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for UrlBuilder {
|
||||||
|
fn from_ref(input: &AppState) -> Self {
|
||||||
|
input.url_builder.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for Mailer {
|
||||||
|
fn from_ref(input: &AppState) -> Self {
|
||||||
|
input.mailer.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for MatrixHomeserver {
|
||||||
|
fn from_ref(input: &AppState) -> Self {
|
||||||
|
input.homeserver.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromRef<AppState> for Arc<PolicyFactory> {
|
||||||
|
fn from_ref(input: &AppState) -> Self {
|
||||||
|
input.policy_factory.clone()
|
||||||
|
}
|
||||||
|
}
|
@ -12,7 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{response::IntoResponse, Extension, Json};
|
use axum::{extract::State, response::IntoResponse, Json};
|
||||||
use chrono::{Duration, Utc};
|
use chrono::{Duration, Utc};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType};
|
use mas_data_model::{CompatSession, CompatSsoLoginState, Device, TokenType};
|
||||||
@ -197,8 +197,8 @@ impl IntoResponse for RouteError {
|
|||||||
|
|
||||||
#[tracing::instrument(skip_all, err)]
|
#[tracing::instrument(skip_all, err)]
|
||||||
pub(crate) async fn post(
|
pub(crate) async fn post(
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Extension(homeserver): Extension<MatrixHomeserver>,
|
State(homeserver): State<MatrixHomeserver>,
|
||||||
Json(input): Json<RequestBody>,
|
Json(input): Json<RequestBody>,
|
||||||
) -> Result<impl IntoResponse, RouteError> {
|
) -> Result<impl IntoResponse, RouteError> {
|
||||||
let mut txn = pool.begin().await?;
|
let mut txn = pool.begin().await?;
|
||||||
|
@ -16,9 +16,8 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Form, Path, Query},
|
extract::{Form, Path, Query, State},
|
||||||
response::{Html, IntoResponse, Redirect, Response},
|
response::{Html, IntoResponse, Redirect, Response},
|
||||||
Extension,
|
|
||||||
};
|
};
|
||||||
use axum_extra::extract::PrivateCookieJar;
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
use chrono::{Duration, Utc};
|
use chrono::{Duration, Utc};
|
||||||
@ -50,8 +49,8 @@ pub struct Params {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get(
|
pub async fn get(
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Path(id): Path<i64>,
|
Path(id): Path<i64>,
|
||||||
Query(params): Query<Params>,
|
Query(params): Query<Params>,
|
||||||
@ -114,12 +113,12 @@ pub async fn get(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn post(
|
pub async fn post(
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Path(id): Path<i64>,
|
Path(id): Path<i64>,
|
||||||
Form(form): Form<ProtectedForm<()>>,
|
|
||||||
Query(params): Query<Params>,
|
Query(params): Query<Params>,
|
||||||
|
Form(form): Form<ProtectedForm<()>>,
|
||||||
) -> Result<Response, FancyError> {
|
) -> Result<Response, FancyError> {
|
||||||
let mut txn = pool.begin().await?;
|
let mut txn = pool.begin().await?;
|
||||||
|
|
||||||
|
@ -13,7 +13,10 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{extract::Query, response::IntoResponse, Extension};
|
use axum::{
|
||||||
|
extract::{Query, State},
|
||||||
|
response::IntoResponse,
|
||||||
|
};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
|
use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder};
|
||||||
use mas_storage::compat::insert_compat_sso_login;
|
use mas_storage::compat::insert_compat_sso_login;
|
||||||
@ -63,8 +66,8 @@ impl IntoResponse for RouteError {
|
|||||||
|
|
||||||
#[tracing::instrument(skip(pool, url_builder), err)]
|
#[tracing::instrument(skip(pool, url_builder), err)]
|
||||||
pub async fn get(
|
pub async fn get(
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Extension(url_builder): Extension<UrlBuilder>,
|
State(url_builder): State<UrlBuilder>,
|
||||||
Query(params): Query<Params>,
|
Query(params): Query<Params>,
|
||||||
) -> Result<impl IntoResponse, RouteError> {
|
) -> Result<impl IntoResponse, RouteError> {
|
||||||
// Check the redirectUrl parameter
|
// Check the redirectUrl parameter
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{response::IntoResponse, Extension, Json, TypedHeader};
|
use axum::{extract::State, response::IntoResponse, Json, TypedHeader};
|
||||||
use headers::{authorization::Bearer, Authorization};
|
use headers::{authorization::Bearer, Authorization};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use mas_data_model::{TokenFormatError, TokenType};
|
use mas_data_model::{TokenFormatError, TokenType};
|
||||||
@ -64,7 +64,7 @@ impl From<TokenFormatError> for RouteError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn post(
|
pub(crate) async fn post(
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
maybe_authorization: Option<TypedHeader<Authorization<Bearer>>>,
|
maybe_authorization: Option<TypedHeader<Authorization<Bearer>>>,
|
||||||
) -> Result<impl IntoResponse, RouteError> {
|
) -> Result<impl IntoResponse, RouteError> {
|
||||||
let mut conn = pool.acquire().await?;
|
let mut conn = pool.acquire().await?;
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{response::IntoResponse, Extension, Json};
|
use axum::{extract::State, response::IntoResponse, Json};
|
||||||
use chrono::Duration;
|
use chrono::Duration;
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use mas_data_model::{TokenFormatError, TokenType};
|
use mas_data_model::{TokenFormatError, TokenType};
|
||||||
@ -96,7 +96,7 @@ pub struct ResponseBody {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn post(
|
pub(crate) async fn post(
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Json(input): Json<RequestBody>,
|
Json(input): Json<RequestBody>,
|
||||||
) -> Result<impl IntoResponse, RouteError> {
|
) -> Result<impl IntoResponse, RouteError> {
|
||||||
let mut txn = pool.begin().await?;
|
let mut txn = pool.begin().await?;
|
||||||
|
@ -12,12 +12,12 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{extract::Extension, response::IntoResponse};
|
use axum::{extract::State, response::IntoResponse};
|
||||||
use mas_axum_utils::FancyError;
|
use mas_axum_utils::FancyError;
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
use tracing::{info_span, Instrument};
|
use tracing::{info_span, Instrument};
|
||||||
|
|
||||||
pub async fn get(Extension(pool): Extension<PgPool>) -> Result<impl IntoResponse, FancyError> {
|
pub async fn get(State(pool): State<PgPool>) -> Result<impl IntoResponse, FancyError> {
|
||||||
let mut conn = pool.acquire().await?;
|
let mut conn = pool.acquire().await?;
|
||||||
|
|
||||||
sqlx::query("SELECT $1")
|
sqlx::query("SELECT $1")
|
||||||
@ -38,7 +38,9 @@ mod tests {
|
|||||||
|
|
||||||
#[sqlx::test(migrator = "mas_storage::MIGRATOR")]
|
#[sqlx::test(migrator = "mas_storage::MIGRATOR")]
|
||||||
async fn test_get_health(pool: PgPool) -> Result<(), anyhow::Error> {
|
async fn test_get_health(pool: PgPool) -> Result<(), anyhow::Error> {
|
||||||
let app = crate::test_router(&pool).await?;
|
let state = crate::test_state(pool).await?;
|
||||||
|
let app = crate::api_router(state);
|
||||||
|
|
||||||
let request = Request::builder().uri("/health").body(Body::empty())?;
|
let request = Request::builder().uri("/health").body(Body::empty())?;
|
||||||
|
|
||||||
let response = app.oneshot(request).await?;
|
let response = app.oneshot(request).await?;
|
||||||
|
@ -23,7 +23,7 @@ use std::{convert::Infallible, sync::Arc, time::Duration};
|
|||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
body::HttpBody,
|
body::HttpBody,
|
||||||
extract::Extension,
|
extract::FromRef,
|
||||||
response::{Html, IntoResponse},
|
response::{Html, IntoResponse},
|
||||||
routing::{get, on, post, MethodFilter},
|
routing::{get, on, post, MethodFilter},
|
||||||
Router,
|
Router,
|
||||||
@ -37,9 +37,10 @@ use mas_policy::PolicyFactory;
|
|||||||
use mas_router::{Route, UrlBuilder};
|
use mas_router::{Route, UrlBuilder};
|
||||||
use mas_templates::{ErrorContext, Templates};
|
use mas_templates::{ErrorContext, Templates};
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
use tower::util::ThenLayer;
|
use tower::util::AndThenLayer;
|
||||||
use tower_http::cors::{Any, CorsLayer};
|
use tower_http::cors::{Any, CorsLayer};
|
||||||
|
|
||||||
|
mod app_state;
|
||||||
mod compat;
|
mod compat;
|
||||||
mod health;
|
mod health;
|
||||||
mod oauth2;
|
mod oauth2;
|
||||||
@ -47,30 +48,24 @@ mod views;
|
|||||||
|
|
||||||
pub use compat::MatrixHomeserver;
|
pub use compat::MatrixHomeserver;
|
||||||
|
|
||||||
|
pub use self::app_state::AppState;
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
#[allow(
|
#[allow(clippy::trait_duplication_in_bounds)]
|
||||||
clippy::too_many_lines,
|
pub fn api_router<S, B>(state: Arc<S>) -> Router<S, B>
|
||||||
clippy::missing_panics_doc,
|
|
||||||
clippy::too_many_arguments,
|
|
||||||
clippy::trait_duplication_in_bounds
|
|
||||||
)]
|
|
||||||
pub fn router<B>(
|
|
||||||
pool: &PgPool,
|
|
||||||
templates: &Templates,
|
|
||||||
key_store: &Keystore,
|
|
||||||
encrypter: &Encrypter,
|
|
||||||
mailer: &Mailer,
|
|
||||||
url_builder: &UrlBuilder,
|
|
||||||
homeserver: &MatrixHomeserver,
|
|
||||||
policy_factory: &Arc<PolicyFactory>,
|
|
||||||
) -> Router<B>
|
|
||||||
where
|
where
|
||||||
B: HttpBody + Send + 'static,
|
B: HttpBody + Send + 'static,
|
||||||
<B as HttpBody>::Data: Send,
|
<B as HttpBody>::Data: Send,
|
||||||
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
||||||
|
S: Send + Sync + 'static,
|
||||||
|
Keystore: FromRef<S>,
|
||||||
|
UrlBuilder: FromRef<S>,
|
||||||
|
Arc<PolicyFactory>: FromRef<S>,
|
||||||
|
PgPool: FromRef<S>,
|
||||||
|
Encrypter: FromRef<S>,
|
||||||
{
|
{
|
||||||
// All those routes are API-like, with a common CORS layer
|
// All those routes are API-like, with a common CORS layer
|
||||||
let api_router = Router::new()
|
Router::with_state_arc(state)
|
||||||
.route(
|
.route(
|
||||||
mas_router::ChangePasswordDiscovery::route(),
|
mas_router::ChangePasswordDiscovery::route(),
|
||||||
get(|| async { mas_router::AccountPassword.go() }),
|
get(|| async { mas_router::AccountPassword.go() }),
|
||||||
@ -118,9 +113,21 @@ where
|
|||||||
CONTENT_TYPE,
|
CONTENT_TYPE,
|
||||||
])
|
])
|
||||||
.max_age(Duration::from_secs(60 * 60)),
|
.max_age(Duration::from_secs(60 * 60)),
|
||||||
);
|
)
|
||||||
|
}
|
||||||
let compat_router = Router::new()
|
#[must_use]
|
||||||
|
#[allow(clippy::trait_duplication_in_bounds)]
|
||||||
|
pub fn compat_router<S, B>(state: Arc<S>) -> Router<S, B>
|
||||||
|
where
|
||||||
|
B: HttpBody + Send + 'static,
|
||||||
|
<B as HttpBody>::Data: Send,
|
||||||
|
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
||||||
|
S: Send + Sync + 'static,
|
||||||
|
UrlBuilder: FromRef<S>,
|
||||||
|
PgPool: FromRef<S>,
|
||||||
|
MatrixHomeserver: FromRef<S>,
|
||||||
|
{
|
||||||
|
Router::with_state_arc(state)
|
||||||
.route(
|
.route(
|
||||||
mas_router::CompatLogin::route(),
|
mas_router::CompatLogin::route(),
|
||||||
get(self::compat::login::get).post(self::compat::login::post),
|
get(self::compat::login::get).post(self::compat::login::post),
|
||||||
@ -146,106 +153,131 @@ where
|
|||||||
HeaderName::from_static("x-requested-with"),
|
HeaderName::from_static("x-requested-with"),
|
||||||
])
|
])
|
||||||
.max_age(Duration::from_secs(60 * 60)),
|
.max_age(Duration::from_secs(60 * 60)),
|
||||||
);
|
)
|
||||||
|
}
|
||||||
|
|
||||||
let human_router = {
|
#[must_use]
|
||||||
let templates = templates.clone();
|
#[allow(clippy::trait_duplication_in_bounds)]
|
||||||
Router::new()
|
pub fn human_router<S, B>(state: Arc<S>) -> Router<S, B>
|
||||||
.route(mas_router::Index::route(), get(self::views::index::get))
|
where
|
||||||
.route(mas_router::Healthcheck::route(), get(self::health::get))
|
B: HttpBody + Send + 'static,
|
||||||
.route(
|
<B as HttpBody>::Data: Send,
|
||||||
mas_router::Login::route(),
|
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
||||||
get(self::views::login::get).post(self::views::login::post),
|
S: Send + Sync + 'static,
|
||||||
)
|
UrlBuilder: FromRef<S>,
|
||||||
.route(mas_router::Logout::route(), post(self::views::logout::post))
|
Arc<PolicyFactory>: FromRef<S>,
|
||||||
.route(
|
PgPool: FromRef<S>,
|
||||||
mas_router::Reauth::route(),
|
Encrypter: FromRef<S>,
|
||||||
get(self::views::reauth::get).post(self::views::reauth::post),
|
Templates: FromRef<S>,
|
||||||
)
|
Mailer: FromRef<S>,
|
||||||
.route(
|
{
|
||||||
mas_router::Register::route(),
|
let templates = Templates::from_ref(&state);
|
||||||
get(self::views::register::get).post(self::views::register::post),
|
Router::with_state_arc(state)
|
||||||
)
|
.route(mas_router::Index::route(), get(self::views::index::get))
|
||||||
.route(mas_router::Account::route(), get(self::views::account::get))
|
.route(mas_router::Healthcheck::route(), get(self::health::get))
|
||||||
.route(
|
.route(
|
||||||
mas_router::AccountPassword::route(),
|
mas_router::Login::route(),
|
||||||
get(self::views::account::password::get).post(self::views::account::password::post),
|
get(self::views::login::get).post(self::views::login::post),
|
||||||
)
|
)
|
||||||
.route(
|
.route(mas_router::Logout::route(), post(self::views::logout::post))
|
||||||
mas_router::AccountEmails::route(),
|
.route(
|
||||||
get(self::views::account::emails::get).post(self::views::account::emails::post),
|
mas_router::Reauth::route(),
|
||||||
)
|
get(self::views::reauth::get).post(self::views::reauth::post),
|
||||||
.route(
|
)
|
||||||
mas_router::AccountVerifyEmail::route(),
|
.route(
|
||||||
get(self::views::account::emails::verify::get)
|
mas_router::Register::route(),
|
||||||
.post(self::views::account::emails::verify::post),
|
get(self::views::register::get).post(self::views::register::post),
|
||||||
)
|
)
|
||||||
.route(
|
.route(mas_router::Account::route(), get(self::views::account::get))
|
||||||
mas_router::AccountAddEmail::route(),
|
.route(
|
||||||
get(self::views::account::emails::add::get)
|
mas_router::AccountPassword::route(),
|
||||||
.post(self::views::account::emails::add::post),
|
get(self::views::account::password::get).post(self::views::account::password::post),
|
||||||
)
|
)
|
||||||
.route(
|
.route(
|
||||||
mas_router::OAuth2AuthorizationEndpoint::route(),
|
mas_router::AccountEmails::route(),
|
||||||
get(self::oauth2::authorization::get),
|
get(self::views::account::emails::get).post(self::views::account::emails::post),
|
||||||
)
|
)
|
||||||
.route(
|
.route(
|
||||||
mas_router::ContinueAuthorizationGrant::route(),
|
mas_router::AccountVerifyEmail::route(),
|
||||||
get(self::oauth2::authorization::complete::get),
|
get(self::views::account::emails::verify::get)
|
||||||
)
|
.post(self::views::account::emails::verify::post),
|
||||||
.route(
|
)
|
||||||
mas_router::Consent::route(),
|
.route(
|
||||||
get(self::oauth2::consent::get).post(self::oauth2::consent::post),
|
mas_router::AccountAddEmail::route(),
|
||||||
)
|
get(self::views::account::emails::add::get)
|
||||||
.route(
|
.post(self::views::account::emails::add::post),
|
||||||
mas_router::CompatLoginSsoRedirect::route(),
|
)
|
||||||
get(self::compat::login_sso_redirect::get),
|
.route(
|
||||||
)
|
mas_router::OAuth2AuthorizationEndpoint::route(),
|
||||||
.route(
|
get(self::oauth2::authorization::get),
|
||||||
mas_router::CompatLoginSsoRedirectIdp::route(),
|
)
|
||||||
get(self::compat::login_sso_redirect::get),
|
.route(
|
||||||
)
|
mas_router::ContinueAuthorizationGrant::route(),
|
||||||
.route(
|
get(self::oauth2::authorization::complete::get),
|
||||||
mas_router::CompatLoginSsoComplete::route(),
|
)
|
||||||
get(self::compat::login_sso_complete::get)
|
.route(
|
||||||
.post(self::compat::login_sso_complete::post),
|
mas_router::Consent::route(),
|
||||||
)
|
get(self::oauth2::consent::get).post(self::oauth2::consent::post),
|
||||||
.layer(ThenLayer::new(
|
)
|
||||||
move |result: Result<axum::response::Response, Infallible>| async move {
|
.route(
|
||||||
let response = result.unwrap();
|
mas_router::CompatLoginSsoRedirect::route(),
|
||||||
|
get(self::compat::login_sso_redirect::get),
|
||||||
if response.status().is_server_error() {
|
)
|
||||||
// Error responses should have an ErrorContext attached to them
|
.route(
|
||||||
let ext = response.extensions().get::<ErrorContext>();
|
mas_router::CompatLoginSsoRedirectIdp::route(),
|
||||||
if let Some(ctx) = ext {
|
get(self::compat::login_sso_redirect::get),
|
||||||
if let Ok(res) = templates.render_error(ctx).await {
|
)
|
||||||
let (mut parts, _original_body) = response.into_parts();
|
.route(
|
||||||
parts.headers.remove(CONTENT_TYPE);
|
mas_router::CompatLoginSsoComplete::route(),
|
||||||
return Ok((parts, Html(res)).into_response());
|
get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post),
|
||||||
}
|
)
|
||||||
|
.layer(AndThenLayer::new(
|
||||||
|
move |response: axum::response::Response| async move {
|
||||||
|
if response.status().is_server_error() {
|
||||||
|
// Error responses should have an ErrorContext attached to them
|
||||||
|
let ext = response.extensions().get::<ErrorContext>();
|
||||||
|
if let Some(ctx) = ext {
|
||||||
|
if let Ok(res) = templates.render_error(ctx).await {
|
||||||
|
let (mut parts, _original_body) = response.into_parts();
|
||||||
|
parts.headers.remove(CONTENT_TYPE);
|
||||||
|
return Ok((parts, Html(res)).into_response());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Ok(response)
|
Ok::<_, Infallible>(response)
|
||||||
},
|
},
|
||||||
))
|
))
|
||||||
};
|
}
|
||||||
|
|
||||||
human_router
|
#[must_use]
|
||||||
.merge(api_router)
|
#[allow(clippy::trait_duplication_in_bounds)]
|
||||||
.merge(compat_router)
|
pub fn router<S, B>(state: S) -> Router<S, B>
|
||||||
.layer(Extension(pool.clone()))
|
where
|
||||||
.layer(Extension(templates.clone()))
|
B: HttpBody + Send + 'static,
|
||||||
.layer(Extension(key_store.clone()))
|
<B as HttpBody>::Data: Send,
|
||||||
.layer(Extension(encrypter.clone()))
|
<B as HttpBody>::Error: std::error::Error + Send + Sync,
|
||||||
.layer(Extension(url_builder.clone()))
|
S: Send + Sync + 'static,
|
||||||
.layer(Extension(mailer.clone()))
|
Keystore: FromRef<S>,
|
||||||
.layer(Extension(homeserver.clone()))
|
UrlBuilder: FromRef<S>,
|
||||||
.layer(Extension(policy_factory.clone()))
|
Arc<PolicyFactory>: FromRef<S>,
|
||||||
|
PgPool: FromRef<S>,
|
||||||
|
Encrypter: FromRef<S>,
|
||||||
|
Templates: FromRef<S>,
|
||||||
|
Mailer: FromRef<S>,
|
||||||
|
MatrixHomeserver: FromRef<S>,
|
||||||
|
{
|
||||||
|
let state = Arc::new(state);
|
||||||
|
|
||||||
|
let api_router = api_router(state.clone());
|
||||||
|
let compat_router = compat_router(state.clone());
|
||||||
|
let human_router = human_router(state);
|
||||||
|
|
||||||
|
human_router.merge(api_router).merge(compat_router)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
async fn test_router(pool: &PgPool) -> Result<Router, anyhow::Error> {
|
async fn test_state(pool: PgPool) -> Result<Arc<AppState>, anyhow::Error> {
|
||||||
use mas_email::MailTransport;
|
use mas_email::MailTransport;
|
||||||
|
|
||||||
let templates = Templates::load(None, true).await?;
|
let templates = Templates::load(None, true).await?;
|
||||||
@ -265,14 +297,14 @@ async fn test_router(pool: &PgPool) -> Result<Router, anyhow::Error> {
|
|||||||
let policy_factory = PolicyFactory::load_default(serde_json::json!({})).await?;
|
let policy_factory = PolicyFactory::load_default(serde_json::json!({})).await?;
|
||||||
let policy_factory = Arc::new(policy_factory);
|
let policy_factory = Arc::new(policy_factory);
|
||||||
|
|
||||||
Ok(router(
|
Ok(Arc::new(AppState {
|
||||||
pool,
|
pool,
|
||||||
&templates,
|
templates,
|
||||||
&key_store,
|
key_store,
|
||||||
&encrypter,
|
encrypter,
|
||||||
&mailer,
|
url_builder,
|
||||||
&url_builder,
|
mailer,
|
||||||
&homeserver,
|
homeserver,
|
||||||
&policy_factory,
|
policy_factory,
|
||||||
))
|
}))
|
||||||
}
|
}
|
||||||
|
@ -16,9 +16,8 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::Path,
|
extract::{Path, State},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
Extension,
|
|
||||||
};
|
};
|
||||||
use axum_extra::extract::PrivateCookieJar;
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
@ -104,9 +103,9 @@ impl From<CallbackDestinationError> for RouteError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn get(
|
pub(crate) async fn get(
|
||||||
Extension(policy_factory): Extension<Arc<PolicyFactory>>,
|
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Path(grant_id): Path<i64>,
|
Path(grant_id): Path<i64>,
|
||||||
) -> Result<Response, RouteError> {
|
) -> Result<Response, RouteError> {
|
||||||
|
@ -16,7 +16,7 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
use anyhow::{anyhow, Context};
|
use anyhow::{anyhow, Context};
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Extension, Form},
|
extract::{Form, State},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use axum_extra::extract::PrivateCookieJar;
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
@ -156,9 +156,9 @@ fn resolve_response_mode(
|
|||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
pub(crate) async fn get(
|
pub(crate) async fn get(
|
||||||
Extension(policy_factory): Extension<Arc<PolicyFactory>>,
|
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Form(params): Form<Params>,
|
Form(params): Form<Params>,
|
||||||
) -> Result<Response, RouteError> {
|
) -> Result<Response, RouteError> {
|
||||||
|
@ -16,7 +16,7 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Extension, Form, Path},
|
extract::{Form, Path, State},
|
||||||
response::{Html, IntoResponse, Response},
|
response::{Html, IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use axum_extra::extract::PrivateCookieJar;
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
@ -50,9 +50,9 @@ impl IntoResponse for RouteError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn get(
|
pub(crate) async fn get(
|
||||||
Extension(policy_factory): Extension<Arc<PolicyFactory>>,
|
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Path(grant_id): Path<i64>,
|
Path(grant_id): Path<i64>,
|
||||||
) -> Result<Response, RouteError> {
|
) -> Result<Response, RouteError> {
|
||||||
@ -112,8 +112,8 @@ pub(crate) async fn get(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn post(
|
pub(crate) async fn post(
|
||||||
Extension(policy_factory): Extension<Arc<PolicyFactory>>,
|
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Path(grant_id): Path<i64>,
|
Path(grant_id): Path<i64>,
|
||||||
Form(form): Form<ProtectedForm<()>>,
|
Form(form): Form<ProtectedForm<()>>,
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{extract::Extension, response::IntoResponse, Json};
|
use axum::{extract::State, response::IntoResponse, Json};
|
||||||
use mas_iana::{
|
use mas_iana::{
|
||||||
jose::JsonWebSignatureAlg,
|
jose::JsonWebSignatureAlg,
|
||||||
oauth::{
|
oauth::{
|
||||||
@ -30,8 +30,8 @@ use oauth2_types::{
|
|||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
pub(crate) async fn get(
|
pub(crate) async fn get(
|
||||||
Extension(key_store): Extension<Keystore>,
|
State(key_store): State<Keystore>,
|
||||||
Extension(url_builder): Extension<UrlBuilder>,
|
State(url_builder): State<UrlBuilder>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
// This is how clients can authenticate
|
// This is how clients can authenticate
|
||||||
let client_auth_methods_supported = Some(vec![
|
let client_auth_methods_supported = Some(vec![
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{extract::Extension, response::IntoResponse, Json};
|
use axum::{extract::State, response::IntoResponse, Json};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use mas_axum_utils::client_authorization::{ClientAuthorization, CredentialsVerificationError};
|
use mas_axum_utils::client_authorization::{ClientAuthorization, CredentialsVerificationError};
|
||||||
use mas_data_model::{TokenFormatError, TokenType};
|
use mas_data_model::{TokenFormatError, TokenType};
|
||||||
@ -154,8 +154,8 @@ const INACTIVE: IntrospectionResponse = IntrospectionResponse {
|
|||||||
|
|
||||||
#[tracing::instrument(skip_all, err)]
|
#[tracing::instrument(skip_all, err)]
|
||||||
pub(crate) async fn post(
|
pub(crate) async fn post(
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Extension(encrypter): Extension<Encrypter>,
|
State(encrypter): State<Encrypter>,
|
||||||
client_authorization: ClientAuthorization<IntrospectionRequest>,
|
client_authorization: ClientAuthorization<IntrospectionRequest>,
|
||||||
) -> Result<impl IntoResponse, RouteError> {
|
) -> Result<impl IntoResponse, RouteError> {
|
||||||
let mut conn = pool.acquire().await?;
|
let mut conn = pool.acquire().await?;
|
||||||
|
@ -12,10 +12,10 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{extract::Extension, response::IntoResponse, Json};
|
use axum::{extract::State, response::IntoResponse, Json};
|
||||||
use mas_keystore::Keystore;
|
use mas_keystore::Keystore;
|
||||||
|
|
||||||
pub(crate) async fn get(Extension(key_store): Extension<Keystore>) -> impl IntoResponse {
|
pub(crate) async fn get(State(key_store): State<Keystore>) -> impl IntoResponse {
|
||||||
let jwks = key_store.public_jwks();
|
let jwks = key_store.public_jwks();
|
||||||
Json(jwks)
|
Json(jwks)
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use axum::{response::IntoResponse, Extension, Json};
|
use axum::{extract::State, response::IntoResponse, Json};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
use mas_policy::{PolicyFactory, Violation};
|
use mas_policy::{PolicyFactory, Violation};
|
||||||
use mas_storage::oauth2::client::insert_client;
|
use mas_storage::oauth2::client::insert_client;
|
||||||
@ -105,8 +105,8 @@ impl IntoResponse for RouteError {
|
|||||||
|
|
||||||
#[tracing::instrument(skip_all, err)]
|
#[tracing::instrument(skip_all, err)]
|
||||||
pub(crate) async fn post(
|
pub(crate) async fn post(
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Extension(policy_factory): Extension<Arc<PolicyFactory>>,
|
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||||
Json(body): Json<ClientMetadata>,
|
Json(body): Json<ClientMetadata>,
|
||||||
) -> Result<impl IntoResponse, RouteError> {
|
) -> Result<impl IntoResponse, RouteError> {
|
||||||
info!(?body, "Client registration");
|
info!(?body, "Client registration");
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use axum::{extract::Extension, response::IntoResponse, Json};
|
use axum::{extract::State, response::IntoResponse, Json};
|
||||||
use chrono::{DateTime, Duration, Utc};
|
use chrono::{DateTime, Duration, Utc};
|
||||||
use data_encoding::BASE64URL_NOPAD;
|
use data_encoding::BASE64URL_NOPAD;
|
||||||
use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma};
|
use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma};
|
||||||
@ -188,11 +188,11 @@ impl From<JwtSignatureError> for RouteError {
|
|||||||
|
|
||||||
#[tracing::instrument(skip_all, err)]
|
#[tracing::instrument(skip_all, err)]
|
||||||
pub(crate) async fn post(
|
pub(crate) async fn post(
|
||||||
|
State(key_store): State<Keystore>,
|
||||||
|
State(url_builder): State<UrlBuilder>,
|
||||||
|
State(pool): State<PgPool>,
|
||||||
|
State(encrypter): State<Encrypter>,
|
||||||
client_authorization: ClientAuthorization<AccessTokenRequest>,
|
client_authorization: ClientAuthorization<AccessTokenRequest>,
|
||||||
Extension(key_store): Extension<Keystore>,
|
|
||||||
Extension(url_builder): Extension<UrlBuilder>,
|
|
||||||
Extension(pool): Extension<PgPool>,
|
|
||||||
Extension(encrypter): Extension<Encrypter>,
|
|
||||||
) -> Result<impl IntoResponse, RouteError> {
|
) -> Result<impl IntoResponse, RouteError> {
|
||||||
let mut txn = pool.begin().await?;
|
let mut txn = pool.begin().await?;
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::Extension,
|
extract::State,
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
Json,
|
Json,
|
||||||
};
|
};
|
||||||
@ -48,9 +48,9 @@ struct SignedUserInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get(
|
pub async fn get(
|
||||||
Extension(url_builder): Extension<UrlBuilder>,
|
State(url_builder): State<UrlBuilder>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Extension(key_store): Extension<Keystore>,
|
State(key_store): State<Keystore>,
|
||||||
user_authorization: UserAuthorization,
|
user_authorization: UserAuthorization,
|
||||||
) -> Result<Response, FancyError> {
|
) -> Result<Response, FancyError> {
|
||||||
// TODO: error handling
|
// TODO: error handling
|
||||||
|
@ -12,7 +12,11 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{extract::Query, response::IntoResponse, Extension, Json, TypedHeader};
|
use axum::{
|
||||||
|
extract::{Query, State},
|
||||||
|
response::IntoResponse,
|
||||||
|
Json, TypedHeader,
|
||||||
|
};
|
||||||
use headers::ContentType;
|
use headers::ContentType;
|
||||||
use mas_router::UrlBuilder;
|
use mas_router::UrlBuilder;
|
||||||
use oauth2_types::webfinger::WebFingerResponse;
|
use oauth2_types::webfinger::WebFingerResponse;
|
||||||
@ -33,7 +37,7 @@ fn jrd() -> mime::Mime {
|
|||||||
|
|
||||||
pub(crate) async fn get(
|
pub(crate) async fn get(
|
||||||
Query(params): Query<Params>,
|
Query(params): Query<Params>,
|
||||||
Extension(url_builder): Extension<UrlBuilder>,
|
State(url_builder): State<UrlBuilder>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
// TODO: should we validate the subject?
|
// TODO: should we validate the subject?
|
||||||
let subject = params.resource;
|
let subject = params.resource;
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Extension, Form, Query},
|
extract::{Form, Query, State},
|
||||||
response::{Html, IntoResponse, Response},
|
response::{Html, IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use axum_extra::extract::PrivateCookieJar;
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
@ -38,8 +38,8 @@ pub struct EmailForm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn get(
|
pub(crate) async fn get(
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
) -> Result<Response, FancyError> {
|
) -> Result<Response, FancyError> {
|
||||||
let mut conn = pool.begin().await?;
|
let mut conn = pool.begin().await?;
|
||||||
@ -66,8 +66,8 @@ pub(crate) async fn get(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn post(
|
pub(crate) async fn post(
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Extension(mailer): Extension<Mailer>,
|
State(mailer): State<Mailer>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Query(query): Query<OptionalPostAuthAction>,
|
Query(query): Query<OptionalPostAuthAction>,
|
||||||
Form(form): Form<ProtectedForm<EmailForm>>,
|
Form(form): Form<ProtectedForm<EmailForm>>,
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Extension, Form},
|
extract::{Form, State},
|
||||||
response::{Html, IntoResponse, Response},
|
response::{Html, IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use axum_extra::extract::PrivateCookieJar;
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
@ -52,8 +52,8 @@ pub enum ManagementForm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn get(
|
pub(crate) async fn get(
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
) -> Result<Response, FancyError> {
|
) -> Result<Response, FancyError> {
|
||||||
let mut conn = pool.acquire().await?;
|
let mut conn = pool.acquire().await?;
|
||||||
@ -118,9 +118,9 @@ async fn start_email_verification(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn post(
|
pub(crate) async fn post(
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Extension(mailer): Extension<Mailer>,
|
State(mailer): State<Mailer>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Form(form): Form<ProtectedForm<ManagementForm>>,
|
Form(form): Form<ProtectedForm<ManagementForm>>,
|
||||||
) -> Result<Response, FancyError> {
|
) -> Result<Response, FancyError> {
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Extension, Form, Path, Query},
|
extract::{Form, Path, Query, State},
|
||||||
response::{Html, IntoResponse, Response},
|
response::{Html, IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use axum_extra::extract::PrivateCookieJar;
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
@ -40,8 +40,8 @@ pub struct CodeForm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn get(
|
pub(crate) async fn get(
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Query(query): Query<OptionalPostAuthAction>,
|
Query(query): Query<OptionalPostAuthAction>,
|
||||||
Path(id): Path<i64>,
|
Path(id): Path<i64>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
@ -78,7 +78,7 @@ pub(crate) async fn get(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn post(
|
pub(crate) async fn post(
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Query(query): Query<OptionalPostAuthAction>,
|
Query(query): Query<OptionalPostAuthAction>,
|
||||||
Path(id): Path<i64>,
|
Path(id): Path<i64>,
|
||||||
|
@ -16,7 +16,7 @@ pub mod emails;
|
|||||||
pub mod password;
|
pub mod password;
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::Extension,
|
extract::State,
|
||||||
response::{Html, IntoResponse, Response},
|
response::{Html, IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use axum_extra::extract::PrivateCookieJar;
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
@ -28,8 +28,8 @@ use mas_templates::{AccountContext, TemplateContext, Templates};
|
|||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
|
|
||||||
pub(crate) async fn get(
|
pub(crate) async fn get(
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
) -> Result<Response, FancyError> {
|
) -> Result<Response, FancyError> {
|
||||||
let mut conn = pool.acquire().await?;
|
let mut conn = pool.acquire().await?;
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
use argon2::Argon2;
|
use argon2::Argon2;
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Extension, Form},
|
extract::{Form, State},
|
||||||
response::{Html, IntoResponse, Response},
|
response::{Html, IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use axum_extra::extract::PrivateCookieJar;
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
@ -41,8 +41,8 @@ pub struct ChangeForm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn get(
|
pub(crate) async fn get(
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
) -> Result<Response, FancyError> {
|
) -> Result<Response, FancyError> {
|
||||||
let mut conn = pool.acquire().await?;
|
let mut conn = pool.acquire().await?;
|
||||||
@ -76,8 +76,8 @@ async fn render(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn post(
|
pub(crate) async fn post(
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Form(form): Form<ProtectedForm<ChangeForm>>,
|
Form(form): Form<ProtectedForm<ChangeForm>>,
|
||||||
) -> Result<Response, FancyError> {
|
) -> Result<Response, FancyError> {
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::Extension,
|
extract::State,
|
||||||
response::{Html, IntoResponse},
|
response::{Html, IntoResponse},
|
||||||
};
|
};
|
||||||
use axum_extra::extract::PrivateCookieJar;
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
@ -24,9 +24,9 @@ use mas_templates::{IndexContext, TemplateContext, Templates};
|
|||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
|
|
||||||
pub async fn get(
|
pub async fn get(
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
Extension(url_builder): Extension<UrlBuilder>,
|
State(url_builder): State<UrlBuilder>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
) -> Result<impl IntoResponse, FancyError> {
|
) -> Result<impl IntoResponse, FancyError> {
|
||||||
let mut conn = pool.acquire().await?;
|
let mut conn = pool.acquire().await?;
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Extension, Form, Query},
|
extract::{Form, Query, State},
|
||||||
response::{Html, IntoResponse, Response},
|
response::{Html, IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use axum_extra::extract::PrivateCookieJar;
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
@ -44,8 +44,8 @@ impl ToFormState for LoginForm {
|
|||||||
|
|
||||||
#[tracing::instrument(skip(templates, pool, cookie_jar))]
|
#[tracing::instrument(skip(templates, pool, cookie_jar))]
|
||||||
pub(crate) async fn get(
|
pub(crate) async fn get(
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Query(query): Query<OptionalPostAuthAction>,
|
Query(query): Query<OptionalPostAuthAction>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
) -> Result<Response, FancyError> {
|
) -> Result<Response, FancyError> {
|
||||||
@ -74,8 +74,8 @@ pub(crate) async fn get(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn post(
|
pub(crate) async fn post(
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Query(query): Query<OptionalPostAuthAction>,
|
Query(query): Query<OptionalPostAuthAction>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Form(form): Form<ProtectedForm<LoginForm>>,
|
Form(form): Form<ProtectedForm<LoginForm>>,
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Extension, Form},
|
extract::{Form, State},
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
};
|
};
|
||||||
use axum_extra::extract::PrivateCookieJar;
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
@ -27,7 +27,7 @@ use mas_storage::user::end_session;
|
|||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
|
|
||||||
pub(crate) async fn post(
|
pub(crate) async fn post(
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Form(form): Form<ProtectedForm<Option<PostAuthAction>>>,
|
Form(form): Form<ProtectedForm<Option<PostAuthAction>>>,
|
||||||
) -> Result<impl IntoResponse, FancyError> {
|
) -> Result<impl IntoResponse, FancyError> {
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Extension, Form, Query},
|
extract::{Form, Query, State},
|
||||||
response::{Html, IntoResponse, Response},
|
response::{Html, IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use axum_extra::extract::PrivateCookieJar;
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
@ -36,8 +36,8 @@ pub(crate) struct ReauthForm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn get(
|
pub(crate) async fn get(
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Query(query): Query<OptionalPostAuthAction>,
|
Query(query): Query<OptionalPostAuthAction>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
) -> Result<Response, FancyError> {
|
) -> Result<Response, FancyError> {
|
||||||
@ -75,7 +75,7 @@ pub(crate) async fn get(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn post(
|
pub(crate) async fn post(
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Query(query): Query<OptionalPostAuthAction>,
|
Query(query): Query<OptionalPostAuthAction>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Form(form): Form<ProtectedForm<ReauthForm>>,
|
Form(form): Form<ProtectedForm<ReauthForm>>,
|
||||||
|
@ -18,7 +18,7 @@ use std::{str::FromStr, sync::Arc};
|
|||||||
|
|
||||||
use argon2::Argon2;
|
use argon2::Argon2;
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Extension, Form, Query},
|
extract::{Form, Query, State},
|
||||||
response::{Html, IntoResponse, Response},
|
response::{Html, IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use axum_extra::extract::PrivateCookieJar;
|
use axum_extra::extract::PrivateCookieJar;
|
||||||
@ -57,8 +57,8 @@ impl ToFormState for RegisterForm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn get(
|
pub(crate) async fn get(
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Query(query): Query<OptionalPostAuthAction>,
|
Query(query): Query<OptionalPostAuthAction>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
) -> Result<Response, FancyError> {
|
) -> Result<Response, FancyError> {
|
||||||
@ -87,10 +87,10 @@ pub(crate) async fn get(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn post(
|
pub(crate) async fn post(
|
||||||
Extension(mailer): Extension<Mailer>,
|
State(mailer): State<Mailer>,
|
||||||
Extension(policy_factory): Extension<Arc<PolicyFactory>>,
|
State(policy_factory): State<Arc<PolicyFactory>>,
|
||||||
Extension(templates): Extension<Templates>,
|
State(templates): State<Templates>,
|
||||||
Extension(pool): Extension<PgPool>,
|
State(pool): State<PgPool>,
|
||||||
Query(query): Query<OptionalPostAuthAction>,
|
Query(query): Query<OptionalPostAuthAction>,
|
||||||
cookie_jar: PrivateCookieJar<Encrypter>,
|
cookie_jar: PrivateCookieJar<Encrypter>,
|
||||||
Form(form): Form<ProtectedForm<RegisterForm>>,
|
Form(form): Form<ProtectedForm<RegisterForm>>,
|
||||||
|
@ -6,7 +6,7 @@ edition = "2021"
|
|||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
axum = { version = "0.5.15", optional = true }
|
axum = { version = "0.6.0-rc.1", optional = true }
|
||||||
bytes = "1.2.1"
|
bytes = "1.2.1"
|
||||||
futures-util = "0.3.24"
|
futures-util = "0.3.24"
|
||||||
headers = "0.3.8"
|
headers = "0.3.8"
|
||||||
|
@ -6,7 +6,7 @@ edition = "2021"
|
|||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
axum = { version = "0.5.15", default-features = false }
|
axum = { version = "0.6.0-rc.1", default-features = false }
|
||||||
serde = { version = "1.0.144", features = ["derive"] }
|
serde = { version = "1.0.144", features = ["derive"] }
|
||||||
serde_urlencoded = "0.7.1"
|
serde_urlencoded = "0.7.1"
|
||||||
serde_with = "2.0.0"
|
serde_with = "2.0.0"
|
||||||
|
@ -9,8 +9,8 @@ license = "Apache-2.0"
|
|||||||
dev = []
|
dev = []
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
axum = "0.5.15"
|
axum = { version = "0.6.0-rc.1", features = ["headers"] }
|
||||||
headers = "0.3.8"
|
headers = "0.3.7"
|
||||||
http = "0.2.8"
|
http = "0.2.8"
|
||||||
http-body = "0.4.5"
|
http-body = "0.4.5"
|
||||||
mime_guess = "2.0.4"
|
mime_guess = "2.0.4"
|
||||||
|
Reference in New Issue
Block a user