1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-08-07 17:03:01 +03:00

Upgrade axum to 0.5

This commit is contained in:
Quentin Gliech
2022-04-06 16:32:28 +02:00
parent 4e31fc6c84
commit 31bc8504c9
17 changed files with 65 additions and 79 deletions

13
Cargo.lock generated
View File

@@ -488,9 +488,9 @@ dependencies = [
[[package]] [[package]]
name = "axum" name = "axum"
version = "0.4.8" version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9f346c92c1e9a71d14fe4aaf7c2a5d9932cc4e5e48d8fb6641524416eb79ddd" checksum = "47594e438a243791dba58124b6669561f5baa14cb12046641d8008bf035e5a25"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core", "axum-core",
@@ -501,6 +501,7 @@ dependencies = [
"http", "http",
"http-body", "http-body",
"hyper", "hyper",
"itoa 1.0.1",
"matchit", "matchit",
"memchr", "memchr",
"mime", "mime",
@@ -519,9 +520,9 @@ dependencies = [
[[package]] [[package]]
name = "axum-core" name = "axum-core"
version = "0.1.2" version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dbcda393bef9c87572779cb8ef916f12d77750b27535dd6819fa86591627a51" checksum = "9a671c9ae99531afdd5d3ee8340b8da547779430689947144c140fc74a740244"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"bytes 1.1.0", "bytes 1.1.0",
@@ -2282,9 +2283,9 @@ checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f"
[[package]] [[package]]
name = "matchit" name = "matchit"
version = "0.4.6" version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9376a4f0340565ad675d11fc1419227faf5f60cd7ac9cb2e7185a471f30af833" checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
[[package]] [[package]]
name = "md-5" name = "md-5"

View File

@@ -7,7 +7,7 @@ license = "Apache-2.0"
[dependencies] [dependencies]
async-trait = "0.1.52" async-trait = "0.1.52"
axum = { version = "0.4.8", features = ["headers"] } axum = { version = "0.5.1", features = ["headers"] }
bincode = "1.3.3" bincode = "1.3.3"
chrono = "0.4.19" chrono = "0.4.19"
cookie = { version = "0.16.0", features = ["signed", "private", "percent-encode"] } cookie = { version = "0.16.0", features = ["signed", "private", "percent-encode"] }

View File

@@ -223,9 +223,7 @@ where
// If it's missing it is fine // If it's missing it is fine
TypedHeaderRejectionReason::Missing => None, TypedHeaderRejectionReason::Missing => None,
// If the header could not be parsed, return the error // If the header could not be parsed, return the error
TypedHeaderRejectionReason::Error(_) => { _ => return Err(ClientAuthorizationError::InvalidHeader),
return Err(ClientAuthorizationError::InvalidHeader)
}
}, },
}; };

View File

