You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-03 07:22:32 +03:00
Run the registration policy on upstream OAuth registration
This commit is contained in:
@ -24,6 +24,7 @@ pub struct FancyError {
|
||||
}
|
||||
|
||||
impl FancyError {
|
||||
#[must_use]
|
||||
pub fn new(context: ErrorContext) -> Self {
|
||||
Self { context }
|
||||
}
|
||||
|
@ -21,10 +21,11 @@ use hyper::StatusCode;
|
||||
use mas_axum_utils::{
|
||||
cookies::CookieJar,
|
||||
csrf::{CsrfExt, ProtectedForm},
|
||||
SessionInfoExt,
|
||||
FancyError, SessionInfoExt,
|
||||
};
|
||||
use mas_data_model::{UpstreamOAuthProviderImportPreference, User};
|
||||
use mas_jose::jwt::Jwt;
|
||||
use mas_policy::Policy;
|
||||
use mas_storage::{
|
||||
job::{JobRepositoryExt, ProvisionUserJob},
|
||||
upstream_oauth2::{UpstreamOAuthLinkRepository, UpstreamOAuthSessionRepository},
|
||||
@ -32,7 +33,8 @@ use mas_storage::{
|
||||
BoxClock, BoxRepository, BoxRng, RepositoryAccess,
|
||||
};
|
||||
use mas_templates::{
|
||||
TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
|
||||
ErrorContext, TemplateContext, Templates, UpstreamExistingLinkContext, UpstreamRegister,
|
||||
UpstreamSuggestLink,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use thiserror::Error;
|
||||
@ -76,6 +78,11 @@ pub(crate) enum RouteError {
|
||||
#[error("Missing username")]
|
||||
MissingUsername,
|
||||
|
||||
#[error("Policy violation: {violations:?}")]
|
||||
PolicyViolation {
|
||||
violations: Vec<mas_policy::Violation>,
|
||||
},
|
||||
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error>),
|
||||
}
|
||||
@ -84,6 +91,7 @@ impl_from_error_for_route!(mas_templates::TemplateError);
|
||||
impl_from_error_for_route!(mas_axum_utils::csrf::CsrfError);
|
||||
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
|
||||
impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_policy::EvaluationError);
|
||||
impl_from_error_for_route!(mas_jose::jwt::JwtDecodeError);
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
@ -91,6 +99,16 @@ impl IntoResponse for RouteError {
|
||||
sentry::capture_error(&self);
|
||||
match self {
|
||||
Self::LinkNotFound => (StatusCode::NOT_FOUND, "Link not found").into_response(),
|
||||
Self::PolicyViolation { violations } => {
|
||||
let details = violations.iter().map(|v| v.msg.clone()).collect::<Vec<_>>();
|
||||
let details = details.join("\n");
|
||||
let ctx = ErrorContext::new()
|
||||
.with_description(
|
||||
"Account registration denied because of policy violation".to_owned(),
|
||||
)
|
||||
.with_details(details);
|
||||
FancyError::new(ctx).into_response()
|
||||
}
|
||||
Self::Internal(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
|
||||
e => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
|
||||
}
|
||||
@ -358,6 +376,7 @@ pub(crate) async fn post(
|
||||
mut repo: BoxRepository,
|
||||
cookie_jar: CookieJar,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
mut policy: Policy,
|
||||
Path(link_id): Path<Ulid>,
|
||||
Form(form): Form<ProtectedForm<FormData>>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
@ -478,6 +497,16 @@ pub(crate) async fn post(
|
||||
|
||||
let username = username.ok_or(RouteError::MissingUsername)?;
|
||||
|
||||
// Policy check
|
||||
let res = policy
|
||||
.evaluate_upstream_oauth_register(&username, email.as_deref())
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
return Err(RouteError::PolicyViolation {
|
||||
violations: res.violations,
|
||||
});
|
||||
}
|
||||
|
||||
// Now we can create the user
|
||||
let user = repo.user().add(&mut rng, &clock, username).await?;
|
||||
|
||||
|
@ -107,12 +107,9 @@ pub(crate) async fn post(
|
||||
let user_email = if let Some(user_email) = existing_user_email {
|
||||
user_email
|
||||
} else {
|
||||
let user_email = repo
|
||||
.user_email()
|
||||
repo.user_email()
|
||||
.add(&mut rng, &clock, &session.user, form.email)
|
||||
.await?;
|
||||
|
||||
user_email
|
||||
.await?
|
||||
};
|
||||
|
||||
// If the email was not confirmed, send a confirmation email & redirect to the
|
||||
|
@ -251,6 +251,31 @@ impl Policy {
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "policy.evaluate.upstream_oauth_register",
|
||||
skip_all,
|
||||
fields(
|
||||
input.registration_method = "password",
|
||||
input.user.username = username,
|
||||
input.user.email = email,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
pub async fn evaluate_upstream_oauth_register(
|
||||
&mut self,
|
||||
username: &str,
|
||||
email: Option<&str>,
|
||||
) -> Result<EvaluationResult, EvaluationError> {
|
||||
let input = RegisterInput::UpstreamOAuth2 { username, email };
|
||||
|
||||
let [res]: [EvaluationResult; 1] = self
|
||||
.instance
|
||||
.evaluate(&mut self.store, &self.entrypoints.register, &input)
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub async fn evaluate_client_registration(
|
||||
&mut self,
|
||||
|
@ -56,14 +56,23 @@ impl EvaluationResult {
|
||||
|
||||
/// Input for the user registration policy.
|
||||
#[derive(Serialize, Debug)]
|
||||
#[serde(tag = "registration_method", rename_all = "snake_case")]
|
||||
#[serde(tag = "registration_method")]
|
||||
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
|
||||
pub enum RegisterInput<'a> {
|
||||
#[serde(rename = "password")]
|
||||
Password {
|
||||
username: &'a str,
|
||||
password: &'a str,
|
||||
email: &'a str,
|
||||
},
|
||||
|
||||
#[serde(rename = "upstream-oauth2")]
|
||||
UpstreamOAuth2 {
|
||||
username: &'a str,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
email: Option<&'a str>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Input for the client registration policy.
|
||||
|
@ -46,5 +46,5 @@ coverage:
|
||||
|
||||
.PHONY: lint
|
||||
lint:
|
||||
$(OPA) fmt -d --fail .
|
||||
$(OPA) check --strict .
|
||||
$(OPA) fmt -d --fail ./*.rego util/*.rego
|
||||
$(OPA) check --strict --schema schema/ ./*.rego util/*.rego
|
||||
|
26
policies/email_test.rego
Normal file
26
policies/email_test.rego
Normal file
@ -0,0 +1,26 @@
|
||||
package email
|
||||
|
||||
test_allow_all_domains {
|
||||
allow with input.email as "hello@staging.element.io"
|
||||
}
|
||||
|
||||
test_allowed_domain {
|
||||
allow with input.email as "hello@staging.element.io"
|
||||
with data.allowed_domains as ["*.element.io"]
|
||||
}
|
||||
|
||||
test_not_allowed_domain {
|
||||
not allow with input.email as "hello@staging.element.io"
|
||||
with data.allowed_domains as ["example.com"]
|
||||
}
|
||||
|
||||
test_banned_domain {
|
||||
not allow with input.email as "hello@staging.element.io"
|
||||
with data.banned_domains as ["*.element.io"]
|
||||
}
|
||||
|
||||
test_banned_subdomain {
|
||||
not allow with input.email as "hello@staging.element.io"
|
||||
with data.allowed_domains as ["*.element.io"]
|
||||
with data.banned_domains as ["staging.element.io"]
|
||||
}
|
29
policies/password_test.rego
Normal file
29
policies/password_test.rego
Normal file
@ -0,0 +1,29 @@
|
||||
package password
|
||||
|
||||
test_password_require_number {
|
||||
allow with data.passwords.require_number as true
|
||||
|
||||
not allow with input.password as "hunter"
|
||||
with data.passwords.require_number as true
|
||||
}
|
||||
|
||||
test_password_require_lowercase {
|
||||
allow with data.passwords.require_lowercase as true
|
||||
|
||||
not allow with input.password as "HUNTER2"
|
||||
with data.passwords.require_lowercase as true
|
||||
}
|
||||
|
||||
test_password_require_uppercase {
|
||||
allow with data.passwords.require_uppercase as true
|
||||
|
||||
not allow with input.password as "hunter2"
|
||||
with data.passwords.require_uppercase as true
|
||||
}
|
||||
|
||||
test_password_min_length {
|
||||
allow with data.passwords.min_length as 6
|
||||
|
||||
not allow with input.password as "short"
|
||||
with data.passwords.min_length as 6
|
||||
}
|
@ -22,6 +22,18 @@ violation[{"field": "username", "msg": "username too long"}] {
|
||||
count(input.username) >= 15
|
||||
}
|
||||
|
||||
violation[{"field": "username", "msg": "username contains invalid characters"}] {
|
||||
not regex.match("^[a-z0-9.=_/-]+$", input.username)
|
||||
}
|
||||
|
||||
violation[{"msg": "unspecified registration method"}] {
|
||||
not input.registration_method
|
||||
}
|
||||
|
||||
violation[{"msg": "unknown registration method"}] {
|
||||
not input.registration_method in ["password", "upstream-oauth2"]
|
||||
}
|
||||
|
||||
violation[object.union({"field": "password"}, v)] {
|
||||
# Check if the registration method is password
|
||||
input.registration_method == "password"
|
||||
@ -30,9 +42,19 @@ violation[object.union({"field": "password"}, v)] {
|
||||
some v in password_policy.violation
|
||||
}
|
||||
|
||||
# Check that we supplied an email for password registration
|
||||
violation[{"field": "email", "msg": "email required for password-based registration"}] {
|
||||
input.registration_method == "password"
|
||||
|
||||
not input.email
|
||||
}
|
||||
|
||||
# Check if the email is valid using the email policy
|
||||
# and add the email field to the violation object
|
||||
violation[object.union({"field": "email"}, v)] {
|
||||
# Check if we have an email set in the input
|
||||
input.email
|
||||
|
||||
# Get the violation object from the email policy
|
||||
some v in email_policy.violation
|
||||
}
|
||||
|
@ -32,54 +32,58 @@ test_banned_subdomain {
|
||||
with data.banned_domains as ["staging.element.io"]
|
||||
}
|
||||
|
||||
test_email_required {
|
||||
not allow with input as {"username": "hello", "registration_method": "password"}
|
||||
}
|
||||
|
||||
test_no_email {
|
||||
allow with input as {"username": "hello", "registration_method": "upstream-oauth2"}
|
||||
}
|
||||
|
||||
test_short_username {
|
||||
not allow with input as {"username": "a", "email": "hello@element.io"}
|
||||
not allow with input as {"username": "a", "registration_method": "upstream-oauth2"}
|
||||
}
|
||||
|
||||
test_long_username {
|
||||
not allow with input as {"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "email": "hello@element.io"}
|
||||
not allow with input as {"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaa", "registration_method": "upstream-oauth2"}
|
||||
}
|
||||
|
||||
test_invalid_username {
|
||||
not allow with input as {"username": "hello world", "registration_method": "upstream-oauth2"}
|
||||
}
|
||||
|
||||
test_password_require_number {
|
||||
allow with input as mock_registration
|
||||
with input.registration_method as "password"
|
||||
with data.passwords.require_number as true
|
||||
|
||||
not allow with input as mock_registration
|
||||
with input.registration_method as "password"
|
||||
with input.password as "hunter"
|
||||
with data.passwords.require_number as true
|
||||
}
|
||||
|
||||
test_password_require_lowercase {
|
||||
allow with input as mock_registration
|
||||
with input.registration_method as "password"
|
||||
with data.passwords.require_lowercase as true
|
||||
|
||||
not allow with input as mock_registration
|
||||
with input.registration_method as "password"
|
||||
with input.password as "HUNTER2"
|
||||
with data.passwords.require_lowercase as true
|
||||
}
|
||||
|
||||
test_password_require_uppercase {
|
||||
allow with input as mock_registration
|
||||
with input.registration_method as "password"
|
||||
with data.passwords.require_uppercase as true
|
||||
|
||||
not allow with input as mock_registration
|
||||
with input.registration_method as "password"
|
||||
with input.password as "hunter2"
|
||||
with data.passwords.require_uppercase as true
|
||||
}
|
||||
|
||||
test_password_min_length {
|
||||
allow with input as mock_registration
|
||||
with input.registration_method as "password"
|
||||
with data.passwords.min_length as 6
|
||||
|
||||
not allow with input as mock_registration
|
||||
with input.registration_method as "password"
|
||||
with input.password as "short"
|
||||
with data.passwords.min_length as 6
|
||||
}
|
||||
|
@ -28,6 +28,27 @@
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"registration_method",
|
||||
"username"
|
||||
],
|
||||
"properties": {
|
||||
"email": {
|
||||
"type": "string"
|
||||
},
|
||||
"registration_method": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"upstream-oauth2"
|
||||
]
|
||||
},
|
||||
"username": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
#}
|
||||
|
||||
{% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text", autocorrect=false, autocapitalize=false, disabled=false) %}
|
||||
{% macro input(label, name, type="text", form_state=false, autocomplete=false, class="", inputmode="text", autocorrect=false, autocapitalize=false, disabled=false, required=false) %}
|
||||
{% if not form_state %}
|
||||
{% set form_state = dict(errors=[], fields=dict()) %}
|
||||
{% endif %}
|
||||
@ -35,6 +35,7 @@ limitations under the License.
|
||||
class="z-0 px-3 py-2 bg-white dark:bg-black-900 rounded-lg {{ border_color }} border-2 focus:border-accent focus:ring-0 focus:outline-0"
|
||||
type="{{ type }}"
|
||||
inputmode="{{ inputmode }}"
|
||||
{% if required %} required {% endif %}
|
||||
{% if disabled %} disabled {% endif %}
|
||||
{% if autocomplete %} autocomplete="{{ autocomplete }}" {% endif %}
|
||||
{% if autocorrect %} autocorrect="{{ autocorrect }}" {% endif %}
|
||||
|
@ -33,7 +33,7 @@ limitations under the License.
|
||||
{% endif %}
|
||||
|
||||
<input type="hidden" name="csrf" value="{{ csrf_token }}" />
|
||||
{{ field::input(label="Email", name="email", type="email", form_state=form, autocomplete="email") }}
|
||||
{{ field::input(label="Email", name="email", type="email", form_state=form, autocomplete="email", required=true) }}
|
||||
{{ button::button(text="Next") }}
|
||||
</section>
|
||||
{% endblock content %}
|
||||
|
@ -17,23 +17,25 @@ limitations under the License.
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block content %}
|
||||
<section class="hero is-danger">
|
||||
<div class="hero-body">
|
||||
<div class="container">
|
||||
{% if code %}
|
||||
<p class="title">
|
||||
{{ code }}
|
||||
</p>
|
||||
{% endif %}
|
||||
{% if description %}
|
||||
<p class="subtitle">
|
||||
{{ description }}
|
||||
</p>
|
||||
{% endif %}
|
||||
{% if details %}
|
||||
<pre><code>{{ details }}</code></pre>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
<section class="flex-1 flex items-center justify-center">
|
||||
<div class="w-64 flex flex-col gap-2">
|
||||
<h1 class="text-xl font-semibold">Unexpected error</h1>
|
||||
{% if code %}
|
||||
<p class="font-semibold font-mono">
|
||||
{{ code }}
|
||||
</p>
|
||||
{% endif %}
|
||||
{% if description %}
|
||||
<p>
|
||||
{{ description }}
|
||||
</p>
|
||||
{% endif %}
|
||||
{% if details %}
|
||||
<hr />
|
||||
<code>
|
||||
<pre class="font-mono whitespace-pre-wrap break-all">{{ details }}</pre>
|
||||
</code>
|
||||
{% endif %}
|
||||
</div>
|
||||
</section>
|
||||
{% endblock %}
|
||||
|
Reference in New Issue
Block a user