From 50654d2e40d68ea3dd02f9eb46d1b2a714d77a54 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 7 Dec 2023 17:59:35 +0100 Subject: [PATCH] Implement the device code authorisation request --- crates/handlers/src/lib.rs | 4 + .../handlers/src/oauth2/device/authorize.rs | 208 ++++++++++++++++++ crates/handlers/src/oauth2/device/mod.rs | 1 + crates/handlers/src/oauth2/discovery.rs | 2 + crates/oauth2-types/src/requests.rs | 12 +- crates/router/src/endpoints.rs | 15 ++ crates/router/src/url_builder.rs | 18 ++ .../src/oauth2/device_code_grant.rs | 4 +- 8 files changed, 256 insertions(+), 8 deletions(-) create mode 100644 crates/handlers/src/oauth2/device/authorize.rs diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 4a62a187..86d18749 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -225,6 +225,10 @@ where mas_router::OAuth2RegistrationEndpoint::route(), post(self::oauth2::registration::post), ) + .route( + mas_router::OAuth2DeviceAuthorizationEndpoint::route(), + post(self::oauth2::device::authorize::post), + ) .layer( CorsLayer::new() .allow_origin(Any) diff --git a/crates/handlers/src/oauth2/device/authorize.rs b/crates/handlers/src/oauth2/device/authorize.rs new file mode 100644 index 00000000..7564bb98 --- /dev/null +++ b/crates/handlers/src/oauth2/device/authorize.rs @@ -0,0 +1,208 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use axum::{extract::State, response::IntoResponse, Json, TypedHeader}; +use chrono::Duration; +use headers::{CacheControl, Pragma}; +use hyper::StatusCode; +use mas_axum_utils::{ + client_authorization::{ClientAuthorization, CredentialsVerificationError}, + http_client_factory::HttpClientFactory, + sentry::SentryEventID, +}; +use mas_keystore::Encrypter; +use mas_router::UrlBuilder; +use mas_storage::{oauth2::OAuth2DeviceCodeGrantParams, BoxClock, BoxRepository, BoxRng}; +use oauth2_types::{ + errors::{ClientError, ClientErrorCode}, + requests::{DeviceAuthorizationRequest, DeviceAuthorizationResponse}, + scope::ScopeToken, +}; +use rand::distributions::{Alphanumeric, DistString}; +use thiserror::Error; + +use crate::impl_from_error_for_route; + +#[derive(Debug, Error)] +pub(crate) enum RouteError { + #[error(transparent)] + Internal(Box), + + #[error("client not found")] + ClientNotFound, + + #[error("client not allowed")] + ClientNotAllowed, + + #[error("could not verify client credentials")] + ClientCredentialsVerification(#[from] CredentialsVerificationError), +} + +impl_from_error_for_route!(mas_storage::RepositoryError); + +impl IntoResponse for RouteError { + fn into_response(self) -> axum::response::Response { + let event_id = sentry::capture_error(&self); + + let response = match self { + Self::Internal(_) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ClientError::from(ClientErrorCode::ServerError)), + ), + Self::ClientNotFound | Self::ClientCredentialsVerification(_) => ( + StatusCode::UNAUTHORIZED, + Json(ClientError::from(ClientErrorCode::InvalidClient)), + ), + Self::ClientNotAllowed => ( + StatusCode::UNAUTHORIZED, + Json(ClientError::from(ClientErrorCode::UnauthorizedClient)), + ), + }; + + (SentryEventID::from(event_id), response).into_response() + } +} + +#[tracing::instrument( + name = "handlers.oauth2.device.request.post", + fields(client.id = client_authorization.client_id()), + skip_all, + err, +)] +pub(crate) async fn post( + mut rng: BoxRng, + clock: BoxClock, + mut repo: BoxRepository, + State(url_builder): State, + State(http_client_factory): State, + State(encrypter): State, + client_authorization: ClientAuthorization, +) -> Result { + let client = client_authorization + .credentials + .fetch(&mut repo) + .await? + .ok_or(RouteError::ClientNotFound)?; + + // Reuse the token endpoint auth method to verify the client + let method = client + .token_endpoint_auth_method + .as_ref() + .ok_or(RouteError::ClientNotAllowed)?; + + client_authorization + .credentials + .verify(&http_client_factory, &encrypter, method, &client) + .await?; + + client_authorization + .credentials + .verify(&http_client_factory, &encrypter, method, &client) + .await?; + + // TODO: check if the client can use the device code grant type + + let scope = client_authorization + .form + .and_then(|f| f.scope) + // XXX: Is this really how we do empty scopes? + .unwrap_or(std::iter::empty::().collect()); + + let expires_in = Duration::minutes(20); + + let device_code = Alphanumeric.sample_string(&mut rng, 32); + let user_code = Alphanumeric.sample_string(&mut rng, 6).to_uppercase(); + + let device_code = repo + .oauth2_device_code_grant() + .add( + &mut rng, + &clock, + OAuth2DeviceCodeGrantParams { + client: &client, + scope, + device_code, + user_code, + expires_in, + }, + ) + .await?; + + repo.save().await?; + + let response = DeviceAuthorizationResponse { + device_code: device_code.device_code, + user_code: device_code.user_code.clone(), + verification_uri: url_builder.device_code_link(), + verification_uri_complete: Some(url_builder.device_code_link_full(device_code.user_code)), + expires_in, + interval: Some(Duration::seconds(5)), + }; + + Ok(( + StatusCode::OK, + TypedHeader(CacheControl::new().with_no_store()), + TypedHeader(Pragma::no_cache()), + Json(response), + ) + .into_response()) +} + +#[cfg(test)] +mod tests { + use hyper::{Request, StatusCode}; + use mas_router::SimpleRoute; + use oauth2_types::{ + registration::ClientRegistrationResponse, requests::DeviceAuthorizationResponse, + }; + use sqlx::PgPool; + + use crate::test_utils::{init_tracing, RequestBuilderExt, ResponseExt, TestState}; + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_device_code_request(pool: PgPool) { + init_tracing(); + let state = TestState::from_pool(pool).await.unwrap(); + + // Provision a client + let request = + Request::post(mas_router::OAuth2RegistrationEndpoint::PATH).json(serde_json::json!({ + "client_uri": "https://example.com/", + "contacts": ["contact@example.com"], + "token_endpoint_auth_method": "none", + "grant_types": ["urn:ietf:params:oauth:grant-type:device_code"], + "response_types": [], + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::CREATED); + + let response: ClientRegistrationResponse = response.json(); + let client_id = response.client_id; + + // Test the happy path: the client is allowed to use the device code grant type + let request = Request::post(mas_router::OAuth2DeviceAuthorizationEndpoint::PATH).form( + serde_json::json!({ + "client_id": client_id, + "scope": "openid", + }), + ); + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + + let response: DeviceAuthorizationResponse = response.json(); + assert_eq!(response.device_code.len(), 32); + assert_eq!(response.user_code.len(), 6); + } +} diff --git a/crates/handlers/src/oauth2/device/mod.rs b/crates/handlers/src/oauth2/device/mod.rs index b822a3e1..7236a11c 100644 --- a/crates/handlers/src/oauth2/device/mod.rs +++ b/crates/handlers/src/oauth2/device/mod.rs @@ -12,4 +12,5 @@ // See the License for the specific language governing permissions and // limitations under the License. +pub mod authorize; pub mod link; diff --git a/crates/handlers/src/oauth2/discovery.rs b/crates/handlers/src/oauth2/discovery.rs index 19f3ef3a..374817bd 100644 --- a/crates/handlers/src/oauth2/discovery.rs +++ b/crates/handlers/src/oauth2/discovery.rs @@ -65,6 +65,7 @@ pub(crate) async fn get( let issuer = Some(url_builder.oidc_issuer().into()); let authorization_endpoint = Some(url_builder.oauth_authorization_endpoint()); let token_endpoint = Some(url_builder.oauth_token_endpoint()); + let device_authorization_endpoint = Some(url_builder.oauth_device_authorization_endpoint()); let jwks_uri = Some(url_builder.jwks_uri()); let introspection_endpoint = Some(url_builder.oauth_introspection_endpoint()); let revocation_endpoint = Some(url_builder.oauth_revocation_endpoint()); @@ -166,6 +167,7 @@ pub(crate) async fn get( request_parameter_supported, request_uri_parameter_supported, prompt_values_supported, + device_authorization_endpoint, ..ProviderMetadata::default() }; diff --git a/crates/oauth2-types/src/requests.rs b/crates/oauth2-types/src/requests.rs index aee3525a..d28c0202 100644 --- a/crates/oauth2-types/src/requests.rs +++ b/crates/oauth2-types/src/requests.rs @@ -376,32 +376,32 @@ pub const DEFAULT_DEVICE_AUTHORIZATION_INTERVAL_SECONDS: i64 = 5; #[derive(Serialize, Deserialize, Clone, PartialEq, Eq)] pub struct DeviceAuthorizationResponse { /// The device verification code. - device_code: String, + pub device_code: String, /// The end-user verification code. - user_code: String, + pub user_code: String, /// The end-user verification URI on the authorization server. /// /// The URI should be short and easy to remember as end users will be asked /// to manually type it into their user agent. - verification_uri: Url, + pub verification_uri: Url, /// A verification URI that includes the `user_code` (or other information /// with the same function as the `user_code`), which is designed for /// non-textual transmission. - verification_uri_complete: Option, + pub verification_uri_complete: Option, /// The lifetime of the `device_code` and `user_code`. #[serde_as(as = "DurationSeconds")] - expires_in: Duration, + pub expires_in: Duration, /// The minimum amount of time in seconds that the client should wait /// between polling requests to the token endpoint. /// /// Defaults to [`DEFAULT_DEVICE_AUTHORIZATION_INTERVAL_SECONDS`]. #[serde_as(as = "Option>")] - interval: Option, + pub interval: Option, } impl DeviceAuthorizationResponse { diff --git a/crates/router/src/endpoints.rs b/crates/router/src/endpoints.rs index aea28e58..dd8c06c0 100644 --- a/crates/router/src/endpoints.rs +++ b/crates/router/src/endpoints.rs @@ -695,6 +695,13 @@ pub struct DeviceCodeLink { code: Option, } +impl DeviceCodeLink { + #[must_use] + pub fn with_code(code: String) -> Self { + Self { code: Some(code) } + } +} + impl Route for DeviceCodeLink { type Query = DeviceCodeLink; fn route() -> &'static str { @@ -706,6 +713,14 @@ impl Route for DeviceCodeLink { } } +/// `POST /oauth2/device` +#[derive(Default, Serialize, Deserialize, Debug, Clone)] +pub struct OAuth2DeviceAuthorizationEndpoint; + +impl SimpleRoute for OAuth2DeviceAuthorizationEndpoint { + const PATH: &'static str = "/oauth2/device"; +} + /// `GET /assets` pub struct StaticAsset { path: String, diff --git a/crates/router/src/url_builder.rs b/crates/router/src/url_builder.rs index 6d72a84a..b86505c5 100644 --- a/crates/router/src/url_builder.rs +++ b/crates/router/src/url_builder.rs @@ -154,6 +154,24 @@ impl UrlBuilder { self.absolute_url_for(&crate::endpoints::OAuth2RegistrationEndpoint) } + /// OAuth 2.0 device authorization endpoint + #[must_use] + pub fn oauth_device_authorization_endpoint(&self) -> Url { + self.absolute_url_for(&crate::endpoints::OAuth2DeviceAuthorizationEndpoint) + } + + /// OAuth 2.0 device code link + #[must_use] + pub fn device_code_link(&self) -> Url { + self.absolute_url_for(&crate::endpoints::DeviceCodeLink::default()) + } + + /// OAuth 2.0 device code link full URL + #[must_use] + pub fn device_code_link_full(&self, code: String) -> Url { + self.absolute_url_for(&crate::endpoints::DeviceCodeLink::with_code(code)) + } + // OIDC userinfo endpoint #[must_use] pub fn oidc_userinfo_endpoint(&self) -> Url { diff --git a/crates/storage-pg/src/oauth2/device_code_grant.rs b/crates/storage-pg/src/oauth2/device_code_grant.rs index 0b5094b1..519c1dd2 100644 --- a/crates/storage-pg/src/oauth2/device_code_grant.rs +++ b/crates/storage-pg/src/oauth2/device_code_grant.rs @@ -353,7 +353,7 @@ impl<'c> OAuth2DeviceCodeGrantRepository for PgOAuth2DeviceCodeGrantRepository<' ) -> Result { let fulfilled_at = clock.now(); let device_code_grant = device_code_grant - .fulfill(&browser_session, fulfilled_at) + .fulfill(browser_session, fulfilled_at) .map_err(DatabaseError::to_invalid_operation)?; let res = sqlx::query!( @@ -396,7 +396,7 @@ impl<'c> OAuth2DeviceCodeGrantRepository for PgOAuth2DeviceCodeGrantRepository<' ) -> Result { let fulfilled_at = clock.now(); let device_code_grant = device_code_grant - .reject(&browser_session, fulfilled_at) + .reject(browser_session, fulfilled_at) .map_err(DatabaseError::to_invalid_operation)?; let res = sqlx::query!(