@@ -14,10 +14,13 @@
//! Private (encrypted) cookie jar, based on axum-extra's cookie jar //! Private (encrypted) cookie jar, based on axum-extra's cookie jar
use std::marker::PhantomData; use std::{convert::Infallible, marker::PhantomData};
use async_trait::async_trait; use async_trait::async_trait;
use axum::extract::{Extension, FromRequest, RequestParts}; use axum::{
extract::{Extension, FromRequest, RequestParts},
response::IntoResponseParts,
};
pub use cookie::Cookie; pub use cookie::Cookie;
use data_encoding::BASE64URL_NOPAD; use data_encoding::BASE64URL_NOPAD;
use headers::HeaderMap; use headers::HeaderMap;
@@ -68,12 +71,6 @@ impl<K> PrivateCookieJar<K> {
} }
} }
} }
pub fn headers(self) -> HeaderMap {
let mut headers = HeaderMap::new();
self.set_cookies(&mut headers);
headers
}
} }
#[async_trait] #[async_trait]
@@ -91,13 +88,8 @@ where
let mut jar = cookie::CookieJar::new(); let mut jar = cookie::CookieJar::new();
let mut private_jar = jar.private_mut(&key); let mut private_jar = jar.private_mut(&key);
// TODO: remove this when axum 0.5 gets released
// https://github.com/tokio-rs/axum/pull/698
let empty_headers = HeaderMap::new();
let cookies = req let cookies = req
.headers() .headers()
.unwrap_or(&empty_headers)
.get_all(COOKIE) .get_all(COOKIE)
.into_iter() .into_iter()
.filter_map(|value| value.to_str().ok()) .filter_map(|value| value.to_str().ok())
@@ -118,6 +110,17 @@ where
} }
} }
impl<K> IntoResponseParts for PrivateCookieJar<K> {
type Error = Infallible;
fn into_response_parts(
self,
mut res: axum::response::ResponseParts,
) -> Result<axum::response::ResponseParts, Self::Error> {
self.set_cookies(res.headers_mut());
Ok(res)
}
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
#[error("could not decode cookie")] #[error("could not decode cookie")]
pub enum CookieDecodeError { pub enum CookieDecodeError {

View File

@@ -287,9 +287,7 @@ where
// If it's missing it is fine // If it's missing it is fine
TypedHeaderRejectionReason::Missing => None, TypedHeaderRejectionReason::Missing => None,
// If the header could not be parsed, return the error // If the header could not be parsed, return the error
TypedHeaderRejectionReason::Error(_) => { _ => return Err(UserAuthorizationError::InvalidHeader),
return Err(UserAuthorizationError::InvalidHeader)
}
}, },
}; };

View File

@@ -22,7 +22,7 @@ anyhow = "1.0.56"
# Web server # Web server
hyper = { version = "0.14.17", features = ["full"] } hyper = { version = "0.14.17", features = ["full"] }
tower = "0.4.12" tower = "0.4.12"
axum = "0.4.8" axum = "0.5.1"
axum-macros = "0.2.0" axum-macros = "0.2.0"
# Emails # Emails

View File

@@ -149,12 +149,8 @@ where
.context("could not serialize redirect URI query params")?; .context("could not serialize redirect URI query params")?;
redirect_uri.set_query(Some(&new_qs)); redirect_uri.set_query(Some(&new_qs));
let redirect_uri = redirect_uri
.as_str()
.parse()
.context("could not convert redirect URI")?;
Ok(Redirect::to(redirect_uri).into_response()) Ok(Redirect::to(redirect_uri.as_str()).into_response())
} }
ResponseMode::Fragment => { ResponseMode::Fragment => {
let existing: Option<HashMap<&str, &str>> = redirect_uri let existing: Option<HashMap<&str, &str>> = redirect_uri
@@ -173,12 +169,8 @@ where
.context("could not serialize redirect URI fragment params")?; .context("could not serialize redirect URI fragment params")?;
redirect_uri.set_fragment(Some(&new_qs)); redirect_uri.set_fragment(Some(&new_qs));
let redirect_uri = redirect_uri
.as_str()
.parse()
.context("could not convert redirect URI")?;
Ok(Redirect::to(redirect_uri).into_response()) Ok(Redirect::to(redirect_uri.as_str()).into_response())
} }
ResponseMode::FormPost => { ResponseMode::FormPost => {
let merged = ParamsWithState { state, params }; let merged = ParamsWithState { state, params };
@@ -389,7 +381,7 @@ pub(crate) async fn get(
let next: ReauthRequest = next.into(); let next: ReauthRequest = next.into();
let next = next.build_uri()?; let next = next.build_uri()?;
Ok(Redirect::to(next).into_response()) Ok(Redirect::to(&next.to_string()).into_response())
} }
(Some(user_session), _) => { (Some(user_session), _) => {
// Other cases where we already have a session // Other cases where we already have a session
@@ -403,7 +395,7 @@ pub(crate) async fn get(
let next: RegisterRequest = next.into(); let next: RegisterRequest = next.into();
let next = next.build_uri()?; let next = next.build_uri()?;
Ok(Redirect::to(next).into_response()) Ok(Redirect::to(&next.to_string()).into_response())
} }
(None, _) => { (None, _) => {
// Other cases where we don't have a session, ask for a login // Other cases where we don't have a session, ask for a login
@@ -413,7 +405,7 @@ pub(crate) async fn get(
let next: LoginRequest = next.into(); let next: LoginRequest = next.into();
let next = next.build_uri()?; let next = next.build_uri()?;
Ok(Redirect::to(next).into_response()) Ok(Redirect::to(&next.to_string()).into_response())
} }
} }
}) })
@@ -424,7 +416,7 @@ pub(crate) async fn get(
Err(_e) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), Err(_e) => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
}; };
Ok((cookie_jar.headers(), response).into_response()) Ok((cookie_jar, response).into_response())
} }
#[derive(Serialize, Deserialize, Clone)] #[derive(Serialize, Deserialize, Clone)]
@@ -486,7 +478,7 @@ pub(crate) async fn step_get(
let next: PostAuthAction = next.into(); let next: PostAuthAction = next.into();
let login: LoginRequest = next.into(); let login: LoginRequest = next.into();
let login = login.build_uri()?; let login = login.build_uri()?;
return Ok((cookie_jar.headers(), Redirect::to(login)).into_response()); return Ok((cookie_jar, Redirect::to(&login.to_string())).into_response());
}; };
step(next, session, txn, &templates).await step(next, session, txn, &templates).await
@@ -565,7 +557,7 @@ async fn step(
let next: ReauthRequest = next.into(); let next: ReauthRequest = next.into();
let next = next.build_uri()?; let next = next.build_uri()?;
Redirect::to(next).into_response() Redirect::to(&next.to_string()).into_response()
} }
}; };

