1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-03 07:22:32 +03:00

Run the registration policy on upstream OAuth registration

This commit is contained in:
Quentin Gliech
2023-08-30 18:36:53 +02:00
parent 7fcd022eea
commit 23571e87ea
14 changed files with 207 additions and 41 deletions

View File

@ -24,6 +24,7 @@ pub struct FancyError {
} }
impl FancyError { impl FancyError {
#[must_use]
pub fn new(context: ErrorContext) -> Self { pub fn new(context: ErrorContext) -> Self {
Self { context } Self { context }
} }

View File

@ -21,10 +21,11 @@ use hyper::StatusCode;
use mas_axum_utils::{ use mas_axum_utils::{
cookies::CookieJar, cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm}, csrf::{CsrfExt, ProtectedForm},
SessionInfoExt, FancyError, SessionInfoExt,
}; };
use mas_data_model::{UpstreamOAuthProviderImportPreference, User}; use mas_data_model::{UpstreamOAuthProviderImportPreference, User};
use mas_jose::jwt::Jwt; use mas_jose::jwt::Jwt;
use mas_policy::Policy;
use mas_storage::{ use mas_storage::{
job::{JobRepositoryExt, ProvisionUserJob}, job::{JobRepositoryExt, ProvisionUserJob},
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository}, upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
@ -32,7 +33,8 @@ use mas_storage::{
BoxClock, BoxRepository, BoxRng, RepositoryAccess, BoxClock, BoxRepository, BoxRng, RepositoryAccess,
}; };
use mas_templates::{ use mas_templates::{
TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink, ErrorContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister,
UpstreamSuggestLink,
}; };
use serde::Deserialize; use serde::Deserialize;
use thiserror::Error; use thiserror::Error;
@ -76,6 +78,11 @@ pub(crate) enum RouteError {
#[error("Missing username")] #[error("Missing username")]
MissingUsername, MissingUsername,
#[error("Policy violation: {violations:?}")]
PolicyViolation {
violations: Vec<mas_policy::Violation>,
},
#[error(transparent)] #[error(transparent)]
Internal(Box<dyn std::error::Error>), Internal(Box<dyn std::error::Error>),
} }
@ -84,6 +91,7 @@ impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError); impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound); impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
impl_from_error_for_route!(mas_storage::RepositoryError); impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError); impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);
impl IntoResponse for RouteError { impl IntoResponse for RouteError {
@ -91,6 +99,16 @@ impl IntoResponse for RouteError {
sentry::capture_error(&self); sentry::capture_error(&self);
match self { match self {
Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(), Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
Self::PolicyViolation { violations } => {
let details = violations.iter().map(|v| v.msg.clone()).collect::<Vec<_>>();
let details = details.join("\n");
let ctx = ErrorContext::new()
.with_description(
"Account registration denied because of policy violation".to_owned(),
)
.with_details(details);
FancyError::new(ctx).into_response()
}
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
e => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(), e => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
} }
@ -358,6 +376,7 @@ pub(crate) async fn post(
mut repo: BoxRepository, mut repo: BoxRepository,
cookie_jar: CookieJar, cookie_jar: CookieJar,
user_agent: Option<TypedHeader<headers::UserAgent>>, user_agent: Option<TypedHeader<headers::UserAgent>>,
mut policy: Policy,
Path(link_id): Path<Ulid>, Path(link_id): Path<Ulid>,
Form(form): Form<ProtectedForm<FormData>>, Form(form): Form<ProtectedForm<FormData>>,
) -> Result<impl IntoResponse, RouteError> { ) -> Result<impl IntoResponse, RouteError> {
@ -478,6 +497,16 @@ pub(crate) async fn post(
let username = username.ok_or(RouteError::MissingUsername)?; let username = username.ok_or(RouteError::MissingUsername)?;
// Policy check
let res = policy
.evaluate_upstream_oauth_register(&username, email.as_deref())
.await?;
if !res.valid() {
return Err(RouteError::PolicyViolation {
violations: res.violations,
});
}
// Now we can create the user // Now we can create the user
let user = repo.user().add(&mut rng, &clock, username).await?; let user = repo.user().add(&mut rng, &clock, username).await?;

View File

@ -107,12 +107,9 @@ pub(crate) async fn post(
let user_email = if let Some(user_email) = existing_user_email { let user_email = if let Some(user_email) = existing_user_email {
user_email user_email
} else { } else {
let user_email = repo repo.user_email()
.user_email()
.add(&mut rng, &clock, &session.user, form.email) .add(&mut rng, &clock, &session.user, form.email)
.await?; .await?
user_email
}; };
// If the email was not confirmed, send a confirmation email & redirect to the // If the email was not confirmed, send a confirmation email & redirect to the

View File

@ -251,6 +251,31 @@ impl Policy {
Ok(res) Ok(res)
} }
#[tracing::instrument(
name = "policy.evaluate.upstream_oauth_register",
skip_all,
fields(
input.registration_method = "password",
input.user.username = username,
input.user.email = email,
),
err,
)]
pub async fn evaluate_upstream_oauth_register(
&mut self,
username: &str,
email: Option<&str>,
) -> Result<EvaluationResult, EvaluationError> {
let input = RegisterInput::UpstreamOAuth2 { username, email };
let [res]: [EvaluationResult; 1] = self
.instance
.evaluate(&mut self.store, &self.entrypoints.register, &input)
.await?;
Ok(res)
}
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub async fn evaluate_client_registration( pub async fn evaluate_client_registration(
&mut self, &mut self,

View File

@ -56,14 +56,23 @@ impl EvaluationResult {
/// Input for the user registration policy. /// Input for the user registration policy.
#[derive(Serialize, Debug)] #[derive(Serialize, Debug)]
#[serde(tag = "registration_method", rename_all = "snake_case")] #[serde(tag = "registration_method")]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
pub enum RegisterInput<'a> { pub enum RegisterInput<'a> {
#[serde(rename = "password")]
Password { Password {
username: &'a str, username: &'a str,
password: &'a str, password: &'a str,
email: &'a str, email: &'a str,
}, },
#[serde(rename = "upstream-oauth2")]
UpstreamOAuth2 {
username: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
email: Option<&'a str>,
},
} }
/// Input for the client registration policy. /// Input for the client registration policy.

View File

@ -46,5 +46,5 @@ coverage:
.PHONY: lint .PHONY: lint
lint: lint:
$(OPA) fmt -d --fail . $(OPA) fmt -d --fail ./*.rego util/*.rego
$(OPA) check --strict . $(OPA) check --strict --schema schema/ ./*.rego util/*.rego

26
policies/email_test.rego Normal file
View File

@ -0,0 +1,26 @@
package email
test_allow_all_domains {
allow with input.email as "hello@staging.element.io"
}
test_allowed_domain {
allow with input.email as "hello@staging.element.io"
with data.allowed_domains as ["*.element.io"]
}
test_not_allowed_domain {
not allow with input.email as "hello@staging.element.io"
with data.allowed_domains as ["example.com"]
}
test_banned_domain {
not allow with input.email as "hello@staging.element.io"
with data.banned_domains as ["*.element.io"]
}
test_banned_subdomain {
not allow with input.email as "hello@staging.element.io"
with data.allowed_domains as ["*.element.io"]
with data.banned_domains as ["staging.element.io"]
}

View File

@ -0,0 +1,29 @@
package password
test_password_require_number {
allow with data.passwords.require_number as true
not allow with input.password as "hunter"
with data.passwords.require_number as true
}
test_password_require_lowercase {
allow with data.passwords.require_lowercase as true
not allow with input.password as "HUNTER2"
with data.passwords.require_lowercase as true
}
test_password_require_uppercase {
allow with data.passwords.require_uppercase as true
not allow with input.password as "hunter2"
with data.passwords.require_uppercase as true
}
test_password_min_length {
allow with data.passwords.min_length as 6
not allow with input.password as "short"
with data.passwords.min_length as 6
}

View File

@ -22,6 +22,18 @@ violation[{"field": "username", "msg": "username too long"}] {
count(input.username) >= 15 count(input.username) >= 15
} }
violation[{"field": "username", "msg": "username contains invalid characters"}] {
not regex.match("^[a-z0-9.=_/-]+$", input.username)
}
violation[{"msg": "unspecified registration method"}] {
not input.registration_method
}
violation[{"msg": "unknown registration method"}] {
not input.registration_method in ["password", "upstream-oauth2"]
}
violation[object.union({"field": "password"}, v)] { violation[object.union({"field": "password"}, v)] {
# Check if the registration method is password # Check if the registration method is password
input.registration_method == "password" input.registration_method == "password"
@ -30,9 +42,19 @@ violation[object.union({"field": "password"}, v)] {
some v in password_policy.violation some v in password_policy.violation
} }
# Check that we supplied an email for password registration
violation[{"field": "email", "msg": "email required for password-based registration"}] {
input.registration_method == "password"
not input.email
}
# Check if the email is valid using the email policy # Check if the email is valid using the email policy
# and add the email field to the violation object # and add the email field to the violation object
violation[object.union({"field": "email"}, v)] { violation[object.union({"field": "email"}, v)] {
# Check if we have an email set in the input
input.email
# Get the violation object from the email policy # Get the violation object from the email policy
some v in email_policy.violation some v in email_policy.violation
} }

View File

@ -32,54 +32,58 @@ test_banned_subdomain {
with data.banned_domains as ["staging.element.io"] with data.banned_domains as ["staging.element.io"]
} }
test_email_required {
not allow with input as {"username": "hello", "registration_method": "password"}
}
test_no_email {
allow with input as {"username": "hello", "registration_method": "upstream-oauth2"}
}
test_short_username { test_short_username {
not allow with input as {"username": "a", "email": "hello@element.io"} not allow with input as {"username": "a", "registration_method": "upstream-oauth2"}
} }
test_long_username { test_long_username {
not allow with input as {"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "email": "hello@element.io"} not allow with input as {"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "registration_method": "upstream-oauth2"}
}
test_invalid_username {
not allow with input as {"username": "hello world", "registration_method": "upstream-oauth2"}
} }
test_password_require_number { test_password_require_number {
allow with input as mock_registration allow with input as mock_registration
with input.registration_method as "password"
with data.passwords.require_number as true with data.passwords.require_number as true
not allow with input as mock_registration not allow with input as mock_registration
with input.registration_method as "password"
with input.password as "hunter" with input.password as "hunter"
with data.passwords.require_number as true with data.passwords.require_number as true
} }
test_password_require_lowercase { test_password_require_lowercase {
allow with input as mock_registration allow with input as mock_registration
with input.registration_method as "password"
with data.passwords.require_lowercase as true with data.passwords.require_lowercase as true
not allow with input as mock_registration not allow with input as mock_registration
with input.registration_method as "password"
with input.password as "HUNTER2" with input.password as "HUNTER2"
with data.passwords.require_lowercase as true with data.passwords.require_lowercase as true
} }
test_password_require_uppercase { test_password_require_uppercase {
allow with input as mock_registration allow with input as mock_registration
with input.registration_method as "password"
with data.passwords.require_uppercase as true with data.passwords.require_uppercase as true
not allow with input as mock_registration not allow with input as mock_registration
with input.registration_method as "password"
with input.password as "hunter2" with input.password as "hunter2"
with data.passwords.require_uppercase as true with data.passwords.require_uppercase as true
} }
test_password_min_length { test_password_min_length {
allow with input as mock_registration allow with input as mock_registration
with input.registration_method as "password"
with data.passwords.min_length as 6 with data.passwords.min_length as 6
not allow with input as mock_registration not allow with input as mock_registration
with input.registration_method as "password"
with input.password as "short" with input.password as "short"
with data.passwords.min_length as 6 with data.passwords.min_length as 6
} }

View File

@ -28,6 +28,27 @@
"type": "string" "type": "string"
} }
} }
},
{
"type": "object",
"required": [
"registration_method",
"username"
],
"properties": {
"email": {
"type": "string"
},
"registration_method": {
"type": "string",
"enum": [
"upstream-oauth2"
]
},
"username": {
"type": "string"
}
}
} }
] ]
} }

View File

@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
#} #}
{% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text", autocorrect=false, autocapitalize=false, disabled=false) %} {% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text", autocorrect=false, autocapitalize=false, disabled=false, required=false) %}
{% if not form_state %} {% if not form_state %}
{% set form_state = dict(errors=[], fields=dict()) %} {% set form_state = dict(errors=[], fields=dict()) %}
{% endif %} {% endif %}
@ -35,6 +35,7 @@ limitations under the License.
class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-2 focus:border-accent focus:ring-0 focus:outline-0" class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-2 focus:border-accent focus:ring-0 focus:outline-0"
type="{{ type }}" type="{{ type }}"
inputmode="{{ inputmode }}" inputmode="{{ inputmode }}"
{% if required %} required {% endif %}
{% if disabled %} disabled {% endif %} {% if disabled %} disabled {% endif %}
{% if autocomplete %} autocomplete="{{ autocomplete }}" {% endif %} {% if autocomplete %} autocomplete="{{ autocomplete }}" {% endif %}
{% if autocorrect %} autocorrect="{{ autocorrect }}" {% endif %} {% if autocorrect %} autocorrect="{{ autocorrect }}" {% endif %}

View File

@ -33,7 +33,7 @@ limitations under the License.
{% endif %} {% endif %}
<input type="hidden" name="csrf" value="{{ csrf_token }}" /> <input type="hidden" name="csrf" value="{{ csrf_token }}" />
{{ field::input(label="Email", name="email", type="email", form_state=form, autocomplete="email") }} {{ field::input(label="Email", name="email", type="email", form_state=form, autocomplete="email", required=true) }}
{{ button::button(text="Next") }} {{ button::button(text="Next") }}
</section> </section>
{% endblock content %} {% endblock content %}

View File

@ -17,23 +17,25 @@ limitations under the License.
{% extends "base.html" %} {% extends "base.html" %}
{% block content %} {% block content %}
<section class="hero is-danger"> <section class="flex-1 flex items-center justify-center">
<div class="hero-body"> <div class="w-64 flex flex-col gap-2">
<div class="container"> <h1 class="text-xl font-semibold">Unexpected error</h1>
{% if code %} {% if code %}
<p class="title"> <p class="font-semibold font-mono">
{{ code }} {{ code }}
</p> </p>
{% endif %} {% endif %}
{% if description %} {% if description %}
<p class="subtitle"> <p>
{{ description }} {{ description }}
</p> </p>
{% endif %} {% endif %}
{% if details %} {% if details %}
<pre><code>{{ details }}</code></pre> <hr />
{% endif %} <code>
</div> <pre class="font-mono whitespace-pre-wrap break-all">{{ details }}</pre>
</div> </code>
</section> {% endif %}
</div>
</section>
{% endblock %} {% endblock %}