1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

PKCE support

This commit is contained in:
Quentin Gliech
2021-10-05 14:08:21 +02:00
parent af71adbe7a
commit 8ecdf7c6c8
12 changed files with 168 additions and 57 deletions

2
Cargo.lock generated
View File

@ -1623,6 +1623,7 @@ name = "oauth2-types"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"chrono", "chrono",
"data-encoding",
"http", "http",
"indoc", "indoc",
"language-tags", "language-tags",
@ -1630,6 +1631,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"serde_with", "serde_with",
"sha2",
"sqlx", "sqlx",
"url", "url",
] ]

View File

@ -1,3 +1,17 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
fn main() { fn main() {
// trigger recompilation when a new migration is added // trigger recompilation when a new migration is added
println!("cargo:rerun-if-changed=migrations"); println!("cargo:rerun-if-changed=migrations");

View File

@ -381,56 +381,6 @@
] ]
} }
}, },
"886dee6a6f1f426f0e891790bbeffbc222fd75d8da0a107e7de673f1cc445f30": {
"query": "\n SELECT\n oc.id,\n os.id AS \"oauth2_session_id!\",\n os.client_id AS \"client_id!\",\n os.redirect_uri,\n os.scope AS \"scope!\",\n os.nonce\n FROM oauth2_codes oc\n INNER JOIN oauth2_sessions os\n ON os.id = oc.oauth2_session_id\n WHERE oc.code = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "oauth2_session_id!",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "client_id!",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "redirect_uri",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "scope!",
"type_info": "Text"
},
{
"ordinal": 5,
"name": "nonce",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false,
false,
false,
false,
true
]
}
},
"88ac8783bd5881c42eafd9cf87a16fe6031f3153fd6a8618e689694584aeb2de": { "88ac8783bd5881c42eafd9cf87a16fe6031f3153fd6a8618e689694584aeb2de": {
"query": "\n DELETE FROM oauth2_access_tokens\n WHERE id = $1\n ", "query": "\n DELETE FROM oauth2_access_tokens\n WHERE id = $1\n ",
"describe": { "describe": {
@ -673,6 +623,68 @@
"nullable": [] "nullable": []
} }
}, },
"eb5f772a7387de0dc2f9f660f470476c075da097134a8ded226eb630545c16eb": {
"query": "\n SELECT\n oc.id,\n oc.code_challenge,\n oc.code_challenge_method,\n os.id AS \"oauth2_session_id!\",\n os.client_id AS \"client_id!\",\n os.redirect_uri,\n os.scope AS \"scope!\",\n os.nonce\n FROM oauth2_codes oc\n INNER JOIN oauth2_sessions os\n ON os.id = oc.oauth2_session_id\n WHERE oc.code = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "code_challenge",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "code_challenge_method",
"type_info": "Int2"
},
{
"ordinal": 3,
"name": "oauth2_session_id!",
"type_info": "Int8"
},
{
"ordinal": 4,
"name": "client_id!",
"type_info": "Text"
},
{
"ordinal": 5,
"name": "redirect_uri",
"type_info": "Text"
},
{
"ordinal": 6,
"name": "scope!",
"type_info": "Text"
},
{
"ordinal": 7,
"name": "nonce",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
true,
true,
false,
false,
false,
false,
true
]
}
},
"f9a09ff53b6f221649f4f050e3d5ade114f852ddf50a78610a6c0ef0689af681": { "f9a09ff53b6f221649f4f050e3d5ade114f852ddf50a78610a6c0ef0689af681": {
"query": "\n INSERT INTO users (username, hashed_password)\n VALUES ($1, $2)\n RETURNING id\n ", "query": "\n INSERT INTO users (username, hashed_password)\n VALUES ($1, $2)\n RETURNING id\n ",
"describe": { "describe": {

View File

@ -168,7 +168,7 @@ struct Params {
auth: AuthorizationRequest, auth: AuthorizationRequest,
#[serde(flatten)] #[serde(flatten)]
pkce: Option<pkce::Request>, pkce: Option<pkce::AuthorizationRequest>,
} }
/// Given a list of response types and an optional user-defined response mode, /// Given a list of response types and an optional user-defined response mode,
@ -349,7 +349,7 @@ async fn get(
.add_code(&mut txn, &code, &params.pkce) .add_code(&mut txn, &code, &params.pkce)
.await .await
.wrap_error()?; .wrap_error()?;
}; }
// Do we already have a user session for this oauth2 session? // Do we already have a user session for this oauth2 session?
let user_session = oauth2_session.fetch_session(&mut txn).await.wrap_error()?; let user_session = oauth2_session.fetch_session(&mut txn).await.wrap_error()?;

View File

@ -16,6 +16,7 @@ use std::collections::HashSet;
use oauth2_types::{ use oauth2_types::{
oidc::Metadata, oidc::Metadata,
pkce::CodeChallengeMethod,
requests::{ClientAuthenticationMethod, GrantType, ResponseMode}, requests::{ClientAuthenticationMethod, GrantType, ResponseMode},
}; };
use warp::{Filter, Rejection, Reply}; use warp::{Filter, Rejection, Reply};
@ -62,6 +63,13 @@ pub(super) fn filter(
s s
}); });
let code_challenge_methods_supported = Some({
let mut s = HashSet::new();
s.insert(CodeChallengeMethod::Plain);
s.insert(CodeChallengeMethod::S256);
s
});
let metadata = Metadata { let metadata = Metadata {
authorization_endpoint: base.join("oauth2/authorize").ok(), authorization_endpoint: base.join("oauth2/authorize").ok(),
token_endpoint: base.join("oauth2/token").ok(), token_endpoint: base.join("oauth2/token").ok(),
@ -75,7 +83,7 @@ pub(super) fn filter(
response_modes_supported, response_modes_supported,
grant_types_supported, grant_types_supported,
token_endpoint_auth_methods_supported, token_endpoint_auth_methods_supported,
code_challenge_methods_supported: None, code_challenge_methods_supported,
}; };
let cors = warp::cors().allow_any_origin(); let cors = warp::cors().allow_any_origin();

View File

@ -19,7 +19,10 @@ use headers::{CacheControl, Pragma};
use hyper::StatusCode; use hyper::StatusCode;
use jwt_compact::{Claims, Header, TimeOptions}; use jwt_compact::{Claims, Header, TimeOptions};
use oauth2_types::{ use oauth2_types::{
errors::{InvalidGrant, OAuth2Error, OAuth2ErrorCode, UnauthorizedClient}, errors::{
InvalidGrant, InvalidRequest, OAuth2Error, OAuth2ErrorCode, ServerError, UnauthorizedClient,
},
pkce::CodeChallengeMethod,
requests::{ requests::{
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant, AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant,
}, },
@ -166,6 +169,34 @@ async fn authorization_code_grant(
return error(UnauthorizedClient); return error(UnauthorizedClient);
} }
match (
code.code_challenge_method.as_ref(),
code.code_challenge.as_ref(),
grant.code_verifier.as_ref(),
) {
(None, None, None) => {}
// We have a challenge but no verifier (or vice-versa)? Bad request.
(Some(_), Some(_), None) | (None, None, Some(_)) => return error(InvalidRequest),
(Some(0 /* Plain */), Some(code_challenge), Some(code_verifier)) => {
if !CodeChallengeMethod::Plain.verify(code_challenge, code_verifier) {
return error(InvalidRequest);
}
}
(Some(1 /* S256 */), Some(code_challenge), Some(code_verifier)) => {
if !CodeChallengeMethod::S256.verify(code_challenge, code_verifier) {
return error(InvalidRequest);
}
}
// We have something else?
// That's a DB inconcistancy, we should bail out
_ => {
// TODO: are we sure we want to handle errors like that?
tracing::error!("Invalid state from the database");
return error(ServerError); // Somthing bad happened in the database
}
};
// TODO: verify PKCE // TODO: verify PKCE
let ttl = Duration::minutes(5); let ttl = Duration::minutes(5);
let (access_token, refresh_token) = { let (access_token, refresh_token) = {

View File

@ -32,7 +32,7 @@ pub async fn add_code(
executor: impl Executor<'_, Database = Postgres>, executor: impl Executor<'_, Database = Postgres>,
oauth2_session_id: i64, oauth2_session_id: i64,
code: &str, code: &str,
code_challenge: &Option<pkce::Request>, code_challenge: &Option<pkce::AuthorizationRequest>,
) -> anyhow::Result<OAuth2Code> { ) -> anyhow::Result<OAuth2Code> {
let code_challenge_method = code_challenge let code_challenge_method = code_challenge
.as_ref() .as_ref()
@ -65,6 +65,8 @@ pub struct OAuth2CodeLookup {
pub redirect_uri: String, pub redirect_uri: String,
pub scope: String, pub scope: String,
pub nonce: Option<String>, pub nonce: Option<String>,
pub code_challenge: Option<String>,
pub code_challenge_method: Option<i16>,
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
@ -84,11 +86,14 @@ pub async fn lookup_code(
executor: impl Executor<'_, Database = Postgres>, executor: impl Executor<'_, Database = Postgres>,
code: &str, code: &str,
) -> Result<OAuth2CodeLookup, CodeLookupError> { ) -> Result<OAuth2CodeLookup, CodeLookupError> {
// TODO: this should return a better type
let res = sqlx::query_as!( let res = sqlx::query_as!(
OAuth2CodeLookup, OAuth2CodeLookup,
r#" r#"
SELECT SELECT
oc.id, oc.id,
oc.code_challenge,
oc.code_challenge_method,
os.id AS "oauth2_session_id!", os.id AS "oauth2_session_id!",
os.client_id AS "client_id!", os.client_id AS "client_id!",
os.redirect_uri, os.redirect_uri,

View File

@ -52,7 +52,7 @@ impl OAuth2Session {
&self, &self,
executor: impl Executor<'e, Database = Postgres>, executor: impl Executor<'e, Database = Postgres>,
code: &str, code: &str,
code_challenge: &Option<pkce::Request>, code_challenge: &Option<pkce::AuthorizationRequest>,
) -> anyhow::Result<OAuth2Code> { ) -> anyhow::Result<OAuth2Code> {
add_code(executor, self.id, code, code_challenge).await add_code(executor, self.id, code, code_challenge).await
} }

View File

@ -16,6 +16,8 @@ indoc = "1.0.3"
serde_with = { version = "1.10.0", features = ["chrono"] } serde_with = { version = "1.10.0", features = ["chrono"] }
sqlx = { version = "0.5.9", default-features = false, optional = true } sqlx = { version = "0.5.9", default-features = false, optional = true }
chrono = "0.4.19" chrono = "0.4.19"
sha2 = "0.9.8"
data-encoding = "2.3.2"
[features] [features]
sqlx_type = ["sqlx"] sqlx_type = ["sqlx"]

View File

@ -237,6 +237,7 @@ pub mod rfc6749 {
oauth2_error! { oauth2_error! {
ServerError, ServerError,
code: INTERNAL_SERVER_ERROR,
"server_error" => "server_error" =>
"The authorization server encountered an unexpected \ "The authorization server encountered an unexpected \
condition that prevented it from fulfilling the request." condition that prevented it from fulfilling the request."

View File

@ -12,8 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::borrow::Cow;
use data_encoding::BASE64URL_NOPAD;
use parse_display::{Display, FromStr}; use parse_display::{Display, FromStr};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
#[derive( #[derive(
Debug, Debug,
@ -41,8 +45,34 @@ pub enum CodeChallengeMethod {
S256 = 1, S256 = 1,
} }
impl CodeChallengeMethod {
#[must_use]
pub fn compute_challenge(self, verifier: &str) -> Cow<'_, str> {
match self {
CodeChallengeMethod::Plain => verifier.into(),
CodeChallengeMethod::S256 => {
let mut hasher = Sha256::new();
hasher.update(verifier.as_bytes());
let hash = hasher.finalize();
let verifier = BASE64URL_NOPAD.encode(&hash);
verifier.into()
}
}
}
#[must_use]
pub fn verify(self, challenge: &str, verifier: &str) -> bool {
self.compute_challenge(verifier) == challenge
}
}
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
pub struct Request { pub struct AuthorizationRequest {
pub code_challenge_method: CodeChallengeMethod, pub code_challenge_method: CodeChallengeMethod,
pub code_challenge: String, pub code_challenge: String,
} }
#[derive(Serialize, Deserialize)]
pub struct TokenRequest {
pub code_challenge_verifier: String,
}

View File

@ -200,11 +200,16 @@ pub enum TokenType {
Bearer, Bearer,
} }
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, PartialEq)] #[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct AuthorizationCodeGrant { pub struct AuthorizationCodeGrant {
pub code: String, pub code: String,
#[serde(default)] #[serde(default)]
pub redirect_uri: Option<Url>, pub redirect_uri: Option<Url>,
// TODO: move this somehow in the pkce module
#[serde(default)]
pub code_verifier: Option<String>,
} }
#[serde_as] #[serde_as]
@ -406,6 +411,7 @@ mod tests {
let req = AccessTokenRequest::AuthorizationCode(AuthorizationCodeGrant { let req = AccessTokenRequest::AuthorizationCode(AuthorizationCodeGrant {
code: "abcd".into(), code: "abcd".into(),
redirect_uri: Some("https://example.com/redirect".parse().unwrap()), redirect_uri: Some("https://example.com/redirect".parse().unwrap()),
code_verifier: None,
}); });
assert_serde_json(&req, expected); assert_serde_json(&req, expected);