View File

@@ -70,7 +70,7 @@ pub(crate) async fn get(
} else { } else {
let login = LoginRequest::default(); let login = LoginRequest::default();
let login = login.build_uri().map_err(fancy_error(templates.clone()))?; let login = login.build_uri().map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar.headers(), Redirect::to(login)).into_response()) Ok((cookie_jar, Redirect::to(&login.to_string())).into_response())
} }
} }
@@ -95,7 +95,7 @@ async fn render(
.await .await
.map_err(fancy_error(templates))?; .map_err(fancy_error(templates))?;
Ok((cookie_jar.headers(), Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }
async fn start_email_verification( async fn start_email_verification(
@@ -151,7 +151,7 @@ pub(crate) async fn post(
} else { } else {
let login = LoginRequest::default(); let login = LoginRequest::default();
let login = login.build_uri().map_err(fancy_error(templates.clone()))?; let login = login.build_uri().map_err(fancy_error(templates.clone()))?;
return Ok((cookie_jar.headers(), Redirect::to(login)).into_response()); return Ok((cookie_jar, Redirect::to(&login.to_string())).into_response());
}; };
let form = cookie_jar let form = cookie_jar

View File

@@ -50,7 +50,7 @@ pub(crate) async fn get(
} else { } else {
let login = LoginRequest::default(); let login = LoginRequest::default();
let login = login.build_uri().map_err(fancy_error(templates.clone()))?; let login = login.build_uri().map_err(fancy_error(templates.clone()))?;
return Ok((cookie_jar.headers(), Redirect::to(login)).into_response()); return Ok((cookie_jar, Redirect::to(&login.to_string())).into_response());
}; };
let active_sessions = count_active_sessions(&mut conn, &session.user) let active_sessions = count_active_sessions(&mut conn, &session.user)
@@ -70,5 +70,5 @@ pub(crate) async fn get(
.await .await
.map_err(fancy_error(templates.clone()))?; .map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar.headers(), Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }

View File

@@ -62,7 +62,7 @@ pub(crate) async fn get(
} else { } else {
let login = LoginRequest::default(); let login = LoginRequest::default();
let login = login.build_uri().map_err(fancy_error(templates.clone()))?; let login = login.build_uri().map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar.headers(), Redirect::to(login)).into_response()) Ok((cookie_jar, Redirect::to(&login.to_string())).into_response())
} }
} }
@@ -82,7 +82,7 @@ async fn render(
.await .await
.map_err(fancy_error(templates))?; .map_err(fancy_error(templates))?;
Ok((cookie_jar.headers(), Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }
pub(crate) async fn post( pub(crate) async fn post(
@@ -109,7 +109,7 @@ pub(crate) async fn post(
} else { } else {
let login = LoginRequest::default(); let login = LoginRequest::default();
let login = login.build_uri().map_err(fancy_error(templates.clone()))?; let login = login.build_uri().map_err(fancy_error(templates.clone()))?;
return Ok((cookie_jar.headers(), Redirect::to(login)).into_response()); return Ok((cookie_jar, Redirect::to(&login.to_string())).into_response());
}; };
authenticate_session(&mut txn, &mut session, form.current_password) authenticate_session(&mut txn, &mut session, form.current_password)

View File

@@ -50,5 +50,5 @@ pub async fn get(
.await .await
.map_err(fancy_error(templates))?; .map_err(fancy_error(templates))?;
Ok((cookie_jar.headers(), Html(content))) Ok((cookie_jar, Html(content)))
} }

View File

@@ -71,7 +71,7 @@ impl LoginRequest {
Uri::from_static("/") Uri::from_static("/")
}; };
Ok(Redirect::to(uri)) Ok(Redirect::to(&uri.to_string()))
} }
} }
@@ -129,7 +129,7 @@ pub(crate) async fn get(
.await .await
.map_err(fancy_error(templates.clone()))?; .map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar.headers(), Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }
} }
@@ -157,7 +157,7 @@ pub(crate) async fn post(
Ok(session_info) => { Ok(session_info) => {
let cookie_jar = cookie_jar.set_session(&session_info); let cookie_jar = cookie_jar.set_session(&session_info);
let reply = query.redirect().map_err(fancy_error(templates.clone()))?; let reply = query.redirect().map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar.headers(), reply).into_response()) Ok((cookie_jar, reply).into_response())
} }
Err(e) => { Err(e) => {
let errored_form = match e { let errored_form = match e {
@@ -174,7 +174,7 @@ pub(crate) async fn post(
.await .await
.map_err(fancy_error(templates.clone()))?; .map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar.headers(), Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }
} }
} }

