You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-08-09 04:22:45 +03:00
Upgrade axum to 0.6.0-rc.1
This commit is contained in:
@@ -7,8 +7,8 @@ license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
async-trait = "0.1.57"
|
||||
axum = { version = "0.5.15", features = ["headers"] }
|
||||
axum-extra = { version = "0.3.7", features = ["cookie-private"] }
|
||||
axum = { version = "0.6.0-rc.1", features = ["headers"] }
|
||||
axum-extra = { version = "0.4.0-rc.1", features = ["cookie-private"] }
|
||||
bincode = "1.3.3"
|
||||
chrono = "0.4.22"
|
||||
data-encoding = "2.3.2"
|
||||
|
@@ -19,13 +19,13 @@ use axum::{
|
||||
body::HttpBody,
|
||||
extract::{
|
||||
rejection::{FailedToDeserializeQueryString, FormRejection, TypedHeaderRejectionReason},
|
||||
Form, FromRequest, RequestParts, TypedHeader,
|
||||
Form, FromRequest, FromRequestParts, TypedHeader,
|
||||
},
|
||||
response::IntoResponse,
|
||||
BoxError,
|
||||
};
|
||||
use headers::{authorization::Basic, Authorization};
|
||||
use http::StatusCode;
|
||||
use http::{Request, StatusCode};
|
||||
use mas_data_model::{Client, JwksOrJwksUri, StorageBackend};
|
||||
use mas_http::HttpServiceExt;
|
||||
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||
@@ -234,18 +234,23 @@ impl IntoResponse for ClientAuthorizationError {
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<B, F> FromRequest<B> for ClientAuthorization<F>
|
||||
impl<S, B, F> FromRequest<S, B> for ClientAuthorization<F>
|
||||
where
|
||||
B: Send + HttpBody,
|
||||
B::Data: Send,
|
||||
B::Error: std::error::Error + Send + Sync + 'static,
|
||||
F: DeserializeOwned,
|
||||
B: HttpBody + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: Into<BoxError>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = ClientAuthorizationError;
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||
let header = TypedHeader::<Authorization<Basic>>::from_request(req).await;
|
||||
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
|
||||
// Split the request into parts so we can extract some headers
|
||||
let (mut parts, body) = req.into_parts();
|
||||
|
||||
let header =
|
||||
TypedHeader::<Authorization<Basic>>::from_request_parts(&mut parts, state).await;
|
||||
|
||||
// Take the Authorization header
|
||||
let credentials_from_header = match header {
|
||||
@@ -258,6 +263,9 @@ where
|
||||
},
|
||||
};
|
||||
|
||||
// Reconstruct the request from the parts
|
||||
let req = Request::from_parts(parts, body);
|
||||
|
||||
// Take the form value
|
||||
let (
|
||||
client_id_from_form,
|
||||
@@ -265,7 +273,7 @@ where
|
||||
client_assertion_type,
|
||||
client_assertion,
|
||||
form,
|
||||
) = match Form::<AuthorizedForm<F>>::from_request(req).await {
|
||||
) = match Form::<AuthorizedForm<F>>::from_request(req, state).await {
|
||||
Ok(Form(form)) => (
|
||||
form.client_id,
|
||||
form.client_secret,
|
||||
@@ -385,19 +393,17 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn none_test() {
|
||||
let mut req = RequestParts::new(
|
||||
Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.body(Full::<Bytes>::new("client_id=client-id&foo=bar".into()))
|
||||
.unwrap(),
|
||||
);
|
||||
let req = Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.body(Full::<Bytes>::new("client_id=client-id&foo=bar".into()))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
||||
ClientAuthorization::<serde_json::Value>::from_request(req, &())
|
||||
.await
|
||||
.unwrap(),
|
||||
ClientAuthorization {
|
||||
@@ -411,23 +417,21 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_basic_test() {
|
||||
let mut req = RequestParts::new(
|
||||
Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
||||
)
|
||||
.body(Full::<Bytes>::new("foo=bar".into()))
|
||||
.unwrap(),
|
||||
);
|
||||
let req = Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
||||
)
|
||||
.body(Full::<Bytes>::new("foo=bar".into()))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
||||
ClientAuthorization::<serde_json::Value>::from_request(req, &())
|
||||
.await
|
||||
.unwrap(),
|
||||
ClientAuthorization {
|
||||
@@ -440,23 +444,21 @@ mod tests {
|
||||
);
|
||||
|
||||
// client_id in both header and body
|
||||
let mut req = RequestParts::new(
|
||||
Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
||||
)
|
||||
.body(Full::<Bytes>::new("client_id=client-id&foo=bar".into()))
|
||||
.unwrap(),
|
||||
);
|
||||
let req = Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
||||
)
|
||||
.body(Full::<Bytes>::new("client_id=client-id&foo=bar".into()))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
||||
ClientAuthorization::<serde_json::Value>::from_request(req, &())
|
||||
.await
|
||||
.unwrap(),
|
||||
ClientAuthorization {
|
||||
@@ -469,62 +471,56 @@ mod tests {
|
||||
);
|
||||
|
||||
// client_id in both header and body mismatch
|
||||
let mut req = RequestParts::new(
|
||||
Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
||||
)
|
||||
.body(Full::<Bytes>::new("client_id=mismatch-id&foo=bar".into()))
|
||||
.unwrap(),
|
||||
);
|
||||
let req = Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.header(
|
||||
http::header::AUTHORIZATION,
|
||||
"Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=",
|
||||
)
|
||||
.body(Full::<Bytes>::new("client_id=mismatch-id&foo=bar".into()))
|
||||
.unwrap();
|
||||
|
||||
assert!(matches!(
|
||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req).await,
|
||||
ClientAuthorization::<serde_json::Value>::from_request(req, &()).await,
|
||||
Err(ClientAuthorizationError::ClientIdMismatch { .. }),
|
||||
));
|
||||
|
||||
// Invalid header
|
||||
let mut req = RequestParts::new(
|
||||
Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.header(http::header::AUTHORIZATION, "Basic invalid")
|
||||
.body(Full::<Bytes>::new("foo=bar".into()))
|
||||
.unwrap(),
|
||||
);
|
||||
let req = Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.header(http::header::AUTHORIZATION, "Basic invalid")
|
||||
.body(Full::<Bytes>::new("foo=bar".into()))
|
||||
.unwrap();
|
||||
|
||||
assert!(matches!(
|
||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req).await,
|
||||
ClientAuthorization::<serde_json::Value>::from_request(req, &()).await,
|
||||
Err(ClientAuthorizationError::InvalidHeader),
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_post_test() {
|
||||
let mut req = RequestParts::new(
|
||||
Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.body(Full::<Bytes>::new(
|
||||
"client_id=client-id&client_secret=client-secret&foo=bar".into(),
|
||||
))
|
||||
.unwrap(),
|
||||
);
|
||||
let req = Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.body(Full::<Bytes>::new(
|
||||
"client_id=client-id&client_secret=client-secret&foo=bar".into(),
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
||||
ClientAuthorization::<serde_json::Value>::from_request(req, &())
|
||||
.await
|
||||
.unwrap(),
|
||||
ClientAuthorization {
|
||||
@@ -546,18 +542,16 @@ mod tests {
|
||||
JWT_BEARER_CLIENT_ASSERTION, jwt,
|
||||
));
|
||||
|
||||
let mut req = RequestParts::new(
|
||||
Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.body(Full::new(body))
|
||||
.unwrap(),
|
||||
);
|
||||
let req = Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.body(Full::new(body))
|
||||
.unwrap();
|
||||
|
||||
let authz = ClientAuthorization::<serde_json::Value>::from_request(&mut req)
|
||||
let authz = ClientAuthorization::<serde_json::Value>::from_request(req, &())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"})));
|
||||
|
@@ -19,12 +19,13 @@ use axum::{
|
||||
body::HttpBody,
|
||||
extract::{
|
||||
rejection::{FailedToDeserializeQueryString, FormRejection, TypedHeaderRejectionReason},
|
||||
Form, FromRequest, TypedHeader,
|
||||
Form, FromRequest, FromRequestParts, TypedHeader,
|
||||
},
|
||||
response::{IntoResponse, Response},
|
||||
BoxError,
|
||||
};
|
||||
use headers::{authorization::Bearer, Authorization, Header, HeaderMapExt, HeaderName};
|
||||
use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, StatusCode};
|
||||
use http::{header::WWW_AUTHENTICATE, HeaderMap, HeaderValue, Request, StatusCode};
|
||||
use mas_data_model::Session;
|
||||
use mas_storage::{
|
||||
oauth2::access_token::{lookup_active_access_token, AccessTokenLookupError},
|
||||
@@ -275,19 +276,20 @@ impl IntoResponse for AuthorizationVerificationError {
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<B, F> FromRequest<B> for UserAuthorization<F>
|
||||
impl<S, B, F> FromRequest<S, B> for UserAuthorization<F>
|
||||
where
|
||||
B: Send + HttpBody,
|
||||
B::Data: Send,
|
||||
B::Error: Error + Send + Sync + 'static,
|
||||
F: DeserializeOwned,
|
||||
B: HttpBody + Send + 'static,
|
||||
B::Data: Send,
|
||||
B::Error: Into<BoxError>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = UserAuthorizationError;
|
||||
|
||||
async fn from_request(
|
||||
req: &mut axum::extract::RequestParts<B>,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
let header = TypedHeader::<Authorization<Bearer>>::from_request(req).await;
|
||||
async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let (mut parts, body) = req.into_parts();
|
||||
let header =
|
||||
TypedHeader::<Authorization<Bearer>>::from_request_parts(&mut parts, state).await;
|
||||
|
||||
// Take the Authorization header
|
||||
let token_from_header = match header {
|
||||
@@ -300,18 +302,21 @@ where
|
||||
},
|
||||
};
|
||||
|
||||
let req = Request::from_parts(parts, body);
|
||||
|
||||
// Take the form value
|
||||
let (token_from_form, form) = match Form::<AuthorizedForm<F>>::from_request(req).await {
|
||||
Ok(Form(form)) => (form.access_token, Some(form.inner)),
|
||||
// If it is not a form, continue
|
||||
Err(FormRejection::InvalidFormContentType(_err)) => (None, None),
|
||||
// If the form could not be read, return a Bad Request error
|
||||
Err(FormRejection::FailedToDeserializeQueryString(err)) => {
|
||||
return Err(UserAuthorizationError::BadForm(err))
|
||||
}
|
||||
// Other errors (body read twice, byte stream broke) return an internal error
|
||||
Err(e) => return Err(UserAuthorizationError::InternalError(Box::new(e))),
|
||||
};
|
||||
let (token_from_form, form) =
|
||||
match Form::<AuthorizedForm<F>>::from_request(req, state).await {
|
||||
Ok(Form(form)) => (form.access_token, Some(form.inner)),
|
||||
// If it is not a form, continue
|
||||
Err(FormRejection::InvalidFormContentType(_err)) => (None, None),
|
||||
// If the form could not be read, return a Bad Request error
|
||||
Err(FormRejection::FailedToDeserializeQueryString(err)) => {
|
||||
return Err(UserAuthorizationError::BadForm(err))
|
||||
}
|
||||
// Other errors (body read twice, byte stream broke) return an internal error
|
||||
Err(e) => return Err(UserAuthorizationError::InternalError(Box::new(e))),
|
||||
};
|
||||
|
||||
let access_token = match (token_from_header, token_from_form) {
|
||||
// Ensure the token should not be in both the form and the access token
|
||||
|
Reference in New Issue
Block a user