You've already forked authentication-service
mirror of
https://github.com/matrix-org/matrix-authentication-service.git
synced 2025-07-29 22:01:14 +03:00
Split the service in multiple crates
This commit is contained in:
27
crates/cli/Cargo.toml
Normal file
27
crates/cli/Cargo.toml
Normal file
@ -0,0 +1,27 @@
|
||||
[package]
|
||||
name = "mas-cli"
|
||||
version = "0.1.0"
|
||||
authors = ["Quentin Gliech <quenting@element.io>"]
|
||||
edition = "2018"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.11.0", features = ["full"] }
|
||||
anyhow = "1.0.44"
|
||||
clap = "3.0.0-beta.4"
|
||||
tracing = "0.1.27"
|
||||
tracing-subscriber = "0.2.22"
|
||||
dotenv = "0.15.0"
|
||||
schemars = { version = "0.8.3", features = ["url", "chrono"] }
|
||||
tower = { version = "0.4.8", features = ["full"] }
|
||||
tower-http = { version = "0.1.1", features = ["full"] }
|
||||
hyper = { version = "0.14.12", features = ["full"] }
|
||||
serde_yaml = "0.8.21"
|
||||
warp = "0.3.1"
|
||||
argon2 = { version = "0.3.1", features = ["password-hash"] }
|
||||
|
||||
mas-config = { path = "../config" }
|
||||
mas-core = { path = "../core" }
|
||||
|
||||
[dev-dependencies]
|
||||
indoc = "1.0.3"
|
75
crates/cli/src/config.rs
Normal file
75
crates/cli/src/config.rs
Normal file
@ -0,0 +1,75 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use clap::Clap;
|
||||
use mas_config::{ConfigurationSection, RootConfig};
|
||||
use schemars::schema_for;
|
||||
use tracing::info;
|
||||
|
||||
use super::RootCommand;
|
||||
|
||||
#[derive(Clap, Debug)]
|
||||
pub(super) struct ConfigCommand {
|
||||
#[clap(subcommand)]
|
||||
subcommand: ConfigSubcommand,
|
||||
}
|
||||
|
||||
#[derive(Clap, Debug)]
|
||||
enum ConfigSubcommand {
|
||||
/// Dump the current config as YAML
|
||||
Dump,
|
||||
|
||||
/// Print the JSON Schema that validates configuration files
|
||||
Schema,
|
||||
|
||||
/// Check a config file
|
||||
Check,
|
||||
|
||||
/// Generate a new config file
|
||||
Generate,
|
||||
}
|
||||
|
||||
impl ConfigCommand {
|
||||
pub async fn run(&self, root: &RootCommand) -> anyhow::Result<()> {
|
||||
use ConfigSubcommand as SC;
|
||||
match &self.subcommand {
|
||||
SC::Dump => {
|
||||
let config: RootConfig = root.load_config()?;
|
||||
|
||||
serde_yaml::to_writer(std::io::stdout(), &config)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
SC::Schema => {
|
||||
let schema = schema_for!(RootConfig);
|
||||
|
||||
serde_yaml::to_writer(std::io::stdout(), &schema)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
SC::Check => {
|
||||
let _config: RootConfig = root.load_config()?;
|
||||
info!(path = ?root.config, "Configuration file looks good");
|
||||
Ok(())
|
||||
}
|
||||
SC::Generate => {
|
||||
let config = RootConfig::load_and_generate().await?;
|
||||
|
||||
serde_yaml::to_writer(std::io::stdout(), &config)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
47
crates/cli/src/database.rs
Normal file
47
crates/cli/src/database.rs
Normal file
@ -0,0 +1,47 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::Context;
|
||||
use clap::Clap;
|
||||
use mas_config::DatabaseConfig;
|
||||
use mas_core::storage::MIGRATOR;
|
||||
|
||||
use super::RootCommand;
|
||||
|
||||
#[derive(Clap, Debug)]
|
||||
pub(super) struct DatabaseCommand {
|
||||
#[clap(subcommand)]
|
||||
subcommand: DatabaseSubcommand,
|
||||
}
|
||||
|
||||
#[derive(Clap, Debug)]
|
||||
enum DatabaseSubcommand {
|
||||
/// Run database migrations
|
||||
Migrate,
|
||||
}
|
||||
|
||||
impl DatabaseCommand {
|
||||
pub async fn run(&self, root: &RootCommand) -> anyhow::Result<()> {
|
||||
let config: DatabaseConfig = root.load_config()?;
|
||||
let pool = config.connect().await?;
|
||||
|
||||
// Run pending migrations
|
||||
MIGRATOR
|
||||
.run(&pool)
|
||||
.await
|
||||
.context("could not run migrations")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
103
crates/cli/src/main.rs
Normal file
103
crates/cli/src/main.rs
Normal file
@ -0,0 +1,103 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#![forbid(unsafe_code)]
|
||||
#![deny(clippy::all)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
#![allow(clippy::suspicious_else_formatting)]
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::Context;
|
||||
use clap::Clap;
|
||||
use mas_config::ConfigurationSection;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Registry};
|
||||
|
||||
use self::{
|
||||
config::ConfigCommand, database::DatabaseCommand, manage::ManageCommand, server::ServerCommand,
|
||||
};
|
||||
|
||||
mod config;
|
||||
mod database;
|
||||
mod manage;
|
||||
mod server;
|
||||
|
||||
#[derive(Clap, Debug)]
|
||||
enum Subcommand {
|
||||
/// Configuration-related commands
|
||||
Config(ConfigCommand),
|
||||
|
||||
/// Manage the database
|
||||
Database(DatabaseCommand),
|
||||
|
||||
/// Runs the web server
|
||||
Server(ServerCommand),
|
||||
|
||||
/// Manage the instance
|
||||
Manage(ManageCommand),
|
||||
}
|
||||
|
||||
#[derive(Clap, Debug)]
|
||||
struct RootCommand {
|
||||
/// Path to the configuration file
|
||||
#[clap(short, long, global = true, default_value = "config.yaml")]
|
||||
config: PathBuf,
|
||||
|
||||
#[clap(subcommand)]
|
||||
subcommand: Option<Subcommand>,
|
||||
}
|
||||
|
||||
impl RootCommand {
|
||||
async fn run(&self) -> anyhow::Result<()> {
|
||||
use Subcommand as S;
|
||||
match &self.subcommand {
|
||||
Some(S::Config(c)) => c.run(self).await,
|
||||
Some(S::Database(c)) => c.run(self).await,
|
||||
Some(S::Server(c)) => c.run(self).await,
|
||||
Some(S::Manage(c)) => c.run(self).await,
|
||||
None => ServerCommand::default().run(self).await,
|
||||
}
|
||||
}
|
||||
|
||||
fn load_config<'de, T: ConfigurationSection<'de>>(&self) -> anyhow::Result<T> {
|
||||
T::load_from_file(&self.config).context("could not load configuration")
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// Load environment variables from .env files
|
||||
if let Err(e) = dotenv::dotenv() {
|
||||
// Display the error if it is something other than the .env file not existing
|
||||
if !e.not_found() {
|
||||
return Err(e).context("could not load .env file");
|
||||
}
|
||||
}
|
||||
|
||||
// Setup logging & tracing
|
||||
let fmt_layer = tracing_subscriber::fmt::layer().with_writer(std::io::stderr);
|
||||
let filter_layer = EnvFilter::try_from_default_env().or_else(|_| EnvFilter::try_new("info"))?;
|
||||
|
||||
let subscriber = Registry::default().with(filter_layer).with(fmt_layer);
|
||||
subscriber
|
||||
.try_init()
|
||||
.context("could not initialize logging")?;
|
||||
|
||||
// Parse the CLI arguments
|
||||
let opts = RootCommand::parse();
|
||||
|
||||
// And run the command
|
||||
opts.run().await
|
||||
}
|
59
crates/cli/src/manage.rs
Normal file
59
crates/cli/src/manage.rs
Normal file
@ -0,0 +1,59 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use argon2::Argon2;
|
||||
use clap::Clap;
|
||||
use mas_config::DatabaseConfig;
|
||||
use mas_core::storage::register_user;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use super::RootCommand;
|
||||
|
||||
#[derive(Clap, Debug)]
|
||||
pub(super) struct ManageCommand {
|
||||
#[clap(subcommand)]
|
||||
subcommand: ManageSubcommand,
|
||||
}
|
||||
|
||||
#[derive(Clap, Debug)]
|
||||
enum ManageSubcommand {
|
||||
/// Register a new user
|
||||
Register { username: String, password: String },
|
||||
|
||||
/// List active users
|
||||
Users,
|
||||
}
|
||||
|
||||
impl ManageCommand {
|
||||
pub async fn run(&self, root: &RootCommand) -> anyhow::Result<()> {
|
||||
use ManageSubcommand as SC;
|
||||
match &self.subcommand {
|
||||
SC::Register { username, password } => {
|
||||
let config: DatabaseConfig = root.load_config()?;
|
||||
let pool = config.connect().await?;
|
||||
let hasher = Argon2::default();
|
||||
|
||||
let user = register_user(&pool, hasher, username, password).await?;
|
||||
info!(?user, "User registered");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
SC::Users => {
|
||||
warn!("Not implemented yet");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
93
crates/cli/src/server.rs
Normal file
93
crates/cli/src/server.rs
Normal file
@ -0,0 +1,93 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
net::{SocketAddr, TcpListener},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use clap::Clap;
|
||||
use hyper::{header, Server};
|
||||
use mas_config::RootConfig;
|
||||
use mas_core::{
|
||||
tasks::{self, TaskQueue},
|
||||
templates::Templates,
|
||||
};
|
||||
use tower::{make::Shared, ServiceBuilder};
|
||||
use tower_http::{
|
||||
compression::CompressionLayer,
|
||||
sensitive_headers::SetSensitiveHeadersLayer,
|
||||
trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer},
|
||||
LatencyUnit,
|
||||
};
|
||||
|
||||
use super::RootCommand;
|
||||
|
||||
#[derive(Clap, Debug, Default)]
|
||||
pub(super) struct ServerCommand;
|
||||
|
||||
impl ServerCommand {
|
||||
pub async fn run(&self, root: &RootCommand) -> anyhow::Result<()> {
|
||||
let config: RootConfig = root.load_config()?;
|
||||
|
||||
let addr: SocketAddr = config.http.address.parse()?;
|
||||
let listener = TcpListener::bind(addr)?;
|
||||
|
||||
// Connect to the database
|
||||
let pool = config.database.connect().await?;
|
||||
|
||||
// Load and compile the templates
|
||||
let templates = Templates::load().context("could not load templates")?;
|
||||
|
||||
// Start the server
|
||||
let root = mas_core::handlers::root(&pool, &templates, &config);
|
||||
|
||||
let queue = TaskQueue::default();
|
||||
queue.recuring(Duration::from_secs(15), tasks::cleanup_expired(&pool));
|
||||
queue.start();
|
||||
|
||||
let warp_service = warp::service(root);
|
||||
|
||||
let service = ServiceBuilder::new()
|
||||
// Add high level tracing/logging to all requests
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(DefaultMakeSpan::new().include_headers(true))
|
||||
.on_response(
|
||||
DefaultOnResponse::new()
|
||||
.include_headers(true)
|
||||
.latency_unit(LatencyUnit::Micros),
|
||||
),
|
||||
)
|
||||
// Set a timeout
|
||||
.timeout(Duration::from_secs(10))
|
||||
// Compress responses
|
||||
.layer(CompressionLayer::new())
|
||||
// Mark the `Authorization` and `Cookie` headers as sensitive so it doesn't show in logs
|
||||
.layer(SetSensitiveHeadersLayer::new(vec![
|
||||
header::AUTHORIZATION,
|
||||
header::COOKIE,
|
||||
]))
|
||||
.service(warp_service);
|
||||
|
||||
tracing::info!("Listening on http://{}", listener.local_addr().unwrap());
|
||||
|
||||
Server::from_tcp(listener)?
|
||||
.serve(Shared::new(service))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
38
crates/config/Cargo.toml
Normal file
38
crates/config/Cargo.toml
Normal file
@ -0,0 +1,38 @@
|
||||
[package]
|
||||
name = "mas-config"
|
||||
version = "0.1.0"
|
||||
authors = ["Quentin Gliech <quenting@element.io>"]
|
||||
edition = "2018"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.11.0", features = [] }
|
||||
tracing = "0.1.27"
|
||||
async-trait = "0.1.51"
|
||||
|
||||
thiserror = "1.0.29"
|
||||
anyhow = "1.0.44"
|
||||
|
||||
schemars = { version = "0.8.3", features = ["url", "chrono"] }
|
||||
figment = { version = "0.10.6", features = ["env", "yaml", "test"] }
|
||||
chrono = { version = "0.4.19", features = ["serde"] }
|
||||
url = { version = "2.2.2", features = ["serde"] }
|
||||
|
||||
serde = { version = "1.0.130", features = ["derive"] }
|
||||
serde_with = { version = "1.10.0", features = ["hex", "chrono"] }
|
||||
serde_json = "1.0.68"
|
||||
sqlx = { version = "0.5.7", features = ["runtime-tokio-rustls", "postgres"] }
|
||||
|
||||
rand = "0.8.4"
|
||||
rsa = "0.5.0"
|
||||
k256 = "0.9.6"
|
||||
pkcs8 = { version = "0.7.6", features = ["pem"] }
|
||||
elliptic-curve = { version = "0.10.6", features = ["pem"] }
|
||||
|
||||
indoc = "1.0.3"
|
||||
|
||||
[dependencies.jwt-compact]
|
||||
# Waiting on the next release because of the bump of the `rsa` dependency
|
||||
git = "https://github.com/slowli/jwt-compact.git"
|
||||
rev = "7a6dee6824c1d4e7c7f81019c9a968e5c9e44923"
|
||||
features = ["rsa", "k256"]
|
49
crates/config/src/cookies.rs
Normal file
49
crates/config/src/cookies.rs
Normal file
@ -0,0 +1,49 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::serde_as;
|
||||
|
||||
use super::ConfigurationSection;
|
||||
|
||||
fn secret_schema(gen: &mut SchemaGenerator) -> Schema {
|
||||
String::json_schema(gen)
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct CookiesConfig {
|
||||
#[schemars(schema_with = "secret_schema")]
|
||||
#[serde_as(as = "serde_with::hex::Hex")]
|
||||
pub secret: [u8; 32],
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConfigurationSection<'_> for CookiesConfig {
|
||||
fn path() -> &'static str {
|
||||
"cookies"
|
||||
}
|
||||
|
||||
async fn generate() -> anyhow::Result<Self> {
|
||||
Ok(Self {
|
||||
secret: rand::random(),
|
||||
})
|
||||
}
|
||||
|
||||
fn test() -> Self {
|
||||
Self { secret: [0xEA; 32] }
|
||||
}
|
||||
}
|
85
crates/config/src/csrf.rs
Normal file
85
crates/config/src/csrf.rs
Normal file
@ -0,0 +1,85 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::Duration;
|
||||
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::serde_as;
|
||||
|
||||
use super::ConfigurationSection;
|
||||
|
||||
fn default_ttl() -> Duration {
|
||||
Duration::hours(1)
|
||||
}
|
||||
|
||||
fn ttl_schema(gen: &mut SchemaGenerator) -> Schema {
|
||||
u64::json_schema(gen)
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct CsrfConfig {
|
||||
#[schemars(schema_with = "ttl_schema")]
|
||||
#[serde(default = "default_ttl")]
|
||||
#[serde_as(as = "serde_with::DurationSeconds<i64>")]
|
||||
pub ttl: Duration,
|
||||
}
|
||||
|
||||
impl Default for CsrfConfig {
|
||||
fn default() -> Self {
|
||||
Self { ttl: default_ttl() }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConfigurationSection<'_> for CsrfConfig {
|
||||
fn path() -> &'static str {
|
||||
"csrf"
|
||||
}
|
||||
|
||||
async fn generate() -> anyhow::Result<Self> {
|
||||
Ok(Self::default())
|
||||
}
|
||||
|
||||
fn test() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use figment::Jail;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn load_config() {
|
||||
Jail::expect_with(|jail| {
|
||||
jail.create_file(
|
||||
"config.yaml",
|
||||
r#"
|
||||
csrf:
|
||||
ttl: 1800
|
||||
"#,
|
||||
)?;
|
||||
|
||||
let config = CsrfConfig::load_from_file("config.yaml")?;
|
||||
|
||||
assert_eq!(config.ttl, Duration::minutes(30));
|
||||
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
}
|
166
crates/config/src/database.rs
Normal file
166
crates/config/src/database.rs
Normal file
@ -0,0 +1,166 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, skip_serializing_none};
|
||||
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
|
||||
|
||||
// FIXME
|
||||
// use sqlx::ConnectOptions
|
||||
// use tracing::log::LevelFilter;
|
||||
use super::ConfigurationSection;
|
||||
|
||||
fn default_uri() -> String {
|
||||
"postgresql://".to_string()
|
||||
}
|
||||
|
||||
fn default_max_connections() -> u32 {
|
||||
10
|
||||
}
|
||||
|
||||
fn default_connect_timeout() -> Duration {
|
||||
Duration::from_secs(30)
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn default_idle_timeout() -> Option<Duration> {
|
||||
Some(Duration::from_secs(10 * 60))
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn default_max_lifetime() -> Option<Duration> {
|
||||
Some(Duration::from_secs(30 * 60))
|
||||
}
|
||||
|
||||
impl Default for DatabaseConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
uri: default_uri(),
|
||||
max_connections: default_max_connections(),
|
||||
min_connections: Default::default(),
|
||||
connect_timeout: default_connect_timeout(),
|
||||
idle_timeout: default_idle_timeout(),
|
||||
max_lifetime: default_max_lifetime(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn duration_schema(gen: &mut SchemaGenerator) -> Schema {
|
||||
Option::<u64>::json_schema(gen)
|
||||
}
|
||||
|
||||
fn optional_duration_schema(gen: &mut SchemaGenerator) -> Schema {
|
||||
u64::json_schema(gen)
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct DatabaseConfig {
|
||||
#[serde(default = "default_uri")]
|
||||
uri: String,
|
||||
|
||||
#[serde(default = "default_max_connections")]
|
||||
max_connections: u32,
|
||||
|
||||
#[serde(default)]
|
||||
min_connections: u32,
|
||||
|
||||
#[schemars(schema_with = "duration_schema")]
|
||||
#[serde(default = "default_connect_timeout")]
|
||||
#[serde_as(as = "serde_with::DurationSeconds<u64>")]
|
||||
connect_timeout: Duration,
|
||||
|
||||
#[schemars(schema_with = "optional_duration_schema")]
|
||||
#[serde(default = "default_idle_timeout")]
|
||||
#[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
|
||||
idle_timeout: Option<Duration>,
|
||||
|
||||
#[schemars(schema_with = "optional_duration_schema")]
|
||||
#[serde(default = "default_max_lifetime")]
|
||||
#[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
|
||||
max_lifetime: Option<Duration>,
|
||||
}
|
||||
|
||||
impl DatabaseConfig {
|
||||
#[tracing::instrument(err)]
|
||||
pub async fn connect(&self) -> anyhow::Result<PgPool> {
|
||||
let options = self
|
||||
.uri
|
||||
.parse::<PgConnectOptions>()
|
||||
.context("invalid database URL")?
|
||||
.application_name("matrix-authentication-service");
|
||||
|
||||
// FIXME
|
||||
// options
|
||||
// .log_statements(LevelFilter::Debug)
|
||||
// .log_slow_statements(LevelFilter::Warn, Duration::from_millis(100));
|
||||
|
||||
PgPoolOptions::new()
|
||||
.max_connections(self.max_connections)
|
||||
.min_connections(self.min_connections)
|
||||
.connect_timeout(self.connect_timeout)
|
||||
.idle_timeout(self.idle_timeout)
|
||||
.max_lifetime(self.max_lifetime)
|
||||
.connect_with(options)
|
||||
.await
|
||||
.context("could not connect to the database")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConfigurationSection<'_> for DatabaseConfig {
|
||||
fn path() -> &'static str {
|
||||
"database"
|
||||
}
|
||||
|
||||
async fn generate() -> anyhow::Result<Self> {
|
||||
Ok(Self::default())
|
||||
}
|
||||
|
||||
fn test() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use figment::Jail;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn load_config() {
|
||||
Jail::expect_with(|jail| {
|
||||
jail.create_file(
|
||||
"config.yaml",
|
||||
r#"
|
||||
database:
|
||||
uri: postgresql://user:password@host/database
|
||||
"#,
|
||||
)?;
|
||||
|
||||
let config = DatabaseConfig::load_from_file("config.yaml")?;
|
||||
|
||||
assert_eq!(config.uri, "postgresql://user:password@host/database");
|
||||
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
}
|
52
crates/config/src/http.rs
Normal file
52
crates/config/src/http.rs
Normal file
@ -0,0 +1,52 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::ConfigurationSection;
|
||||
|
||||
fn default_http_address() -> String {
|
||||
"[::]:8080".into()
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct HttpConfig {
|
||||
#[serde(default = "default_http_address")]
|
||||
pub address: String,
|
||||
}
|
||||
|
||||
impl Default for HttpConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
address: default_http_address(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConfigurationSection<'_> for HttpConfig {
|
||||
fn path() -> &'static str {
|
||||
"http"
|
||||
}
|
||||
|
||||
async fn generate() -> anyhow::Result<Self> {
|
||||
Ok(Self::default())
|
||||
}
|
||||
|
||||
fn test() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
}
|
76
crates/config/src/lib.rs
Normal file
76
crates/config/src/lib.rs
Normal file
@ -0,0 +1,76 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
mod cookies;
|
||||
mod csrf;
|
||||
mod database;
|
||||
mod http;
|
||||
mod oauth2;
|
||||
mod util;
|
||||
|
||||
pub use self::{
|
||||
cookies::CookiesConfig,
|
||||
csrf::CsrfConfig,
|
||||
database::DatabaseConfig,
|
||||
http::HttpConfig,
|
||||
oauth2::{Algorithm, KeySet, OAuth2ClientConfig, OAuth2Config},
|
||||
util::ConfigurationSection,
|
||||
};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct RootConfig {
|
||||
pub oauth2: OAuth2Config,
|
||||
|
||||
#[serde(default)]
|
||||
pub http: HttpConfig,
|
||||
|
||||
#[serde(default)]
|
||||
pub database: DatabaseConfig,
|
||||
|
||||
pub cookies: CookiesConfig,
|
||||
|
||||
#[serde(default)]
|
||||
pub csrf: CsrfConfig,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConfigurationSection<'_> for RootConfig {
|
||||
fn path() -> &'static str {
|
||||
""
|
||||
}
|
||||
|
||||
async fn generate() -> anyhow::Result<Self> {
|
||||
Ok(Self {
|
||||
oauth2: OAuth2Config::generate().await?,
|
||||
http: HttpConfig::generate().await?,
|
||||
database: DatabaseConfig::generate().await?,
|
||||
cookies: CookiesConfig::generate().await?,
|
||||
csrf: CsrfConfig::generate().await?,
|
||||
})
|
||||
}
|
||||
|
||||
fn test() -> Self {
|
||||
Self {
|
||||
oauth2: OAuth2Config::test(),
|
||||
http: HttpConfig::test(),
|
||||
database: DatabaseConfig::test(),
|
||||
cookies: CookiesConfig::test(),
|
||||
csrf: CsrfConfig::test(),
|
||||
}
|
||||
}
|
||||
}
|
476
crates/config/src/oauth2.rs
Normal file
476
crates/config/src/oauth2.rs
Normal file
@ -0,0 +1,476 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use jwt_compact::{
|
||||
alg::{self, StrongAlg, StrongKey},
|
||||
jwk::JsonWebKey,
|
||||
AlgorithmExt, Claims, Header,
|
||||
};
|
||||
use pkcs8::{FromPrivateKey, ToPrivateKey};
|
||||
use rsa::RsaPrivateKey;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{
|
||||
de::{MapAccess, Visitor},
|
||||
ser::SerializeStruct,
|
||||
Deserialize, Serialize,
|
||||
};
|
||||
use serde_with::skip_serializing_none;
|
||||
use thiserror::Error;
|
||||
use tokio::task;
|
||||
use tracing::info;
|
||||
use url::Url;
|
||||
|
||||
use super::ConfigurationSection;
|
||||
|
||||
// TODO: a lot of the signing logic should go out somewhere else
|
||||
|
||||
const RS256: StrongAlg<alg::Rsa> = StrongAlg(alg::Rsa::rs256());
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
|
||||
#[serde(rename_all = "UPPERCASE")]
|
||||
pub enum Algorithm {
|
||||
Rs256,
|
||||
Es256k,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct Jwk {
|
||||
kid: String,
|
||||
alg: Algorithm,
|
||||
|
||||
#[serde(flatten)]
|
||||
inner: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct Jwks {
|
||||
keys: Vec<Jwk>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(transparent)]
|
||||
pub struct KeySet(Vec<Key>);
|
||||
|
||||
impl KeySet {
|
||||
pub fn to_public_jwks(&self) -> Jwks {
|
||||
let keys = self.0.iter().map(Key::to_public_jwk).collect();
|
||||
Jwks { keys }
|
||||
}
|
||||
|
||||
#[tracing::instrument(err)]
|
||||
pub async fn token<T>(
|
||||
&self,
|
||||
alg: Algorithm,
|
||||
header: Header,
|
||||
claims: Claims<T>,
|
||||
) -> anyhow::Result<String>
|
||||
where
|
||||
T: std::fmt::Debug + Serialize + Send + Sync + 'static,
|
||||
{
|
||||
match alg {
|
||||
Algorithm::Rs256 => {
|
||||
let (kid, key) = self
|
||||
.0
|
||||
.iter()
|
||||
.find_map(Key::rsa)
|
||||
.context("could not find RSA key")?;
|
||||
let header = header.with_key_id(kid);
|
||||
|
||||
// TODO: store them as strong keys
|
||||
let key = StrongKey::try_from(key.clone())?;
|
||||
task::spawn_blocking(move || {
|
||||
RS256
|
||||
.token(header, &claims, &key)
|
||||
.context("failed to sign token")
|
||||
})
|
||||
.await?
|
||||
}
|
||||
Algorithm::Es256k => {
|
||||
// TODO: make this const with lazy_static?
|
||||
let es256k: alg::Es256k = alg::Es256k::default();
|
||||
let (kid, key) = self
|
||||
.0
|
||||
.iter()
|
||||
.find_map(Key::ecdsa)
|
||||
.context("could not find ECDSA key")?;
|
||||
let key = k256::ecdsa::SigningKey::from(key);
|
||||
let header = header.with_key_id(kid);
|
||||
// TODO: use StrongAlg
|
||||
|
||||
task::spawn_blocking(move || {
|
||||
es256k
|
||||
.token(header, &claims, &key)
|
||||
.context("failed to sign token")
|
||||
})
|
||||
.await?
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[non_exhaustive]
|
||||
pub enum Key {
|
||||
Rsa { key: RsaPrivateKey, kid: String },
|
||||
Ecdsa { key: k256::SecretKey, kid: String },
|
||||
}
|
||||
|
||||
impl Key {
|
||||
fn from_ecdsa(key: k256::SecretKey) -> Self {
|
||||
// TODO: hash the key and use as KID
|
||||
let kid = String::from("ecdsa-kid");
|
||||
Self::Ecdsa { kid, key }
|
||||
}
|
||||
|
||||
fn from_ecdsa_pem(key: &str) -> anyhow::Result<Self> {
|
||||
let key = k256::SecretKey::from_pkcs8_pem(key)?;
|
||||
Ok(Self::from_ecdsa(key))
|
||||
}
|
||||
|
||||
fn from_rsa(key: RsaPrivateKey) -> Self {
|
||||
// TODO: hash the key and use as KID
|
||||
let kid = String::from("rsa-kid");
|
||||
Self::Rsa { kid, key }
|
||||
}
|
||||
|
||||
fn from_rsa_pem(key: &str) -> anyhow::Result<Self> {
|
||||
let key = RsaPrivateKey::from_pkcs8_pem(key)?;
|
||||
Ok(Self::from_rsa(key))
|
||||
}
|
||||
|
||||
fn to_public_jwk(&self) -> Jwk {
|
||||
match self {
|
||||
Key::Rsa { key, kid } => {
|
||||
let pubkey = key.to_public_key();
|
||||
let inner = JsonWebKey::from(&pubkey);
|
||||
let inner = serde_json::to_value(&inner).unwrap();
|
||||
let kid = kid.to_string();
|
||||
let alg = Algorithm::Rs256;
|
||||
Jwk { kid, alg, inner }
|
||||
}
|
||||
Key::Ecdsa { key, kid } => {
|
||||
let pubkey = k256::ecdsa::VerifyingKey::from(key.public_key());
|
||||
let inner = JsonWebKey::from(&pubkey);
|
||||
let inner = serde_json::to_value(&inner).unwrap();
|
||||
let kid = kid.to_string();
|
||||
let alg = Algorithm::Es256k;
|
||||
Jwk { kid, alg, inner }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn rsa(&self) -> Option<(&str, &RsaPrivateKey)> {
|
||||
match self {
|
||||
Key::Rsa { key, kid } => Some((kid, key)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn ecdsa(&self) -> Option<(&str, &k256::SecretKey)> {
|
||||
match self {
|
||||
Key::Ecdsa { key, kid } => Some((kid, key)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Key {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut map = serializer.serialize_struct("Key", 2)?;
|
||||
match self {
|
||||
Key::Rsa { key, kid: _ } => {
|
||||
map.serialize_field("type", "rsa")?;
|
||||
let pem = key.to_pkcs8_pem().map_err(serde::ser::Error::custom)?;
|
||||
map.serialize_field("key", pem.as_str())?;
|
||||
}
|
||||
Key::Ecdsa { key, kid: _ } => {
|
||||
map.serialize_field("type", "ecdsa")?;
|
||||
let pem = key.to_pkcs8_pem().map_err(serde::ser::Error::custom)?;
|
||||
map.serialize_field("key", pem.as_str())?;
|
||||
}
|
||||
}
|
||||
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Key {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(field_identifier, rename_all = "lowercase")]
|
||||
enum Field {
|
||||
Type,
|
||||
Key,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum KeyType {
|
||||
Rsa,
|
||||
Ecdsa,
|
||||
}
|
||||
|
||||
struct KeyVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for KeyVisitor {
|
||||
type Value = Key;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("struct Key")
|
||||
}
|
||||
|
||||
fn visit_map<V>(self, mut map: V) -> Result<Key, V::Error>
|
||||
where
|
||||
V: MapAccess<'de>,
|
||||
{
|
||||
let mut key_type = None;
|
||||
let mut key_key = None;
|
||||
while let Some(key) = map.next_key()? {
|
||||
match key {
|
||||
Field::Type => {
|
||||
if key_type.is_some() {
|
||||
return Err(serde::de::Error::duplicate_field("type"));
|
||||
}
|
||||
key_type = Some(map.next_value()?);
|
||||
}
|
||||
Field::Key => {
|
||||
if key_key.is_some() {
|
||||
return Err(serde::de::Error::duplicate_field("key"));
|
||||
}
|
||||
key_key = Some(map.next_value()?);
|
||||
}
|
||||
}
|
||||
}
|
||||
let key_type: KeyType =
|
||||
key_type.ok_or_else(|| serde::de::Error::missing_field("type"))?;
|
||||
let key_key: String =
|
||||
key_key.ok_or_else(|| serde::de::Error::missing_field("key"))?;
|
||||
|
||||
match key_type {
|
||||
KeyType::Rsa => Key::from_rsa_pem(&key_key).map_err(serde::de::Error::custom),
|
||||
KeyType::Ecdsa => {
|
||||
Key::from_ecdsa_pem(&key_key).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_struct("Key", &["type", "key"], KeyVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct OAuth2ClientConfig {
|
||||
pub client_id: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub client_secret: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub redirect_uris: Vec<Url>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("Invalid redirect URI")]
|
||||
pub struct InvalidRedirectUriError;
|
||||
|
||||
impl OAuth2ClientConfig {
|
||||
pub fn resolve_redirect_uri<'a>(
|
||||
&'a self,
|
||||
suggested_uri: &'a Option<Url>,
|
||||
) -> Result<&'a Url, InvalidRedirectUriError> {
|
||||
suggested_uri.as_ref().map_or_else(
|
||||
|| self.redirect_uris.get(0).ok_or(InvalidRedirectUriError),
|
||||
|suggested_uri| self.check_redirect_uri(suggested_uri),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn check_redirect_uri<'a>(
|
||||
&self,
|
||||
redirect_uri: &'a Url,
|
||||
) -> Result<&'a Url, InvalidRedirectUriError> {
|
||||
if self.redirect_uris.contains(redirect_uri) {
|
||||
Ok(redirect_uri)
|
||||
} else {
|
||||
Err(InvalidRedirectUriError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_oauth2_issuer() -> Url {
|
||||
"http://[::]:8080".parse().unwrap()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct OAuth2Config {
|
||||
#[serde(default = "default_oauth2_issuer")]
|
||||
pub issuer: Url,
|
||||
|
||||
#[serde(default)]
|
||||
pub clients: Vec<OAuth2ClientConfig>,
|
||||
|
||||
#[schemars(with = "Vec<String>")] // TODO: this is a lie
|
||||
pub keys: KeySet,
|
||||
}
|
||||
|
||||
impl OAuth2Config {
|
||||
pub fn discovery_url(&self) -> Url {
|
||||
self.issuer
|
||||
.join(".well-known/openid-configuration")
|
||||
.expect("could not build discovery url")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConfigurationSection<'_> for OAuth2Config {
|
||||
fn path() -> &'static str {
|
||||
"oauth2"
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn generate() -> anyhow::Result<Self> {
|
||||
info!("Generating keys...");
|
||||
|
||||
let span = tracing::info_span!("rsa");
|
||||
let rsa_key = task::spawn_blocking(move || {
|
||||
let _entered = span.enter();
|
||||
let mut rng = rand::thread_rng();
|
||||
let ret =
|
||||
RsaPrivateKey::new(&mut rng, 2048).context("could not generate RSA private key");
|
||||
info!("Done generating RSA key");
|
||||
ret
|
||||
})
|
||||
.await
|
||||
.context("could not join blocking task")??;
|
||||
|
||||
let span = tracing::info_span!("ecdsa");
|
||||
let ecdsa_key = task::spawn_blocking(move || {
|
||||
let _entered = span.enter();
|
||||
let rng = rand::thread_rng();
|
||||
let ret = k256::SecretKey::random(rng);
|
||||
info!("Done generating ECDSA key");
|
||||
ret
|
||||
})
|
||||
.await
|
||||
.context("could not join blocking task")?;
|
||||
|
||||
Ok(Self {
|
||||
issuer: default_oauth2_issuer(),
|
||||
clients: Vec::new(),
|
||||
keys: KeySet(vec![Key::from_rsa(rsa_key), Key::from_ecdsa(ecdsa_key)]),
|
||||
})
|
||||
}
|
||||
|
||||
fn test() -> Self {
|
||||
let rsa_key = Key::from_rsa_pem(indoc::indoc! {r#"
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIBVQIBADANBgkqhkiG9w0BAQEFAASCAT8wggE7AgEAAkEAymS2RkeIZo7pUeEN
|
||||
QUGCG4GLJru5jzxomO9jiNr5D/oRcerhpQVc9aCpBfAAg4l4a1SmYdBzWqX0X5pU
|
||||
scgTtQIDAQABAkEArNIMlrxUK4bSklkCcXtXdtdKE9vuWfGyOw0GyAB69fkEUBxh
|
||||
3j65u+u3ZmW+bpMWHgp1FtdobE9nGwb2VBTWAQIhAOyU1jiUEkrwKK004+6b5QRE
|
||||
vC9UI2vDWy5vioMNx5Y1AiEA2wGAJ6ETF8FF2Vd+kZlkKK7J0em9cl0gbJDsWIEw
|
||||
N4ECIEyWYkMurD1WQdTQqnk0Po+DMOihdFYOiBYgRdbnPxWBAiEAmtd0xJAd7622
|
||||
tPQniMnrBtiN2NxqFXHCev/8Gpc8gAECIBcaPcF59qVeRmYrfqzKBxFm7LmTwlAl
|
||||
Gh7BNzCeN+D6
|
||||
-----END PRIVATE KEY-----
|
||||
"#})
|
||||
.unwrap();
|
||||
let ecdsa_key = Key::from_ecdsa_pem(indoc::indoc! {r#"
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIGEAgEAMBAGByqGSM49AgEGBSuBBAAKBG0wawIBAQQgqfn5mYO/5Qq/wOOiWgHA
|
||||
NaiDiepgUJ2GI5eq2V8D8nahRANCAARMK9aKUd/H28qaU+0qvS6bSJItzAge1VHn
|
||||
OhBAAUVci1RpmUA+KdCL5sw9nadAEiONeiGr+28RYHZmlB9qXnjC
|
||||
-----END PRIVATE KEY-----
|
||||
"#})
|
||||
.unwrap();
|
||||
|
||||
Self {
|
||||
issuer: default_oauth2_issuer(),
|
||||
clients: Vec::new(),
|
||||
keys: KeySet(vec![rsa_key, ecdsa_key]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use figment::Jail;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn load_config() {
|
||||
Jail::expect_with(|jail| {
|
||||
jail.create_file(
|
||||
"config.yaml",
|
||||
r#"
|
||||
oauth2:
|
||||
keys:
|
||||
- type: rsa
|
||||
key: |
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIBVQIBADANBgkqhkiG9w0BAQEFAASCAT8wggE7AgEAAkEAymS2RkeIZo7pUeEN
|
||||
QUGCG4GLJru5jzxomO9jiNr5D/oRcerhpQVc9aCpBfAAg4l4a1SmYdBzWqX0X5pU
|
||||
scgTtQIDAQABAkEArNIMlrxUK4bSklkCcXtXdtdKE9vuWfGyOw0GyAB69fkEUBxh
|
||||
3j65u+u3ZmW+bpMWHgp1FtdobE9nGwb2VBTWAQIhAOyU1jiUEkrwKK004+6b5QRE
|
||||
vC9UI2vDWy5vioMNx5Y1AiEA2wGAJ6ETF8FF2Vd+kZlkKK7J0em9cl0gbJDsWIEw
|
||||
N4ECIEyWYkMurD1WQdTQqnk0Po+DMOihdFYOiBYgRdbnPxWBAiEAmtd0xJAd7622
|
||||
tPQniMnrBtiN2NxqFXHCev/8Gpc8gAECIBcaPcF59qVeRmYrfqzKBxFm7LmTwlAl
|
||||
Gh7BNzCeN+D6
|
||||
-----END PRIVATE KEY-----
|
||||
- type: ecdsa
|
||||
key: |
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIGEAgEAMBAGByqGSM49AgEGBSuBBAAKBG0wawIBAQQgqfn5mYO/5Qq/wOOiWgHA
|
||||
NaiDiepgUJ2GI5eq2V8D8nahRANCAARMK9aKUd/H28qaU+0qvS6bSJItzAge1VHn
|
||||
OhBAAUVci1RpmUA+KdCL5sw9nadAEiONeiGr+28RYHZmlB9qXnjC
|
||||
-----END PRIVATE KEY-----
|
||||
issuer: https://example.com
|
||||
clients:
|
||||
- client_id: hello
|
||||
redirect_uris:
|
||||
- https://exemple.fr/callback
|
||||
- client_id: world
|
||||
"#,
|
||||
)?;
|
||||
|
||||
let config = OAuth2Config::load_from_file("config.yaml")?;
|
||||
|
||||
assert_eq!(config.issuer, "https://example.com".parse().unwrap());
|
||||
assert_eq!(config.clients.len(), 2);
|
||||
|
||||
assert_eq!(config.clients[0].client_id, "hello");
|
||||
assert_eq!(
|
||||
config.clients[0].redirect_uris,
|
||||
vec!["https://exemple.fr/callback".parse().unwrap()]
|
||||
);
|
||||
|
||||
assert_eq!(config.clients[1].client_id, "world");
|
||||
assert_eq!(config.clients[1].redirect_uris, Vec::new());
|
||||
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
}
|
72
crates/config/src/util.rs
Normal file
72
crates/config/src/util.rs
Normal file
@ -0,0 +1,72 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use figment::{
|
||||
error::Error as FigmentError,
|
||||
providers::{Env, Format, Serialized, Yaml},
|
||||
Figment, Profile,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[async_trait]
|
||||
/// Trait implemented by all configuration section to help loading specific part
|
||||
/// of the config and generate the sample config.
|
||||
pub trait ConfigurationSection<'a>: Sized + Deserialize<'a> + Serialize {
|
||||
/// Specify where this section should live relative to the root.
|
||||
fn path() -> &'static str;
|
||||
|
||||
/// Generate a sample configuration for this section.
|
||||
async fn generate() -> anyhow::Result<Self>;
|
||||
|
||||
/// Generate a sample configuration and override it with environment
|
||||
/// variables.
|
||||
///
|
||||
/// This is what backs the `config generate` subcommand, allowing to
|
||||
/// programatically generate a configuration file, e.g.
|
||||
///
|
||||
/// ```sh
|
||||
/// export MAS_OAUTH2_ISSUER=https://example.com/
|
||||
/// export MAS_HTTP_ADDRESS=127.0.0.1:1234
|
||||
/// matrix-authentication-service config generate
|
||||
/// ```
|
||||
async fn load_and_generate() -> anyhow::Result<Self> {
|
||||
let base = Self::generate()
|
||||
.await
|
||||
.context("could not generate configuration")?;
|
||||
|
||||
Figment::new()
|
||||
.merge(Serialized::from(&base, Profile::Default))
|
||||
.merge(Env::prefixed("MAS_").split("_"))
|
||||
.extract_inner(Self::path())
|
||||
.context("could not load configuration")
|
||||
}
|
||||
|
||||
/// Load configuration from a file and environment variables.
|
||||
fn load_from_file<P>(path: P) -> Result<Self, FigmentError>
|
||||
where
|
||||
P: AsRef<Path>,
|
||||
{
|
||||
Figment::new()
|
||||
.merge(Env::prefixed("MAS_").split("_"))
|
||||
.merge(Yaml::file(path))
|
||||
.extract_inner(Self::path())
|
||||
}
|
||||
|
||||
/// Generate config used in unit tests
|
||||
fn test() -> Self;
|
||||
}
|
77
crates/core/Cargo.toml
Normal file
77
crates/core/Cargo.toml
Normal file
@ -0,0 +1,77 @@
|
||||
[package]
|
||||
name = "mas-core"
|
||||
version = "0.1.0"
|
||||
authors = ["Quentin Gliech <quenting@element.io>"]
|
||||
edition = "2018"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
# Async runtime
|
||||
tokio = { version = "1.11.0", features = ["full"] }
|
||||
async-trait = "0.1.51"
|
||||
tokio-stream = "0.1.7"
|
||||
futures-util = "0.3.17"
|
||||
|
||||
# Logging and tracing
|
||||
tracing = "0.1.27"
|
||||
|
||||
# Error management
|
||||
thiserror = "1.0.29"
|
||||
anyhow = "1.0.44"
|
||||
|
||||
# Web server
|
||||
warp = "0.3.1"
|
||||
hyper = { version = "0.14.12", features = ["full"] }
|
||||
|
||||
# Template engine
|
||||
tera = "1.12.1"
|
||||
|
||||
# Database access
|
||||
sqlx = { version = "0.5.7", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline"] }
|
||||
|
||||
# Various structure (de)serialization
|
||||
serde = { version = "1.0.130", features = ["derive"] }
|
||||
serde_yaml = "0.8.21"
|
||||
serde_with = { version = "1.10.0", features = ["hex", "chrono"] }
|
||||
serde_json = "1.0.68"
|
||||
serde_urlencoded = "0.7.0"
|
||||
|
||||
# Argument & config parsing
|
||||
figment = { version = "0.10.6", features = ["env", "yaml", "test"] }
|
||||
schemars = { version = "0.8.3", features = ["url", "chrono"] }
|
||||
|
||||
# Password hashing
|
||||
argon2 = { version = "0.3.1", features = ["password-hash"] }
|
||||
password-hash = { version = "0.3.2", features = ["std"] }
|
||||
|
||||
# Crypto, hashing and signing stuff
|
||||
rsa = "0.5.0"
|
||||
k256 = "0.9.6"
|
||||
pkcs8 = { version = "0.7.6", features = ["pem"] }
|
||||
elliptic-curve = { version = "0.10.6", features = ["pem"] }
|
||||
chacha20poly1305 = { version = "0.9.0", features = ["std"] }
|
||||
sha2 = "0.9.8"
|
||||
crc = "2.0.0"
|
||||
|
||||
# Various data types and utilities
|
||||
data-encoding = "2.3.2"
|
||||
chrono = { version = "0.4.19", features = ["serde"] }
|
||||
url = { version = "2.2.2", features = ["serde"] }
|
||||
itertools = "0.10.1"
|
||||
mime = "0.3.16"
|
||||
rand = "0.8.4"
|
||||
bincode = "1.3.3"
|
||||
headers = "0.3.4"
|
||||
cookie = "0.15.1"
|
||||
|
||||
oauth2-types = { path = "../oauth2-types", features = ["sqlx_type"] }
|
||||
mas-config = { path = "../config" }
|
||||
|
||||
[dependencies.jwt-compact]
|
||||
# Waiting on the next release because of the bump of the `rsa` dependency
|
||||
git = "https://github.com/slowli/jwt-compact.git"
|
||||
rev = "7a6dee6824c1d4e7c7f81019c9a968e5c9e44923"
|
||||
features = ["rsa", "k256"]
|
||||
|
||||
[dev-dependencies]
|
||||
indoc = "1.0.3"
|
@ -0,0 +1,15 @@
|
||||
-- Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
DROP FUNCTION IF EXISTS trigger_set_timestamp();
|
@ -0,0 +1,21 @@
|
||||
-- Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
CREATE OR REPLACE FUNCTION trigger_set_timestamp()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW.updated_at = NOW();
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
16
crates/core/migrations/20210716213724_users.down.sql
Normal file
16
crates/core/migrations/20210716213724_users.down.sql
Normal file
@ -0,0 +1,16 @@
|
||||
-- Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
DROP TRIGGER set_timestamp ON users;
|
||||
DROP TABLE users;
|
26
crates/core/migrations/20210716213724_users.up.sql
Normal file
26
crates/core/migrations/20210716213724_users.up.sql
Normal file
@ -0,0 +1,26 @@
|
||||
-- Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
CREATE TABLE users (
|
||||
"id" BIGSERIAL PRIMARY KEY,
|
||||
"username" TEXT NOT NULL UNIQUE,
|
||||
"hashed_password" TEXT NOT NULL,
|
||||
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
|
||||
"updated_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TRIGGER set_timestamp
|
||||
BEFORE UPDATE ON users
|
||||
FOR EACH ROW
|
||||
EXECUTE PROCEDURE trigger_set_timestamp();
|
17
crates/core/migrations/20210722072901_user_sessions.down.sql
Normal file
17
crates/core/migrations/20210722072901_user_sessions.down.sql
Normal file
@ -0,0 +1,17 @@
|
||||
-- Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
DROP TRIGGER set_timestamp ON user_sessions;
|
||||
DROP TABLE user_session_authentications;
|
||||
DROP TABLE user_sessions;
|
35
crates/core/migrations/20210722072901_user_sessions.up.sql
Normal file
35
crates/core/migrations/20210722072901_user_sessions.up.sql
Normal file
@ -0,0 +1,35 @@
|
||||
-- Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
-- A logged in session
|
||||
CREATE TABLE user_sessions (
|
||||
"id" BIGSERIAL PRIMARY KEY,
|
||||
"user_id" BIGINT NOT NULL REFERENCES users (id) ON DELETE CASCADE,
|
||||
"active" BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
|
||||
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
|
||||
"updated_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TRIGGER set_timestamp
|
||||
BEFORE UPDATE ON user_sessions
|
||||
FOR EACH ROW
|
||||
EXECUTE PROCEDURE trigger_set_timestamp();
|
||||
|
||||
-- An authentication within a session
|
||||
CREATE TABLE user_session_authentications (
|
||||
"id" BIGSERIAL PRIMARY KEY,
|
||||
"session_id" BIGINT NOT NULL REFERENCES user_sessions (id) ON DELETE CASCADE,
|
||||
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
|
||||
);
|
@ -0,0 +1,17 @@
|
||||
-- Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
DROP TRIGGER set_timestamp ON oauth2_sessions;
|
||||
DROP TABLE oauth2_codes;
|
||||
DROP TABLE oauth2_sessions;
|
45
crates/core/migrations/20210731130515_oauth2_sessions.up.sql
Normal file
45
crates/core/migrations/20210731130515_oauth2_sessions.up.sql
Normal file
@ -0,0 +1,45 @@
|
||||
-- Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
CREATE TABLE oauth2_sessions (
|
||||
"id" BIGSERIAL PRIMARY KEY,
|
||||
"user_session_id" BIGINT REFERENCES user_sessions (id) ON DELETE CASCADE,
|
||||
"client_id" TEXT NOT NULL,
|
||||
"redirect_uri" TEXT NOT NULL,
|
||||
"scope" TEXT NOT NULL,
|
||||
"state" TEXT,
|
||||
"nonce" TEXT,
|
||||
"max_age" INT,
|
||||
"response_type" TEXT NOT NULL,
|
||||
"response_mode" TEXT NOT NULL,
|
||||
|
||||
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
|
||||
"updated_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TRIGGER set_timestamp
|
||||
BEFORE UPDATE ON oauth2_sessions
|
||||
FOR EACH ROW
|
||||
EXECUTE PROCEDURE trigger_set_timestamp();
|
||||
|
||||
CREATE TABLE oauth2_codes (
|
||||
"id" BIGSERIAL PRIMARY KEY,
|
||||
"oauth2_session_id" BIGINT NOT NULL REFERENCES oauth2_sessions (id) ON DELETE CASCADE,
|
||||
"code" TEXT UNIQUE NOT NULL,
|
||||
"code_challenge_method" SMALLINT,
|
||||
"code_challenge" TEXT,
|
||||
|
||||
CHECK (("code_challenge" IS NULL AND "code_challenge_method" IS NULL)
|
||||
OR ("code_challenge" IS NOT NULL AND "code_challenge_method" IS NOT NULL))
|
||||
);
|
@ -0,0 +1,15 @@
|
||||
-- Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
DROP TABLE oauth2_access_tokens;
|
@ -0,0 +1,23 @@
|
||||
-- Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
CREATE TABLE oauth2_access_tokens (
|
||||
"id" BIGSERIAL PRIMARY KEY,
|
||||
"oauth2_session_id" BIGINT NOT NULL REFERENCES oauth2_sessions (id) ON DELETE CASCADE,
|
||||
|
||||
"token" TEXT UNIQUE NOT NULL,
|
||||
|
||||
"expires_after" INT NOT NULL,
|
||||
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
|
||||
);
|
@ -0,0 +1,16 @@
|
||||
-- Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
DROP TRIGGER set_timestamp ON oauth2_refresh_tokens;
|
||||
DROP TABLE oauth2_refresh_tokens;
|
@ -0,0 +1,30 @@
|
||||
-- Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
CREATE TABLE oauth2_refresh_tokens (
|
||||
"id" BIGSERIAL PRIMARY KEY,
|
||||
"oauth2_session_id" BIGINT NOT NULL REFERENCES oauth2_sessions (id) ON DELETE CASCADE,
|
||||
"oauth2_access_token_id" BIGINT REFERENCES oauth2_access_tokens (id) ON DELETE SET NULL,
|
||||
|
||||
"token" TEXT UNIQUE NOT NULL,
|
||||
"next_token_id" BIGINT REFERENCES oauth2_refresh_tokens (id),
|
||||
|
||||
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
|
||||
"updated_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TRIGGER set_timestamp
|
||||
BEFORE UPDATE ON oauth2_refresh_tokens
|
||||
FOR EACH ROW
|
||||
EXECUTE PROCEDURE trigger_set_timestamp();
|
771
crates/core/sqlx-data.json
Normal file
771
crates/core/sqlx-data.json
Normal file
@ -0,0 +1,771 @@
|
||||
{
|
||||
"db": "PostgreSQL",
|
||||
"037ba804eabd0b4290d87d1de37054f358eb11397d3a8e4b69a81cdce0a178e0": {
|
||||
"query": "\n SELECT id, username\n FROM users\n WHERE username = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "username",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"138c3297a66107d8428ca10d04f9a4dd75faf9c1d3f84bcedd3b09f55dd84206": {
|
||||
"query": "\n INSERT INTO oauth2_codes\n (oauth2_session_id, code, code_challenge_method, code_challenge)\n VALUES\n ($1, $2, $3, $4)\n RETURNING\n id, oauth2_session_id, code, code_challenge_method, code_challenge\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "oauth2_session_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "code",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "code_challenge_method",
|
||||
"type_info": "Int2"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "code_challenge",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Text",
|
||||
"Int2",
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true
|
||||
]
|
||||
}
|
||||
},
|
||||
"17729fd0354a84e04bfcd525db6575ed2ba75dd730bea3f2be964f4b347dd484": {
|
||||
"query": "\n SELECT code\n FROM oauth2_codes\n WHERE oauth2_session_id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "code",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"35bedaa6fdf7ac91d54b458b4637f2182c2f82be3e2f80cd2db934ee279a7f2a": {
|
||||
"query": "\n SELECT id, username\n FROM users\n WHERE id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "username",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"49888f812910633b87ce65c277f8969377fe264be154d8aa6b33d861d26d2b3b": {
|
||||
"query": "\n SELECT\n u.username AS \"username!\",\n us.active AS \"active!\",\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\",\n at.created_at AS \"created_at!\",\n at.expires_after AS \"expires_after!\"\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n ON os.id = at.oauth2_session_id\n INNER JOIN user_sessions us\n ON us.id = os.user_session_id\n INNER JOIN users u\n ON u.id = us.user_id\n WHERE at.token = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "username!",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "active!",
|
||||
"type_info": "Bool"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "client_id!",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "scope!",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "created_at!",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "expires_after!",
|
||||
"type_info": "Int4"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"4f925a277d73df779360f81e0cf5d7983b50ebe744f461559dd561b7e36c20d4": {
|
||||
"query": "\n SELECT\n s.id,\n u.id as user_id,\n u.username,\n s.active,\n s.created_at,\n a.created_at as \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u \n ON s.user_id = u.id\n LEFT JOIN user_session_authentications a\n ON a.session_id = s.id\n WHERE s.id = $1 AND s.active\n ORDER BY a.created_at DESC\n LIMIT 1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "user_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "username",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "active",
|
||||
"type_info": "Bool"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "last_authd_at?",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"562b0d4dcf857e99c20e9288e9c8bd46232290715c0d2459b0398a1c746cf65d": {
|
||||
"query": "\n SELECT\n rt.id,\n rt.oauth2_session_id,\n rt.oauth2_access_token_id,\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n ON os.id = rt.oauth2_session_id\n WHERE rt.token = $1 AND rt.next_token_id IS NULL\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "oauth2_session_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "oauth2_access_token_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "client_id!",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "scope!",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"5d1a17b2ad6153217551ae31549ad9d62cc39d2f9a4e62a7ccb60fd91e0ac685": {
|
||||
"query": "\n DELETE FROM oauth2_access_tokens\n WHERE created_at + (expires_after * INTERVAL '1 second') + INTERVAL '15 minutes' < now()\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": []
|
||||
},
|
||||
"nullable": []
|
||||
}
|
||||
},
|
||||
"62986972431bfc4649e3d8c8c7648f9049c4197773e53496422ad8b8aa15b459": {
|
||||
"query": "\n SELECT\n s.id,\n u.id as user_id,\n u.username,\n s.active,\n s.created_at,\n a.created_at as \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u \n ON s.user_id = u.id\n LEFT JOIN user_session_authentications a\n ON a.session_id = s.id\n WHERE s.id = $1\n ORDER BY a.created_at DESC\n LIMIT 1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "user_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "username",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "active",
|
||||
"type_info": "Bool"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "last_authd_at?",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"73f2d928f7bf88af79a3685bd6346652b4e4454b0ce75e38343840c9765e3f27": {
|
||||
"query": "\n INSERT INTO oauth2_refresh_tokens\n (oauth2_session_id, oauth2_access_token_id, token)\n VALUES\n ($1, $2, $3)\n RETURNING\n id, oauth2_session_id, oauth2_access_token_id, token, next_token_id, \n created_at, updated_at\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "oauth2_session_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "oauth2_access_token_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "token",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "next_token_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "updated_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Int8",
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"886dee6a6f1f426f0e891790bbeffbc222fd75d8da0a107e7de673f1cc445f30": {
|
||||
"query": "\n SELECT\n oc.id,\n os.id AS \"oauth2_session_id!\",\n os.client_id AS \"client_id!\",\n os.redirect_uri,\n os.scope AS \"scope!\",\n os.nonce\n FROM oauth2_codes oc\n INNER JOIN oauth2_sessions os\n ON os.id = oc.oauth2_session_id\n WHERE oc.code = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "oauth2_session_id!",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "client_id!",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "redirect_uri",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "scope!",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "nonce",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true
|
||||
]
|
||||
}
|
||||
},
|
||||
"88ac8783bd5881c42eafd9cf87a16fe6031f3153fd6a8618e689694584aeb2de": {
|
||||
"query": "\n DELETE FROM oauth2_access_tokens\n WHERE id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
}
|
||||
},
|
||||
"9ba45ab114b656105cc46b0c10fb05769860fcdc05eaf54d6225640fb914dab9": {
|
||||
"query": "\n INSERT INTO user_session_authentications (session_id)\n VALUES ($1)\n RETURNING created_at\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"a09dfe1019110f2ec6eba0d35bafa467ab4b7980dd8b556826f03863f8edb0ab": {
|
||||
"query": "UPDATE user_sessions SET active = FALSE WHERE id = $1",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
}
|
||||
},
|
||||
"a552eee8a8e5ffdee4d4789c634851bd64780dfe730807aac20142d7cd643814": {
|
||||
"query": "\n SELECT u.hashed_password\n FROM user_sessions s\n INNER JOIN users u\n ON u.id = s.user_id \n WHERE s.id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "hashed_password",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"a6eb935107d060dd01bf9824ceff87b9ff5492b58cefef002a49f444d3a3daa1": {
|
||||
"query": "UPDATE oauth2_sessions SET user_session_id = $1 WHERE id = $2",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
}
|
||||
},
|
||||
"b766b2b41d8770b5bef9928bb3b96abbaf8466b473e12b21f145c015b7cf2f05": {
|
||||
"query": "\n INSERT INTO oauth2_access_tokens\n (oauth2_session_id, token, expires_after)\n VALUES\n ($1, $2, $3)\n RETURNING\n id, oauth2_session_id, token, expires_after, created_at\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "oauth2_session_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "token",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "expires_after",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Text",
|
||||
"Int4"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"c2c402cfe0adcafa615f14a499caba4c96ca71d9ffb163e1feb05e5d85f3462c": {
|
||||
"query": "\n UPDATE oauth2_refresh_tokens\n SET next_token_id = $2\n WHERE id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
}
|
||||
},
|
||||
"cacec823f5d4ed886854fbd62b5f5bb2def792582df58c8a047c769d34d9b190": {
|
||||
"query": "\n INSERT INTO oauth2_sessions\n (user_session_id, client_id, redirect_uri, scope, state, nonce, max_age,\n response_type, response_mode)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9)\n RETURNING\n id, user_session_id, client_id, redirect_uri, scope, state, nonce, max_age,\n response_type, response_mode, created_at, updated_at\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "user_session_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "client_id",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "redirect_uri",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "scope",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "state",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "nonce",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 7,
|
||||
"name": "max_age",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 8,
|
||||
"name": "response_type",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 9,
|
||||
"name": "response_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 10,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 11,
|
||||
"name": "updated_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Int4",
|
||||
"Text",
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"f9a09ff53b6f221649f4f050e3d5ade114f852ddf50a78610a6c0ef0689af681": {
|
||||
"query": "\n INSERT INTO users (username, hashed_password)\n VALUES ($1, $2)\n RETURNING id\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Text",
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"ff515ebb80ba4af1948472f5c7120a03e25b1ebe42151b8a2036bfbb042f17f6": {
|
||||
"query": "\n SELECT\n id, user_session_id, client_id, redirect_uri, scope, state, nonce,\n max_age, response_type, response_mode, created_at, updated_at\n FROM oauth2_sessions\n WHERE id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "user_session_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "client_id",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "redirect_uri",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "scope",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "state",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "nonce",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 7,
|
||||
"name": "max_age",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 8,
|
||||
"name": "response_type",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 9,
|
||||
"name": "response_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 10,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 11,
|
||||
"name": "updated_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
128
crates/core/src/errors.rs
Normal file
128
crates/core/src/errors.rs
Normal file
@ -0,0 +1,128 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::HashMap, fmt::Debug, hash::Hash};
|
||||
|
||||
use serde::{ser::SerializeMap, Serialize};
|
||||
use warp::{reject::Reject, Rejection};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct WrappedError(anyhow::Error);
|
||||
|
||||
impl warp::reject::Reject for WrappedError {}
|
||||
|
||||
pub trait WrapError<T> {
|
||||
fn wrap_error(self) -> Result<T, Rejection>;
|
||||
}
|
||||
|
||||
impl<T, E> WrapError<T> for Result<T, E>
|
||||
where
|
||||
E: Into<anyhow::Error>,
|
||||
{
|
||||
fn wrap_error(self) -> Result<T, Rejection> {
|
||||
self.map_err(|e| warp::reject::custom(WrappedError(e.into())))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait HtmlError: Debug + Send + Sync + 'static {
|
||||
fn html_display(&self) -> String;
|
||||
}
|
||||
|
||||
pub trait WrapFormError<FieldType> {
|
||||
fn on_form(self) -> ErroredForm<FieldType>;
|
||||
fn on_field(self, field: FieldType) -> ErroredForm<FieldType>;
|
||||
}
|
||||
|
||||
impl<E, FieldType> WrapFormError<FieldType> for E
|
||||
where
|
||||
E: HtmlError,
|
||||
{
|
||||
fn on_form(self) -> ErroredForm<FieldType> {
|
||||
let mut f = ErroredForm::new();
|
||||
f.form.push(FormError {
|
||||
error: Box::new(self),
|
||||
});
|
||||
f
|
||||
}
|
||||
|
||||
fn on_field(self, field: FieldType) -> ErroredForm<FieldType> {
|
||||
let mut f = ErroredForm::new();
|
||||
f.fields.push(FieldError {
|
||||
field,
|
||||
error: Box::new(self),
|
||||
});
|
||||
f
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct FormError {
|
||||
error: Box<dyn HtmlError>,
|
||||
}
|
||||
|
||||
impl Serialize for FormError {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
serializer.serialize_str(&self.error.html_display())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct FieldError<FieldType> {
|
||||
field: FieldType,
|
||||
error: Box<dyn HtmlError>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct ErroredForm<FieldType> {
|
||||
form: Vec<FormError>,
|
||||
fields: Vec<FieldError<FieldType>>,
|
||||
}
|
||||
|
||||
impl<T> ErroredForm<T> {
|
||||
#[must_use] pub fn new() -> Self {
|
||||
Self {
|
||||
form: Vec::new(),
|
||||
fields: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Reject for ErroredForm<T> where T: Debug + Send + Sync + 'static {}
|
||||
|
||||
impl<FieldType: Copy + Serialize + Hash + Eq> Serialize for ErroredForm<FieldType> {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut map = serializer.serialize_map(Some(2))?;
|
||||
let has_errors = !self.form.is_empty() || !self.fields.is_empty();
|
||||
map.serialize_entry("has_errors", &has_errors)?;
|
||||
map.serialize_entry("form_errors", &self.form)?;
|
||||
|
||||
let fields: HashMap<FieldType, Vec<String>> =
|
||||
self.fields.iter().fold(HashMap::new(), |mut map, err| {
|
||||
map.entry(err.field)
|
||||
.or_default()
|
||||
.push(err.error.html_display());
|
||||
map
|
||||
});
|
||||
|
||||
map.serialize_entry("fields_errors", &fields)?;
|
||||
|
||||
map.end()
|
||||
}
|
||||
}
|
55
crates/core/src/filters/authenticate.rs
Normal file
55
crates/core/src/filters/authenticate.rs
Normal file
@ -0,0 +1,55 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::Utc;
|
||||
use headers::{authorization::Bearer, Authorization};
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use warp::{Filter, Rejection};
|
||||
|
||||
use super::{database::with_connection, headers::with_typed_header};
|
||||
use crate::{
|
||||
errors::WrapError,
|
||||
storage::oauth2::access_token::{lookup_access_token, OAuth2AccessTokenLookup},
|
||||
tokens,
|
||||
};
|
||||
|
||||
pub fn with_authentication(
|
||||
pool: &PgPool,
|
||||
) -> impl Filter<Extract = (OAuth2AccessTokenLookup,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
{
|
||||
with_connection(pool)
|
||||
.and(with_typed_header())
|
||||
.and_then(authenticate)
|
||||
}
|
||||
|
||||
async fn authenticate(
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
auth: Authorization<Bearer>,
|
||||
) -> Result<OAuth2AccessTokenLookup, Rejection> {
|
||||
let token = auth.0.token();
|
||||
let token_type = tokens::check(token).wrap_error()?;
|
||||
if token_type != tokens::TokenType::AccessToken {
|
||||
return Err(anyhow::anyhow!("wrong token type")).wrap_error();
|
||||
}
|
||||
|
||||
let token = lookup_access_token(&mut conn, token).await.wrap_error()?;
|
||||
let exp = token.exp();
|
||||
|
||||
// Check it is active and did not expire
|
||||
if !token.active || exp < Utc::now() {
|
||||
return Err(anyhow::anyhow!("token expired")).wrap_error();
|
||||
}
|
||||
|
||||
Ok(token)
|
||||
}
|
213
crates/core/src/filters/client.rs
Normal file
213
crates/core/src/filters/client.rs
Normal file
@ -0,0 +1,213 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use headers::{authorization::Basic, Authorization};
|
||||
use serde::{de::DeserializeOwned, Deserialize};
|
||||
use thiserror::Error;
|
||||
use warp::{reject::Reject, Filter, Rejection};
|
||||
|
||||
use super::headers::with_typed_header;
|
||||
use crate::config::{OAuth2ClientConfig, OAuth2Config};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum ClientAuthentication {
|
||||
ClientSecretBasic,
|
||||
ClientSecretPost,
|
||||
None,
|
||||
}
|
||||
|
||||
impl ClientAuthentication {
|
||||
#[must_use]
|
||||
pub fn public(&self) -> bool {
|
||||
matches!(self, &Self::None)
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_client_auth<T: DeserializeOwned + Send + 'static>(
|
||||
oauth2_config: &OAuth2Config,
|
||||
) -> impl Filter<Extract = (ClientAuthentication, OAuth2ClientConfig, T), Error = Rejection>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static {
|
||||
// First, extract the client credentials
|
||||
let credentials = with_typed_header()
|
||||
.and(warp::body::form())
|
||||
// Either from the "Authorization" header
|
||||
.map(|auth: Authorization<Basic>, body: T| {
|
||||
let client_id = auth.0.username().to_string();
|
||||
let client_secret = Some(auth.0.password().to_string());
|
||||
(
|
||||
ClientAuthentication::ClientSecretBasic,
|
||||
client_id,
|
||||
client_secret,
|
||||
body,
|
||||
)
|
||||
})
|
||||
// Or from the form body
|
||||
.or(warp::body::form().map(|form: ClientAuthForm<T>| {
|
||||
let ClientAuthForm {
|
||||
client_id,
|
||||
client_secret,
|
||||
body,
|
||||
} = form;
|
||||
let auth_type = if client_secret.is_some() {
|
||||
ClientAuthentication::ClientSecretPost
|
||||
} else {
|
||||
ClientAuthentication::None
|
||||
};
|
||||
(auth_type, client_id, client_secret, body)
|
||||
}))
|
||||
.unify()
|
||||
.untuple_one();
|
||||
|
||||
let clients = oauth2_config.clients.clone();
|
||||
warp::any()
|
||||
.map(move || clients.clone())
|
||||
.and(credentials)
|
||||
.and_then(authenticate_client)
|
||||
.untuple_one()
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
enum ClientAuthenticationError {
|
||||
#[error("no client secret found for client {client_id:?}")]
|
||||
NoClientSecret { client_id: String },
|
||||
|
||||
#[error("wrong client secret for client {client_id:?}")]
|
||||
ClientSecretMismatch { client_id: String },
|
||||
|
||||
#[error("could not find client {client_id:?}")]
|
||||
ClientNotFound { client_id: String },
|
||||
|
||||
#[error("client secret required for client {client_id:?}")]
|
||||
ClientSecretRequired { client_id: String },
|
||||
}
|
||||
|
||||
impl Reject for ClientAuthenticationError {}
|
||||
|
||||
async fn authenticate_client<T>(
|
||||
clients: Vec<OAuth2ClientConfig>,
|
||||
auth_type: ClientAuthentication,
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
body: T,
|
||||
) -> Result<(ClientAuthentication, OAuth2ClientConfig, T), Rejection> {
|
||||
let client = clients
|
||||
.iter()
|
||||
.find(|client| client.client_id == client_id)
|
||||
.ok_or_else(|| ClientAuthenticationError::ClientNotFound {
|
||||
client_id: client_id.to_string(),
|
||||
})?;
|
||||
|
||||
let client = match (client_secret, client.client_secret.as_ref()) {
|
||||
(None, None) => Ok(client),
|
||||
(Some(ref given), Some(expected)) if given == expected => Ok(client),
|
||||
(Some(_), Some(_)) => Err(ClientAuthenticationError::ClientSecretMismatch { client_id }),
|
||||
(Some(_), None) => Err(ClientAuthenticationError::NoClientSecret { client_id }),
|
||||
(None, Some(_)) => Err(ClientAuthenticationError::ClientSecretRequired { client_id }),
|
||||
}?;
|
||||
|
||||
Ok((auth_type, client.clone(), body))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ClientAuthForm<T> {
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
|
||||
#[serde(flatten)]
|
||||
body: T,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use mas_config::ConfigurationSection;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn oauth2_config() -> OAuth2Config {
|
||||
let mut config = OAuth2Config::test();
|
||||
config.clients.push(OAuth2ClientConfig {
|
||||
client_id: "public".to_string(),
|
||||
client_secret: None,
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
config.clients.push(OAuth2ClientConfig {
|
||||
client_id: "confidential".to_string(),
|
||||
client_secret: Some("secret".to_string()),
|
||||
redirect_uris: Vec::new(),
|
||||
});
|
||||
config
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Form {
|
||||
foo: String,
|
||||
bar: String,
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_post() {
|
||||
let filter = with_client_auth::<Form>(&oauth2_config());
|
||||
|
||||
let (auth, client, body) = warp::test::request()
|
||||
.method("POST")
|
||||
.body("client_id=confidential&client_secret=secret&foo=baz&bar=foobar")
|
||||
.filter(&filter)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(auth, ClientAuthentication::ClientSecretPost);
|
||||
assert_eq!(client.client_id, "confidential");
|
||||
assert_eq!(body.foo, "baz");
|
||||
assert_eq!(body.bar, "foobar");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn client_secret_basic() {
|
||||
let filter = with_client_auth::<Form>(&oauth2_config());
|
||||
|
||||
let (auth, client, body) = warp::test::request()
|
||||
.method("POST")
|
||||
.header("Authorization", "Basic Y29uZmlkZW50aWFsOnNlY3JldA==")
|
||||
.body("foo=baz&bar=foobar")
|
||||
.filter(&filter)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(auth, ClientAuthentication::ClientSecretBasic);
|
||||
assert_eq!(client.client_id, "confidential");
|
||||
assert_eq!(body.foo, "baz");
|
||||
assert_eq!(body.bar, "foobar");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn none() {
|
||||
let filter = with_client_auth::<Form>(&oauth2_config());
|
||||
|
||||
let (auth, client, body) = warp::test::request()
|
||||
.method("POST")
|
||||
.body("client_id=public&foo=baz&bar=foobar")
|
||||
.filter(&filter)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(auth, ClientAuthentication::None);
|
||||
assert_eq!(client.client_id, "public");
|
||||
assert_eq!(body.foo, "baz");
|
||||
assert_eq!(body.bar, "foobar");
|
||||
}
|
||||
}
|
134
crates/core/src/filters/cookies.rs
Normal file
134
crates/core/src/filters/cookies.rs
Normal file
@ -0,0 +1,134 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use chacha20poly1305::{
|
||||
aead::{generic_array::GenericArray, Aead, NewAead},
|
||||
ChaCha20Poly1305,
|
||||
};
|
||||
use cookie::Cookie;
|
||||
use data_encoding::BASE64URL_NOPAD;
|
||||
use headers::{Header, HeaderValue, SetCookie};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use super::headers::{typed_header, WithTypedHeader};
|
||||
use crate::{config::CookiesConfig, errors::WrapError};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct EncryptedCookie {
|
||||
nonce: [u8; 12],
|
||||
ciphertext: Vec<u8>,
|
||||
}
|
||||
|
||||
impl EncryptedCookie {
|
||||
/// Encrypt from a given key
|
||||
fn encrypt<T: Serialize>(payload: T, key: &[u8; 32]) -> anyhow::Result<Self> {
|
||||
let key = GenericArray::from_slice(key);
|
||||
let aead = ChaCha20Poly1305::new(key);
|
||||
let message = bincode::serialize(&payload)?;
|
||||
let nonce: [u8; 12] = rand::random();
|
||||
let ciphertext = aead.encrypt(GenericArray::from_slice(&nonce[..]), &message[..])?;
|
||||
Ok(Self { nonce, ciphertext })
|
||||
}
|
||||
|
||||
/// Decrypt the content of the cookie from a given key
|
||||
fn decrypt<T: DeserializeOwned>(&self, key: &[u8; 32]) -> anyhow::Result<T> {
|
||||
let key = GenericArray::from_slice(key);
|
||||
let aead = ChaCha20Poly1305::new(key);
|
||||
let message = aead.decrypt(
|
||||
GenericArray::from_slice(&self.nonce[..]),
|
||||
&self.ciphertext[..],
|
||||
)?;
|
||||
let token = bincode::deserialize(&message)?;
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
/// Encode the encrypted cookie to be then saved as a cookie
|
||||
fn to_cookie_value(&self) -> anyhow::Result<String> {
|
||||
let raw = bincode::serialize(self)?;
|
||||
Ok(BASE64URL_NOPAD.encode(&raw))
|
||||
}
|
||||
|
||||
fn from_cookie_value(value: &str) -> anyhow::Result<Self> {
|
||||
let raw = BASE64URL_NOPAD.decode(value.as_bytes())?;
|
||||
let content = bincode::deserialize(&raw)?;
|
||||
Ok(content)
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use] pub fn maybe_encrypted<T>(
|
||||
options: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (Option<T>,), Error = Infallible> + Clone + Send + Sync + 'static
|
||||
where
|
||||
T: DeserializeOwned + EncryptableCookieValue + Send + 'static,
|
||||
{
|
||||
let secret = options.secret;
|
||||
warp::cookie::optional(T::cookie_key()).map(move |maybe_value: Option<String>| {
|
||||
maybe_value
|
||||
.and_then(|value| EncryptedCookie::from_cookie_value(&value).ok())
|
||||
.and_then(|encrypted| encrypted.decrypt(&secret).ok())
|
||||
})
|
||||
}
|
||||
|
||||
#[must_use] pub fn encrypted<T>(
|
||||
options: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
where
|
||||
T: DeserializeOwned + EncryptableCookieValue + Send + 'static,
|
||||
{
|
||||
let secret = options.secret;
|
||||
warp::cookie::cookie(T::cookie_key()).and_then(move |value: String| async move {
|
||||
let encrypted = EncryptedCookie::from_cookie_value(&value).wrap_error()?;
|
||||
let decrypted = encrypted.decrypt(&secret).wrap_error()?;
|
||||
Ok::<_, Rejection>(decrypted)
|
||||
})
|
||||
}
|
||||
|
||||
#[must_use] pub fn with_cookie_saver(
|
||||
options: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (EncryptedCookieSaver,), Error = Infallible> + Clone + Send + Sync + 'static
|
||||
{
|
||||
let secret = options.secret;
|
||||
warp::any().map(move || EncryptedCookieSaver { secret })
|
||||
}
|
||||
|
||||
/// A cookie that can be encrypted with a well-known cookie key
|
||||
pub trait EncryptableCookieValue {
|
||||
fn cookie_key() -> &'static str;
|
||||
}
|
||||
|
||||
pub struct EncryptedCookieSaver {
|
||||
secret: [u8; 32],
|
||||
}
|
||||
|
||||
impl EncryptedCookieSaver {
|
||||
pub fn save_encrypted<T: Serialize + EncryptableCookieValue, R: Reply>(
|
||||
&self,
|
||||
cookie: &T,
|
||||
reply: R,
|
||||
) -> Result<WithTypedHeader<R, SetCookie>, Rejection> {
|
||||
let encrypted = EncryptedCookie::encrypt(cookie, &self.secret)
|
||||
.wrap_error()?
|
||||
.to_cookie_value()
|
||||
.wrap_error()?;
|
||||
let value = Cookie::build(T::cookie_key(), encrypted)
|
||||
.finish()
|
||||
.to_string();
|
||||
let header = SetCookie::decode(&mut [HeaderValue::from_str(&value).wrap_error()?].iter())
|
||||
.wrap_error()?;
|
||||
Ok(typed_header(header, reply))
|
||||
}
|
||||
}
|
159
crates/core/src/filters/csrf.rs
Normal file
159
crates/core/src/filters/csrf.rs
Normal file
@ -0,0 +1,159 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Stateless CSRF protection middleware based on a chacha20-poly1305 encrypted
|
||||
//! and signed token
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use data_encoding::{DecodeError, BASE64URL_NOPAD};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use serde_with::{serde_as, TimestampSeconds};
|
||||
use thiserror::Error;
|
||||
use warp::{reject::Reject, Filter, Rejection};
|
||||
|
||||
use super::cookies::EncryptableCookieValue;
|
||||
use crate::config::{CookiesConfig, CsrfConfig};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CsrfError {
|
||||
#[error("CSRF token mismatch")]
|
||||
Mismatch,
|
||||
|
||||
#[error("CSRF token expired")]
|
||||
Expired,
|
||||
|
||||
#[error("could not decode CSRF token")]
|
||||
Decode(#[from] DecodeError),
|
||||
}
|
||||
|
||||
impl Reject for CsrfError {}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct CsrfToken {
|
||||
#[serde_as(as = "TimestampSeconds<i64>")]
|
||||
expiration: DateTime<Utc>,
|
||||
token: [u8; 32],
|
||||
}
|
||||
|
||||
impl CsrfToken {
|
||||
/// Create a new token from a defined value valid for a specified duration
|
||||
fn new(token: [u8; 32], ttl: Duration) -> Self {
|
||||
let expiration = Utc::now() + ttl;
|
||||
Self { expiration, token }
|
||||
}
|
||||
|
||||
/// Generate a new random token valid for a specified duration
|
||||
fn generate(ttl: Duration) -> Self {
|
||||
let token = rand::random();
|
||||
Self::new(token, ttl)
|
||||
}
|
||||
|
||||
/// Generate a new token with the same value but an up to date expiration
|
||||
fn refresh(self, ttl: Duration) -> Self {
|
||||
Self::new(self.token, ttl)
|
||||
}
|
||||
|
||||
/// Get the value to include in HTML forms
|
||||
#[must_use] pub fn form_value(&self) -> String {
|
||||
BASE64URL_NOPAD.encode(&self.token[..])
|
||||
}
|
||||
|
||||
/// Verifies that the value got from an HTML form matches this token
|
||||
pub fn verify_form_value(&self, form_value: &str) -> Result<(), CsrfError> {
|
||||
let form_value = BASE64URL_NOPAD.decode(form_value.as_bytes())?;
|
||||
if self.token[..] == form_value {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(CsrfError::Mismatch)
|
||||
}
|
||||
}
|
||||
|
||||
fn verify_expiration(self) -> Result<Self, CsrfError> {
|
||||
if Utc::now() < self.expiration {
|
||||
Ok(self)
|
||||
} else {
|
||||
Err(CsrfError::Expired)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EncryptableCookieValue for CsrfToken {
|
||||
fn cookie_key() -> &'static str {
|
||||
"csrf"
|
||||
}
|
||||
}
|
||||
|
||||
/// A CSRF-protected form
|
||||
#[derive(Deserialize)]
|
||||
struct CsrfForm<T> {
|
||||
csrf: String,
|
||||
|
||||
#[serde(flatten)]
|
||||
inner: T,
|
||||
}
|
||||
|
||||
impl<T> CsrfForm<T> {
|
||||
fn verify_csrf(self, token: &CsrfToken) -> Result<T, CsrfError> {
|
||||
// Verify CSRF from request
|
||||
token.verify_form_value(&self.csrf)?;
|
||||
Ok(self.inner)
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use] pub fn csrf_token(
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (CsrfToken,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
super::cookies::encrypted(cookies_config).and_then(move |token: CsrfToken| async move {
|
||||
let verified = token.verify_expiration()?;
|
||||
Ok::<_, Rejection>(verified)
|
||||
})
|
||||
}
|
||||
|
||||
#[must_use] pub fn updated_csrf_token(
|
||||
cookies_config: &CookiesConfig,
|
||||
csrf_config: &CsrfConfig,
|
||||
) -> impl Filter<Extract = (CsrfToken,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
let ttl = csrf_config.ttl;
|
||||
super::cookies::maybe_encrypted(cookies_config).and_then(
|
||||
move |maybe_token: Option<CsrfToken>| async move {
|
||||
// Explicitely specify the "Error" type here to have the `?` operation working
|
||||
Ok::<_, Rejection>(
|
||||
maybe_token
|
||||
// Verify its TTL (but do not hard-error if it expired)
|
||||
.and_then(|token| token.verify_expiration().ok())
|
||||
.map_or_else(
|
||||
// Generate a new token if no valid one were found
|
||||
|| CsrfToken::generate(ttl),
|
||||
// Else, refresh the expiration of the token
|
||||
|token| token.refresh(ttl),
|
||||
),
|
||||
)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
#[must_use] pub fn protected_form<T>(
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
where
|
||||
T: DeserializeOwned + Send + 'static,
|
||||
{
|
||||
csrf_token(cookies_config).and(warp::body::form()).and_then(
|
||||
|csrf_token: CsrfToken, protected_form: CsrfForm<T>| async move {
|
||||
let form = protected_form.verify_csrf(&csrf_token)?;
|
||||
Ok::<_, Rejection>(form)
|
||||
},
|
||||
)
|
||||
}
|
54
crates/core/src/filters/database.rs
Normal file
54
crates/core/src/filters/database.rs
Normal file
@ -0,0 +1,54 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres, Transaction};
|
||||
use warp::{Filter, Rejection};
|
||||
|
||||
use crate::errors::WrapError;
|
||||
|
||||
fn with_pool(
|
||||
pool: &PgPool,
|
||||
) -> impl Filter<Extract = (PgPool,), Error = Infallible> + Clone + Send + Sync + 'static {
|
||||
let pool = pool.clone();
|
||||
warp::any().map(move || pool.clone())
|
||||
}
|
||||
|
||||
pub fn with_connection(
|
||||
pool: &PgPool,
|
||||
) -> impl Filter<Extract = (PoolConnection<Postgres>,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
{
|
||||
with_pool(pool).and_then(acquire_connection)
|
||||
}
|
||||
|
||||
async fn acquire_connection(pool: PgPool) -> Result<PoolConnection<Postgres>, Rejection> {
|
||||
let conn = pool.acquire().await.wrap_error()?;
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
pub fn with_transaction(
|
||||
pool: &PgPool,
|
||||
) -> impl Filter<Extract = (Transaction<'static, Postgres>,), Error = Rejection>
|
||||
+ Clone
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static {
|
||||
with_pool(pool).and_then(acquire_transaction)
|
||||
}
|
||||
|
||||
async fn acquire_transaction(pool: PgPool) -> Result<Transaction<'static, Postgres>, Rejection> {
|
||||
let txn = pool.begin().await.wrap_error()?;
|
||||
Ok(txn)
|
||||
}
|
200
crates/core/src/filters/errors.rs
Normal file
200
crates/core/src/filters/errors.rs
Normal file
@ -0,0 +1,200 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{cmp::Reverse, future::Future, pin::Pin};
|
||||
|
||||
use mime::{Mime, STAR};
|
||||
use serde::Serialize;
|
||||
use tera::Context;
|
||||
use tide::{
|
||||
http::headers::{ACCEPT, LOCATION},
|
||||
Body, Request, StatusCode,
|
||||
};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::{state::State, templates::common_context};
|
||||
|
||||
/// Get the weight parameter for a mime type from 0 to 1000
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
fn get_weight(mime: &Mime) -> usize {
|
||||
let q = mime
|
||||
.get_param("q")
|
||||
.map_or(1.0_f64, |q| q.as_str().parse().unwrap_or(0.0))
|
||||
.min(1.0)
|
||||
.max(0.0);
|
||||
|
||||
// Weight have a 3 digit precision so we can multiply by 1000 and cast to
|
||||
// int. Sign loss should not happen here because of the min/max up there and
|
||||
// truncation does not matter here.
|
||||
(q * 1000.0) as _
|
||||
}
|
||||
|
||||
/// Find what content type should be used for a given request
|
||||
fn preferred_mime_type<'a>(
|
||||
request: &Request<State>,
|
||||
supported_types: &'a [Mime],
|
||||
) -> Option<&'a Mime> {
|
||||
let accept = request.header(ACCEPT)?;
|
||||
// Parse the Accept header as a list of mime types with their associated
|
||||
// weight
|
||||
let accepted_types: Vec<(Mime, usize)> = {
|
||||
let v: Option<Vec<_>> = accept
|
||||
.into_iter()
|
||||
.flat_map(|value| value.as_str().split(','))
|
||||
.map(|mime| {
|
||||
mime.trim().parse().ok().map(|mime| {
|
||||
let q = get_weight(&mime);
|
||||
(mime, q)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let mut v = v?;
|
||||
v.sort_by_key(|(_, weight)| Reverse(*weight));
|
||||
v
|
||||
};
|
||||
|
||||
// For each supported content type, find out if it is accepted with what
|
||||
// weight and specificity
|
||||
let mut types: Vec<_> = supported_types
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(index, supported)| {
|
||||
accepted_types.iter().find_map(|(accepted, weight)| {
|
||||
if accepted.type_() == supported.type_()
|
||||
&& accepted.subtype() == supported.subtype()
|
||||
{
|
||||
// Accept: text/html
|
||||
Some((supported, *weight, 2_usize, index))
|
||||
} else if accepted.type_() == supported.type_() && accepted.subtype() == STAR {
|
||||
// Accept: text/*
|
||||
Some((supported, *weight, 1, index))
|
||||
} else if accepted.type_() == STAR && accepted.subtype() == STAR {
|
||||
// Accept: */*
|
||||
Some((supported, *weight, 0, index))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
types.sort_by_key(|(_, weight, specificity, index)| {
|
||||
(Reverse(*weight), Reverse(*specificity), *index)
|
||||
});
|
||||
|
||||
types.first().map(|(mime, _, _, _)| *mime)
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ErrorContext {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
code: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
details: Option<String>,
|
||||
}
|
||||
|
||||
impl ErrorContext {
|
||||
fn should_render(&self) -> bool {
|
||||
self.code.is_some() || self.description.is_some() || self.details.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn middleware<'a>(
|
||||
request: tide::Request<State>,
|
||||
next: tide::Next<'a, State>,
|
||||
) -> Pin<Box<dyn Future<Output = tide::Result> + Send + 'a>> {
|
||||
Box::pin(async {
|
||||
let content_type = preferred_mime_type(
|
||||
&request,
|
||||
&[mime::TEXT_PLAIN, mime::TEXT_HTML, mime::APPLICATION_JSON],
|
||||
);
|
||||
debug!("Content-Type from Accept: {:?}", content_type);
|
||||
|
||||
// TODO: We should not clone here
|
||||
let templates = request.state().templates().clone();
|
||||
|
||||
// TODO: This context should probably be comptuted somewhere else
|
||||
let pctx = common_context(&request).await?.clone();
|
||||
|
||||
let mut response = next.run(request).await;
|
||||
|
||||
// Find out what message should be displayed from the response status
|
||||
// code
|
||||
let (code, description) = match response.status() {
|
||||
StatusCode::NotFound => (Some("Not found".to_string()), None),
|
||||
StatusCode::MethodNotAllowed => (Some("Method not allowed".to_string()), None),
|
||||
StatusCode::Found
|
||||
| StatusCode::PermanentRedirect
|
||||
| StatusCode::TemporaryRedirect
|
||||
| StatusCode::SeeOther => {
|
||||
let description = response.header(LOCATION).map(|loc| format!("To {}", loc));
|
||||
(Some("Redirecting".to_string()), description)
|
||||
}
|
||||
StatusCode::InternalServerError => (Some("Internal server error".to_string()), None),
|
||||
_ => (None, None),
|
||||
};
|
||||
|
||||
// If there is an error associated to the response, format it in a nice
|
||||
// way with a backtrace if we have one
|
||||
let details = response.take_error().map(|err| {
|
||||
format!(
|
||||
"{:?}{}",
|
||||
err,
|
||||
err.backtrace()
|
||||
.map(|bt| format!("\nBacktrace:\n{}", bt.to_string()))
|
||||
.unwrap_or_default()
|
||||
)
|
||||
});
|
||||
|
||||
let error_context = ErrorContext {
|
||||
code,
|
||||
description,
|
||||
details,
|
||||
};
|
||||
|
||||
// This is the case if one of the code, description or details is not
|
||||
// None
|
||||
if error_context.should_render() {
|
||||
match content_type {
|
||||
Some(c) if c == &mime::APPLICATION_JSON => {
|
||||
response.set_body(Body::from_json(&error_context)?);
|
||||
response.set_content_type("application/json");
|
||||
}
|
||||
Some(c) if c == &mime::TEXT_HTML => {
|
||||
let mut ctx = Context::from_serialize(&error_context)?;
|
||||
ctx.extend(pctx);
|
||||
response.set_body(templates.render("error.html", &ctx)?);
|
||||
response.set_content_type("text/html");
|
||||
}
|
||||
Some(c) if c == &mime::TEXT_PLAIN => {
|
||||
let mut ctx = Context::from_serialize(&error_context)?;
|
||||
ctx.extend(pctx);
|
||||
response.set_body(templates.render("error.txt", &ctx)?);
|
||||
response.set_content_type("text/plain");
|
||||
}
|
||||
_ => {
|
||||
response.set_body("Unsupported Content-Type in Accept header");
|
||||
response.set_content_type("text/plain");
|
||||
response.set_status(StatusCode::NotAcceptable);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
})
|
||||
}
|
50
crates/core/src/filters/headers.rs
Normal file
50
crates/core/src/filters/headers.rs
Normal file
@ -0,0 +1,50 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
use headers::{Header, HeaderMapExt, HeaderValue};
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::errors::WrapError;
|
||||
|
||||
pub fn typed_header<R, H>(header: H, reply: R) -> WithTypedHeader<R, H> {
|
||||
WithTypedHeader { reply, header }
|
||||
}
|
||||
|
||||
pub struct WithTypedHeader<R, H> {
|
||||
reply: R,
|
||||
header: H,
|
||||
}
|
||||
|
||||
impl<R, H> Reply for WithTypedHeader<R, H>
|
||||
where
|
||||
R: Reply,
|
||||
H: Header + Send,
|
||||
{
|
||||
fn into_response(self) -> warp::reply::Response {
|
||||
let mut res = self.reply.into_response();
|
||||
res.headers_mut().typed_insert(self.header);
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_typed_header<T: Header + Send + 'static>(
|
||||
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
warp::header::value(T::name().as_str()).and_then(decode_typed_header)
|
||||
}
|
||||
|
||||
async fn decode_typed_header<T: Header>(header: HeaderValue) -> Result<T, Rejection> {
|
||||
let mut it = std::iter::once(&header);
|
||||
let decoded = T::decode(&mut it).wrap_error()?;
|
||||
Ok(decoded)
|
||||
}
|
48
crates/core/src/filters/mod.rs
Normal file
48
crates/core/src/filters/mod.rs
Normal file
@ -0,0 +1,48 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#![allow(clippy::unused_async)] // Some warp filters need that
|
||||
|
||||
pub mod csrf;
|
||||
// mod errors;
|
||||
pub mod authenticate;
|
||||
pub mod client;
|
||||
pub mod cookies;
|
||||
pub mod database;
|
||||
pub mod headers;
|
||||
pub mod session;
|
||||
|
||||
use std::convert::Infallible;
|
||||
|
||||
use warp::Filter;
|
||||
|
||||
pub use self::csrf::CsrfToken;
|
||||
use crate::{
|
||||
config::{KeySet, OAuth2Config},
|
||||
templates::Templates,
|
||||
};
|
||||
|
||||
#[must_use] pub fn with_templates(
|
||||
templates: &Templates,
|
||||
) -> impl Filter<Extract = (Templates,), Error = Infallible> + Clone + Send + Sync + 'static {
|
||||
let templates = templates.clone();
|
||||
warp::any().map(move || templates.clone())
|
||||
}
|
||||
|
||||
#[must_use] pub fn with_keys(
|
||||
oauth2_config: &OAuth2Config,
|
||||
) -> impl Filter<Extract = (KeySet,), Error = Infallible> + Clone + Send + Sync + 'static {
|
||||
let keyset = oauth2_config.keys.clone();
|
||||
warp::any().map(move || keyset.clone())
|
||||
}
|
86
crates/core/src/filters/session.rs
Normal file
86
crates/core/src/filters/session.rs
Normal file
@ -0,0 +1,86 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{pool::PoolConnection, Executor, PgPool, Postgres};
|
||||
use warp::{Filter, Rejection};
|
||||
|
||||
use super::{
|
||||
cookies::{encrypted, maybe_encrypted, EncryptableCookieValue},
|
||||
database::with_connection,
|
||||
};
|
||||
use crate::{
|
||||
config::CookiesConfig,
|
||||
errors::WrapError,
|
||||
storage::{lookup_active_session, SessionInfo},
|
||||
};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct SessionCookie {
|
||||
current: i64,
|
||||
}
|
||||
|
||||
impl SessionCookie {
|
||||
#[must_use] pub fn from_session_info(info: &SessionInfo) -> Self {
|
||||
Self {
|
||||
current: info.key(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn load_session_info(
|
||||
&self,
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
) -> anyhow::Result<SessionInfo> {
|
||||
lookup_active_session(executor, self.current).await
|
||||
}
|
||||
}
|
||||
|
||||
impl EncryptableCookieValue for SessionCookie {
|
||||
fn cookie_key() -> &'static str {
|
||||
"session"
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use] pub fn with_optional_session(
|
||||
pool: &PgPool,
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (Option<SessionInfo>,), Error = Rejection> + Clone + Send + Sync + 'static
|
||||
{
|
||||
maybe_encrypted(cookies_config)
|
||||
.and(with_connection(pool))
|
||||
.and_then(
|
||||
|maybe_session: Option<SessionCookie>, mut conn: PoolConnection<Postgres>| async move {
|
||||
let maybe_session_info = if let Some(session) = maybe_session {
|
||||
session.load_session_info(&mut conn).await.ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok::<_, Rejection>(maybe_session_info)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
#[must_use] pub fn with_session(
|
||||
pool: &PgPool,
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (SessionInfo,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
encrypted(cookies_config)
|
||||
.and(with_connection(pool))
|
||||
.and_then(
|
||||
|session: SessionCookie, mut conn: PoolConnection<Postgres>| async move {
|
||||
let session_info = session.load_session_info(&mut conn).await.wrap_error()?;
|
||||
Ok::<_, Rejection>(session_info)
|
||||
},
|
||||
)
|
||||
}
|
41
crates/core/src/handlers/health.rs
Normal file
41
crates/core/src/handlers/health.rs
Normal file
@ -0,0 +1,41 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use hyper::header::CONTENT_TYPE;
|
||||
use mime::TEXT_PLAIN;
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use tracing::{info_span, Instrument};
|
||||
use warp::{reply::with_header, Filter, Rejection, Reply};
|
||||
|
||||
use crate::{errors::WrapError, filters::database::with_connection};
|
||||
|
||||
pub fn filter(
|
||||
pool: &PgPool,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
warp::path!("health")
|
||||
.and(warp::get())
|
||||
.and(with_connection(pool))
|
||||
.and_then(get)
|
||||
}
|
||||
|
||||
async fn get(mut conn: PoolConnection<Postgres>) -> Result<impl Reply, Rejection> {
|
||||
sqlx::query("SELECT $1")
|
||||
.bind(1_i64)
|
||||
.execute(&mut conn)
|
||||
.instrument(info_span!("DB health"))
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
Ok(with_header("ok", CONTENT_TYPE, TEXT_PLAIN.to_string()))
|
||||
}
|
43
crates/core/src/handlers/mod.rs
Normal file
43
crates/core/src/handlers/mod.rs
Normal file
@ -0,0 +1,43 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#![allow(clippy::unused_async)] // Some warp filters need that
|
||||
|
||||
use sqlx::PgPool;
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::{config::RootConfig, templates::Templates};
|
||||
|
||||
mod health;
|
||||
mod oauth2;
|
||||
mod views;
|
||||
|
||||
use self::{health::filter as health, oauth2::filter as oauth2, views::filter as views};
|
||||
|
||||
#[must_use] pub fn root(
|
||||
pool: &PgPool,
|
||||
templates: &Templates,
|
||||
config: &RootConfig,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
health(pool)
|
||||
.or(oauth2(pool, templates, &config.oauth2, &config.cookies))
|
||||
.or(views(
|
||||
pool,
|
||||
templates,
|
||||
&config.oauth2,
|
||||
&config.csrf,
|
||||
&config.cookies,
|
||||
))
|
||||
.with(warp::log(module_path!()))
|
||||
}
|
459
crates/core/src/handlers/oauth2/authorization.rs
Normal file
459
crates/core/src/handlers/oauth2/authorization.rs
Normal file
@ -0,0 +1,459 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
convert::TryFrom,
|
||||
};
|
||||
|
||||
use chrono::Duration;
|
||||
use hyper::{
|
||||
header::LOCATION,
|
||||
http::uri::{Parts, PathAndQuery, Uri},
|
||||
StatusCode,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use oauth2_types::{
|
||||
errors::{ErrorResponse, InvalidRequest, OAuth2Error},
|
||||
pkce,
|
||||
requests::{
|
||||
AccessTokenResponse, AuthorizationRequest, AuthorizationResponse, ResponseMode,
|
||||
ResponseType,
|
||||
},
|
||||
};
|
||||
use rand::{distributions::Alphanumeric, thread_rng, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use sqlx::{PgPool, Postgres, Transaction};
|
||||
use url::Url;
|
||||
use warp::{
|
||||
redirect::see_other,
|
||||
reject::InvalidQuery,
|
||||
reply::{html, with_header},
|
||||
Filter, Rejection, Reply,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
config::{CookiesConfig, OAuth2ClientConfig, OAuth2Config},
|
||||
errors::WrapError,
|
||||
filters::{
|
||||
database::with_transaction,
|
||||
session::{with_optional_session, with_session},
|
||||
with_templates,
|
||||
},
|
||||
handlers::views::LoginRequest,
|
||||
storage::{
|
||||
oauth2::{
|
||||
access_token::add_access_token,
|
||||
refresh_token::add_refresh_token,
|
||||
session::{get_session_by_id, start_session},
|
||||
},
|
||||
SessionInfo,
|
||||
},
|
||||
templates::{FormPostContext, Templates},
|
||||
tokens,
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct PartialParams {
|
||||
client_id: Option<String>,
|
||||
redirect_uri: Option<String>,
|
||||
/*
|
||||
response_type: Option<String>,
|
||||
response_mode: Option<String>,
|
||||
*/
|
||||
}
|
||||
|
||||
enum ReplyOrBackToClient {
|
||||
Reply(Box<dyn Reply>),
|
||||
BackToClient {
|
||||
params: Value,
|
||||
redirect_uri: Url,
|
||||
response_mode: ResponseMode,
|
||||
},
|
||||
Error(Box<dyn OAuth2Error>),
|
||||
}
|
||||
|
||||
fn back_to_client<T>(
|
||||
mut redirect_uri: Url,
|
||||
response_mode: ResponseMode,
|
||||
params: T,
|
||||
templates: &Templates,
|
||||
) -> anyhow::Result<Box<dyn Reply>>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
#[derive(Serialize)]
|
||||
struct AllParams<'s, T> {
|
||||
#[serde(flatten, skip_serializing_if = "Option::is_none")]
|
||||
existing: Option<HashMap<&'s str, &'s str>>,
|
||||
|
||||
#[serde(flatten)]
|
||||
params: T,
|
||||
}
|
||||
|
||||
match response_mode {
|
||||
ResponseMode::Query => {
|
||||
let existing: Option<HashMap<&str, &str>> = redirect_uri
|
||||
.query()
|
||||
.map(|qs| serde_urlencoded::from_str(qs))
|
||||
.transpose()?;
|
||||
|
||||
let merged = AllParams { existing, params };
|
||||
|
||||
let new_qs = serde_urlencoded::to_string(merged)?;
|
||||
|
||||
redirect_uri.set_query(Some(&new_qs));
|
||||
|
||||
Ok(Box::new(with_header(
|
||||
StatusCode::SEE_OTHER,
|
||||
LOCATION,
|
||||
redirect_uri.as_str(),
|
||||
)))
|
||||
}
|
||||
ResponseMode::Fragment => {
|
||||
let existing: Option<HashMap<&str, &str>> = redirect_uri
|
||||
.fragment()
|
||||
.map(|qs| serde_urlencoded::from_str(qs))
|
||||
.transpose()?;
|
||||
|
||||
let merged = AllParams { existing, params };
|
||||
|
||||
let new_qs = serde_urlencoded::to_string(merged)?;
|
||||
|
||||
redirect_uri.set_fragment(Some(&new_qs));
|
||||
|
||||
Ok(Box::new(with_header(
|
||||
StatusCode::SEE_OTHER,
|
||||
LOCATION,
|
||||
redirect_uri.as_str(),
|
||||
)))
|
||||
}
|
||||
ResponseMode::FormPost => {
|
||||
let ctx = FormPostContext::new(redirect_uri, params);
|
||||
let rendered = templates.render_form_post(&ctx)?;
|
||||
Ok(Box::new(html(rendered)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Params {
|
||||
#[serde(flatten)]
|
||||
auth: AuthorizationRequest,
|
||||
|
||||
#[serde(flatten)]
|
||||
pkce: Option<pkce::Request>,
|
||||
}
|
||||
|
||||
/// Given a list of response types and an optional user-defined response mode,
|
||||
/// figure out what response mode must be used, and emit an error if the
|
||||
/// suggested response mode isn't allowed for the given response types.
|
||||
fn resolve_response_mode(
|
||||
response_type: &HashSet<ResponseType>,
|
||||
suggested_response_mode: Option<ResponseMode>,
|
||||
) -> anyhow::Result<ResponseMode> {
|
||||
use ResponseMode as M;
|
||||
use ResponseType as T;
|
||||
|
||||
// If the response type includes either "token" or "id_token", the default
|
||||
// response mode is "fragment" and the response mode "query" must not be
|
||||
// used
|
||||
if response_type.contains(&T::Token) || response_type.contains(&T::IdToken) {
|
||||
match suggested_response_mode {
|
||||
None => Ok(M::Fragment),
|
||||
Some(M::Query) => Err(anyhow::anyhow!("invalid response mode")),
|
||||
Some(mode) => Ok(mode),
|
||||
}
|
||||
} else {
|
||||
// In other cases, all response modes are allowed, defaulting to "query"
|
||||
Ok(suggested_response_mode.unwrap_or(M::Query))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn filter(
|
||||
pool: &PgPool,
|
||||
templates: &Templates,
|
||||
oauth2_config: &OAuth2Config,
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
let clients = oauth2_config.clients.clone();
|
||||
let authorize = warp::path!("oauth2" / "authorize")
|
||||
.and(warp::get())
|
||||
.map(move || clients.clone())
|
||||
.and(warp::query())
|
||||
.and(with_optional_session(pool, cookies_config))
|
||||
.and(with_transaction(pool))
|
||||
.and_then(get);
|
||||
|
||||
let step = warp::path!("oauth2" / "authorize" / "step")
|
||||
.and(warp::get())
|
||||
.and(warp::query().map(|s: StepRequest| s.id))
|
||||
.and(with_session(pool, cookies_config))
|
||||
.and(with_transaction(pool))
|
||||
.and_then(step);
|
||||
|
||||
let clients = oauth2_config.clients.clone();
|
||||
authorize
|
||||
.or(step)
|
||||
.unify()
|
||||
.recover(recover)
|
||||
.unify()
|
||||
.and(warp::query())
|
||||
.and(warp::any().map(move || clients.clone()))
|
||||
.and(with_templates(templates))
|
||||
.and_then(actually_reply)
|
||||
}
|
||||
|
||||
async fn recover(rejection: Rejection) -> Result<ReplyOrBackToClient, Rejection> {
|
||||
if rejection.find::<InvalidQuery>().is_some() {
|
||||
Ok(ReplyOrBackToClient::Error(Box::new(InvalidRequest)))
|
||||
} else {
|
||||
Err(rejection)
|
||||
}
|
||||
}
|
||||
|
||||
async fn actually_reply(
|
||||
rep: ReplyOrBackToClient,
|
||||
q: PartialParams,
|
||||
clients: Vec<OAuth2ClientConfig>,
|
||||
templates: Templates,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
let (redirect_uri, response_mode, params) = match rep {
|
||||
ReplyOrBackToClient::Reply(r) => return Ok(r),
|
||||
ReplyOrBackToClient::BackToClient {
|
||||
redirect_uri,
|
||||
response_mode,
|
||||
params,
|
||||
} => (redirect_uri, response_mode, params),
|
||||
ReplyOrBackToClient::Error(error) => {
|
||||
let PartialParams {
|
||||
client_id,
|
||||
redirect_uri,
|
||||
..
|
||||
} = q;
|
||||
|
||||
// First, disover the client
|
||||
let client = client_id.and_then(|client_id| {
|
||||
clients
|
||||
.into_iter()
|
||||
.find(|client| client.client_id == client_id)
|
||||
});
|
||||
|
||||
let client = match client {
|
||||
Some(client) => client,
|
||||
None => return Ok(Box::new(html(templates.render_error(&error.into())?))),
|
||||
};
|
||||
|
||||
let redirect_uri: Result<Option<Url>, _> = redirect_uri.map(|r| r.parse()).transpose();
|
||||
let redirect_uri = match redirect_uri {
|
||||
Ok(r) => r,
|
||||
Err(_) => return Ok(Box::new(html(templates.render_error(&error.into())?))),
|
||||
};
|
||||
|
||||
let redirect_uri = client.resolve_redirect_uri(&redirect_uri);
|
||||
let redirect_uri = match redirect_uri {
|
||||
Ok(r) => r,
|
||||
Err(_) => return Ok(Box::new(html(templates.render_error(&error.into())?))),
|
||||
};
|
||||
|
||||
let reply: ErrorResponse = error.into();
|
||||
let reply = serde_json::to_value(&reply).wrap_error()?;
|
||||
// TODO: resolve response mode
|
||||
(redirect_uri.clone(), ResponseMode::Query, reply)
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: we should include the state param in errors
|
||||
back_to_client(redirect_uri, response_mode, params, &templates).wrap_error()
|
||||
}
|
||||
|
||||
async fn get(
|
||||
clients: Vec<OAuth2ClientConfig>,
|
||||
params: Params,
|
||||
maybe_session: Option<SessionInfo>,
|
||||
mut txn: Transaction<'_, Postgres>,
|
||||
) -> Result<ReplyOrBackToClient, Rejection> {
|
||||
// First, find out what client it is
|
||||
let client = clients
|
||||
.into_iter()
|
||||
.find(|client| client.client_id == params.auth.client_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("could not find client"))
|
||||
.wrap_error()?;
|
||||
|
||||
let maybe_session_id = maybe_session.as_ref().map(SessionInfo::key);
|
||||
|
||||
let scope: String = {
|
||||
let it = params.auth.scope.iter().map(ToString::to_string);
|
||||
Itertools::intersperse(it, " ".to_string()).collect()
|
||||
};
|
||||
|
||||
let redirect_uri = client
|
||||
.resolve_redirect_uri(¶ms.auth.redirect_uri)
|
||||
.wrap_error()?;
|
||||
let response_type = ¶ms.auth.response_type;
|
||||
let response_mode =
|
||||
resolve_response_mode(response_type, params.auth.response_mode).wrap_error()?;
|
||||
|
||||
let oauth2_session = start_session(
|
||||
&mut txn,
|
||||
maybe_session_id,
|
||||
&client.client_id,
|
||||
redirect_uri,
|
||||
&scope,
|
||||
params.auth.state.as_deref(),
|
||||
params.auth.nonce.as_deref(),
|
||||
params.auth.max_age,
|
||||
response_type,
|
||||
response_mode,
|
||||
)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
// Generate the code at this stage, since we have the PKCE params ready
|
||||
if response_type.contains(&ResponseType::Code) {
|
||||
// 32 random alphanumeric characters, about 190bit of entropy
|
||||
let code: String = thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(32)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
oauth2_session
|
||||
.add_code(&mut txn, &code, ¶ms.pkce)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
};
|
||||
|
||||
// Do we already have a user session for this oauth2 session?
|
||||
let user_session = oauth2_session.fetch_session(&mut txn).await.wrap_error()?;
|
||||
|
||||
if let Some(user_session) = user_session {
|
||||
step(oauth2_session.id, user_session, txn).await
|
||||
} else {
|
||||
// If not, redirect the user to the login page
|
||||
txn.commit().await.wrap_error()?;
|
||||
|
||||
let next = StepRequest::new(oauth2_session.id)
|
||||
.build_uri()
|
||||
.wrap_error()?
|
||||
.to_string();
|
||||
|
||||
let destination = LoginRequest::new(Some(next)).build_uri().wrap_error()?;
|
||||
Ok(ReplyOrBackToClient::Reply(Box::new(see_other(destination))))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
struct StepRequest {
|
||||
id: i64,
|
||||
}
|
||||
|
||||
impl StepRequest {
|
||||
fn new(id: i64) -> Self {
|
||||
Self { id }
|
||||
}
|
||||
|
||||
fn build_uri(&self) -> anyhow::Result<Uri> {
|
||||
let qs = serde_urlencoded::to_string(self)?;
|
||||
let path_and_query = PathAndQuery::try_from(format!("/oauth2/authorize/step?{}", qs))?;
|
||||
let uri = Uri::from_parts({
|
||||
let mut parts = Parts::default();
|
||||
parts.path_and_query = Some(path_and_query);
|
||||
parts
|
||||
})?;
|
||||
Ok(uri)
|
||||
}
|
||||
}
|
||||
|
||||
async fn step(
|
||||
oauth2_session_id: i64,
|
||||
user_session: SessionInfo,
|
||||
mut txn: Transaction<'_, Postgres>,
|
||||
) -> Result<ReplyOrBackToClient, Rejection> {
|
||||
let mut oauth2_session = get_session_by_id(&mut txn, oauth2_session_id)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
let user_session = oauth2_session
|
||||
.match_or_set_session(&mut txn, user_session)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
let response_mode = oauth2_session.response_mode().wrap_error()?;
|
||||
let response_type = oauth2_session.response_type().wrap_error()?;
|
||||
let redirect_uri = oauth2_session.redirect_uri().wrap_error()?;
|
||||
|
||||
// Check if the active session is valid
|
||||
let reply = if user_session.active
|
||||
&& user_session.last_authd_at >= oauth2_session.max_auth_time()
|
||||
{
|
||||
// Yep! Let's complete the auth now
|
||||
let mut params = AuthorizationResponse {
|
||||
state: oauth2_session.state.clone(),
|
||||
..AuthorizationResponse::default()
|
||||
};
|
||||
|
||||
// Did they request an auth code?
|
||||
if response_type.contains(&ResponseType::Code) {
|
||||
params.code = Some(oauth2_session.fetch_code(&mut txn).await.wrap_error()?);
|
||||
}
|
||||
|
||||
// Did they request an access token?
|
||||
if response_type.contains(&ResponseType::Token) {
|
||||
let ttl = Duration::minutes(5);
|
||||
let (access_token, refresh_token) = {
|
||||
let mut rng = thread_rng();
|
||||
(
|
||||
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
|
||||
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
|
||||
)
|
||||
};
|
||||
|
||||
let access_token = add_access_token(&mut txn, oauth2_session_id, &access_token, ttl)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
let refresh_token =
|
||||
add_refresh_token(&mut txn, oauth2_session_id, access_token.id, &refresh_token)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
params.response = Some(
|
||||
AccessTokenResponse::new(access_token.token)
|
||||
.with_expires_in(ttl)
|
||||
.with_refresh_token(refresh_token.token),
|
||||
);
|
||||
}
|
||||
|
||||
// Did they request an ID token?
|
||||
if response_type.contains(&ResponseType::IdToken) {
|
||||
todo!("id tokens are not implemented yet");
|
||||
}
|
||||
|
||||
let params = serde_json::to_value(¶ms).unwrap();
|
||||
ReplyOrBackToClient::BackToClient {
|
||||
redirect_uri,
|
||||
response_mode,
|
||||
params,
|
||||
}
|
||||
} else {
|
||||
// Ask for a reauth
|
||||
// TODO: have the OAuth2 session ID in there
|
||||
ReplyOrBackToClient::Reply(Box::new(see_other(Uri::from_static("/reauth"))))
|
||||
};
|
||||
|
||||
txn.commit().await.wrap_error()?;
|
||||
Ok(reply)
|
||||
}
|
87
crates/core/src/handlers/oauth2/discovery.rs
Normal file
87
crates/core/src/handlers/oauth2/discovery.rs
Normal file
@ -0,0 +1,87 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
use oauth2_types::{
|
||||
oidc::Metadata,
|
||||
requests::{ClientAuthenticationMethod, GrantType, ResponseMode},
|
||||
};
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::config::OAuth2Config;
|
||||
|
||||
pub(super) fn filter(
|
||||
config: &OAuth2Config,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
let base = config.issuer.clone();
|
||||
|
||||
let response_modes_supported = Some({
|
||||
let mut s = HashSet::new();
|
||||
s.insert(ResponseMode::FormPost);
|
||||
s.insert(ResponseMode::Query);
|
||||
s.insert(ResponseMode::Fragment);
|
||||
s
|
||||
});
|
||||
|
||||
let response_types_supported = Some({
|
||||
let mut s = HashSet::new();
|
||||
s.insert("code".to_string());
|
||||
s.insert("token".to_string());
|
||||
s.insert("id_token".to_string());
|
||||
s.insert("code token".to_string());
|
||||
s.insert("code id_token".to_string());
|
||||
s.insert("token id_token".to_string());
|
||||
s.insert("code token id_token".to_string());
|
||||
s
|
||||
});
|
||||
|
||||
let grant_types_supported = Some({
|
||||
let mut s = HashSet::new();
|
||||
s.insert(GrantType::AuthorizationCode);
|
||||
s.insert(GrantType::RefreshToken);
|
||||
s
|
||||
});
|
||||
|
||||
let token_endpoint_auth_methods_supported = Some({
|
||||
let mut s = HashSet::new();
|
||||
s.insert(ClientAuthenticationMethod::ClientSecretBasic);
|
||||
s.insert(ClientAuthenticationMethod::ClientSecretPost);
|
||||
s.insert(ClientAuthenticationMethod::None);
|
||||
s
|
||||
});
|
||||
|
||||
let metadata = Metadata {
|
||||
authorization_endpoint: base.join("oauth2/authorize").ok(),
|
||||
token_endpoint: base.join("oauth2/token").ok(),
|
||||
jwks_uri: base.join("oauth2/keys.json").ok(),
|
||||
introspection_endpoint: base.join("oauth2/introspect").ok(),
|
||||
userinfo_endpoint: base.join("oauth2/userinfo").ok(),
|
||||
issuer: base,
|
||||
registration_endpoint: None,
|
||||
scopes_supported: None,
|
||||
response_types_supported,
|
||||
response_modes_supported,
|
||||
grant_types_supported,
|
||||
token_endpoint_auth_methods_supported,
|
||||
code_challenge_methods_supported: None,
|
||||
};
|
||||
|
||||
let cors = warp::cors().allow_any_origin();
|
||||
|
||||
warp::path!(".well-known" / "openid-configuration")
|
||||
.and(warp::get())
|
||||
.map(move || warp::reply::json(&metadata))
|
||||
.with(cors)
|
||||
}
|
136
crates/core/src/handlers/oauth2/introspection.rs
Normal file
136
crates/core/src/handlers/oauth2/introspection.rs
Normal file
@ -0,0 +1,136 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use chrono::Utc;
|
||||
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse, TokenTypeHint};
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use tracing::{info, warn};
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
config::{OAuth2ClientConfig, OAuth2Config},
|
||||
errors::WrapError,
|
||||
filters::{
|
||||
client::{with_client_auth, ClientAuthentication},
|
||||
database::with_connection,
|
||||
},
|
||||
storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token},
|
||||
tokens,
|
||||
};
|
||||
|
||||
pub fn filter(
|
||||
pool: &PgPool,
|
||||
oauth2_config: &OAuth2Config,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
warp::path!("oauth2" / "introspect")
|
||||
.and(warp::post())
|
||||
.and(with_connection(pool))
|
||||
.and(with_client_auth(oauth2_config))
|
||||
.and_then(introspect)
|
||||
.recover(recover)
|
||||
}
|
||||
|
||||
const INACTIVE: IntrospectionResponse = IntrospectionResponse {
|
||||
active: false,
|
||||
scope: None,
|
||||
client_id: None,
|
||||
username: None,
|
||||
token_type: None,
|
||||
exp: None,
|
||||
iat: None,
|
||||
nbf: None,
|
||||
sub: None,
|
||||
aud: None,
|
||||
iss: None,
|
||||
jti: None,
|
||||
};
|
||||
|
||||
async fn introspect(
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
auth: ClientAuthentication,
|
||||
client: OAuth2ClientConfig,
|
||||
params: IntrospectionRequest,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
// Token introspection is only allowed by confidential clients
|
||||
if auth.public() {
|
||||
warn!(?client, "Client tried to introspect");
|
||||
// TODO: have a nice error here
|
||||
return Ok(warp::reply::json(&INACTIVE));
|
||||
}
|
||||
|
||||
let token = ¶ms.token;
|
||||
let token_type = tokens::check(token).wrap_error()?;
|
||||
if let Some(hint) = params.token_type_hint {
|
||||
if token_type != hint {
|
||||
info!("Token type hint did not match");
|
||||
return Ok(warp::reply::json(&INACTIVE));
|
||||
}
|
||||
}
|
||||
|
||||
let reply = match token_type {
|
||||
tokens::TokenType::AccessToken => {
|
||||
let token = lookup_access_token(&mut conn, token).await.wrap_error()?;
|
||||
let exp = token.exp();
|
||||
|
||||
// Check it is active and did not expire
|
||||
if !token.active || exp < Utc::now() {
|
||||
info!(?token, "Access token expired");
|
||||
return Ok(warp::reply::json(&INACTIVE));
|
||||
}
|
||||
|
||||
IntrospectionResponse {
|
||||
active: true,
|
||||
scope: None, // TODO: parse back scopes
|
||||
client_id: Some(token.client_id.clone()),
|
||||
username: Some(token.username.clone()),
|
||||
token_type: Some(TokenTypeHint::AccessToken),
|
||||
exp: Some(exp),
|
||||
iat: Some(token.created_at),
|
||||
nbf: Some(token.created_at),
|
||||
sub: None,
|
||||
aud: None,
|
||||
iss: None,
|
||||
jti: None,
|
||||
}
|
||||
}
|
||||
tokens::TokenType::RefreshToken => {
|
||||
let token = lookup_refresh_token(&mut conn, token).await.wrap_error()?;
|
||||
|
||||
IntrospectionResponse {
|
||||
active: true,
|
||||
scope: None, // TODO: parse back scopes
|
||||
client_id: Some(token.client_id),
|
||||
username: None,
|
||||
token_type: Some(TokenTypeHint::RefreshToken),
|
||||
exp: None,
|
||||
iat: None,
|
||||
nbf: None,
|
||||
sub: None,
|
||||
aud: None,
|
||||
iss: None,
|
||||
jti: None,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(warp::reply::json(&reply))
|
||||
}
|
||||
|
||||
async fn recover(rejection: Rejection) -> Result<impl Reply, Rejection> {
|
||||
if rejection.is_not_found() {
|
||||
Err(rejection)
|
||||
} else {
|
||||
Ok(warp::reply::json(&INACTIVE))
|
||||
}
|
||||
}
|
30
crates/core/src/handlers/oauth2/keys.rs
Normal file
30
crates/core/src/handlers/oauth2/keys.rs
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::config::OAuth2Config;
|
||||
|
||||
pub(super) fn filter(
|
||||
config: &OAuth2Config,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
let jwks = config.keys.to_public_jwks();
|
||||
|
||||
let cors = warp::cors().allow_any_origin();
|
||||
|
||||
warp::path!("oauth2" / "keys.json")
|
||||
.and(warp::get())
|
||||
.map(move || warp::reply::json(&jwks))
|
||||
.with(cors)
|
||||
}
|
53
crates/core/src/handlers/oauth2/mod.rs
Normal file
53
crates/core/src/handlers/oauth2/mod.rs
Normal file
@ -0,0 +1,53 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use sqlx::PgPool;
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
config::{CookiesConfig, OAuth2Config},
|
||||
templates::Templates,
|
||||
};
|
||||
|
||||
mod authorization;
|
||||
mod discovery;
|
||||
mod introspection;
|
||||
mod keys;
|
||||
mod token;
|
||||
mod userinfo;
|
||||
|
||||
use self::{
|
||||
authorization::filter as authorization, discovery::filter as discovery,
|
||||
introspection::filter as introspection, keys::filter as keys, token::filter as token,
|
||||
userinfo::filter as userinfo,
|
||||
};
|
||||
|
||||
pub fn filter(
|
||||
pool: &PgPool,
|
||||
templates: &Templates,
|
||||
oauth2_config: &OAuth2Config,
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
discovery(oauth2_config)
|
||||
.or(keys(oauth2_config))
|
||||
.or(authorization(
|
||||
pool,
|
||||
templates,
|
||||
oauth2_config,
|
||||
cookies_config,
|
||||
))
|
||||
.or(userinfo(pool, oauth2_config))
|
||||
.or(introspection(pool, oauth2_config))
|
||||
.or(token(pool, oauth2_config))
|
||||
}
|
276
crates/core/src/handlers/oauth2/token.rs
Normal file
276
crates/core/src/handlers/oauth2/token.rs
Normal file
@ -0,0 +1,276 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::Context;
|
||||
use chrono::Duration;
|
||||
use data_encoding::BASE64URL_NOPAD;
|
||||
use headers::{CacheControl, Pragma};
|
||||
use hyper::StatusCode;
|
||||
use jwt_compact::{Claims, Header, TimeOptions};
|
||||
use oauth2_types::{
|
||||
errors::{InvalidGrant, OAuth2Error, OAuth2ErrorCode, UnauthorizedClient},
|
||||
requests::{
|
||||
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant,
|
||||
},
|
||||
};
|
||||
use rand::thread_rng;
|
||||
use serde::Serialize;
|
||||
use serde_with::skip_serializing_none;
|
||||
use sha2::{Digest, Sha256};
|
||||
use sqlx::{pool::PoolConnection, Acquire, PgPool, Postgres};
|
||||
use url::Url;
|
||||
use warp::{
|
||||
reject::Reject,
|
||||
reply::{json, with_status},
|
||||
Filter, Rejection, Reply,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
config::{KeySet, OAuth2ClientConfig, OAuth2Config},
|
||||
errors::WrapError,
|
||||
filters::{
|
||||
client::{with_client_auth, ClientAuthentication},
|
||||
database::with_connection,
|
||||
headers::typed_header,
|
||||
with_keys,
|
||||
},
|
||||
storage::oauth2::{
|
||||
access_token::{add_access_token, revoke_access_token},
|
||||
authorization_code::lookup_code,
|
||||
refresh_token::{add_refresh_token, lookup_refresh_token, replace_refresh_token},
|
||||
},
|
||||
tokens,
|
||||
};
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Debug)]
|
||||
struct CustomClaims {
|
||||
#[serde(rename = "iss")]
|
||||
issuer: Url,
|
||||
#[serde(rename = "sub")]
|
||||
subject: String,
|
||||
#[serde(rename = "aud")]
|
||||
audiences: Vec<String>,
|
||||
nonce: Option<String>,
|
||||
at_hash: String,
|
||||
c_hash: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Error {
|
||||
json: serde_json::Value,
|
||||
status: StatusCode,
|
||||
}
|
||||
|
||||
impl Reject for Error {}
|
||||
|
||||
fn error<T, E>(e: E) -> Result<T, Rejection>
|
||||
where
|
||||
E: OAuth2ErrorCode + 'static,
|
||||
{
|
||||
let status = e.status();
|
||||
let json = serde_json::to_value(e.into_response()).wrap_error()?;
|
||||
Err(Error { json, status }.into())
|
||||
}
|
||||
|
||||
pub fn filter(
|
||||
pool: &PgPool,
|
||||
oauth2_config: &OAuth2Config,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
let issuer = oauth2_config.issuer.clone();
|
||||
warp::path!("oauth2" / "token")
|
||||
.and(warp::post())
|
||||
.and(with_client_auth(oauth2_config))
|
||||
.and(with_keys(oauth2_config))
|
||||
.and(warp::any().map(move || issuer.clone()))
|
||||
.and(with_connection(pool))
|
||||
.and_then(token)
|
||||
.recover(recover)
|
||||
}
|
||||
|
||||
async fn recover(rejection: Rejection) -> Result<impl Reply, Rejection> {
|
||||
if let Some(Error { json, status }) = rejection.find::<Error>() {
|
||||
Ok(with_status(warp::reply::json(json), *status))
|
||||
} else {
|
||||
Err(rejection)
|
||||
}
|
||||
}
|
||||
|
||||
async fn token(
|
||||
_auth: ClientAuthentication,
|
||||
client: OAuth2ClientConfig,
|
||||
req: AccessTokenRequest,
|
||||
keys: KeySet,
|
||||
issuer: Url,
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
let reply = match req {
|
||||
AccessTokenRequest::AuthorizationCode(grant) => {
|
||||
let reply = authorization_code_grant(&grant, &client, &keys, issuer, &mut conn).await?;
|
||||
json(&reply)
|
||||
}
|
||||
AccessTokenRequest::RefreshToken(grant) => {
|
||||
let reply = refresh_token_grant(&grant, &client, &mut conn).await?;
|
||||
json(&reply)
|
||||
}
|
||||
_ => {
|
||||
let reply = InvalidGrant.into_response();
|
||||
json(&reply)
|
||||
}
|
||||
};
|
||||
|
||||
Ok(typed_header(
|
||||
Pragma::no_cache(),
|
||||
typed_header(CacheControl::new().with_no_store(), reply),
|
||||
))
|
||||
}
|
||||
|
||||
fn hash<H: Digest>(mut hasher: H, token: &str) -> anyhow::Result<String> {
|
||||
hasher.update(token);
|
||||
let hash = hasher.finalize();
|
||||
// Left-most 128bit
|
||||
let bits = hash
|
||||
.get(..16)
|
||||
.context("failed to get first 128 bits of hash")?;
|
||||
Ok(BASE64URL_NOPAD.encode(bits))
|
||||
}
|
||||
|
||||
async fn authorization_code_grant(
|
||||
grant: &AuthorizationCodeGrant,
|
||||
client: &OAuth2ClientConfig,
|
||||
keys: &KeySet,
|
||||
issuer: Url,
|
||||
conn: &mut PoolConnection<Postgres>,
|
||||
) -> Result<AccessTokenResponse, Rejection> {
|
||||
let mut txn = conn.begin().await.wrap_error()?;
|
||||
let code = lookup_code(&mut txn, &grant.code).await.wrap_error()?;
|
||||
if client.client_id != code.client_id {
|
||||
return error(UnauthorizedClient);
|
||||
}
|
||||
|
||||
// TODO: verify PKCE
|
||||
// TODO: make the code invalid
|
||||
let ttl = Duration::minutes(5);
|
||||
let (access_token, refresh_token) = {
|
||||
let mut rng = thread_rng();
|
||||
(
|
||||
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
|
||||
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
|
||||
)
|
||||
};
|
||||
|
||||
let access_token = add_access_token(&mut txn, code.oauth2_session_id, &access_token, ttl)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
let refresh_token = add_refresh_token(
|
||||
&mut txn,
|
||||
code.oauth2_session_id,
|
||||
access_token.id,
|
||||
&refresh_token,
|
||||
)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
// TODO: generate id_token only if the "openid" scope was asked
|
||||
let header = Header::default();
|
||||
let options = TimeOptions::default();
|
||||
let claims = Claims::new(CustomClaims {
|
||||
issuer,
|
||||
// TODO: get that from the session
|
||||
subject: "random-subject".to_string(),
|
||||
audiences: vec![client.client_id.clone()],
|
||||
nonce: code.nonce,
|
||||
at_hash: hash(Sha256::new(), &access_token.token).wrap_error()?,
|
||||
c_hash: hash(Sha256::new(), &grant.code).wrap_error()?,
|
||||
})
|
||||
.set_duration_and_issuance(&options, Duration::minutes(30));
|
||||
let id_token = keys
|
||||
.token(crate::config::Algorithm::Rs256, header, claims)
|
||||
.await
|
||||
.context("could not sign ID token")
|
||||
.wrap_error()?;
|
||||
|
||||
// TODO: have the scopes back here
|
||||
let params = AccessTokenResponse::new(access_token.token)
|
||||
.with_expires_in(ttl)
|
||||
.with_refresh_token(refresh_token.token)
|
||||
.with_id_token(id_token);
|
||||
|
||||
txn.commit().await.wrap_error()?;
|
||||
|
||||
Ok(params)
|
||||
}
|
||||
|
||||
async fn refresh_token_grant(
|
||||
grant: &RefreshTokenGrant,
|
||||
client: &OAuth2ClientConfig,
|
||||
conn: &mut PoolConnection<Postgres>,
|
||||
) -> Result<AccessTokenResponse, Rejection> {
|
||||
let mut txn = conn.begin().await.wrap_error()?;
|
||||
// TODO: scope handling
|
||||
let refresh_token_lookup = lookup_refresh_token(&mut txn, &grant.refresh_token)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
if client.client_id != refresh_token_lookup.client_id {
|
||||
// As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
||||
return error(InvalidGrant);
|
||||
}
|
||||
|
||||
let ttl = Duration::minutes(5);
|
||||
let (access_token, refresh_token) = {
|
||||
let mut rng = thread_rng();
|
||||
(
|
||||
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
|
||||
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
|
||||
)
|
||||
};
|
||||
|
||||
let access_token = add_access_token(
|
||||
&mut txn,
|
||||
refresh_token_lookup.oauth2_session_id,
|
||||
&access_token,
|
||||
ttl,
|
||||
)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
let refresh_token = add_refresh_token(
|
||||
&mut txn,
|
||||
refresh_token_lookup.oauth2_session_id,
|
||||
access_token.id,
|
||||
&refresh_token,
|
||||
)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
replace_refresh_token(&mut txn, refresh_token_lookup.id, refresh_token.id)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
if let Some(access_token_id) = refresh_token_lookup.oauth2_access_token_id {
|
||||
revoke_access_token(&mut txn, access_token_id)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
}
|
||||
|
||||
let params = AccessTokenResponse::new(access_token.token)
|
||||
.with_expires_in(ttl)
|
||||
.with_refresh_token(refresh_token.token);
|
||||
|
||||
txn.commit().await.wrap_error()?;
|
||||
|
||||
Ok(params)
|
||||
}
|
43
crates/core/src/handlers/oauth2/userinfo.rs
Normal file
43
crates/core/src/handlers/oauth2/userinfo.rs
Normal file
@ -0,0 +1,43 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use serde::Serialize;
|
||||
use sqlx::PgPool;
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
config::OAuth2Config, filters::authenticate::with_authentication,
|
||||
storage::oauth2::access_token::OAuth2AccessTokenLookup,
|
||||
};
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct UserInfo {
|
||||
sub: String,
|
||||
}
|
||||
|
||||
pub(super) fn filter(
|
||||
pool: &PgPool,
|
||||
_config: &OAuth2Config,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
warp::path!("oauth2" / "userinfo")
|
||||
.and(warp::get().or(warp::post()).unify())
|
||||
.and(with_authentication(pool))
|
||||
.and_then(userinfo)
|
||||
}
|
||||
|
||||
async fn userinfo(token: OAuth2AccessTokenLookup) -> Result<impl Reply, Rejection> {
|
||||
Ok(warp::reply::json(&UserInfo {
|
||||
sub: token.username,
|
||||
}))
|
||||
}
|
64
crates/core/src/handlers/views/index.rs
Normal file
64
crates/core/src/handlers/views/index.rs
Normal file
@ -0,0 +1,64 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use sqlx::PgPool;
|
||||
use url::Url;
|
||||
use warp::{reply::html, Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
config::{CookiesConfig, CsrfConfig, OAuth2Config},
|
||||
filters::{
|
||||
cookies::{with_cookie_saver, EncryptedCookieSaver},
|
||||
csrf::updated_csrf_token,
|
||||
session::with_optional_session,
|
||||
with_templates, CsrfToken,
|
||||
},
|
||||
storage::SessionInfo,
|
||||
templates::{IndexContext, TemplateContext, Templates},
|
||||
};
|
||||
|
||||
pub(super) fn filter(
|
||||
pool: &PgPool,
|
||||
templates: &Templates,
|
||||
oauth2_config: &OAuth2Config,
|
||||
csrf_config: &CsrfConfig,
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
let discovery_url = oauth2_config.discovery_url();
|
||||
warp::path::end()
|
||||
.and(warp::get())
|
||||
.map(move || discovery_url.clone())
|
||||
.and(with_templates(templates))
|
||||
.and(with_cookie_saver(cookies_config))
|
||||
.and(updated_csrf_token(cookies_config, csrf_config))
|
||||
.and(with_optional_session(pool, cookies_config))
|
||||
.and_then(get)
|
||||
}
|
||||
|
||||
async fn get(
|
||||
discovery_url: Url,
|
||||
templates: Templates,
|
||||
cookie_saver: EncryptedCookieSaver,
|
||||
csrf_token: CsrfToken,
|
||||
session: Option<SessionInfo>,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
let ctx = IndexContext::new(discovery_url)
|
||||
.maybe_with_session(session)
|
||||
.with_csrf(&csrf_token);
|
||||
|
||||
let content = templates.render_index(&ctx)?;
|
||||
let reply = html(content);
|
||||
let reply = cookie_saver.save_encrypted(&csrf_token, reply)?;
|
||||
Ok(Box::new(reply))
|
||||
}
|
154
crates/core/src/handlers/views/login.rs
Normal file
154
crates/core/src/handlers/views/login.rs
Normal file
@ -0,0 +1,154 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use hyper::http::uri::{Parts, PathAndQuery, Uri};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use warp::{reply::html, Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
config::{CookiesConfig, CsrfConfig},
|
||||
errors::{WrapError, WrapFormError},
|
||||
filters::{
|
||||
cookies::{with_cookie_saver, EncryptedCookieSaver},
|
||||
csrf::{protected_form, updated_csrf_token},
|
||||
database::with_connection,
|
||||
session::{with_optional_session, SessionCookie},
|
||||
with_templates, CsrfToken,
|
||||
},
|
||||
storage::{login, SessionInfo},
|
||||
templates::{LoginContext, LoginFormField, TemplateContext, Templates},
|
||||
};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct LoginRequest {
|
||||
next: Option<String>,
|
||||
}
|
||||
|
||||
impl LoginRequest {
|
||||
pub fn new(next: Option<String>) -> Self {
|
||||
Self { next }
|
||||
}
|
||||
|
||||
pub fn build_uri(&self) -> anyhow::Result<Uri> {
|
||||
let qs = serde_urlencoded::to_string(self)?;
|
||||
let path_and_query = PathAndQuery::try_from(format!("/login?{}", qs))?;
|
||||
let uri = Uri::from_parts({
|
||||
let mut parts = Parts::default();
|
||||
parts.path_and_query = Some(path_and_query);
|
||||
parts
|
||||
})?;
|
||||
Ok(uri)
|
||||
}
|
||||
|
||||
fn redirect(self) -> Result<impl Reply, Rejection> {
|
||||
let uri: Uri = Uri::from_parts({
|
||||
let mut parts = Parts::default();
|
||||
parts.path_and_query = Some(
|
||||
self.next
|
||||
.map(warp::http::uri::PathAndQuery::try_from)
|
||||
.transpose()
|
||||
.wrap_error()?
|
||||
.unwrap_or_else(|| PathAndQuery::from_static("/")),
|
||||
);
|
||||
parts
|
||||
})
|
||||
.wrap_error()?;
|
||||
Ok(warp::redirect::see_other(uri))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LoginForm {
|
||||
username: String,
|
||||
password: String,
|
||||
}
|
||||
|
||||
pub(super) fn filter(
|
||||
pool: &PgPool,
|
||||
templates: &Templates,
|
||||
csrf_config: &CsrfConfig,
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
let get = warp::get()
|
||||
.and(with_templates(templates))
|
||||
.and(with_cookie_saver(cookies_config))
|
||||
.and(updated_csrf_token(cookies_config, csrf_config))
|
||||
.and(warp::query())
|
||||
.and(with_optional_session(pool, cookies_config))
|
||||
.and_then(get);
|
||||
|
||||
let post = warp::post()
|
||||
.and(with_templates(templates))
|
||||
.and(with_connection(pool))
|
||||
.and(with_cookie_saver(cookies_config))
|
||||
.and(updated_csrf_token(cookies_config, csrf_config))
|
||||
.and(protected_form(cookies_config))
|
||||
.and(warp::query())
|
||||
.and_then(post);
|
||||
|
||||
warp::path!("login").and(get.or(post))
|
||||
}
|
||||
|
||||
async fn get(
|
||||
templates: Templates,
|
||||
cookie_saver: EncryptedCookieSaver,
|
||||
csrf_token: CsrfToken,
|
||||
query: LoginRequest,
|
||||
maybe_session: Option<SessionInfo>,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
if maybe_session.is_some() {
|
||||
Ok(Box::new(query.redirect()?))
|
||||
} else {
|
||||
let ctx = LoginContext::default().with_csrf(&csrf_token);
|
||||
let content = templates.render_login(&ctx)?;
|
||||
let reply = html(content);
|
||||
let reply = cookie_saver.save_encrypted(&csrf_token, reply)?;
|
||||
Ok(Box::new(reply))
|
||||
}
|
||||
}
|
||||
|
||||
async fn post(
|
||||
templates: Templates,
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
cookie_saver: EncryptedCookieSaver,
|
||||
csrf_token: CsrfToken,
|
||||
form: LoginForm,
|
||||
query: LoginRequest,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
use crate::storage::user::LoginError;
|
||||
// TODO: recover
|
||||
match login(&mut conn, &form.username, form.password).await {
|
||||
Ok(session_info) => {
|
||||
let session_cookie = SessionCookie::from_session_info(&session_info);
|
||||
let reply = query.redirect()?;
|
||||
let reply = cookie_saver.save_encrypted(&session_cookie, reply)?;
|
||||
Ok(Box::new(reply))
|
||||
}
|
||||
Err(e) => {
|
||||
let errored_form = match e {
|
||||
LoginError::NotFound { .. } => e.on_field(LoginFormField::Username),
|
||||
LoginError::Authentication { .. } => e.on_field(LoginFormField::Password),
|
||||
LoginError::Other(_) => e.on_form(),
|
||||
};
|
||||
let ctx = LoginContext::with_form_error(errored_form).with_csrf(&csrf_token);
|
||||
let content = templates.render_login(&ctx)?;
|
||||
let reply = html(content);
|
||||
let reply = cookie_saver.save_encrypted(&csrf_token, reply)?;
|
||||
Ok(Box::new(reply))
|
||||
}
|
||||
}
|
||||
}
|
44
crates/core/src/handlers/views/logout.rs
Normal file
44
crates/core/src/handlers/views/logout.rs
Normal file
@ -0,0 +1,44 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use warp::{hyper::Uri, Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
config::CookiesConfig,
|
||||
errors::WrapError,
|
||||
filters::{csrf::protected_form, database::with_connection, session::with_session},
|
||||
storage::SessionInfo,
|
||||
};
|
||||
|
||||
pub(super) fn filter(
|
||||
pool: &PgPool,
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
warp::path!("logout")
|
||||
.and(warp::post())
|
||||
.and(with_session(pool, cookies_config))
|
||||
.and(with_connection(pool))
|
||||
.and(protected_form(cookies_config))
|
||||
.and_then(post)
|
||||
}
|
||||
|
||||
async fn post(
|
||||
session: SessionInfo,
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
_form: (),
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
session.end(&mut conn).await.wrap_error()?;
|
||||
Ok(warp::redirect(Uri::from_static("/login")))
|
||||
}
|
48
crates/core/src/handlers/views/mod.rs
Normal file
48
crates/core/src/handlers/views/mod.rs
Normal file
@ -0,0 +1,48 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use sqlx::PgPool;
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
config::{CookiesConfig, CsrfConfig, OAuth2Config},
|
||||
templates::Templates,
|
||||
};
|
||||
|
||||
mod index;
|
||||
mod login;
|
||||
mod logout;
|
||||
mod reauth;
|
||||
mod register;
|
||||
|
||||
pub use self::login::LoginRequest;
|
||||
use self::{
|
||||
index::filter as index, login::filter as login, logout::filter as logout,
|
||||
reauth::filter as reauth, register::filter as register,
|
||||
};
|
||||
|
||||
pub(super) fn filter(
|
||||
pool: &PgPool,
|
||||
templates: &Templates,
|
||||
oauth2_config: &OAuth2Config,
|
||||
csrf_config: &CsrfConfig,
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
index(pool, templates, oauth2_config, csrf_config, cookies_config)
|
||||
.or(login(pool, templates, csrf_config, cookies_config))
|
||||
.or(register(pool, templates, csrf_config, cookies_config))
|
||||
.or(logout(pool, cookies_config))
|
||||
.or(reauth(pool, templates, csrf_config, cookies_config))
|
||||
.boxed()
|
||||
}
|
85
crates/core/src/handlers/views/reauth.rs
Normal file
85
crates/core/src/handlers/views/reauth.rs
Normal file
@ -0,0 +1,85 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use serde::Deserialize;
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use warp::{hyper::Uri, reply::html, Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
config::{CookiesConfig, CsrfConfig},
|
||||
errors::WrapError,
|
||||
filters::{
|
||||
cookies::{with_cookie_saver, EncryptedCookieSaver},
|
||||
csrf::{protected_form, updated_csrf_token},
|
||||
database::with_connection,
|
||||
session::with_session,
|
||||
with_templates, CsrfToken,
|
||||
},
|
||||
storage::SessionInfo,
|
||||
templates::{TemplateContext, Templates},
|
||||
};
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct ReauthForm {
|
||||
password: String,
|
||||
}
|
||||
|
||||
pub(super) fn filter(
|
||||
pool: &PgPool,
|
||||
templates: &Templates,
|
||||
csrf_config: &CsrfConfig,
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
let get = warp::get()
|
||||
.and(with_templates(templates))
|
||||
.and(with_cookie_saver(cookies_config))
|
||||
.and(updated_csrf_token(cookies_config, csrf_config))
|
||||
.and(with_session(pool, cookies_config))
|
||||
.and_then(get);
|
||||
|
||||
let post = warp::post()
|
||||
.and(with_session(pool, cookies_config))
|
||||
.and(with_connection(pool))
|
||||
.and(protected_form(cookies_config))
|
||||
.and_then(post);
|
||||
|
||||
warp::path!("reauth").and(get.or(post))
|
||||
}
|
||||
|
||||
async fn get(
|
||||
templates: Templates,
|
||||
cookie_saver: EncryptedCookieSaver,
|
||||
csrf_token: CsrfToken,
|
||||
session: SessionInfo,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
let ctx = ().with_session(session).with_csrf(&csrf_token);
|
||||
|
||||
let content = templates.render_reauth(&ctx)?;
|
||||
let reply = html(content);
|
||||
let reply = cookie_saver.save_encrypted(&csrf_token, reply)?;
|
||||
Ok(reply)
|
||||
}
|
||||
|
||||
async fn post(
|
||||
session: SessionInfo,
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
form: ReauthForm,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
let _session = session
|
||||
.reauth(&mut conn, form.password)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
Ok(warp::redirect(Uri::from_static("/")))
|
||||
}
|
147
crates/core/src/handlers/views/register.rs
Normal file
147
crates/core/src/handlers/views/register.rs
Normal file
@ -0,0 +1,147 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use argon2::Argon2;
|
||||
use hyper::http::uri::{Parts, PathAndQuery, Uri};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{pool::PoolConnection, PgPool, Postgres};
|
||||
use warp::{reply::html, Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
config::{CookiesConfig, CsrfConfig},
|
||||
errors::WrapError,
|
||||
filters::{
|
||||
cookies::{with_cookie_saver, EncryptedCookieSaver},
|
||||
csrf::{protected_form, updated_csrf_token},
|
||||
database::with_connection,
|
||||
session::{with_optional_session, SessionCookie},
|
||||
with_templates, CsrfToken,
|
||||
},
|
||||
storage::{register_user, user::start_session, SessionInfo},
|
||||
templates::{TemplateContext, Templates},
|
||||
};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct RegisterRequest {
|
||||
next: Option<String>,
|
||||
}
|
||||
|
||||
impl RegisterRequest {
|
||||
#[allow(dead_code)]
|
||||
pub fn new(next: Option<String>) -> Self {
|
||||
Self { next }
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn build_uri(&self) -> anyhow::Result<Uri> {
|
||||
let qs = serde_urlencoded::to_string(self)?;
|
||||
let path_and_query = PathAndQuery::try_from(format!("/register?{}", qs))?;
|
||||
let uri = Uri::from_parts({
|
||||
let mut parts = Parts::default();
|
||||
parts.path_and_query = Some(path_and_query);
|
||||
parts
|
||||
})?;
|
||||
Ok(uri)
|
||||
}
|
||||
|
||||
fn redirect(self) -> Result<impl Reply, Rejection> {
|
||||
let uri: Uri = Uri::from_parts({
|
||||
let mut parts = Parts::default();
|
||||
parts.path_and_query = Some(
|
||||
self.next
|
||||
.map(warp::http::uri::PathAndQuery::try_from)
|
||||
.transpose()
|
||||
.wrap_error()?
|
||||
.unwrap_or_else(|| PathAndQuery::from_static("/")),
|
||||
);
|
||||
parts
|
||||
})
|
||||
.wrap_error()?;
|
||||
Ok(warp::redirect::see_other(uri))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct RegisterForm {
|
||||
username: String,
|
||||
password: String,
|
||||
password_confirm: String,
|
||||
}
|
||||
|
||||
pub(super) fn filter(
|
||||
pool: &PgPool,
|
||||
templates: &Templates,
|
||||
csrf_config: &CsrfConfig,
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
let get = warp::get()
|
||||
.and(with_templates(templates))
|
||||
.and(with_cookie_saver(cookies_config))
|
||||
.and(updated_csrf_token(cookies_config, csrf_config))
|
||||
.and(warp::query())
|
||||
.and(with_optional_session(pool, cookies_config))
|
||||
.and_then(get);
|
||||
|
||||
let post = warp::post()
|
||||
.and(with_connection(pool))
|
||||
.and(with_cookie_saver(cookies_config))
|
||||
.and(protected_form(cookies_config))
|
||||
.and(warp::query())
|
||||
.and_then(post);
|
||||
|
||||
warp::path!("register").and(get.or(post))
|
||||
}
|
||||
|
||||
async fn get(
|
||||
templates: Templates,
|
||||
cookie_saver: EncryptedCookieSaver,
|
||||
csrf_token: CsrfToken,
|
||||
query: RegisterRequest,
|
||||
maybe_session: Option<SessionInfo>,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
if maybe_session.is_some() {
|
||||
Ok(Box::new(query.redirect()?))
|
||||
} else {
|
||||
let ctx = ().with_csrf(&csrf_token);
|
||||
let content = templates.render_register(&ctx)?;
|
||||
let reply = html(content);
|
||||
let reply = cookie_saver.save_encrypted(&csrf_token, reply)?;
|
||||
Ok(Box::new(reply))
|
||||
}
|
||||
}
|
||||
|
||||
async fn post(
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
cookie_saver: EncryptedCookieSaver,
|
||||
form: RegisterForm,
|
||||
query: RegisterRequest,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
if form.password != form.password_confirm {
|
||||
return Err(anyhow::anyhow!("password mismatch")).wrap_error();
|
||||
}
|
||||
|
||||
let pfh = Argon2::default();
|
||||
let user = register_user(&mut conn, pfh, &form.username, &form.password)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
let session_info = start_session(&mut conn, user).await.wrap_error()?;
|
||||
|
||||
let session_cookie = SessionCookie::from_session_info(&session_info);
|
||||
let reply = query.redirect()?;
|
||||
let reply = cookie_saver.save_encrypted(&session_cookie, reply)?;
|
||||
Ok(reply)
|
||||
}
|
31
crates/core/src/lib.rs
Normal file
31
crates/core/src/lib.rs
Normal file
@ -0,0 +1,31 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#![forbid(unsafe_code)]
|
||||
#![deny(clippy::all)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
#![allow(clippy::missing_panics_doc)]
|
||||
#![allow(clippy::missing_errors_doc)]
|
||||
#![allow(clippy::implicit_hasher)]
|
||||
|
||||
pub(crate) use mas_config as config;
|
||||
|
||||
pub mod errors;
|
||||
pub mod filters;
|
||||
pub mod handlers;
|
||||
pub mod storage;
|
||||
pub mod tasks;
|
||||
pub mod templates;
|
||||
pub mod tokens;
|
24
crates/core/src/storage/mod.rs
Normal file
24
crates/core/src/storage/mod.rs
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#![allow(clippy::used_underscore_binding)] // This is needed by sqlx macros
|
||||
|
||||
use sqlx::migrate::Migrator;
|
||||
|
||||
pub mod oauth2;
|
||||
pub mod user;
|
||||
|
||||
pub use self::user::{login, lookup_active_session, register_user, SessionInfo, User};
|
||||
|
||||
pub static MIGRATOR: Migrator = sqlx::migrate!();
|
141
crates/core/src/storage/oauth2/access_token.rs
Normal file
141
crates/core/src/storage/oauth2/access_token.rs
Normal file
@ -0,0 +1,141 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use anyhow::Context;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use serde::Serialize;
|
||||
use sqlx::{Executor, FromRow, Postgres};
|
||||
|
||||
#[derive(FromRow, Serialize)]
|
||||
pub struct OAuth2AccessToken {
|
||||
pub id: i64,
|
||||
pub oauth2_session_id: i64,
|
||||
pub token: String,
|
||||
expires_after: i32,
|
||||
created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
pub async fn add_access_token(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
oauth2_session_id: i64,
|
||||
token: &str,
|
||||
expires_after: Duration,
|
||||
) -> anyhow::Result<OAuth2AccessToken> {
|
||||
// Checked convertion of duration to i32, maxing at i32::MAX
|
||||
let expires_after = i32::try_from(expires_after.num_seconds()).unwrap_or(i32::MAX);
|
||||
|
||||
sqlx::query_as!(
|
||||
OAuth2AccessToken,
|
||||
r#"
|
||||
INSERT INTO oauth2_access_tokens
|
||||
(oauth2_session_id, token, expires_after)
|
||||
VALUES
|
||||
($1, $2, $3)
|
||||
RETURNING
|
||||
id, oauth2_session_id, token, expires_after, created_at
|
||||
"#,
|
||||
oauth2_session_id,
|
||||
token,
|
||||
expires_after,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not insert oauth2 access token")
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OAuth2AccessTokenLookup {
|
||||
pub active: bool,
|
||||
pub username: String,
|
||||
pub client_id: String,
|
||||
pub scope: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
expires_after: i32,
|
||||
}
|
||||
|
||||
impl OAuth2AccessTokenLookup {
|
||||
#[must_use] pub fn exp(&self) -> DateTime<Utc> {
|
||||
self.created_at + Duration::seconds(i64::from(self.expires_after))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn lookup_access_token(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
token: &str,
|
||||
) -> anyhow::Result<OAuth2AccessTokenLookup> {
|
||||
sqlx::query_as!(
|
||||
OAuth2AccessTokenLookup,
|
||||
r#"
|
||||
SELECT
|
||||
u.username AS "username!",
|
||||
us.active AS "active!",
|
||||
os.client_id AS "client_id!",
|
||||
os.scope AS "scope!",
|
||||
at.created_at AS "created_at!",
|
||||
at.expires_after AS "expires_after!"
|
||||
FROM oauth2_access_tokens at
|
||||
INNER JOIN oauth2_sessions os
|
||||
ON os.id = at.oauth2_session_id
|
||||
INNER JOIN user_sessions us
|
||||
ON us.id = os.user_session_id
|
||||
INNER JOIN users u
|
||||
ON u.id = us.user_id
|
||||
WHERE at.token = $1
|
||||
"#,
|
||||
token,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not introspect oauth2 access token")
|
||||
}
|
||||
|
||||
pub async fn revoke_access_token(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
id: i64,
|
||||
) -> anyhow::Result<()> {
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
DELETE FROM oauth2_access_tokens
|
||||
WHERE id = $1
|
||||
"#,
|
||||
id,
|
||||
)
|
||||
.execute(executor)
|
||||
.await
|
||||
.context("could not revoke access tokens")?;
|
||||
|
||||
if res.rows_affected() == 1 {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!("no row were affected when revoking token"))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn cleanup_expired(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
) -> anyhow::Result<u64> {
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
DELETE FROM oauth2_access_tokens
|
||||
WHERE created_at + (expires_after * INTERVAL '1 second') + INTERVAL '15 minutes' < now()
|
||||
"#,
|
||||
)
|
||||
.execute(executor)
|
||||
.await
|
||||
.context("could not cleanup expired access tokens")?;
|
||||
|
||||
Ok(res.rows_affected())
|
||||
}
|
92
crates/core/src/storage/oauth2/authorization_code.rs
Normal file
92
crates/core/src/storage/oauth2/authorization_code.rs
Normal file
@ -0,0 +1,92 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::Context;
|
||||
use oauth2_types::pkce;
|
||||
use serde::Serialize;
|
||||
use sqlx::{Executor, FromRow, Postgres};
|
||||
|
||||
#[derive(FromRow, Serialize)]
|
||||
pub struct OAuth2Code {
|
||||
id: i64,
|
||||
oauth2_session_id: i64,
|
||||
pub code: String,
|
||||
code_challenge: Option<String>,
|
||||
code_challenge_method: Option<i16>,
|
||||
}
|
||||
|
||||
pub async fn add_code(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
oauth2_session_id: i64,
|
||||
code: &str,
|
||||
code_challenge: &Option<pkce::Request>,
|
||||
) -> anyhow::Result<OAuth2Code> {
|
||||
let code_challenge_method = code_challenge
|
||||
.as_ref()
|
||||
.map(|c| c.code_challenge_method as i16);
|
||||
let code_challenge = code_challenge.as_ref().map(|c| &c.code_challenge);
|
||||
sqlx::query_as!(
|
||||
OAuth2Code,
|
||||
r#"
|
||||
INSERT INTO oauth2_codes
|
||||
(oauth2_session_id, code, code_challenge_method, code_challenge)
|
||||
VALUES
|
||||
($1, $2, $3, $4)
|
||||
RETURNING
|
||||
id, oauth2_session_id, code, code_challenge_method, code_challenge
|
||||
"#,
|
||||
oauth2_session_id,
|
||||
code,
|
||||
code_challenge_method,
|
||||
code_challenge,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not insert oauth2 authorization code")
|
||||
}
|
||||
|
||||
pub struct OAuth2CodeLookup {
|
||||
pub id: i64,
|
||||
pub oauth2_session_id: i64,
|
||||
pub client_id: String,
|
||||
pub redirect_uri: String,
|
||||
pub scope: String,
|
||||
pub nonce: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn lookup_code(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
code: &str,
|
||||
) -> anyhow::Result<OAuth2CodeLookup> {
|
||||
sqlx::query_as!(
|
||||
OAuth2CodeLookup,
|
||||
r#"
|
||||
SELECT
|
||||
oc.id,
|
||||
os.id AS "oauth2_session_id!",
|
||||
os.client_id AS "client_id!",
|
||||
os.redirect_uri,
|
||||
os.scope AS "scope!",
|
||||
os.nonce
|
||||
FROM oauth2_codes oc
|
||||
INNER JOIN oauth2_sessions os
|
||||
ON os.id = oc.oauth2_session_id
|
||||
WHERE oc.code = $1
|
||||
"#,
|
||||
code,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not lookup oauth2 code")
|
||||
}
|
18
crates/core/src/storage/oauth2/mod.rs
Normal file
18
crates/core/src/storage/oauth2/mod.rs
Normal file
@ -0,0 +1,18 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub mod access_token;
|
||||
pub mod authorization_code;
|
||||
pub mod refresh_token;
|
||||
pub mod session;
|
114
crates/core/src/storage/oauth2/refresh_token.rs
Normal file
114
crates/core/src/storage/oauth2/refresh_token.rs
Normal file
@ -0,0 +1,114 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::Context;
|
||||
use chrono::{DateTime, Utc};
|
||||
use sqlx::{Executor, Postgres};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OAuth2RefreshToken {
|
||||
pub id: i64,
|
||||
oauth2_session_id: i64,
|
||||
oauth2_access_token_id: Option<i64>,
|
||||
pub token: String,
|
||||
next_token_id: Option<i64>,
|
||||
created_at: DateTime<Utc>,
|
||||
updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
pub async fn add_refresh_token(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
oauth2_session_id: i64,
|
||||
oauth2_access_token_id: i64,
|
||||
token: &str,
|
||||
) -> anyhow::Result<OAuth2RefreshToken> {
|
||||
sqlx::query_as!(
|
||||
OAuth2RefreshToken,
|
||||
r#"
|
||||
INSERT INTO oauth2_refresh_tokens
|
||||
(oauth2_session_id, oauth2_access_token_id, token)
|
||||
VALUES
|
||||
($1, $2, $3)
|
||||
RETURNING
|
||||
id, oauth2_session_id, oauth2_access_token_id, token, next_token_id,
|
||||
created_at, updated_at
|
||||
"#,
|
||||
oauth2_session_id,
|
||||
oauth2_access_token_id,
|
||||
token,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not insert oauth2 refresh token")
|
||||
}
|
||||
|
||||
pub struct OAuth2RefreshTokenLookup {
|
||||
pub id: i64,
|
||||
pub oauth2_session_id: i64,
|
||||
pub oauth2_access_token_id: Option<i64>,
|
||||
pub client_id: String,
|
||||
pub scope: String,
|
||||
}
|
||||
|
||||
pub async fn lookup_refresh_token(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
token: &str,
|
||||
) -> anyhow::Result<OAuth2RefreshTokenLookup> {
|
||||
sqlx::query_as!(
|
||||
OAuth2RefreshTokenLookup,
|
||||
r#"
|
||||
SELECT
|
||||
rt.id,
|
||||
rt.oauth2_session_id,
|
||||
rt.oauth2_access_token_id,
|
||||
os.client_id AS "client_id!",
|
||||
os.scope AS "scope!"
|
||||
FROM oauth2_refresh_tokens rt
|
||||
INNER JOIN oauth2_sessions os
|
||||
ON os.id = rt.oauth2_session_id
|
||||
WHERE rt.token = $1 AND rt.next_token_id IS NULL
|
||||
"#,
|
||||
token,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("failed to fetch oauth2 refresh token")
|
||||
}
|
||||
|
||||
pub async fn replace_refresh_token(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
refresh_token_id: i64,
|
||||
next_refresh_token_id: i64,
|
||||
) -> anyhow::Result<()> {
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE oauth2_refresh_tokens
|
||||
SET next_token_id = $2
|
||||
WHERE id = $1
|
||||
"#,
|
||||
refresh_token_id,
|
||||
next_refresh_token_id
|
||||
)
|
||||
.execute(executor)
|
||||
.await
|
||||
.context("failed to update oauth2 refresh token")?;
|
||||
|
||||
if res.rows_affected() == 1 {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!(
|
||||
"no row were affected when updating refresh token"
|
||||
))
|
||||
}
|
||||
}
|
214
crates/core/src/storage/oauth2/session.rs
Normal file
214
crates/core/src/storage/oauth2/session.rs
Normal file
@ -0,0 +1,214 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::HashSet, convert::TryFrom, str::FromStr, string::ToString};
|
||||
|
||||
use anyhow::Context;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use itertools::Itertools;
|
||||
use oauth2_types::{
|
||||
pkce,
|
||||
requests::{ResponseMode, ResponseType},
|
||||
};
|
||||
use serde::Serialize;
|
||||
use sqlx::{Executor, FromRow, Postgres};
|
||||
use url::Url;
|
||||
|
||||
use super::{
|
||||
super::{user::lookup_session, SessionInfo},
|
||||
authorization_code::{add_code, OAuth2Code},
|
||||
};
|
||||
|
||||
#[derive(FromRow, Serialize)]
|
||||
pub struct OAuth2Session {
|
||||
pub id: i64,
|
||||
user_session_id: Option<i64>,
|
||||
pub client_id: String,
|
||||
redirect_uri: String,
|
||||
scope: String,
|
||||
pub state: Option<String>,
|
||||
nonce: Option<String>,
|
||||
max_age: Option<i32>,
|
||||
response_type: String,
|
||||
response_mode: String,
|
||||
|
||||
created_at: DateTime<Utc>,
|
||||
updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl OAuth2Session {
|
||||
pub async fn add_code<'e>(
|
||||
&self,
|
||||
executor: impl Executor<'e, Database = Postgres>,
|
||||
code: &str,
|
||||
code_challenge: &Option<pkce::Request>,
|
||||
) -> anyhow::Result<OAuth2Code> {
|
||||
add_code(executor, self.id, code, code_challenge).await
|
||||
}
|
||||
|
||||
pub async fn fetch_session(
|
||||
&self,
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
) -> anyhow::Result<Option<SessionInfo>> {
|
||||
match self.user_session_id {
|
||||
Some(id) => {
|
||||
let info = lookup_session(executor, id).await?;
|
||||
Ok(Some(info))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn fetch_code(
|
||||
&self,
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
) -> anyhow::Result<String> {
|
||||
get_code_for_session(executor, self.id).await
|
||||
}
|
||||
|
||||
pub async fn match_or_set_session(
|
||||
&mut self,
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
session: SessionInfo,
|
||||
) -> anyhow::Result<SessionInfo> {
|
||||
match self.user_session_id {
|
||||
Some(id) if id == session.key() => Ok(session),
|
||||
Some(id) => Err(anyhow::anyhow!(
|
||||
"session mismatch, expected {}, got {}",
|
||||
id,
|
||||
session.key()
|
||||
)),
|
||||
None => {
|
||||
sqlx::query!(
|
||||
"UPDATE oauth2_sessions SET user_session_id = $1 WHERE id = $2",
|
||||
session.key(),
|
||||
self.id,
|
||||
)
|
||||
.execute(executor)
|
||||
.await
|
||||
.context("could not update oauth2 session")?;
|
||||
Ok(session)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use] pub fn max_auth_time(&self) -> Option<DateTime<Utc>> {
|
||||
self.max_age
|
||||
.map(|d| Duration::seconds(i64::from(d)))
|
||||
.map(|d| self.created_at - d)
|
||||
}
|
||||
|
||||
pub fn response_type(&self) -> anyhow::Result<HashSet<ResponseType>> {
|
||||
self.response_type
|
||||
.split(' ')
|
||||
.map(|s| {
|
||||
ResponseType::from_str(s).with_context(|| format!("invalid response type {}", s))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn response_mode(&self) -> anyhow::Result<ResponseMode> {
|
||||
self.response_mode.parse().context("invalid response mode")
|
||||
}
|
||||
|
||||
pub fn redirect_uri(&self) -> anyhow::Result<Url> {
|
||||
self.redirect_uri.parse().context("invalid redirect uri")
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn start_session(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
optional_session_id: Option<i64>,
|
||||
client_id: &str,
|
||||
redirect_uri: &Url,
|
||||
scope: &str,
|
||||
state: Option<&str>,
|
||||
nonce: Option<&str>,
|
||||
max_age: Option<Duration>,
|
||||
response_type: &HashSet<ResponseType>,
|
||||
response_mode: ResponseMode,
|
||||
) -> anyhow::Result<OAuth2Session> {
|
||||
// Checked convertion of duration to i32, maxing at i32::MAX
|
||||
let max_age = max_age.map(|d| i32::try_from(d.num_seconds()).unwrap_or(i32::MAX));
|
||||
let response_mode = response_mode.to_string();
|
||||
let redirect_uri = redirect_uri.to_string();
|
||||
let response_type: String = {
|
||||
let it = response_type.iter().map(ToString::to_string);
|
||||
Itertools::intersperse(it, " ".to_string()).collect()
|
||||
};
|
||||
|
||||
sqlx::query_as!(
|
||||
OAuth2Session,
|
||||
r#"
|
||||
INSERT INTO oauth2_sessions
|
||||
(user_session_id, client_id, redirect_uri, scope, state, nonce, max_age,
|
||||
response_type, response_mode)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
RETURNING
|
||||
id, user_session_id, client_id, redirect_uri, scope, state, nonce, max_age,
|
||||
response_type, response_mode, created_at, updated_at
|
||||
"#,
|
||||
optional_session_id,
|
||||
client_id,
|
||||
redirect_uri,
|
||||
scope,
|
||||
state,
|
||||
nonce,
|
||||
max_age,
|
||||
response_type,
|
||||
response_mode,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not insert oauth2 session")
|
||||
}
|
||||
|
||||
pub async fn get_session_by_id(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
oauth2_session_id: i64,
|
||||
) -> anyhow::Result<OAuth2Session> {
|
||||
sqlx::query_as!(
|
||||
OAuth2Session,
|
||||
r#"
|
||||
SELECT
|
||||
id, user_session_id, client_id, redirect_uri, scope, state, nonce,
|
||||
max_age, response_type, response_mode, created_at, updated_at
|
||||
FROM oauth2_sessions
|
||||
WHERE id = $1
|
||||
"#,
|
||||
oauth2_session_id
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not fetch oauth2 session")
|
||||
}
|
||||
|
||||
pub async fn get_code_for_session(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
oauth2_session_id: i64,
|
||||
) -> anyhow::Result<String> {
|
||||
sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT code
|
||||
FROM oauth2_codes
|
||||
WHERE oauth2_session_id = $1
|
||||
"#,
|
||||
oauth2_session_id
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not fetch oauth2 code")
|
||||
}
|
370
crates/core/src/storage/user.rs
Normal file
370
crates/core/src/storage/user.rs
Normal file
@ -0,0 +1,370 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::borrow::BorrowMut;
|
||||
|
||||
use anyhow::Context;
|
||||
use argon2::Argon2;
|
||||
use chrono::{DateTime, Utc};
|
||||
use password_hash::{PasswordHash, PasswordHasher, SaltString};
|
||||
use rand::rngs::OsRng;
|
||||
use serde::Serialize;
|
||||
use sqlx::{Acquire, Executor, FromRow, Postgres, Transaction};
|
||||
use thiserror::Error;
|
||||
use tokio::task;
|
||||
use tracing::{info_span, Instrument};
|
||||
|
||||
use crate::errors::HtmlError;
|
||||
|
||||
#[derive(Serialize, Debug, Clone, FromRow)]
|
||||
pub struct User {
|
||||
pub id: i64,
|
||||
pub username: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, Clone, FromRow)]
|
||||
pub struct SessionInfo {
|
||||
id: i64,
|
||||
user_id: i64,
|
||||
username: String,
|
||||
pub active: bool,
|
||||
created_at: DateTime<Utc>,
|
||||
pub last_authd_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl SessionInfo {
|
||||
#[must_use] pub fn key(&self) -> i64 {
|
||||
self.id
|
||||
}
|
||||
|
||||
pub async fn reauth(
|
||||
mut self,
|
||||
conn: impl Acquire<'_, Database = Postgres>,
|
||||
password: String,
|
||||
) -> anyhow::Result<Self> {
|
||||
let mut txn = conn.begin().await?;
|
||||
self.last_authd_at = Some(authenticate_session(&mut txn, self.id, password).await?);
|
||||
txn.commit().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub async fn end(
|
||||
mut self,
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
) -> anyhow::Result<Self> {
|
||||
end_session(executor, self.id).await?;
|
||||
self.active = false;
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum LoginError {
|
||||
#[error("could not find user {username:?}")]
|
||||
NotFound {
|
||||
username: String,
|
||||
#[source]
|
||||
source: sqlx::Error,
|
||||
},
|
||||
|
||||
#[error("authentication failed for {username:?}")]
|
||||
Authentication {
|
||||
username: String,
|
||||
#[source]
|
||||
source: AuthenticationError,
|
||||
},
|
||||
|
||||
#[error("failed to login")]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl HtmlError for LoginError {
|
||||
fn html_display(&self) -> String {
|
||||
match self {
|
||||
LoginError::NotFound { .. } => "Could not find user".to_string(),
|
||||
LoginError::Authentication { .. } => "Failed to authenticate user".to_string(),
|
||||
LoginError::Other(e) => format!("Internal error: <pre>{}</pre>", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn login(
|
||||
conn: impl Acquire<'_, Database = Postgres>,
|
||||
username: &str,
|
||||
password: String,
|
||||
) -> Result<SessionInfo, LoginError> {
|
||||
let mut txn = conn.begin().await.context("could not start transaction")?;
|
||||
let user = lookup_user_by_username(&mut txn, username)
|
||||
.await
|
||||
.map_err(|source| {
|
||||
if matches!(source, sqlx::Error::RowNotFound) {
|
||||
LoginError::NotFound {
|
||||
username: username.to_string(),
|
||||
source,
|
||||
}
|
||||
} else {
|
||||
LoginError::Other(source.into())
|
||||
}
|
||||
})?;
|
||||
|
||||
let mut session = start_session(&mut txn, user).await?;
|
||||
session.last_authd_at = Some(
|
||||
authenticate_session(&mut txn, session.id, password)
|
||||
.await
|
||||
.map_err(|source| {
|
||||
if matches!(source, AuthenticationError::Password { .. }) {
|
||||
LoginError::Authentication {
|
||||
username: username.to_string(),
|
||||
source,
|
||||
}
|
||||
} else {
|
||||
LoginError::Other(source.into())
|
||||
}
|
||||
})?,
|
||||
);
|
||||
txn.commit().await.context("could not commit transaction")?;
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
pub async fn lookup_active_session(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
id: i64,
|
||||
) -> anyhow::Result<SessionInfo> {
|
||||
sqlx::query_as!(
|
||||
SessionInfo,
|
||||
r#"
|
||||
SELECT
|
||||
s.id,
|
||||
u.id as user_id,
|
||||
u.username,
|
||||
s.active,
|
||||
s.created_at,
|
||||
a.created_at as "last_authd_at?"
|
||||
FROM user_sessions s
|
||||
INNER JOIN users u
|
||||
ON s.user_id = u.id
|
||||
LEFT JOIN user_session_authentications a
|
||||
ON a.session_id = s.id
|
||||
WHERE s.id = $1 AND s.active
|
||||
ORDER BY a.created_at DESC
|
||||
LIMIT 1
|
||||
"#,
|
||||
id,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not fetch session")
|
||||
}
|
||||
|
||||
pub async fn lookup_session(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
id: i64,
|
||||
) -> anyhow::Result<SessionInfo> {
|
||||
sqlx::query_as!(
|
||||
SessionInfo,
|
||||
r#"
|
||||
SELECT
|
||||
s.id,
|
||||
u.id as user_id,
|
||||
u.username,
|
||||
s.active,
|
||||
s.created_at,
|
||||
a.created_at as "last_authd_at?"
|
||||
FROM user_sessions s
|
||||
INNER JOIN users u
|
||||
ON s.user_id = u.id
|
||||
LEFT JOIN user_session_authentications a
|
||||
ON a.session_id = s.id
|
||||
WHERE s.id = $1
|
||||
ORDER BY a.created_at DESC
|
||||
LIMIT 1
|
||||
"#,
|
||||
id,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not fetch session")
|
||||
}
|
||||
|
||||
pub async fn start_session(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
user: User,
|
||||
) -> anyhow::Result<SessionInfo> {
|
||||
let (id, created_at): (i64, DateTime<Utc>) = sqlx::query_as(
|
||||
r#"
|
||||
INSERT INTO user_sessions (user_id)
|
||||
VALUES ($1)
|
||||
RETURNING id, created_at
|
||||
"#,
|
||||
)
|
||||
.bind(user.id)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not create session")?;
|
||||
|
||||
Ok(SessionInfo {
|
||||
id,
|
||||
user_id: user.id,
|
||||
username: user.username,
|
||||
active: true,
|
||||
created_at,
|
||||
last_authd_at: None,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthenticationError {
|
||||
#[error("could not verify password")]
|
||||
Password(#[from] password_hash::Error),
|
||||
|
||||
#[error("could not fetch user password hash")]
|
||||
Fetch(sqlx::Error),
|
||||
|
||||
#[error("could not save session auth")]
|
||||
Save(sqlx::Error),
|
||||
|
||||
#[error("runtime error")]
|
||||
Internal(#[from] tokio::task::JoinError),
|
||||
}
|
||||
|
||||
pub async fn authenticate_session(
|
||||
txn: &mut Transaction<'_, Postgres>,
|
||||
session_id: i64,
|
||||
password: String,
|
||||
) -> Result<DateTime<Utc>, AuthenticationError> {
|
||||
// First, fetch the hashed password from the user associated with that session
|
||||
let hashed_password: String = sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT u.hashed_password
|
||||
FROM user_sessions s
|
||||
INNER JOIN users u
|
||||
ON u.id = s.user_id
|
||||
WHERE s.id = $1
|
||||
"#,
|
||||
session_id,
|
||||
)
|
||||
.fetch_one(txn.borrow_mut())
|
||||
.await
|
||||
.map_err(AuthenticationError::Fetch)?;
|
||||
|
||||
// TODO: pass verifiers list as parameter
|
||||
// Verify the password in a blocking thread to avoid blocking the async executor
|
||||
task::spawn_blocking(move || {
|
||||
let context = Argon2::default();
|
||||
let hasher = PasswordHash::new(&hashed_password).map_err(AuthenticationError::Password)?;
|
||||
hasher
|
||||
.verify_password(&[&context], &password)
|
||||
.map_err(AuthenticationError::Password)
|
||||
})
|
||||
.await??;
|
||||
|
||||
// That went well, let's insert the auth info
|
||||
let created_at: DateTime<Utc> = sqlx::query_scalar!(
|
||||
r#"
|
||||
INSERT INTO user_session_authentications (session_id)
|
||||
VALUES ($1)
|
||||
RETURNING created_at
|
||||
"#,
|
||||
session_id,
|
||||
)
|
||||
.fetch_one(txn.borrow_mut())
|
||||
.await
|
||||
.map_err(AuthenticationError::Save)?;
|
||||
|
||||
Ok(created_at)
|
||||
}
|
||||
|
||||
pub async fn register_user(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
phf: impl PasswordHasher,
|
||||
username: &str,
|
||||
password: &str,
|
||||
) -> anyhow::Result<User> {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let hashed_password = PasswordHash::generate(phf, password, salt.as_str())?;
|
||||
|
||||
let id: i64 = sqlx::query_scalar!(
|
||||
r#"
|
||||
INSERT INTO users (username, hashed_password)
|
||||
VALUES ($1, $2)
|
||||
RETURNING id
|
||||
"#,
|
||||
username,
|
||||
hashed_password.to_string(),
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.instrument(info_span!("Register user"))
|
||||
.await
|
||||
.context("could not insert user")?;
|
||||
|
||||
Ok(User {
|
||||
id,
|
||||
username: username.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn end_session(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
id: i64,
|
||||
) -> anyhow::Result<()> {
|
||||
let res = sqlx::query!("UPDATE user_sessions SET active = FALSE WHERE id = $1", id)
|
||||
.execute(executor)
|
||||
.instrument(info_span!("End session"))
|
||||
.await
|
||||
.context("could not end session")?;
|
||||
|
||||
match res.rows_affected() {
|
||||
1 => Ok(()),
|
||||
0 => Err(anyhow::anyhow!("no row affected")),
|
||||
_ => Err(anyhow::anyhow!("too many row affected")),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn lookup_user_by_id(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
id: i64,
|
||||
) -> anyhow::Result<User> {
|
||||
sqlx::query_as!(
|
||||
User,
|
||||
r#"
|
||||
SELECT id, username
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
"#,
|
||||
id
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.instrument(info_span!("Fetch user"))
|
||||
.await
|
||||
.context("could not fetch user")
|
||||
}
|
||||
|
||||
pub async fn lookup_user_by_username(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
username: &str,
|
||||
) -> Result<User, sqlx::Error> {
|
||||
sqlx::query_as!(
|
||||
User,
|
||||
r#"
|
||||
SELECT id, username
|
||||
FROM users
|
||||
WHERE username = $1
|
||||
"#,
|
||||
username,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.instrument(info_span!("Fetch user"))
|
||||
.await
|
||||
}
|
43
crates/core/src/tasks/database.rs
Normal file
43
crates/core/src/tasks/database.rs
Normal file
@ -0,0 +1,43 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use sqlx::{Pool, Postgres};
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
use super::Task;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CleanupExpired(Pool<Postgres>);
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Task for CleanupExpired {
|
||||
async fn run(&self) {
|
||||
let res = crate::storage::oauth2::access_token::cleanup_expired(&self.0).await;
|
||||
match res {
|
||||
Ok(0) => {
|
||||
debug!("no token to clean up");
|
||||
}
|
||||
Ok(count) => {
|
||||
info!(count, "cleaned up expired tokens");
|
||||
}
|
||||
Err(error) => {
|
||||
error!(?error, "failed to cleanup expired tokens");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use] pub fn cleanup_expired(pool: &Pool<Postgres>) -> impl Task + Clone {
|
||||
CleanupExpired(pool.clone())
|
||||
}
|
102
crates/core/src/tasks/mod.rs
Normal file
102
crates/core/src/tasks/mod.rs
Normal file
@ -0,0 +1,102 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::VecDeque, sync::Arc, time::Duration};
|
||||
|
||||
use futures_util::StreamExt;
|
||||
use tokio::{
|
||||
sync::{Mutex, Notify},
|
||||
time::Interval,
|
||||
};
|
||||
use tokio_stream::wrappers::IntervalStream;
|
||||
|
||||
mod database;
|
||||
|
||||
pub use self::database::cleanup_expired;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait Task: Send + Sync + 'static {
|
||||
async fn run(&self);
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct TaskQueueInner {
|
||||
pending_tasks: Mutex<VecDeque<Box<dyn Task>>>,
|
||||
notifier: Notify,
|
||||
}
|
||||
|
||||
impl TaskQueueInner {
|
||||
async fn recuring<T: Task + Clone>(&self, interval: Interval, task: T) {
|
||||
let mut stream = IntervalStream::new(interval);
|
||||
|
||||
while (stream.next()).await.is_some() {
|
||||
self.schedule(task.clone()).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn schedule<T: Task>(&self, task: T) {
|
||||
let task = Box::new(task);
|
||||
self.pending_tasks.lock().await.push_back(task);
|
||||
self.notifier.notify_one();
|
||||
}
|
||||
|
||||
async fn tick(&self) {
|
||||
loop {
|
||||
let pending = {
|
||||
let mut tasks = self.pending_tasks.lock().await;
|
||||
tasks.pop_front()
|
||||
};
|
||||
|
||||
if let Some(pending) = pending {
|
||||
pending.run().await;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_forever(&self) {
|
||||
loop {
|
||||
self.notifier.notified().await;
|
||||
self.tick().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct TaskQueue {
|
||||
inner: Arc<TaskQueueInner>,
|
||||
}
|
||||
|
||||
impl TaskQueue {
|
||||
pub fn start(&self) {
|
||||
let queue = self.inner.clone();
|
||||
tokio::task::spawn(async move {
|
||||
queue.run_forever().await;
|
||||
});
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
async fn schedule<T: Task>(&self, task: T) {
|
||||
let queue = self.inner.clone();
|
||||
queue.schedule(task).await;
|
||||
}
|
||||
|
||||
pub fn recuring(&self, every: Duration, task: impl Task + Clone) {
|
||||
let queue = self.inner.clone();
|
||||
tokio::task::spawn(async move {
|
||||
queue.recuring(tokio::time::interval(every), task).await;
|
||||
});
|
||||
}
|
||||
}
|
299
crates/core/src/templates.rs
Normal file
299
crates/core/src/templates.rs
Normal file
@ -0,0 +1,299 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::HashSet, string::ToString, sync::Arc};
|
||||
|
||||
use oauth2_types::errors::OAuth2Error;
|
||||
use serde::Serialize;
|
||||
use tera::{Context, Error as TeraError, Tera};
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, info};
|
||||
use url::Url;
|
||||
use warp::reject::Reject;
|
||||
|
||||
use crate::{errors::ErroredForm, filters::CsrfToken, storage::SessionInfo};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Templates(Arc<Tera>);
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum TemplateLoadingError {
|
||||
#[error("could not load and compile some templates")]
|
||||
Compile(#[from] TeraError),
|
||||
|
||||
#[error("missing templates {missing:?}")]
|
||||
MissingTemplates {
|
||||
missing: HashSet<String>,
|
||||
loaded: HashSet<String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Templates {
|
||||
/// Load the templates and check all needed templates are properly loaded
|
||||
pub fn load() -> Result<Self, TemplateLoadingError> {
|
||||
let path = format!("{}/templates/**/*.{{html,txt}}", env!("CARGO_MANIFEST_DIR"));
|
||||
info!(%path, "Loading templates");
|
||||
let tera = Tera::new(&path)?;
|
||||
|
||||
let loaded: HashSet<_> = tera.get_template_names().collect();
|
||||
let needed: HashSet<_> = std::array::IntoIter::new(TEMPLATES).collect();
|
||||
debug!(?loaded, ?needed, "Templates loaded");
|
||||
let missing: HashSet<_> = needed.difference(&loaded).collect();
|
||||
|
||||
if missing.is_empty() {
|
||||
Ok(Self(Arc::new(tera)))
|
||||
} else {
|
||||
let missing = missing.into_iter().map(ToString::to_string).collect();
|
||||
let loaded = loaded.into_iter().map(ToString::to_string).collect();
|
||||
Err(TemplateLoadingError::MissingTemplates { missing, loaded })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum TemplateError {
|
||||
#[error("could not prepare context for template {template:?}")]
|
||||
Context {
|
||||
template: &'static str,
|
||||
#[source]
|
||||
source: TeraError,
|
||||
},
|
||||
|
||||
#[error("could not render template {template:?}")]
|
||||
Render {
|
||||
template: &'static str,
|
||||
#[source]
|
||||
source: TeraError,
|
||||
},
|
||||
}
|
||||
|
||||
impl Reject for TemplateError {}
|
||||
|
||||
/// Count the number of tokens. Used to have a fixed-sized array for the
|
||||
/// templates list.
|
||||
macro_rules! count {
|
||||
() => (0_usize);
|
||||
( $x:tt $($xs:tt)* ) => (1_usize + count!($($xs)*));
|
||||
}
|
||||
|
||||
/// Macro that helps generating helper function that renders a specific template
|
||||
/// with a strongly-typed context. It also register the template in a static
|
||||
/// array to help detecting missing templates at startup time.
|
||||
///
|
||||
/// The syntax looks almost like a function to confuse syntax highlighter as
|
||||
/// little as possible.
|
||||
macro_rules! register_templates {
|
||||
{
|
||||
$(
|
||||
// Match any attribute on the function, such as #[doc], #[allow(dead_code)], etc.
|
||||
$( #[ $attr:meta ] )*
|
||||
// The function name
|
||||
pub fn $name:ident
|
||||
// Optional list of generics. Taken from
|
||||
// https://newbedev.com/rust-macro-accepting-type-with-generic-parameters
|
||||
$(< $( $lt:tt $( : $clt:tt $(+ $dlt:tt )* )? ),+ >)?
|
||||
// Type of context taken by the template
|
||||
( $param:ty )
|
||||
{
|
||||
// The name of the template file
|
||||
$template:expr
|
||||
}
|
||||
)*
|
||||
} => {
|
||||
/// List of registered templates
|
||||
static TEMPLATES: [&'static str; count!( $( $template )* )] = [ $( $template ),* ];
|
||||
|
||||
impl Templates {
|
||||
$(
|
||||
$(#[$attr])?
|
||||
pub fn $name
|
||||
$(< $( $lt $( : $clt $(+ $dlt )* )? ),+ >)?
|
||||
(&self, context: &$param)
|
||||
-> Result<String, TemplateError> {
|
||||
let ctx = Context::from_serialize(context)
|
||||
.map_err(|source| TemplateError::Context { template: $template, source })?;
|
||||
|
||||
self.0.render($template, &ctx)
|
||||
.map_err(|source| TemplateError::Render { template: $template, source })
|
||||
}
|
||||
)*
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
register_templates! {
|
||||
/// Render the login page
|
||||
pub fn render_login(WithCsrf<LoginContext>) { "login.html" }
|
||||
|
||||
/// Render the registration page
|
||||
pub fn render_register(WithCsrf<()>) { "register.html" }
|
||||
|
||||
/// Render the home page
|
||||
pub fn render_index(WithCsrf<WithOptionalSession<IndexContext>>) { "index.html" }
|
||||
|
||||
/// Render the re-authentication form
|
||||
pub fn render_reauth(WithCsrf<WithSession<()>>) { "reauth.html" }
|
||||
|
||||
/// Render the form used by the form_post response mode
|
||||
pub fn render_form_post<T: Serialize>(FormPostContext<T>) { "form_post.html" }
|
||||
|
||||
/// Render the HTML error page
|
||||
pub fn render_error(ErrorContext) { "error.html" }
|
||||
}
|
||||
|
||||
/// Helper trait to construct context wrappers
|
||||
pub trait TemplateContext: Sized {
|
||||
fn with_session(self, current_session: SessionInfo) -> WithSession<Self> {
|
||||
WithSession {
|
||||
current_session,
|
||||
inner: self,
|
||||
}
|
||||
}
|
||||
|
||||
fn maybe_with_session(self, current_session: Option<SessionInfo>) -> WithOptionalSession<Self> {
|
||||
WithOptionalSession {
|
||||
current_session,
|
||||
inner: self,
|
||||
}
|
||||
}
|
||||
|
||||
fn with_csrf(self, token: &CsrfToken) -> WithCsrf<Self> {
|
||||
WithCsrf {
|
||||
csrf_token: token.form_value(),
|
||||
inner: self,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Sized> TemplateContext for T {}
|
||||
|
||||
/// Context with a CSRF token in it
|
||||
#[derive(Serialize)]
|
||||
pub struct WithCsrf<T> {
|
||||
csrf_token: String,
|
||||
|
||||
#[serde(flatten)]
|
||||
inner: T,
|
||||
}
|
||||
|
||||
/// Context with a user session in it
|
||||
#[derive(Serialize)]
|
||||
pub struct WithSession<T> {
|
||||
current_session: SessionInfo,
|
||||
|
||||
#[serde(flatten)]
|
||||
inner: T,
|
||||
}
|
||||
|
||||
/// Context with an optional user session in it
|
||||
#[derive(Serialize)]
|
||||
pub struct WithOptionalSession<T> {
|
||||
current_session: Option<SessionInfo>,
|
||||
|
||||
#[serde(flatten)]
|
||||
inner: T,
|
||||
}
|
||||
|
||||
// Context used by the `index.html` template
|
||||
#[derive(Serialize)]
|
||||
pub struct IndexContext {
|
||||
discovery_url: Url,
|
||||
}
|
||||
|
||||
impl IndexContext {
|
||||
#[must_use] pub fn new(discovery_url: Url) -> Self {
|
||||
Self { discovery_url }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum LoginFormField {
|
||||
Username,
|
||||
Password,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct LoginContext {
|
||||
form: ErroredForm<LoginFormField>,
|
||||
}
|
||||
|
||||
impl LoginContext {
|
||||
#[must_use] pub fn with_form_error(form: ErroredForm<LoginFormField>) -> Self {
|
||||
Self { form }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LoginContext {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
form: ErroredForm::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Context used by the `form_post.html` template
|
||||
#[derive(Serialize)]
|
||||
pub struct FormPostContext<T> {
|
||||
redirect_uri: Url,
|
||||
params: T,
|
||||
}
|
||||
|
||||
impl<T> FormPostContext<T> {
|
||||
pub fn new(redirect_uri: Url, params: T) -> Self {
|
||||
Self {
|
||||
redirect_uri,
|
||||
params,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize)]
|
||||
pub struct ErrorContext {
|
||||
code: Option<&'static str>,
|
||||
description: Option<String>,
|
||||
details: Option<String>,
|
||||
}
|
||||
|
||||
impl ErrorContext {
|
||||
#[must_use] pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
#[must_use] pub fn with_code(mut self, code: &'static str) -> Self {
|
||||
self.code = Some(code);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use] pub fn with_description(mut self, description: String) -> Self {
|
||||
self.description = Some(description);
|
||||
self
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[must_use] pub fn with_details(mut self, details: String) -> Self {
|
||||
self.details = Some(details);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Box<dyn OAuth2Error>> for ErrorContext {
|
||||
fn from(err: Box<dyn OAuth2Error>) -> Self {
|
||||
let mut ctx = ErrorContext::new().with_code(err.error());
|
||||
if let Some(desc) = err.description() {
|
||||
ctx = ctx.with_description(desc);
|
||||
}
|
||||
ctx
|
||||
}
|
||||
}
|
176
crates/core/src/tokens.rs
Normal file
176
crates/core/src/tokens.rs
Normal file
@ -0,0 +1,176 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
use crc::{Crc, CRC_32_ISO_HDLC};
|
||||
use oauth2_types::requests::TokenTypeHint;
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TokenType {
|
||||
AccessToken,
|
||||
RefreshToken,
|
||||
}
|
||||
|
||||
impl TokenType {
|
||||
fn prefix(self) -> &'static str {
|
||||
match self {
|
||||
TokenType::AccessToken => "mat",
|
||||
TokenType::RefreshToken => "mar",
|
||||
}
|
||||
}
|
||||
|
||||
fn match_prefix(prefix: &str) -> Option<Self> {
|
||||
match prefix {
|
||||
"mat" => Some(TokenType::AccessToken),
|
||||
"mar" => Some(TokenType::RefreshToken),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq<TokenTypeHint> for TokenType {
|
||||
fn eq(&self, other: &TokenTypeHint) -> bool {
|
||||
matches!(
|
||||
(self, other),
|
||||
(TokenType::AccessToken, TokenTypeHint::AccessToken)
|
||||
| (TokenType::RefreshToken, TokenTypeHint::RefreshToken)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const NUM: [u8; 62] = *b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
|
||||
|
||||
fn base62_encode(mut num: u32) -> String {
|
||||
let mut res = String::with_capacity(6);
|
||||
while num > 0 {
|
||||
res.push(NUM[(num % 62) as usize] as char);
|
||||
num /= 62;
|
||||
}
|
||||
|
||||
format!("{:0>6}", res)
|
||||
}
|
||||
|
||||
const CRC: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC);
|
||||
|
||||
pub fn generate(rng: impl Rng, token_type: TokenType) -> String {
|
||||
let random_part: String = rng
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(30)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
let base = format!("{}_{}", token_type.prefix(), random_part);
|
||||
let crc = CRC.checksum(base.as_bytes());
|
||||
let crc = base62_encode(crc);
|
||||
format!("{}_{}", base, crc)
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TokenFormatError {
|
||||
#[error("invalid token format")]
|
||||
InvalidFormat,
|
||||
|
||||
#[error("unknown token prefix {prefix:?}")]
|
||||
UnknownPrefix { prefix: String },
|
||||
|
||||
#[error("invalid crc {got:?}, expected {expected:?}")]
|
||||
InvalidCrc { expected: String, got: String },
|
||||
}
|
||||
|
||||
pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
|
||||
let split: Vec<&str> = token.split('_').collect();
|
||||
let [prefix, random_part, crc]: [&str; 3] = split
|
||||
.try_into()
|
||||
.map_err(|_| TokenFormatError::InvalidFormat)?;
|
||||
|
||||
if prefix.len() != 3 || random_part.len() != 30 || crc.len() != 6 {
|
||||
return Err(TokenFormatError::InvalidFormat);
|
||||
}
|
||||
|
||||
let token_type =
|
||||
TokenType::match_prefix(prefix).ok_or_else(|| TokenFormatError::UnknownPrefix {
|
||||
prefix: prefix.to_string(),
|
||||
})?;
|
||||
|
||||
let base = format!("{}_{}", token_type.prefix(), random_part);
|
||||
let expected_crc = CRC.checksum(base.as_bytes());
|
||||
let expected_crc = base62_encode(expected_crc);
|
||||
if crc != expected_crc {
|
||||
return Err(TokenFormatError::InvalidCrc {
|
||||
expected: expected_crc,
|
||||
got: crc.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(token_type)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashSet;
|
||||
|
||||
use rand::thread_rng;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_prefix_match() {
|
||||
use TokenType::{AccessToken, RefreshToken};
|
||||
assert_eq!(TokenType::match_prefix("mat"), Some(AccessToken));
|
||||
assert_eq!(TokenType::match_prefix("mar"), Some(RefreshToken));
|
||||
assert_eq!(TokenType::match_prefix("matt"), None);
|
||||
assert_eq!(TokenType::match_prefix("marr"), None);
|
||||
assert_eq!(TokenType::match_prefix("ma"), None);
|
||||
assert_eq!(
|
||||
TokenType::match_prefix(TokenType::AccessToken.prefix()),
|
||||
Some(TokenType::AccessToken)
|
||||
);
|
||||
assert_eq!(
|
||||
TokenType::match_prefix(TokenType::RefreshToken.prefix()),
|
||||
Some(TokenType::RefreshToken)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_and_check() {
|
||||
const COUNT: usize = 500; // Generate 500 of each token type
|
||||
let mut rng = thread_rng();
|
||||
// Generate many access tokens
|
||||
let tokens: HashSet<String> = (0..COUNT)
|
||||
.map(|_| generate(&mut rng, TokenType::AccessToken))
|
||||
.collect();
|
||||
|
||||
// Check that they are all different
|
||||
assert_eq!(tokens.len(), COUNT, "All tokens are unique");
|
||||
|
||||
// Check that they are all valid and detected as access tokens
|
||||
for token in tokens {
|
||||
assert_eq!(check(&token).unwrap(), TokenType::AccessToken);
|
||||
}
|
||||
|
||||
// Same, but for refresh tokens
|
||||
let tokens: HashSet<String> = (0..COUNT)
|
||||
.map(|_| generate(&mut rng, TokenType::RefreshToken))
|
||||
.collect();
|
||||
|
||||
assert_eq!(tokens.len(), COUNT, "All tokens are unique");
|
||||
|
||||
for token in tokens {
|
||||
assert_eq!(check(&token).unwrap(), TokenType::RefreshToken);
|
||||
}
|
||||
}
|
||||
}
|
66
crates/core/templates/base.html
Normal file
66
crates/core/templates/base.html
Normal file
@ -0,0 +1,66 @@
|
||||
{#
|
||||
Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
#}
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>{% block title %}matrix-authentication-service{% endblock title %}</title>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.3/css/bulma.min.css">
|
||||
</head>
|
||||
<body>
|
||||
<nav class="navbar is-dark" role="navigation" aria-label="main navigation">
|
||||
<div class="container">
|
||||
<div class="navbar-brand">
|
||||
<a class="navbar-item" href="/">
|
||||
matrix-authentication-service
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div class="navbar-end">
|
||||
{% if current_session %}
|
||||
<div class="navbar-item">
|
||||
Howdy {{ current_session.username }}!
|
||||
</div>
|
||||
<div class="navbar-item">
|
||||
<form method="POST" action="/logout">
|
||||
<input type="hidden" name="csrf" value="{{ csrf_token }}" />
|
||||
<button class="button is-light" action="submit">
|
||||
Log out
|
||||
</button>
|
||||
</form>
|
||||
</div>
|
||||
{% else %}
|
||||
<div class="navbar-item">
|
||||
<a class="button is-light" href="/login">
|
||||
Log in
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div class="navbar-item">
|
||||
<a class="button is-light" href="/register">
|
||||
Register
|
||||
</a>
|
||||
</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
</nav>
|
||||
|
||||
{% block content %}{% endblock content %}
|
||||
</body>
|
||||
</html>
|
39
crates/core/templates/error.html
Normal file
39
crates/core/templates/error.html
Normal file
@ -0,0 +1,39 @@
|
||||
{#
|
||||
Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
#}
|
||||
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block content %}
|
||||
<section class="hero is-danger">
|
||||
<div class="hero-body">
|
||||
<div class="container">
|
||||
{% if code %}
|
||||
<p class="title">
|
||||
{{ code }}
|
||||
</p>
|
||||
{% endif %}
|
||||
{% if description %}
|
||||
<p class="subtitle">
|
||||
{{ description }}
|
||||
</p>
|
||||
{% endif %}
|
||||
{% if details %}
|
||||
<pre><code>{{ details }}</code></pre>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
{% endblock %}
|
27
crates/core/templates/error.txt
Normal file
27
crates/core/templates/error.txt
Normal file
@ -0,0 +1,27 @@
|
||||
{#
|
||||
Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
#}
|
||||
|
||||
{% if code %}
|
||||
{{- code }}
|
||||
{% endif %}
|
||||
|
||||
{%- if description %}
|
||||
{{ description }}
|
||||
{% endif %}
|
||||
|
||||
{%- if details %}
|
||||
{{ details }}
|
||||
{% endif %}
|
31
crates/core/templates/form_post.html
Normal file
31
crates/core/templates/form_post.html
Normal file
@ -0,0 +1,31 @@
|
||||
{#
|
||||
Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
#}
|
||||
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>Redirecting to client</title>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
</head>
|
||||
<body onload="javascript:document.forms[0].submit()">
|
||||
<form method="post" action="{{ redirect_uri }}">
|
||||
{% for key, value in params %}
|
||||
<input type="hidden" name="{{ key }}" value="{{ value }}" />
|
||||
{% endfor %}
|
||||
</form>
|
||||
</body>
|
||||
</html>
|
28
crates/core/templates/index.html
Normal file
28
crates/core/templates/index.html
Normal file
@ -0,0 +1,28 @@
|
||||
{#
|
||||
Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
#}
|
||||
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block content %}
|
||||
<section class="section">
|
||||
<div class="container content">
|
||||
<h1>Matrix Authentication Service</h1>
|
||||
<p>
|
||||
OpenID Connect discovery document: <a href="{{ discovery_url }}">{{ discovery_url }}</a>
|
||||
</p>
|
||||
</div>
|
||||
</section>
|
||||
{% endblock content %}
|
70
crates/core/templates/login.html
Normal file
70
crates/core/templates/login.html
Normal file
@ -0,0 +1,70 @@
|
||||
{#
|
||||
Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
#}
|
||||
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block content %}
|
||||
<section class="section">
|
||||
<div class="container is-max-desktop">
|
||||
<div class="columns">
|
||||
<div class="column is-half is-offset-one-quarter">
|
||||
{% if form.has_errors %}
|
||||
<article class="message is-danger">
|
||||
<div class="message-body">
|
||||
{% for message in form.form_errors %}
|
||||
<p>{{ message | safe }}</p>
|
||||
{% else %}
|
||||
<p>Login failed, check the fields below for more details.</p>
|
||||
{% endfor %}
|
||||
</div>
|
||||
</article>
|
||||
{% endif %}
|
||||
|
||||
<form method="POST">
|
||||
<input type="hidden" name="csrf" value="{{ csrf_token }}" />
|
||||
<div class="field">
|
||||
<label class="label" for="login-username">Username</label>
|
||||
<div class="control">
|
||||
<input class="input{% if 'username' in form.fields_errors %} is-danger{% endif %}" name="username" id="login-username" type="text">
|
||||
</div>
|
||||
{% if 'username' in form.fields_errors %}
|
||||
{% for message in form.fields_errors.username %}
|
||||
<p class="help is-danger">{{ message | safe }}</p>
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
</div>
|
||||
|
||||
<div class="field">
|
||||
<label class="label" for="login-password">Password</label>
|
||||
<div class="control">
|
||||
<input class="input{% if 'password' in form.fields_errors %} is-danger{% endif %}" name="password" id="login-password" type="password">
|
||||
</div>
|
||||
{% if 'password' in form.fields_errors %}
|
||||
{% for message in form.fields_errors.password %}
|
||||
<p class="help is-danger">{{ message | safe }}</p>
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
</div>
|
||||
|
||||
<div class="control">
|
||||
<button type="submit" class="button is-link">Login</button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
{% endblock content %}
|
45
crates/core/templates/reauth.html
Normal file
45
crates/core/templates/reauth.html
Normal file
@ -0,0 +1,45 @@
|
||||
{#
|
||||
Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
#}
|
||||
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block content %}
|
||||
<section class="section">
|
||||
<div class="container is-max-desktop">
|
||||
<div class="columns">
|
||||
<div class="column is-one-third">
|
||||
<form method="POST">
|
||||
<input type="hidden" name="csrf" value="{{ csrf_token }}" />
|
||||
<div class="field">
|
||||
<label class="label" for="login-password">Password</label>
|
||||
<div class="control">
|
||||
<input class="input" name="password" id="login-password" type="password">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="control">
|
||||
<button type="submit" class="button is-link">Submit</button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
<div class="column is-two-third">
|
||||
<pre><code>{{ current_session | json_encode(pretty=True) | safe }}</code></pre>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
{% endblock content %}
|
||||
|
55
crates/core/templates/register.html
Normal file
55
crates/core/templates/register.html
Normal file
@ -0,0 +1,55 @@
|
||||
{#
|
||||
Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
#}
|
||||
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block content %}
|
||||
<section class="section">
|
||||
<div class="container is-max-desktop">
|
||||
<div class="columns">
|
||||
<div class="column is-one-third">
|
||||
<form method="POST">
|
||||
<input type="hidden" name="csrf" value="{{ csrf_token }}" />
|
||||
<div class="field">
|
||||
<label class="label" for="register-username">Username</label>
|
||||
<div class="control">
|
||||
<input class="input" name="username" id="register-username" type="text">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="field">
|
||||
<label class="label" for="register-password">Password</label>
|
||||
<div class="control">
|
||||
<input class="input" name="password" id="register-password" type="password">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="field">
|
||||
<label class="label" for="register-password">Confirm password</label>
|
||||
<div class="control">
|
||||
<input class="input" name="password_confirm" id="register-password-confirm" type="password">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="control">
|
||||
<button type="submit" class="button is-link">Register</button>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
{% endblock content %}
|
21
crates/oauth2-types/Cargo.toml
Normal file
21
crates/oauth2-types/Cargo.toml
Normal file
@ -0,0 +1,21 @@
|
||||
[package]
|
||||
name = "oauth2-types"
|
||||
version = "0.1.0"
|
||||
authors = ["Quentin Gliech <quenting@element.io>"]
|
||||
edition = "2018"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
http = "0.2.4"
|
||||
serde = "1.0.130"
|
||||
serde_json = "1.0.68"
|
||||
language-tags = { version = "0.3.2", features = ["serde"] }
|
||||
url = { version = "2.2.2", features = ["serde"] }
|
||||
parse-display = "0.5.1"
|
||||
indoc = "1.0.3"
|
||||
serde_with = { version = "1.10.0", features = ["chrono"] }
|
||||
sqlx = { version = "0.5.7", default-features = false, optional = true }
|
||||
chrono = "0.4.19"
|
||||
|
||||
[features]
|
||||
sqlx_type = ["sqlx"]
|
268
crates/oauth2-types/src/errors.rs
Normal file
268
crates/oauth2-types/src/errors.rs
Normal file
@ -0,0 +1,268 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use http::status::StatusCode;
|
||||
use serde::ser::{Serialize, SerializeMap};
|
||||
use url::Url;
|
||||
|
||||
pub trait OAuth2Error: std::fmt::Debug + Send + Sync {
|
||||
/// A single ASCII error code.
|
||||
///
|
||||
/// Maps to the required "error" field.
|
||||
fn error(&self) -> &'static str;
|
||||
|
||||
/// Human-readable ASCII text providing additional information, used to
|
||||
/// assist the client developer in understanding the error that
|
||||
/// occurred.
|
||||
///
|
||||
/// Maps to the optional `error_description` field.
|
||||
fn description(&self) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
/// A URI identifying a human-readable web page with information about the
|
||||
/// error, used to provide the client developer with additional
|
||||
/// information about the error.
|
||||
///
|
||||
/// Maps to the optional `error_uri` field.
|
||||
fn uri(&self) -> Option<Url> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Wraps the error with an `ErrorResponse` to help serializing.
|
||||
fn into_response(self) -> ErrorResponse
|
||||
where
|
||||
Self: Sized + 'static,
|
||||
{
|
||||
ErrorResponse(Box::new(self))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait OAuth2ErrorCode: OAuth2Error + 'static {
|
||||
/// The HTTP status code that must be returned by this error
|
||||
fn status(&self) -> StatusCode;
|
||||
}
|
||||
|
||||
impl OAuth2Error for &Box<dyn OAuth2ErrorCode> {
|
||||
fn error(&self) -> &'static str {
|
||||
self.as_ref().error()
|
||||
}
|
||||
fn description(&self) -> Option<String> {
|
||||
self.as_ref().description()
|
||||
}
|
||||
|
||||
fn uri(&self) -> Option<Url> {
|
||||
self.as_ref().uri()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ErrorResponse(Box<dyn OAuth2Error>);
|
||||
|
||||
impl From<Box<dyn OAuth2Error>> for ErrorResponse {
|
||||
fn from(b: Box<dyn OAuth2Error>) -> Self {
|
||||
Self(b)
|
||||
}
|
||||
}
|
||||
|
||||
impl OAuth2Error for ErrorResponse {
|
||||
fn error(&self) -> &'static str {
|
||||
self.0.error()
|
||||
}
|
||||
|
||||
fn description(&self) -> Option<String> {
|
||||
self.0.description()
|
||||
}
|
||||
|
||||
fn uri(&self) -> Option<Url> {
|
||||
self.0.uri()
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for ErrorResponse {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let error = self.0.error();
|
||||
let description = self.0.description();
|
||||
let uri = self.0.uri();
|
||||
|
||||
// Count the number of fields to serialize
|
||||
let len = {
|
||||
let mut x = 1;
|
||||
if description.is_some() {
|
||||
x += 1;
|
||||
}
|
||||
if uri.is_some() {
|
||||
x += 1;
|
||||
}
|
||||
x
|
||||
};
|
||||
|
||||
let mut map = serializer.serialize_map(Some(len))?;
|
||||
map.serialize_entry("error", error)?;
|
||||
if let Some(ref description) = description {
|
||||
map.serialize_entry("error_description", description)?;
|
||||
}
|
||||
if let Some(ref uri) = uri {
|
||||
map.serialize_entry("error_uri", uri)?;
|
||||
}
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! oauth2_error_def {
|
||||
($name:ident) => {
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct $name;
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! oauth2_error_status {
|
||||
($name:ident, $code:ident) => {
|
||||
impl $crate::errors::OAuth2ErrorCode for $name {
|
||||
fn status(&self) -> ::http::status::StatusCode {
|
||||
::http::status::StatusCode::$code
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! oauth2_error_error {
|
||||
($err:literal) => {
|
||||
fn error(&self) -> &'static str {
|
||||
$err
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! oauth2_error_description {
|
||||
($description:expr) => {
|
||||
fn description(&self) -> Option<String> {
|
||||
Some(($description).to_string())
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! oauth2_error {
|
||||
($name:ident, $err:literal => $description:expr) => {
|
||||
oauth2_error_def!($name);
|
||||
impl $crate::errors::OAuth2Error for $name {
|
||||
oauth2_error_error!($err);
|
||||
oauth2_error_description!(indoc::indoc! {$description});
|
||||
}
|
||||
};
|
||||
($name:ident, $err:literal) => {
|
||||
oauth2_error_def!($name);
|
||||
impl $crate::errors::OAuth2Error for $name {
|
||||
oauth2_error_error!($err);
|
||||
}
|
||||
};
|
||||
($name:ident, code: $code:ident, $err:literal => $description:expr) => {
|
||||
oauth2_error!($name, $err => $description);
|
||||
oauth2_error_status!($name, $code);
|
||||
};
|
||||
($name:ident, code: $code:ident, $err:literal) => {
|
||||
oauth2_error!($name, $err);
|
||||
oauth2_error_status!($name, $code);
|
||||
};
|
||||
}
|
||||
|
||||
pub mod rfc6749 {
|
||||
oauth2_error! {
|
||||
InvalidRequest,
|
||||
code: BAD_REQUEST,
|
||||
"invalid_request" =>
|
||||
"The request is missing a required parameter, includes an invalid parameter value, \
|
||||
includes a parameter more than once, or is otherwise malformed."
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
InvalidClient,
|
||||
code: BAD_REQUEST,
|
||||
"invalid_client" =>
|
||||
"Client authentication failed."
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
InvalidGrant,
|
||||
code: BAD_REQUEST,
|
||||
"invalid_grant"
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
UnauthorizedClient,
|
||||
code: BAD_REQUEST,
|
||||
"unauthorized_client" =>
|
||||
"The client is not authorized to request an access token using this method."
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
UnsupportedGrantType,
|
||||
code: BAD_REQUEST,
|
||||
"unsupported_grant_type" =>
|
||||
"The authorization grant type is not supported by the authorization server."
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
AccessDenied,
|
||||
"access_denied" =>
|
||||
"The resource owner or authorization server denied the request."
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
UnsupportedResponseType,
|
||||
"unsupported_response_type" =>
|
||||
"The authorization server does not support obtaining an access token using this method."
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
InvalidScope,
|
||||
code: BAD_REQUEST,
|
||||
"invalid_scope" =>
|
||||
"The requested scope is invalid, unknown, or malformed."
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
ServerError,
|
||||
"server_error" =>
|
||||
"The authorization server encountered an unexpected \
|
||||
condition that prevented it from fulfilling the request."
|
||||
}
|
||||
|
||||
oauth2_error! {
|
||||
TemporarilyUnavailable,
|
||||
"temporarily_unavailable" =>
|
||||
"The authorization server is currently unable to handle \
|
||||
the request due to a temporary overloading or maintenance \
|
||||
of the server."
|
||||
}
|
||||
}
|
||||
|
||||
pub use rfc6749::*;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn serialize_error() {
|
||||
let expected = json!({"error": "invalid_grant"});
|
||||
let actual = serde_json::to_value(InvalidGrant.into_response()).unwrap();
|
||||
assert_eq!(expected, actual);
|
||||
}
|
||||
}
|
25
crates/oauth2-types/src/lib.rs
Normal file
25
crates/oauth2-types/src/lib.rs
Normal file
@ -0,0 +1,25 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#![forbid(unsafe_code)]
|
||||
#![deny(clippy::all)]
|
||||
#![warn(clippy::pedantic)]
|
||||
|
||||
pub mod errors;
|
||||
pub mod oidc;
|
||||
pub mod pkce;
|
||||
pub mod requests;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_utils;
|
75
crates/oauth2-types/src/oidc.rs
Normal file
75
crates/oauth2-types/src/oidc.rs
Normal file
@ -0,0 +1,75 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
use serde::Serialize;
|
||||
use serde_with::skip_serializing_none;
|
||||
use url::Url;
|
||||
|
||||
use crate::{
|
||||
pkce::CodeChallengeMethod,
|
||||
requests::{ClientAuthenticationMethod, GrantType, ResponseMode},
|
||||
};
|
||||
|
||||
// TODO: https://datatracker.ietf.org/doc/html/rfc8414#section-2
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct Metadata {
|
||||
/// The authorization server's issuer identifier, which is a URL that uses
|
||||
/// the "https" scheme and has no query or fragment components.
|
||||
pub issuer: Url,
|
||||
|
||||
/// URL of the authorization server's authorization endpoint.
|
||||
pub authorization_endpoint: Option<Url>,
|
||||
|
||||
/// URL of the authorization server's token endpoint.
|
||||
pub token_endpoint: Option<Url>,
|
||||
|
||||
/// URL of the authorization server's JWK Set document.
|
||||
pub jwks_uri: Option<Url>,
|
||||
|
||||
/// URL of the authorization server's OAuth 2.0 Dynamic Client Registration
|
||||
/// endpoint.
|
||||
pub registration_endpoint: Option<Url>,
|
||||
|
||||
/// JSON array containing a list of the OAuth 2.0 "scope" values that this
|
||||
/// authorization server supports.
|
||||
pub scopes_supported: Option<HashSet<String>>,
|
||||
|
||||
/// JSON array containing a list of the OAuth 2.0 "response_type" values
|
||||
/// that this authorization server supports.
|
||||
pub response_types_supported: Option<HashSet<String>>,
|
||||
|
||||
/// JSON array containing a list of the OAuth 2.0 "response_mode" values
|
||||
/// that this authorization server supports, as specified in "OAuth 2.0
|
||||
/// Multiple Response Type Encoding Practices".
|
||||
pub response_modes_supported: Option<HashSet<ResponseMode>>,
|
||||
|
||||
/// JSON array containing a list of the OAuth 2.0 grant type values that
|
||||
/// this authorization server supports.
|
||||
pub grant_types_supported: Option<HashSet<GrantType>>,
|
||||
|
||||
/// JSON array containing a list of client authentication methods supported
|
||||
/// by this token endpoint.
|
||||
pub token_endpoint_auth_methods_supported: Option<HashSet<ClientAuthenticationMethod>>,
|
||||
|
||||
/// PKCE code challenge methods supported by this authorization server
|
||||
pub code_challenge_methods_supported: Option<HashSet<CodeChallengeMethod>>,
|
||||
|
||||
/// URL of the authorization server's OAuth 2.0 introspection endpoint.
|
||||
pub introspection_endpoint: Option<Url>,
|
||||
|
||||
pub userinfo_endpoint: Option<Url>,
|
||||
}
|
48
crates/oauth2-types/src/pkce.rs
Normal file
48
crates/oauth2-types/src/pkce.rs
Normal file
@ -0,0 +1,48 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use parse_display::{Display, FromStr};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
Hash,
|
||||
PartialEq,
|
||||
Eq,
|
||||
PartialOrd,
|
||||
Ord,
|
||||
Clone,
|
||||
Copy,
|
||||
Display,
|
||||
FromStr,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
)]
|
||||
#[cfg_attr(feature = "sqlx_type", derive(sqlx::Type))]
|
||||
#[repr(i8)]
|
||||
pub enum CodeChallengeMethod {
|
||||
#[serde(rename = "plain")]
|
||||
#[display("plain")]
|
||||
Plain = 0,
|
||||
|
||||
#[serde(rename = "S256")]
|
||||
#[display("S256")]
|
||||
S256 = 1,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Request {
|
||||
pub code_challenge_method: CodeChallengeMethod,
|
||||
pub code_challenge: String,
|
||||
}
|
414
crates/oauth2-types/src/requests.rs
Normal file
414
crates/oauth2-types/src/requests.rs
Normal file
@ -0,0 +1,414 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::HashSet, hash::Hash};
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use language_tags::LanguageTag;
|
||||
use parse_display::{Display, FromStr};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{
|
||||
rust::StringWithSeparator, serde_as, skip_serializing_none, DurationSeconds, SpaceSeparator,
|
||||
TimestampSeconds,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
// ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
Hash,
|
||||
PartialEq,
|
||||
Eq,
|
||||
PartialOrd,
|
||||
Ord,
|
||||
Clone,
|
||||
Copy,
|
||||
Display,
|
||||
FromStr,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
)]
|
||||
#[display(style = "snake_case")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ResponseType {
|
||||
Code,
|
||||
IdToken,
|
||||
Token,
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
Hash,
|
||||
PartialEq,
|
||||
Eq,
|
||||
PartialOrd,
|
||||
Ord,
|
||||
Clone,
|
||||
Copy,
|
||||
Display,
|
||||
FromStr,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ResponseMode {
|
||||
Query,
|
||||
Fragment,
|
||||
FormPost,
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
Hash,
|
||||
PartialEq,
|
||||
Eq,
|
||||
PartialOrd,
|
||||
Ord,
|
||||
Clone,
|
||||
Copy,
|
||||
Display,
|
||||
FromStr,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ClientAuthenticationMethod {
|
||||
None,
|
||||
ClientSecretPost,
|
||||
ClientSecretBasic,
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
Hash,
|
||||
PartialEq,
|
||||
Eq,
|
||||
PartialOrd,
|
||||
Ord,
|
||||
Clone,
|
||||
Copy,
|
||||
Display,
|
||||
FromStr,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Display {
|
||||
Page,
|
||||
Popup,
|
||||
Touch,
|
||||
Wap,
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
Hash,
|
||||
PartialEq,
|
||||
Eq,
|
||||
PartialOrd,
|
||||
Ord,
|
||||
Clone,
|
||||
Copy,
|
||||
Display,
|
||||
FromStr,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
)]
|
||||
#[display(style = "snake_case")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Prompt {
|
||||
None,
|
||||
Login,
|
||||
Consent,
|
||||
SelectAccount,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct AuthorizationRequest {
|
||||
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, ResponseType>")]
|
||||
pub response_type: HashSet<ResponseType>,
|
||||
|
||||
pub client_id: String,
|
||||
|
||||
pub redirect_uri: Option<Url>,
|
||||
|
||||
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, String>")]
|
||||
pub scope: HashSet<String>,
|
||||
|
||||
pub state: Option<String>,
|
||||
|
||||
pub response_mode: Option<ResponseMode>,
|
||||
|
||||
pub nonce: Option<String>,
|
||||
|
||||
display: Option<Display>,
|
||||
|
||||
#[serde_as(as = "Option<DurationSeconds<i64>>")]
|
||||
#[serde(default)]
|
||||
pub max_age: Option<Duration>,
|
||||
|
||||
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, LanguageTag>>")]
|
||||
#[serde(default)]
|
||||
ui_locales: Option<Vec<LanguageTag>>,
|
||||
|
||||
id_token_hint: Option<String>,
|
||||
|
||||
login_hint: Option<String>,
|
||||
|
||||
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
|
||||
#[serde(default)]
|
||||
acr_values: Option<HashSet<String>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
pub struct AuthorizationResponse<R> {
|
||||
pub code: Option<String>,
|
||||
pub state: Option<String>,
|
||||
#[serde(flatten)]
|
||||
pub response: R,
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
Hash,
|
||||
PartialEq,
|
||||
Eq,
|
||||
PartialOrd,
|
||||
Ord,
|
||||
Clone,
|
||||
Copy,
|
||||
Display,
|
||||
FromStr,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TokenType {
|
||||
Bearer,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
pub struct AuthorizationCodeGrant {
|
||||
pub code: String,
|
||||
#[serde(default)]
|
||||
pub redirect_uri: Option<Url>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
pub struct RefreshTokenGrant {
|
||||
pub refresh_token: String,
|
||||
|
||||
#[serde(default)]
|
||||
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
|
||||
scope: Option<HashSet<String>>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
pub struct ClientCredentialsGrant {
|
||||
#[serde(default)]
|
||||
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
|
||||
scope: Option<HashSet<String>>,
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
Hash,
|
||||
PartialEq,
|
||||
Eq,
|
||||
PartialOrd,
|
||||
Ord,
|
||||
Clone,
|
||||
Copy,
|
||||
Display,
|
||||
FromStr,
|
||||
Serialize,
|
||||
Deserialize,
|
||||
)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum GrantType {
|
||||
AuthorizationCode,
|
||||
RefreshToken,
|
||||
ClientCredentials,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
#[serde(tag = "grant_type", rename_all = "snake_case")]
|
||||
pub enum AccessTokenRequest {
|
||||
AuthorizationCode(AuthorizationCodeGrant),
|
||||
RefreshToken(RefreshTokenGrant),
|
||||
ClientCredentials(ClientCredentialsGrant),
|
||||
#[serde(skip_deserializing, other)]
|
||||
Unsupported,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
pub struct AccessTokenResponse {
|
||||
access_token: String,
|
||||
refresh_token: Option<String>,
|
||||
// TODO: this should be somewhere else
|
||||
id_token: Option<String>,
|
||||
|
||||
token_type: TokenType,
|
||||
|
||||
#[serde_as(as = "Option<DurationSeconds<i64>>")]
|
||||
expires_in: Option<Duration>,
|
||||
|
||||
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
|
||||
scope: Option<HashSet<String>>,
|
||||
}
|
||||
|
||||
impl AccessTokenResponse {
|
||||
#[must_use]
|
||||
pub fn new(access_token: String) -> AccessTokenResponse {
|
||||
AccessTokenResponse {
|
||||
access_token,
|
||||
refresh_token: None,
|
||||
id_token: None,
|
||||
token_type: TokenType::Bearer,
|
||||
expires_in: None,
|
||||
scope: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_refresh_token(mut self, refresh_token: String) -> Self {
|
||||
self.refresh_token = Some(refresh_token);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_id_token(mut self, id_token: String) -> Self {
|
||||
self.id_token = Some(id_token);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_scopes(mut self, scope: HashSet<String>) -> Self {
|
||||
self.scope = Some(scope);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_expires_in(mut self, expires_in: Duration) -> Self {
|
||||
self.expires_in = Some(expires_in);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TokenTypeHint {
|
||||
AccessToken,
|
||||
RefreshToken,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
pub struct IntrospectionRequest {
|
||||
pub token: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub token_type_hint: Option<TokenTypeHint>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq, Default)]
|
||||
pub struct IntrospectionResponse {
|
||||
pub active: bool,
|
||||
|
||||
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
|
||||
pub scope: Option<HashSet<String>>,
|
||||
|
||||
pub client_id: Option<String>,
|
||||
|
||||
pub username: Option<String>,
|
||||
|
||||
pub token_type: Option<TokenTypeHint>,
|
||||
|
||||
#[serde_as(as = "Option<TimestampSeconds>")]
|
||||
pub exp: Option<DateTime<Utc>>,
|
||||
|
||||
#[serde_as(as = "Option<TimestampSeconds>")]
|
||||
pub iat: Option<DateTime<Utc>>,
|
||||
|
||||
#[serde_as(as = "Option<TimestampSeconds>")]
|
||||
pub nbf: Option<DateTime<Utc>>,
|
||||
|
||||
pub sub: Option<String>,
|
||||
|
||||
pub aud: Option<String>,
|
||||
|
||||
pub iss: Option<String>,
|
||||
|
||||
pub jti: Option<String>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashSet;
|
||||
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
use crate::test_utils::assert_serde_json;
|
||||
|
||||
#[test]
|
||||
fn serde_refresh_token_grant() {
|
||||
let expected = json!({
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": "abcd",
|
||||
"scope": "openid",
|
||||
});
|
||||
|
||||
let scope = {
|
||||
let mut s = HashSet::new();
|
||||
// TODO: insert multiple scopes and test it. It's a bit tricky to test since
|
||||
// HashSet have no guarantees regarding the ordering of items, so right
|
||||
// now the output is unstable.
|
||||
s.insert("openid".to_string());
|
||||
Some(s)
|
||||
};
|
||||
|
||||
let req = AccessTokenRequest::RefreshToken(RefreshTokenGrant {
|
||||
refresh_token: "abcd".into(),
|
||||
scope,
|
||||
});
|
||||
|
||||
assert_serde_json(&req, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde_authorization_code_grant() {
|
||||
let expected = json!({
|
||||
"grant_type": "authorization_code",
|
||||
"code": "abcd",
|
||||
"redirect_uri": "https://example.com/redirect",
|
||||
});
|
||||
|
||||
let req = AccessTokenRequest::AuthorizationCode(AuthorizationCodeGrant {
|
||||
code: "abcd".into(),
|
||||
redirect_uri: Some("https://example.com/redirect".parse().unwrap()),
|
||||
});
|
||||
|
||||
assert_serde_json(&req, expected);
|
||||
}
|
||||
}
|
30
crates/oauth2-types/src/test_utils.rs
Normal file
30
crates/oauth2-types/src/test_utils.rs
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::fmt::Debug;
|
||||
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
||||
#[track_caller]
|
||||
pub(crate) fn assert_serde_json<T: Serialize + DeserializeOwned + PartialEq + Debug>(
|
||||
got: &T,
|
||||
expected_value: serde_json::Value,
|
||||
) {
|
||||
let got_value = serde_json::to_value(&got).expect("could not serialize object as JSON value");
|
||||
assert_eq!(got_value, expected_value);
|
||||
|
||||
let expected: T =
|
||||
serde_json::from_value(expected_value).expect("could not serialize object as JSON value");
|
||||
assert_eq!(got, &expected);
|
||||
}
|
Reference in New Issue
Block a user