View File

@@ -16,7 +16,6 @@ use axum::{
extract::{Extension, Form}, extract::{Extension, Form},
response::{IntoResponse, Redirect}, response::{IntoResponse, Redirect},
}; };
use hyper::Uri;
use mas_axum_utils::{ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm}, csrf::{CsrfExt, ProtectedForm},
fancy_error, FancyError, PrivateCookieJar, SessionInfoExt, fancy_error, FancyError, PrivateCookieJar, SessionInfoExt,
@@ -54,6 +53,5 @@ pub(crate) async fn post(
txn.commit().await.map_err(fancy_error(templates))?; txn.commit().await.map_err(fancy_error(templates))?;
let to = Uri::from_static("/login"); Ok((cookie_jar, Redirect::to("/login")))
Ok((cookie_jar.headers(), Redirect::to(to)))
} }

View File

@@ -63,13 +63,11 @@ impl ReauthRequest {
} }
fn redirect(self) -> Result<impl IntoResponse, anyhow::Error> { fn redirect(self) -> Result<impl IntoResponse, anyhow::Error> {
let uri = if let Some(action) = self.post_auth_action { if let Some(action) = self.post_auth_action {
action.build_uri()? Ok(Redirect::to(&action.build_uri()?.to_string()))
} else { } else {
Uri::from_static("/") Ok(Redirect::to("/"))
}; }
Ok(Redirect::to(uri))
} }
} }
@@ -104,7 +102,7 @@ pub(crate) async fn get(
// PostAuthAction // PostAuthAction
let login: LoginRequest = query.post_auth_action.into(); let login: LoginRequest = query.post_auth_action.into();
let login = login.build_uri().map_err(fancy_error(templates.clone()))?; let login = login.build_uri().map_err(fancy_error(templates.clone()))?;
return Ok((cookie_jar.headers(), Redirect::to(login)).into_response()); return Ok((cookie_jar, Redirect::to(&login.to_string())).into_response());
}; };
let ctx = ReauthContext::default(); let ctx = ReauthContext::default();
@@ -125,7 +123,7 @@ pub(crate) async fn get(
.await .await
.map_err(fancy_error(templates.clone()))?; .map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar.headers(), Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }
pub(crate) async fn post( pub(crate) async fn post(
@@ -155,7 +153,7 @@ pub(crate) async fn post(
// PostAuthAction // PostAuthAction
let login: LoginRequest = query.post_auth_action.into(); let login: LoginRequest = query.post_auth_action.into();
let login = login.build_uri().map_err(fancy_error(templates.clone()))?; let login = login.build_uri().map_err(fancy_error(templates.clone()))?;
return Ok((cookie_jar.headers(), Redirect::to(login)).into_response()); return Ok((cookie_jar, Redirect::to(&login.to_string())).into_response());
}; };
// TODO: recover from errors here // TODO: recover from errors here
@@ -166,5 +164,5 @@ pub(crate) async fn post(
txn.commit().await.map_err(fancy_error(templates.clone()))?; txn.commit().await.map_err(fancy_error(templates.clone()))?;
let redirection = query.redirect().map_err(fancy_error(templates.clone()))?; let redirection = query.redirect().map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar.headers(), redirection).into_response()) Ok((cookie_jar, redirection).into_response())
} }

View File

@@ -64,13 +64,11 @@ impl RegisterRequest {
} }
fn redirect(self) -> Result<impl IntoResponse, anyhow::Error> { fn redirect(self) -> Result<impl IntoResponse, anyhow::Error> {
let uri = if let Some(action) = self.post_auth_action { if let Some(action) = self.post_auth_action {
action.build_uri()? Ok(Redirect::to(&action.build_uri()?.to_string()))
} else { } else {
Uri::from_static("/") Ok(Redirect::to("/"))
}; }
Ok(Redirect::to(uri))
} }
} }
@@ -129,7 +127,7 @@ pub(crate) async fn get(
.await .await
.map_err(fancy_error(templates.clone()))?; .map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar.headers(), Html(content)).into_response()) Ok((cookie_jar, Html(content)).into_response())
} }
} }
@@ -164,5 +162,5 @@ pub(crate) async fn post(
let cookie_jar = cookie_jar.set_session(&session); let cookie_jar = cookie_jar.set_session(&session);
let reply = query.redirect().map_err(fancy_error(templates.clone()))?; let reply = query.redirect().map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar.headers(), reply).into_response()) Ok((cookie_jar, reply).into_response())
} }

View File

@@ -66,5 +66,5 @@ pub(crate) async fn get(
txn.commit().await.map_err(fancy_error(templates.clone()))?; txn.commit().await.map_err(fancy_error(templates.clone()))?;
Ok((cookie_jar.headers(), Html(content))) Ok((cookie_jar, Html(content)))
} }

View File

@@ -9,7 +9,7 @@ license = "Apache-2.0"
dev = [] dev = []
[dependencies] [dependencies]
axum = "0.4.8" axum = "0.5.1"
headers = "0.3.7" headers = "0.3.7"
http = "0.2.6" http = "0.2.6"
mime_guess = "2.0.4" mime_guess = "2.0.4"