1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Split the service in multiple crates

This commit is contained in:
Quentin Gliech
2021-09-16 14:43:56 +02:00
parent da91564bf9
commit a44e33931c
83 changed files with 311 additions and 174 deletions

27
crates/cli/Cargo.toml Normal file
View File

@ -0,0 +1,27 @@
[package]
name = "mas-cli"
version = "0.1.0"
authors = ["Quentin Gliech <quenting@element.io>"]
edition = "2018"
license = "Apache-2.0"
[dependencies]
tokio = { version = "1.11.0", features = ["full"] }
anyhow = "1.0.44"
clap = "3.0.0-beta.4"
tracing = "0.1.27"
tracing-subscriber = "0.2.22"
dotenv = "0.15.0"
schemars = { version = "0.8.3", features = ["url", "chrono"] }
tower = { version = "0.4.8", features = ["full"] }
tower-http = { version = "0.1.1", features = ["full"] }
hyper = { version = "0.14.12", features = ["full"] }
serde_yaml = "0.8.21"
warp = "0.3.1"
argon2 = { version = "0.3.1", features = ["password-hash"] }
mas-config = { path = "../config" }
mas-core = { path = "../core" }
[dev-dependencies]
indoc = "1.0.3"

75
crates/cli/src/config.rs Normal file
View File

@ -0,0 +1,75 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use clap::Clap;
use mas_config::{ConfigurationSection, RootConfig};
use schemars::schema_for;
use tracing::info;
use super::RootCommand;
#[derive(Clap, Debug)]
pub(super) struct ConfigCommand {
#[clap(subcommand)]
subcommand: ConfigSubcommand,
}
#[derive(Clap, Debug)]
enum ConfigSubcommand {
/// Dump the current config as YAML
Dump,
/// Print the JSON Schema that validates configuration files
Schema,
/// Check a config file
Check,
/// Generate a new config file
Generate,
}
impl ConfigCommand {
pub async fn run(&self, root: &RootCommand) -> anyhow::Result<()> {
use ConfigSubcommand as SC;
match &self.subcommand {
SC::Dump => {
let config: RootConfig = root.load_config()?;
serde_yaml::to_writer(std::io::stdout(), &config)?;
Ok(())
}
SC::Schema => {
let schema = schema_for!(RootConfig);
serde_yaml::to_writer(std::io::stdout(), &schema)?;
Ok(())
}
SC::Check => {
let _config: RootConfig = root.load_config()?;
info!(path = ?root.config, "Configuration file looks good");
Ok(())
}
SC::Generate => {
let config = RootConfig::load_and_generate().await?;
serde_yaml::to_writer(std::io::stdout(), &config)?;
Ok(())
}
}
}
}

View File

@ -0,0 +1,47 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use clap::Clap;
use mas_config::DatabaseConfig;
use mas_core::storage::MIGRATOR;
use super::RootCommand;
#[derive(Clap, Debug)]
pub(super) struct DatabaseCommand {
#[clap(subcommand)]
subcommand: DatabaseSubcommand,
}
#[derive(Clap, Debug)]
enum DatabaseSubcommand {
/// Run database migrations
Migrate,
}
impl DatabaseCommand {
pub async fn run(&self, root: &RootCommand) -> anyhow::Result<()> {
let config: DatabaseConfig = root.load_config()?;
let pool = config.connect().await?;
// Run pending migrations
MIGRATOR
.run(&pool)
.await
.context("could not run migrations")?;
Ok(())
}
}

103
crates/cli/src/main.rs Normal file
View File

@ -0,0 +1,103 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![forbid(unsafe_code)]
#![deny(clippy::all)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::suspicious_else_formatting)]
use std::path::PathBuf;
use anyhow::Context;
use clap::Clap;
use mas_config::ConfigurationSection;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Registry};
use self::{
config::ConfigCommand, database::DatabaseCommand, manage::ManageCommand, server::ServerCommand,
};
mod config;
mod database;
mod manage;
mod server;
#[derive(Clap, Debug)]
enum Subcommand {
/// Configuration-related commands
Config(ConfigCommand),
/// Manage the database
Database(DatabaseCommand),
/// Runs the web server
Server(ServerCommand),
/// Manage the instance
Manage(ManageCommand),
}
#[derive(Clap, Debug)]
struct RootCommand {
/// Path to the configuration file
#[clap(short, long, global = true, default_value = "config.yaml")]
config: PathBuf,
#[clap(subcommand)]
subcommand: Option<Subcommand>,
}
impl RootCommand {
async fn run(&self) -> anyhow::Result<()> {
use Subcommand as S;
match &self.subcommand {
Some(S::Config(c)) => c.run(self).await,
Some(S::Database(c)) => c.run(self).await,
Some(S::Server(c)) => c.run(self).await,
Some(S::Manage(c)) => c.run(self).await,
None => ServerCommand::default().run(self).await,
}
}
fn load_config<'de, T: ConfigurationSection<'de>>(&self) -> anyhow::Result<T> {
T::load_from_file(&self.config).context("could not load configuration")
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Load environment variables from .env files
if let Err(e) = dotenv::dotenv() {
// Display the error if it is something other than the .env file not existing
if !e.not_found() {
return Err(e).context("could not load .env file");
}
}
// Setup logging & tracing
let fmt_layer = tracing_subscriber::fmt::layer().with_writer(std::io::stderr);
let filter_layer = EnvFilter::try_from_default_env().or_else(|_| EnvFilter::try_new("info"))?;
let subscriber = Registry::default().with(filter_layer).with(fmt_layer);
subscriber
.try_init()
.context("could not initialize logging")?;
// Parse the CLI arguments
let opts = RootCommand::parse();
// And run the command
opts.run().await
}

59
crates/cli/src/manage.rs Normal file
View File

@ -0,0 +1,59 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use argon2::Argon2;
use clap::Clap;
use mas_config::DatabaseConfig;
use mas_core::storage::register_user;
use tracing::{info, warn};
use super::RootCommand;
#[derive(Clap, Debug)]
pub(super) struct ManageCommand {
#[clap(subcommand)]
subcommand: ManageSubcommand,
}
#[derive(Clap, Debug)]
enum ManageSubcommand {
/// Register a new user
Register { username: String, password: String },
/// List active users
Users,
}
impl ManageCommand {
pub async fn run(&self, root: &RootCommand) -> anyhow::Result<()> {
use ManageSubcommand as SC;
match &self.subcommand {
SC::Register { username, password } => {
let config: DatabaseConfig = root.load_config()?;
let pool = config.connect().await?;
let hasher = Argon2::default();
let user = register_user(&pool, hasher, username, password).await?;
info!(?user, "User registered");
Ok(())
}
SC::Users => {
warn!("Not implemented yet");
Ok(())
}
}
}
}

93
crates/cli/src/server.rs Normal file
View File

@ -0,0 +1,93 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
net::{SocketAddr, TcpListener},
time::Duration,
};
use anyhow::Context;
use clap::Clap;
use hyper::{header, Server};
use mas_config::RootConfig;
use mas_core::{
tasks::{self, TaskQueue},
templates::Templates,
};
use tower::{make::Shared, ServiceBuilder};
use tower_http::{
compression::CompressionLayer,
sensitive_headers::SetSensitiveHeadersLayer,
trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer},
LatencyUnit,
};
use super::RootCommand;
#[derive(Clap, Debug, Default)]
pub(super) struct ServerCommand;
impl ServerCommand {
pub async fn run(&self, root: &RootCommand) -> anyhow::Result<()> {
let config: RootConfig = root.load_config()?;
let addr: SocketAddr = config.http.address.parse()?;
let listener = TcpListener::bind(addr)?;
// Connect to the database
let pool = config.database.connect().await?;
// Load and compile the templates
let templates = Templates::load().context("could not load templates")?;
// Start the server
let root = mas_core::handlers::root(&pool, &templates, &config);
let queue = TaskQueue::default();
queue.recuring(Duration::from_secs(15), tasks::cleanup_expired(&pool));
queue.start();
let warp_service = warp::service(root);
let service = ServiceBuilder::new()
// Add high level tracing/logging to all requests
.layer(
TraceLayer::new_for_http()
.make_span_with(DefaultMakeSpan::new().include_headers(true))
.on_response(
DefaultOnResponse::new()
.include_headers(true)
.latency_unit(LatencyUnit::Micros),
),
)
// Set a timeout
.timeout(Duration::from_secs(10))
// Compress responses
.layer(CompressionLayer::new())
// Mark the `Authorization` and `Cookie` headers as sensitive so it doesn't show in logs
.layer(SetSensitiveHeadersLayer::new(vec![
header::AUTHORIZATION,
header::COOKIE,
]))
.service(warp_service);
tracing::info!("Listening on http://{}", listener.local_addr().unwrap());
Server::from_tcp(listener)?
.serve(Shared::new(service))
.await?;
Ok(())
}
}

38
crates/config/Cargo.toml Normal file
View File

@ -0,0 +1,38 @@
[package]
name = "mas-config"
version = "0.1.0"
authors = ["Quentin Gliech <quenting@element.io>"]
edition = "2018"
license = "Apache-2.0"
[dependencies]
tokio = { version = "1.11.0", features = [] }
tracing = "0.1.27"
async-trait = "0.1.51"
thiserror = "1.0.29"
anyhow = "1.0.44"
schemars = { version = "0.8.3", features = ["url", "chrono"] }
figment = { version = "0.10.6", features = ["env", "yaml", "test"] }
chrono = { version = "0.4.19", features = ["serde"] }
url = { version = "2.2.2", features = ["serde"] }
serde = { version = "1.0.130", features = ["derive"] }
serde_with = { version = "1.10.0", features = ["hex", "chrono"] }
serde_json = "1.0.68"
sqlx = { version = "0.5.7", features = ["runtime-tokio-rustls", "postgres"] }
rand = "0.8.4"
rsa = "0.5.0"
k256 = "0.9.6"
pkcs8 = { version = "0.7.6", features = ["pem"] }
elliptic-curve = { version = "0.10.6", features = ["pem"] }
indoc = "1.0.3"
[dependencies.jwt-compact]
# Waiting on the next release because of the bump of the `rsa` dependency
git = "https://github.com/slowli/jwt-compact.git"
rev = "7a6dee6824c1d4e7c7f81019c9a968e5c9e44923"
features = ["rsa", "k256"]

View File

@ -0,0 +1,49 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use super::ConfigurationSection;
fn secret_schema(gen: &mut SchemaGenerator) -> Schema {
String::json_schema(gen)
}
#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct CookiesConfig {
#[schemars(schema_with = "secret_schema")]
#[serde_as(as = "serde_with::hex::Hex")]
pub secret: [u8; 32],
}
#[async_trait]
impl ConfigurationSection<'_> for CookiesConfig {
fn path() -> &'static str {
"cookies"
}
async fn generate() -> anyhow::Result<Self> {
Ok(Self {
secret: rand::random(),
})
}
fn test() -> Self {
Self { secret: [0xEA; 32] }
}
}

85
crates/config/src/csrf.rs Normal file
View File

@ -0,0 +1,85 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use chrono::Duration;
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use super::ConfigurationSection;
fn default_ttl() -> Duration {
Duration::hours(1)
}
fn ttl_schema(gen: &mut SchemaGenerator) -> Schema {
u64::json_schema(gen)
}
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct CsrfConfig {
#[schemars(schema_with = "ttl_schema")]
#[serde(default = "default_ttl")]
#[serde_as(as = "serde_with::DurationSeconds<i64>")]
pub ttl: Duration,
}
impl Default for CsrfConfig {
fn default() -> Self {
Self { ttl: default_ttl() }
}
}
#[async_trait]
impl ConfigurationSection<'_> for CsrfConfig {
fn path() -> &'static str {
"csrf"
}
async fn generate() -> anyhow::Result<Self> {
Ok(Self::default())
}
fn test() -> Self {
Self::default()
}
}
#[cfg(test)]
mod tests {
use figment::Jail;
use super::*;
#[test]
fn load_config() {
Jail::expect_with(|jail| {
jail.create_file(
"config.yaml",
r#"
csrf:
ttl: 1800
"#,
)?;
let config = CsrfConfig::load_from_file("config.yaml")?;
assert_eq!(config.ttl, Duration::minutes(30));
Ok(())
});
}
}

View File

@ -0,0 +1,166 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::time::Duration;
use anyhow::Context;
use async_trait::async_trait;
use schemars::{gen::SchemaGenerator, schema::Schema, JsonSchema};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, skip_serializing_none};
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
// FIXME
// use sqlx::ConnectOptions
// use tracing::log::LevelFilter;
use super::ConfigurationSection;
fn default_uri() -> String {
"postgresql://".to_string()
}
fn default_max_connections() -> u32 {
10
}
fn default_connect_timeout() -> Duration {
Duration::from_secs(30)
}
#[allow(clippy::unnecessary_wraps)]
fn default_idle_timeout() -> Option<Duration> {
Some(Duration::from_secs(10 * 60))
}
#[allow(clippy::unnecessary_wraps)]
fn default_max_lifetime() -> Option<Duration> {
Some(Duration::from_secs(30 * 60))
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self {
uri: default_uri(),
max_connections: default_max_connections(),
min_connections: Default::default(),
connect_timeout: default_connect_timeout(),
idle_timeout: default_idle_timeout(),
max_lifetime: default_max_lifetime(),
}
}
}
fn duration_schema(gen: &mut SchemaGenerator) -> Schema {
Option::<u64>::json_schema(gen)
}
fn optional_duration_schema(gen: &mut SchemaGenerator) -> Schema {
u64::json_schema(gen)
}
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct DatabaseConfig {
#[serde(default = "default_uri")]
uri: String,
#[serde(default = "default_max_connections")]
max_connections: u32,
#[serde(default)]
min_connections: u32,
#[schemars(schema_with = "duration_schema")]
#[serde(default = "default_connect_timeout")]
#[serde_as(as = "serde_with::DurationSeconds<u64>")]
connect_timeout: Duration,
#[schemars(schema_with = "optional_duration_schema")]
#[serde(default = "default_idle_timeout")]
#[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
idle_timeout: Option<Duration>,
#[schemars(schema_with = "optional_duration_schema")]
#[serde(default = "default_max_lifetime")]
#[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
max_lifetime: Option<Duration>,
}
impl DatabaseConfig {
#[tracing::instrument(err)]
pub async fn connect(&self) -> anyhow::Result<PgPool> {
let options = self
.uri
.parse::<PgConnectOptions>()
.context("invalid database URL")?
.application_name("matrix-authentication-service");
// FIXME
// options
// .log_statements(LevelFilter::Debug)
// .log_slow_statements(LevelFilter::Warn, Duration::from_millis(100));
PgPoolOptions::new()
.max_connections(self.max_connections)
.min_connections(self.min_connections)
.connect_timeout(self.connect_timeout)
.idle_timeout(self.idle_timeout)
.max_lifetime(self.max_lifetime)
.connect_with(options)
.await
.context("could not connect to the database")
}
}
#[async_trait]
impl ConfigurationSection<'_> for DatabaseConfig {
fn path() -> &'static str {
"database"
}
async fn generate() -> anyhow::Result<Self> {
Ok(Self::default())
}
fn test() -> Self {
Self::default()
}
}
#[cfg(test)]
mod tests {
use figment::Jail;
use super::*;
#[test]
fn load_config() {
Jail::expect_with(|jail| {
jail.create_file(
"config.yaml",
r#"
database:
uri: postgresql://user:password@host/database
"#,
)?;
let config = DatabaseConfig::load_from_file("config.yaml")?;
assert_eq!(config.uri, "postgresql://user:password@host/database");
Ok(())
});
}
}

52
crates/config/src/http.rs Normal file
View File

@ -0,0 +1,52 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use super::ConfigurationSection;
fn default_http_address() -> String {
"[::]:8080".into()
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct HttpConfig {
#[serde(default = "default_http_address")]
pub address: String,
}
impl Default for HttpConfig {
fn default() -> Self {
Self {
address: default_http_address(),
}
}
}
#[async_trait]
impl ConfigurationSection<'_> for HttpConfig {
fn path() -> &'static str {
"http"
}
async fn generate() -> anyhow::Result<Self> {
Ok(Self::default())
}
fn test() -> Self {
Self::default()
}
}

76
crates/config/src/lib.rs Normal file
View File

@ -0,0 +1,76 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
mod cookies;
mod csrf;
mod database;
mod http;
mod oauth2;
mod util;
pub use self::{
cookies::CookiesConfig,
csrf::CsrfConfig,
database::DatabaseConfig,
http::HttpConfig,
oauth2::{Algorithm, KeySet, OAuth2ClientConfig, OAuth2Config},
util::ConfigurationSection,
};
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RootConfig {
pub oauth2: OAuth2Config,
#[serde(default)]
pub http: HttpConfig,
#[serde(default)]
pub database: DatabaseConfig,
pub cookies: CookiesConfig,
#[serde(default)]
pub csrf: CsrfConfig,
}
#[async_trait]
impl ConfigurationSection<'_> for RootConfig {
fn path() -> &'static str {
""
}
async fn generate() -> anyhow::Result<Self> {
Ok(Self {
oauth2: OAuth2Config::generate().await?,
http: HttpConfig::generate().await?,
database: DatabaseConfig::generate().await?,
cookies: CookiesConfig::generate().await?,
csrf: CsrfConfig::generate().await?,
})
}
fn test() -> Self {
Self {
oauth2: OAuth2Config::test(),
http: HttpConfig::test(),
database: DatabaseConfig::test(),
cookies: CookiesConfig::test(),
csrf: CsrfConfig::test(),
}
}
}

476
crates/config/src/oauth2.rs Normal file
View File

@ -0,0 +1,476 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::convert::TryFrom;
use anyhow::Context;
use async_trait::async_trait;
use jwt_compact::{
alg::{self, StrongAlg, StrongKey},
jwk::JsonWebKey,
AlgorithmExt, Claims, Header,
};
use pkcs8::{FromPrivateKey, ToPrivateKey};
use rsa::RsaPrivateKey;
use schemars::JsonSchema;
use serde::{
de::{MapAccess, Visitor},
ser::SerializeStruct,
Deserialize, Serialize,
};
use serde_with::skip_serializing_none;
use thiserror::Error;
use tokio::task;
use tracing::info;
use url::Url;
use super::ConfigurationSection;
// TODO: a lot of the signing logic should go out somewhere else
const RS256: StrongAlg<alg::Rsa> = StrongAlg(alg::Rsa::rs256());
#[derive(Serialize, Deserialize, Clone, Copy, Debug)]
#[serde(rename_all = "UPPERCASE")]
pub enum Algorithm {
Rs256,
Es256k,
}
#[derive(Serialize, Clone)]
pub struct Jwk {
kid: String,
alg: Algorithm,
#[serde(flatten)]
inner: serde_json::Value,
}
#[derive(Serialize, Clone)]
pub struct Jwks {
keys: Vec<Jwk>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(transparent)]
pub struct KeySet(Vec<Key>);
impl KeySet {
pub fn to_public_jwks(&self) -> Jwks {
let keys = self.0.iter().map(Key::to_public_jwk).collect();
Jwks { keys }
}
#[tracing::instrument(err)]
pub async fn token<T>(
&self,
alg: Algorithm,
header: Header,
claims: Claims<T>,
) -> anyhow::Result<String>
where
T: std::fmt::Debug + Serialize + Send + Sync + 'static,
{
match alg {
Algorithm::Rs256 => {
let (kid, key) = self
.0
.iter()
.find_map(Key::rsa)
.context("could not find RSA key")?;
let header = header.with_key_id(kid);
// TODO: store them as strong keys
let key = StrongKey::try_from(key.clone())?;
task::spawn_blocking(move || {
RS256
.token(header, &claims, &key)
.context("failed to sign token")
})
.await?
}
Algorithm::Es256k => {
// TODO: make this const with lazy_static?
let es256k: alg::Es256k = alg::Es256k::default();
let (kid, key) = self
.0
.iter()
.find_map(Key::ecdsa)
.context("could not find ECDSA key")?;
let key = k256::ecdsa::SigningKey::from(key);
let header = header.with_key_id(kid);
// TODO: use StrongAlg
task::spawn_blocking(move || {
es256k
.token(header, &claims, &key)
.context("failed to sign token")
})
.await?
}
}
}
}
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Key {
Rsa { key: RsaPrivateKey, kid: String },
Ecdsa { key: k256::SecretKey, kid: String },
}
impl Key {
fn from_ecdsa(key: k256::SecretKey) -> Self {
// TODO: hash the key and use as KID
let kid = String::from("ecdsa-kid");
Self::Ecdsa { kid, key }
}
fn from_ecdsa_pem(key: &str) -> anyhow::Result<Self> {
let key = k256::SecretKey::from_pkcs8_pem(key)?;
Ok(Self::from_ecdsa(key))
}
fn from_rsa(key: RsaPrivateKey) -> Self {
// TODO: hash the key and use as KID
let kid = String::from("rsa-kid");
Self::Rsa { kid, key }
}
fn from_rsa_pem(key: &str) -> anyhow::Result<Self> {
let key = RsaPrivateKey::from_pkcs8_pem(key)?;
Ok(Self::from_rsa(key))
}
fn to_public_jwk(&self) -> Jwk {
match self {
Key::Rsa { key, kid } => {
let pubkey = key.to_public_key();
let inner = JsonWebKey::from(&pubkey);
let inner = serde_json::to_value(&inner).unwrap();
let kid = kid.to_string();
let alg = Algorithm::Rs256;
Jwk { kid, alg, inner }
}
Key::Ecdsa { key, kid } => {
let pubkey = k256::ecdsa::VerifyingKey::from(key.public_key());
let inner = JsonWebKey::from(&pubkey);
let inner = serde_json::to_value(&inner).unwrap();
let kid = kid.to_string();
let alg = Algorithm::Es256k;
Jwk { kid, alg, inner }
}
}
}
fn rsa(&self) -> Option<(&str, &RsaPrivateKey)> {
match self {
Key::Rsa { key, kid } => Some((kid, key)),
_ => None,
}
}
fn ecdsa(&self) -> Option<(&str, &k256::SecretKey)> {
match self {
Key::Ecdsa { key, kid } => Some((kid, key)),
_ => None,
}
}
}
impl Serialize for Key {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut map = serializer.serialize_struct("Key", 2)?;
match self {
Key::Rsa { key, kid: _ } => {
map.serialize_field("type", "rsa")?;
let pem = key.to_pkcs8_pem().map_err(serde::ser::Error::custom)?;
map.serialize_field("key", pem.as_str())?;
}
Key::Ecdsa { key, kid: _ } => {
map.serialize_field("type", "ecdsa")?;
let pem = key.to_pkcs8_pem().map_err(serde::ser::Error::custom)?;
map.serialize_field("key", pem.as_str())?;
}
}
map.end()
}
}
impl<'de> Deserialize<'de> for Key {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize, Debug)]
#[serde(field_identifier, rename_all = "lowercase")]
enum Field {
Type,
Key,
}
#[derive(Deserialize)]
#[serde(rename_all = "lowercase")]
enum KeyType {
Rsa,
Ecdsa,
}
struct KeyVisitor;
impl<'de> Visitor<'de> for KeyVisitor {
type Value = Key;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("struct Key")
}
fn visit_map<V>(self, mut map: V) -> Result<Key, V::Error>
where
V: MapAccess<'de>,
{
let mut key_type = None;
let mut key_key = None;
while let Some(key) = map.next_key()? {
match key {
Field::Type => {
if key_type.is_some() {
return Err(serde::de::Error::duplicate_field("type"));
}
key_type = Some(map.next_value()?);
}
Field::Key => {
if key_key.is_some() {
return Err(serde::de::Error::duplicate_field("key"));
}
key_key = Some(map.next_value()?);
}
}
}
let key_type: KeyType =
key_type.ok_or_else(|| serde::de::Error::missing_field("type"))?;
let key_key: String =
key_key.ok_or_else(|| serde::de::Error::missing_field("key"))?;
match key_type {
KeyType::Rsa => Key::from_rsa_pem(&key_key).map_err(serde::de::Error::custom),
KeyType::Ecdsa => {
Key::from_ecdsa_pem(&key_key).map_err(serde::de::Error::custom)
}
}
}
}
deserializer.deserialize_struct("Key", &["type", "key"], KeyVisitor)
}
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct OAuth2ClientConfig {
pub client_id: String,
#[serde(default)]
pub client_secret: Option<String>,
#[serde(default)]
pub redirect_uris: Vec<Url>,
}
#[derive(Debug, Error)]
#[error("Invalid redirect URI")]
pub struct InvalidRedirectUriError;
impl OAuth2ClientConfig {
pub fn resolve_redirect_uri<'a>(
&'a self,
suggested_uri: &'a Option<Url>,
) -> Result<&'a Url, InvalidRedirectUriError> {
suggested_uri.as_ref().map_or_else(
|| self.redirect_uris.get(0).ok_or(InvalidRedirectUriError),
|suggested_uri| self.check_redirect_uri(suggested_uri),
)
}
pub fn check_redirect_uri<'a>(
&self,
redirect_uri: &'a Url,
) -> Result<&'a Url, InvalidRedirectUriError> {
if self.redirect_uris.contains(redirect_uri) {
Ok(redirect_uri)
} else {
Err(InvalidRedirectUriError)
}
}
}
fn default_oauth2_issuer() -> Url {
"http://[::]:8080".parse().unwrap()
}
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct OAuth2Config {
#[serde(default = "default_oauth2_issuer")]
pub issuer: Url,
#[serde(default)]
pub clients: Vec<OAuth2ClientConfig>,
#[schemars(with = "Vec<String>")] // TODO: this is a lie
pub keys: KeySet,
}
impl OAuth2Config {
pub fn discovery_url(&self) -> Url {
self.issuer
.join(".well-known/openid-configuration")
.expect("could not build discovery url")
}
}
#[async_trait]
impl ConfigurationSection<'_> for OAuth2Config {
fn path() -> &'static str {
"oauth2"
}
#[tracing::instrument]
async fn generate() -> anyhow::Result<Self> {
info!("Generating keys...");
let span = tracing::info_span!("rsa");
let rsa_key = task::spawn_blocking(move || {
let _entered = span.enter();
let mut rng = rand::thread_rng();
let ret =
RsaPrivateKey::new(&mut rng, 2048).context("could not generate RSA private key");
info!("Done generating RSA key");
ret
})
.await
.context("could not join blocking task")??;
let span = tracing::info_span!("ecdsa");
let ecdsa_key = task::spawn_blocking(move || {
let _entered = span.enter();
let rng = rand::thread_rng();
let ret = k256::SecretKey::random(rng);
info!("Done generating ECDSA key");
ret
})
.await
.context("could not join blocking task")?;
Ok(Self {
issuer: default_oauth2_issuer(),
clients: Vec::new(),
keys: KeySet(vec![Key::from_rsa(rsa_key), Key::from_ecdsa(ecdsa_key)]),
})
}
fn test() -> Self {
let rsa_key = Key::from_rsa_pem(indoc::indoc! {r#"
-----BEGIN PRIVATE KEY-----
MIIBVQIBADANBgkqhkiG9w0BAQEFAASCAT8wggE7AgEAAkEAymS2RkeIZo7pUeEN
QUGCG4GLJru5jzxomO9jiNr5D/oRcerhpQVc9aCpBfAAg4l4a1SmYdBzWqX0X5pU
scgTtQIDAQABAkEArNIMlrxUK4bSklkCcXtXdtdKE9vuWfGyOw0GyAB69fkEUBxh
3j65u+u3ZmW+bpMWHgp1FtdobE9nGwb2VBTWAQIhAOyU1jiUEkrwKK004+6b5QRE
vC9UI2vDWy5vioMNx5Y1AiEA2wGAJ6ETF8FF2Vd+kZlkKK7J0em9cl0gbJDsWIEw
N4ECIEyWYkMurD1WQdTQqnk0Po+DMOihdFYOiBYgRdbnPxWBAiEAmtd0xJAd7622
tPQniMnrBtiN2NxqFXHCev/8Gpc8gAECIBcaPcF59qVeRmYrfqzKBxFm7LmTwlAl
Gh7BNzCeN+D6
-----END PRIVATE KEY-----
"#})
.unwrap();
let ecdsa_key = Key::from_ecdsa_pem(indoc::indoc! {r#"
-----BEGIN PRIVATE KEY-----
MIGEAgEAMBAGByqGSM49AgEGBSuBBAAKBG0wawIBAQQgqfn5mYO/5Qq/wOOiWgHA
NaiDiepgUJ2GI5eq2V8D8nahRANCAARMK9aKUd/H28qaU+0qvS6bSJItzAge1VHn
OhBAAUVci1RpmUA+KdCL5sw9nadAEiONeiGr+28RYHZmlB9qXnjC
-----END PRIVATE KEY-----
"#})
.unwrap();
Self {
issuer: default_oauth2_issuer(),
clients: Vec::new(),
keys: KeySet(vec![rsa_key, ecdsa_key]),
}
}
}
#[cfg(test)]
mod tests {
use figment::Jail;
use super::*;
#[test]
fn load_config() {
Jail::expect_with(|jail| {
jail.create_file(
"config.yaml",
r#"
oauth2:
keys:
- type: rsa
key: |
-----BEGIN PRIVATE KEY-----
MIIBVQIBADANBgkqhkiG9w0BAQEFAASCAT8wggE7AgEAAkEAymS2RkeIZo7pUeEN
QUGCG4GLJru5jzxomO9jiNr5D/oRcerhpQVc9aCpBfAAg4l4a1SmYdBzWqX0X5pU
scgTtQIDAQABAkEArNIMlrxUK4bSklkCcXtXdtdKE9vuWfGyOw0GyAB69fkEUBxh
3j65u+u3ZmW+bpMWHgp1FtdobE9nGwb2VBTWAQIhAOyU1jiUEkrwKK004+6b5QRE
vC9UI2vDWy5vioMNx5Y1AiEA2wGAJ6ETF8FF2Vd+kZlkKK7J0em9cl0gbJDsWIEw
N4ECIEyWYkMurD1WQdTQqnk0Po+DMOihdFYOiBYgRdbnPxWBAiEAmtd0xJAd7622
tPQniMnrBtiN2NxqFXHCev/8Gpc8gAECIBcaPcF59qVeRmYrfqzKBxFm7LmTwlAl
Gh7BNzCeN+D6
-----END PRIVATE KEY-----
- type: ecdsa
key: |
-----BEGIN PRIVATE KEY-----
MIGEAgEAMBAGByqGSM49AgEGBSuBBAAKBG0wawIBAQQgqfn5mYO/5Qq/wOOiWgHA
NaiDiepgUJ2GI5eq2V8D8nahRANCAARMK9aKUd/H28qaU+0qvS6bSJItzAge1VHn
OhBAAUVci1RpmUA+KdCL5sw9nadAEiONeiGr+28RYHZmlB9qXnjC
-----END PRIVATE KEY-----
issuer: https://example.com
clients:
- client_id: hello
redirect_uris:
- https://exemple.fr/callback
- client_id: world
"#,
)?;
let config = OAuth2Config::load_from_file("config.yaml")?;
assert_eq!(config.issuer, "https://example.com".parse().unwrap());
assert_eq!(config.clients.len(), 2);
assert_eq!(config.clients[0].client_id, "hello");
assert_eq!(
config.clients[0].redirect_uris,
vec!["https://exemple.fr/callback".parse().unwrap()]
);
assert_eq!(config.clients[1].client_id, "world");
assert_eq!(config.clients[1].redirect_uris, Vec::new());
Ok(())
});
}
}

72
crates/config/src/util.rs Normal file
View File

@ -0,0 +1,72 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::Path;
use anyhow::Context;
use async_trait::async_trait;
use figment::{
error::Error as FigmentError,
providers::{Env, Format, Serialized, Yaml},
Figment, Profile,
};
use serde::{Deserialize, Serialize};
#[async_trait]
/// Trait implemented by all configuration section to help loading specific part
/// of the config and generate the sample config.
pub trait ConfigurationSection<'a>: Sized + Deserialize<'a> + Serialize {
/// Specify where this section should live relative to the root.
fn path() -> &'static str;
/// Generate a sample configuration for this section.
async fn generate() -> anyhow::Result<Self>;
/// Generate a sample configuration and override it with environment
/// variables.
///
/// This is what backs the `config generate` subcommand, allowing to
/// programatically generate a configuration file, e.g.
///
/// ```sh
/// export MAS_OAUTH2_ISSUER=https://example.com/
/// export MAS_HTTP_ADDRESS=127.0.0.1:1234
/// matrix-authentication-service config generate
/// ```
async fn load_and_generate() -> anyhow::Result<Self> {
let base = Self::generate()
.await
.context("could not generate configuration")?;
Figment::new()
.merge(Serialized::from(&base, Profile::Default))
.merge(Env::prefixed("MAS_").split("_"))
.extract_inner(Self::path())
.context("could not load configuration")
}
/// Load configuration from a file and environment variables.
fn load_from_file<P>(path: P) -> Result<Self, FigmentError>
where
P: AsRef<Path>,
{
Figment::new()
.merge(Env::prefixed("MAS_").split("_"))
.merge(Yaml::file(path))
.extract_inner(Self::path())
}
/// Generate config used in unit tests
fn test() -> Self;
}

77
crates/core/Cargo.toml Normal file
View File

@ -0,0 +1,77 @@
[package]
name = "mas-core"
version = "0.1.0"
authors = ["Quentin Gliech <quenting@element.io>"]
edition = "2018"
license = "Apache-2.0"
[dependencies]
# Async runtime
tokio = { version = "1.11.0", features = ["full"] }
async-trait = "0.1.51"
tokio-stream = "0.1.7"
futures-util = "0.3.17"
# Logging and tracing
tracing = "0.1.27"
# Error management
thiserror = "1.0.29"
anyhow = "1.0.44"
# Web server
warp = "0.3.1"
hyper = { version = "0.14.12", features = ["full"] }
# Template engine
tera = "1.12.1"
# Database access
sqlx = { version = "0.5.7", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline"] }
# Various structure (de)serialization
serde = { version = "1.0.130", features = ["derive"] }
serde_yaml = "0.8.21"
serde_with = { version = "1.10.0", features = ["hex", "chrono"] }
serde_json = "1.0.68"
serde_urlencoded = "0.7.0"
# Argument & config parsing
figment = { version = "0.10.6", features = ["env", "yaml", "test"] }
schemars = { version = "0.8.3", features = ["url", "chrono"] }
# Password hashing
argon2 = { version = "0.3.1", features = ["password-hash"] }
password-hash = { version = "0.3.2", features = ["std"] }
# Crypto, hashing and signing stuff
rsa = "0.5.0"
k256 = "0.9.6"
pkcs8 = { version = "0.7.6", features = ["pem"] }
elliptic-curve = { version = "0.10.6", features = ["pem"] }
chacha20poly1305 = { version = "0.9.0", features = ["std"] }
sha2 = "0.9.8"
crc = "2.0.0"
# Various data types and utilities
data-encoding = "2.3.2"
chrono = { version = "0.4.19", features = ["serde"] }
url = { version = "2.2.2", features = ["serde"] }
itertools = "0.10.1"
mime = "0.3.16"
rand = "0.8.4"
bincode = "1.3.3"
headers = "0.3.4"
cookie = "0.15.1"
oauth2-types = { path = "../oauth2-types", features = ["sqlx_type"] }
mas-config = { path = "../config" }
[dependencies.jwt-compact]
# Waiting on the next release because of the bump of the `rsa` dependency
git = "https://github.com/slowli/jwt-compact.git"
rev = "7a6dee6824c1d4e7c7f81019c9a968e5c9e44923"
features = ["rsa", "k256"]
[dev-dependencies]
indoc = "1.0.3"

View File

@ -0,0 +1,15 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
DROP FUNCTION IF EXISTS trigger_set_timestamp();

View File

@ -0,0 +1,21 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
CREATE OR REPLACE FUNCTION trigger_set_timestamp()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = NOW();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;

View File

@ -0,0 +1,16 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
DROP TRIGGER set_timestamp ON users;
DROP TABLE users;

View File

@ -0,0 +1,26 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
CREATE TABLE users (
"id" BIGSERIAL PRIMARY KEY,
"username" TEXT NOT NULL UNIQUE,
"hashed_password" TEXT NOT NULL,
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
"updated_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
);
CREATE TRIGGER set_timestamp
BEFORE UPDATE ON users
FOR EACH ROW
EXECUTE PROCEDURE trigger_set_timestamp();

View File

@ -0,0 +1,17 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
DROP TRIGGER set_timestamp ON user_sessions;
DROP TABLE user_session_authentications;
DROP TABLE user_sessions;

View File

@ -0,0 +1,35 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- A logged in session
CREATE TABLE user_sessions (
"id" BIGSERIAL PRIMARY KEY,
"user_id" BIGINT NOT NULL REFERENCES users (id) ON DELETE CASCADE,
"active" BOOLEAN NOT NULL DEFAULT TRUE,
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
"updated_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
);
CREATE TRIGGER set_timestamp
BEFORE UPDATE ON user_sessions
FOR EACH ROW
EXECUTE PROCEDURE trigger_set_timestamp();
-- An authentication within a session
CREATE TABLE user_session_authentications (
"id" BIGSERIAL PRIMARY KEY,
"session_id" BIGINT NOT NULL REFERENCES user_sessions (id) ON DELETE CASCADE,
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
);

View File

@ -0,0 +1,17 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
DROP TRIGGER set_timestamp ON oauth2_sessions;
DROP TABLE oauth2_codes;
DROP TABLE oauth2_sessions;

View File

@ -0,0 +1,45 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
CREATE TABLE oauth2_sessions (
"id" BIGSERIAL PRIMARY KEY,
"user_session_id" BIGINT REFERENCES user_sessions (id) ON DELETE CASCADE,
"client_id" TEXT NOT NULL,
"redirect_uri" TEXT NOT NULL,
"scope" TEXT NOT NULL,
"state" TEXT,
"nonce" TEXT,
"max_age" INT,
"response_type" TEXT NOT NULL,
"response_mode" TEXT NOT NULL,
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
"updated_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
);
CREATE TRIGGER set_timestamp
BEFORE UPDATE ON oauth2_sessions
FOR EACH ROW
EXECUTE PROCEDURE trigger_set_timestamp();
CREATE TABLE oauth2_codes (
"id" BIGSERIAL PRIMARY KEY,
"oauth2_session_id" BIGINT NOT NULL REFERENCES oauth2_sessions (id) ON DELETE CASCADE,
"code" TEXT UNIQUE NOT NULL,
"code_challenge_method" SMALLINT,
"code_challenge" TEXT,
CHECK (("code_challenge" IS NULL AND "code_challenge_method" IS NULL)
OR ("code_challenge" IS NOT NULL AND "code_challenge_method" IS NOT NULL))
);

View File

@ -0,0 +1,15 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
DROP TABLE oauth2_access_tokens;

View File

@ -0,0 +1,23 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
CREATE TABLE oauth2_access_tokens (
"id" BIGSERIAL PRIMARY KEY,
"oauth2_session_id" BIGINT NOT NULL REFERENCES oauth2_sessions (id) ON DELETE CASCADE,
"token" TEXT UNIQUE NOT NULL,
"expires_after" INT NOT NULL,
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
);

View File

@ -0,0 +1,16 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
DROP TRIGGER set_timestamp ON oauth2_refresh_tokens;
DROP TABLE oauth2_refresh_tokens;

View File

@ -0,0 +1,30 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
CREATE TABLE oauth2_refresh_tokens (
"id" BIGSERIAL PRIMARY KEY,
"oauth2_session_id" BIGINT NOT NULL REFERENCES oauth2_sessions (id) ON DELETE CASCADE,
"oauth2_access_token_id" BIGINT REFERENCES oauth2_access_tokens (id) ON DELETE SET NULL,
"token" TEXT UNIQUE NOT NULL,
"next_token_id" BIGINT REFERENCES oauth2_refresh_tokens (id),
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
"updated_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
);
CREATE TRIGGER set_timestamp
BEFORE UPDATE ON oauth2_refresh_tokens
FOR EACH ROW
EXECUTE PROCEDURE trigger_set_timestamp();

771
crates/core/sqlx-data.json Normal file
View File

@ -0,0 +1,771 @@
{
"db": "PostgreSQL",
"037ba804eabd0b4290d87d1de37054f358eb11397d3a8e4b69a81cdce0a178e0": {
"query": "\n SELECT id, username\n FROM users\n WHERE username = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "username",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false
]
}
},
"138c3297a66107d8428ca10d04f9a4dd75faf9c1d3f84bcedd3b09f55dd84206": {
"query": "\n INSERT INTO oauth2_codes\n (oauth2_session_id, code, code_challenge_method, code_challenge)\n VALUES\n ($1, $2, $3, $4)\n RETURNING\n id, oauth2_session_id, code, code_challenge_method, code_challenge\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "oauth2_session_id",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "code",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "code_challenge_method",
"type_info": "Int2"
},
{
"ordinal": 4,
"name": "code_challenge",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Int8",
"Text",
"Int2",
"Text"
]
},
"nullable": [
false,
false,
false,
true,
true
]
}
},
"17729fd0354a84e04bfcd525db6575ed2ba75dd730bea3f2be964f4b347dd484": {
"query": "\n SELECT code\n FROM oauth2_codes\n WHERE oauth2_session_id = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "code",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": [
false
]
}
},
"35bedaa6fdf7ac91d54b458b4637f2182c2f82be3e2f80cd2db934ee279a7f2a": {
"query": "\n SELECT id, username\n FROM users\n WHERE id = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "username",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": [
false,
false
]
}
},
"49888f812910633b87ce65c277f8969377fe264be154d8aa6b33d861d26d2b3b": {
"query": "\n SELECT\n u.username AS \"username!\",\n us.active AS \"active!\",\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\",\n at.created_at AS \"created_at!\",\n at.expires_after AS \"expires_after!\"\n FROM oauth2_access_tokens at\n INNER JOIN oauth2_sessions os\n ON os.id = at.oauth2_session_id\n INNER JOIN user_sessions us\n ON us.id = os.user_session_id\n INNER JOIN users u\n ON u.id = us.user_id\n WHERE at.token = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "username!",
"type_info": "Text"
},
{
"ordinal": 1,
"name": "active!",
"type_info": "Bool"
},
{
"ordinal": 2,
"name": "client_id!",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "scope!",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "created_at!",
"type_info": "Timestamptz"
},
{
"ordinal": 5,
"name": "expires_after!",
"type_info": "Int4"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false,
false,
false,
false,
false
]
}
},
"4f925a277d73df779360f81e0cf5d7983b50ebe744f461559dd561b7e36c20d4": {
"query": "\n SELECT\n s.id,\n u.id as user_id,\n u.username,\n s.active,\n s.created_at,\n a.created_at as \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u \n ON s.user_id = u.id\n LEFT JOIN user_session_authentications a\n ON a.session_id = s.id\n WHERE s.id = $1 AND s.active\n ORDER BY a.created_at DESC\n LIMIT 1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "user_id",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "username",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "active",
"type_info": "Bool"
},
{
"ordinal": 4,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 5,
"name": "last_authd_at?",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": [
false,
false,
false,
false,
false,
false
]
}
},
"562b0d4dcf857e99c20e9288e9c8bd46232290715c0d2459b0398a1c746cf65d": {
"query": "\n SELECT\n rt.id,\n rt.oauth2_session_id,\n rt.oauth2_access_token_id,\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n ON os.id = rt.oauth2_session_id\n WHERE rt.token = $1 AND rt.next_token_id IS NULL\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "oauth2_session_id",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "oauth2_access_token_id",
"type_info": "Int8"
},
{
"ordinal": 3,
"name": "client_id!",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "scope!",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false,
true,
false,
false
]
}
},
"5d1a17b2ad6153217551ae31549ad9d62cc39d2f9a4e62a7ccb60fd91e0ac685": {
"query": "\n DELETE FROM oauth2_access_tokens\n WHERE created_at + (expires_after * INTERVAL '1 second') + INTERVAL '15 minutes' < now()\n ",
"describe": {
"columns": [],
"parameters": {
"Left": []
},
"nullable": []
}
},
"62986972431bfc4649e3d8c8c7648f9049c4197773e53496422ad8b8aa15b459": {
"query": "\n SELECT\n s.id,\n u.id as user_id,\n u.username,\n s.active,\n s.created_at,\n a.created_at as \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u \n ON s.user_id = u.id\n LEFT JOIN user_session_authentications a\n ON a.session_id = s.id\n WHERE s.id = $1\n ORDER BY a.created_at DESC\n LIMIT 1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "user_id",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "username",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "active",
"type_info": "Bool"
},
{
"ordinal": 4,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 5,
"name": "last_authd_at?",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": [
false,
false,
false,
false,
false,
false
]
}
},
"73f2d928f7bf88af79a3685bd6346652b4e4454b0ce75e38343840c9765e3f27": {
"query": "\n INSERT INTO oauth2_refresh_tokens\n (oauth2_session_id, oauth2_access_token_id, token)\n VALUES\n ($1, $2, $3)\n RETURNING\n id, oauth2_session_id, oauth2_access_token_id, token, next_token_id, \n created_at, updated_at\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "oauth2_session_id",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "oauth2_access_token_id",
"type_info": "Int8"
},
{
"ordinal": 3,
"name": "token",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "next_token_id",
"type_info": "Int8"
},
{
"ordinal": 5,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 6,
"name": "updated_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Int8",
"Int8",
"Text"
]
},
"nullable": [
false,
false,
true,
false,
true,
false,
false
]
}
},
"886dee6a6f1f426f0e891790bbeffbc222fd75d8da0a107e7de673f1cc445f30": {
"query": "\n SELECT\n oc.id,\n os.id AS \"oauth2_session_id!\",\n os.client_id AS \"client_id!\",\n os.redirect_uri,\n os.scope AS \"scope!\",\n os.nonce\n FROM oauth2_codes oc\n INNER JOIN oauth2_sessions os\n ON os.id = oc.oauth2_session_id\n WHERE oc.code = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "oauth2_session_id!",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "client_id!",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "redirect_uri",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "scope!",
"type_info": "Text"
},
{
"ordinal": 5,
"name": "nonce",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false,
false,
false,
false,
true
]
}
},
"88ac8783bd5881c42eafd9cf87a16fe6031f3153fd6a8618e689694584aeb2de": {
"query": "\n DELETE FROM oauth2_access_tokens\n WHERE id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": []
}
},
"9ba45ab114b656105cc46b0c10fb05769860fcdc05eaf54d6225640fb914dab9": {
"query": "\n INSERT INTO user_session_authentications (session_id)\n VALUES ($1)\n RETURNING created_at\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "created_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": [
false
]
}
},
"a09dfe1019110f2ec6eba0d35bafa467ab4b7980dd8b556826f03863f8edb0ab": {
"query": "UPDATE user_sessions SET active = FALSE WHERE id = $1",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": []
}
},
"a552eee8a8e5ffdee4d4789c634851bd64780dfe730807aac20142d7cd643814": {
"query": "\n SELECT u.hashed_password\n FROM user_sessions s\n INNER JOIN users u\n ON u.id = s.user_id \n WHERE s.id = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "hashed_password",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": [
false
]
}
},
"a6eb935107d060dd01bf9824ceff87b9ff5492b58cefef002a49f444d3a3daa1": {
"query": "UPDATE oauth2_sessions SET user_session_id = $1 WHERE id = $2",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Int8",
"Int8"
]
},
"nullable": []
}
},
"b766b2b41d8770b5bef9928bb3b96abbaf8466b473e12b21f145c015b7cf2f05": {
"query": "\n INSERT INTO oauth2_access_tokens\n (oauth2_session_id, token, expires_after)\n VALUES\n ($1, $2, $3)\n RETURNING\n id, oauth2_session_id, token, expires_after, created_at\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "oauth2_session_id",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "token",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "expires_after",
"type_info": "Int4"
},
{
"ordinal": 4,
"name": "created_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Int8",
"Text",
"Int4"
]
},
"nullable": [
false,
false,
false,
false,
false
]
}
},
"c2c402cfe0adcafa615f14a499caba4c96ca71d9ffb163e1feb05e5d85f3462c": {
"query": "\n UPDATE oauth2_refresh_tokens\n SET next_token_id = $2\n WHERE id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Int8",
"Int8"
]
},
"nullable": []
}
},
"cacec823f5d4ed886854fbd62b5f5bb2def792582df58c8a047c769d34d9b190": {
"query": "\n INSERT INTO oauth2_sessions\n (user_session_id, client_id, redirect_uri, scope, state, nonce, max_age,\n response_type, response_mode)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9)\n RETURNING\n id, user_session_id, client_id, redirect_uri, scope, state, nonce, max_age,\n response_type, response_mode, created_at, updated_at\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "user_session_id",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "client_id",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "redirect_uri",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "scope",
"type_info": "Text"
},
{
"ordinal": 5,
"name": "state",
"type_info": "Text"
},
{
"ordinal": 6,
"name": "nonce",
"type_info": "Text"
},
{
"ordinal": 7,
"name": "max_age",
"type_info": "Int4"
},
{
"ordinal": 8,
"name": "response_type",
"type_info": "Text"
},
{
"ordinal": 9,
"name": "response_mode",
"type_info": "Text"
},
{
"ordinal": 10,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 11,
"name": "updated_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Int8",
"Text",
"Text",
"Text",
"Text",
"Text",
"Int4",
"Text",
"Text"
]
},
"nullable": [
false,
true,
false,
false,
false,
true,
true,
true,
false,
false,
false,
false
]
}
},
"f9a09ff53b6f221649f4f050e3d5ade114f852ddf50a78610a6c0ef0689af681": {
"query": "\n INSERT INTO users (username, hashed_password)\n VALUES ($1, $2)\n RETURNING id\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
}
],
"parameters": {
"Left": [
"Text",
"Text"
]
},
"nullable": [
false
]
}
},
"ff515ebb80ba4af1948472f5c7120a03e25b1ebe42151b8a2036bfbb042f17f6": {
"query": "\n SELECT\n id, user_session_id, client_id, redirect_uri, scope, state, nonce,\n max_age, response_type, response_mode, created_at, updated_at\n FROM oauth2_sessions\n WHERE id = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "user_session_id",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "client_id",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "redirect_uri",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "scope",
"type_info": "Text"
},
{
"ordinal": 5,
"name": "state",
"type_info": "Text"
},
{
"ordinal": 6,
"name": "nonce",
"type_info": "Text"
},
{
"ordinal": 7,
"name": "max_age",
"type_info": "Int4"
},
{
"ordinal": 8,
"name": "response_type",
"type_info": "Text"
},
{
"ordinal": 9,
"name": "response_mode",
"type_info": "Text"
},
{
"ordinal": 10,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 11,
"name": "updated_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": [
false,
true,
false,
false,
false,
true,
true,
true,
false,
false,
false,
false
]
}
}
}

128
crates/core/src/errors.rs Normal file
View File

@ -0,0 +1,128 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashMap, fmt::Debug, hash::Hash};
use serde::{ser::SerializeMap, Serialize};
use warp::{reject::Reject, Rejection};
#[derive(Debug)]
pub struct WrappedError(anyhow::Error);
impl warp::reject::Reject for WrappedError {}
pub trait WrapError<T> {
fn wrap_error(self) -> Result<T, Rejection>;
}
impl<T, E> WrapError<T> for Result<T, E>
where
E: Into<anyhow::Error>,
{
fn wrap_error(self) -> Result<T, Rejection> {
self.map_err(|e| warp::reject::custom(WrappedError(e.into())))
}
}
pub trait HtmlError: Debug + Send + Sync + 'static {
fn html_display(&self) -> String;
}
pub trait WrapFormError<FieldType> {
fn on_form(self) -> ErroredForm<FieldType>;
fn on_field(self, field: FieldType) -> ErroredForm<FieldType>;
}
impl<E, FieldType> WrapFormError<FieldType> for E
where
E: HtmlError,
{
fn on_form(self) -> ErroredForm<FieldType> {
let mut f = ErroredForm::new();
f.form.push(FormError {
error: Box::new(self),
});
f
}
fn on_field(self, field: FieldType) -> ErroredForm<FieldType> {
let mut f = ErroredForm::new();
f.fields.push(FieldError {
field,
error: Box::new(self),
});
f
}
}
#[derive(Debug)]
struct FormError {
error: Box<dyn HtmlError>,
}
impl Serialize for FormError {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&self.error.html_display())
}
}
#[derive(Debug)]
struct FieldError<FieldType> {
field: FieldType,
error: Box<dyn HtmlError>,
}
#[derive(Debug, Default)]
pub struct ErroredForm<FieldType> {
form: Vec<FormError>,
fields: Vec<FieldError<FieldType>>,
}
impl<T> ErroredForm<T> {
#[must_use] pub fn new() -> Self {
Self {
form: Vec::new(),
fields: Vec::new(),
}
}
}
impl<T> Reject for ErroredForm<T> where T: Debug + Send + Sync + 'static {}
impl<FieldType: Copy + Serialize + Hash + Eq> Serialize for ErroredForm<FieldType> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut map = serializer.serialize_map(Some(2))?;
let has_errors = !self.form.is_empty() || !self.fields.is_empty();
map.serialize_entry("has_errors", &has_errors)?;
map.serialize_entry("form_errors", &self.form)?;
let fields: HashMap<FieldType, Vec<String>> =
self.fields.iter().fold(HashMap::new(), |mut map, err| {
map.entry(err.field)
.or_default()
.push(err.error.html_display());
map
});
map.serialize_entry("fields_errors", &fields)?;
map.end()
}
}

View File

@ -0,0 +1,55 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use chrono::Utc;
use headers::{authorization::Bearer, Authorization};
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use warp::{Filter, Rejection};
use super::{database::with_connection, headers::with_typed_header};
use crate::{
errors::WrapError,
storage::oauth2::access_token::{lookup_access_token, OAuth2AccessTokenLookup},
tokens,
};
pub fn with_authentication(
pool: &PgPool,
) -> impl Filter<Extract = (OAuth2AccessTokenLookup,), Error = Rejection> + Clone + Send + Sync + 'static
{
with_connection(pool)
.and(with_typed_header())
.and_then(authenticate)
}
async fn authenticate(
mut conn: PoolConnection<Postgres>,
auth: Authorization<Bearer>,
) -> Result<OAuth2AccessTokenLookup, Rejection> {
let token = auth.0.token();
let token_type = tokens::check(token).wrap_error()?;
if token_type != tokens::TokenType::AccessToken {
return Err(anyhow::anyhow!("wrong token type")).wrap_error();
}
let token = lookup_access_token(&mut conn, token).await.wrap_error()?;
let exp = token.exp();
// Check it is active and did not expire
if !token.active || exp < Utc::now() {
return Err(anyhow::anyhow!("token expired")).wrap_error();
}
Ok(token)
}

View File

@ -0,0 +1,213 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use headers::{authorization::Basic, Authorization};
use serde::{de::DeserializeOwned, Deserialize};
use thiserror::Error;
use warp::{reject::Reject, Filter, Rejection};
use super::headers::with_typed_header;
use crate::config::{OAuth2ClientConfig, OAuth2Config};
#[derive(Debug, PartialEq, Eq)]
pub enum ClientAuthentication {
ClientSecretBasic,
ClientSecretPost,
None,
}
impl ClientAuthentication {
#[must_use]
pub fn public(&self) -> bool {
matches!(self, &Self::None)
}
}
#[must_use]
pub fn with_client_auth<T: DeserializeOwned + Send + 'static>(
oauth2_config: &OAuth2Config,
) -> impl Filter<Extract = (ClientAuthentication, OAuth2ClientConfig, T), Error = Rejection>
+ Clone
+ Send
+ Sync
+ 'static {
// First, extract the client credentials
let credentials = with_typed_header()
.and(warp::body::form())
// Either from the "Authorization" header
.map(|auth: Authorization<Basic>, body: T| {
let client_id = auth.0.username().to_string();
let client_secret = Some(auth.0.password().to_string());
(
ClientAuthentication::ClientSecretBasic,
client_id,
client_secret,
body,
)
})
// Or from the form body
.or(warp::body::form().map(|form: ClientAuthForm<T>| {
let ClientAuthForm {
client_id,
client_secret,
body,
} = form;
let auth_type = if client_secret.is_some() {
ClientAuthentication::ClientSecretPost
} else {
ClientAuthentication::None
};
(auth_type, client_id, client_secret, body)
}))
.unify()
.untuple_one();
let clients = oauth2_config.clients.clone();
warp::any()
.map(move || clients.clone())
.and(credentials)
.and_then(authenticate_client)
.untuple_one()
}
#[derive(Error, Debug)]
enum ClientAuthenticationError {
#[error("no client secret found for client {client_id:?}")]
NoClientSecret { client_id: String },
#[error("wrong client secret for client {client_id:?}")]
ClientSecretMismatch { client_id: String },
#[error("could not find client {client_id:?}")]
ClientNotFound { client_id: String },
#[error("client secret required for client {client_id:?}")]
ClientSecretRequired { client_id: String },
}
impl Reject for ClientAuthenticationError {}
async fn authenticate_client<T>(
clients: Vec<OAuth2ClientConfig>,
auth_type: ClientAuthentication,
client_id: String,
client_secret: Option<String>,
body: T,
) -> Result<(ClientAuthentication, OAuth2ClientConfig, T), Rejection> {
let client = clients
.iter()
.find(|client| client.client_id == client_id)
.ok_or_else(|| ClientAuthenticationError::ClientNotFound {
client_id: client_id.to_string(),
})?;
let client = match (client_secret, client.client_secret.as_ref()) {
(None, None) => Ok(client),
(Some(ref given), Some(expected)) if given == expected => Ok(client),
(Some(_), Some(_)) => Err(ClientAuthenticationError::ClientSecretMismatch { client_id }),
(Some(_), None) => Err(ClientAuthenticationError::NoClientSecret { client_id }),
(None, Some(_)) => Err(ClientAuthenticationError::ClientSecretRequired { client_id }),
}?;
Ok((auth_type, client.clone(), body))
}
#[derive(Deserialize)]
struct ClientAuthForm<T> {
client_id: String,
client_secret: Option<String>,
#[serde(flatten)]
body: T,
}
#[cfg(test)]
mod tests {
use mas_config::ConfigurationSection;
use super::*;
fn oauth2_config() -> OAuth2Config {
let mut config = OAuth2Config::test();
config.clients.push(OAuth2ClientConfig {
client_id: "public".to_string(),
client_secret: None,
redirect_uris: Vec::new(),
});
config.clients.push(OAuth2ClientConfig {
client_id: "confidential".to_string(),
client_secret: Some("secret".to_string()),
redirect_uris: Vec::new(),
});
config
}
#[derive(Deserialize)]
struct Form {
foo: String,
bar: String,
}
#[tokio::test]
async fn client_secret_post() {
let filter = with_client_auth::<Form>(&oauth2_config());
let (auth, client, body) = warp::test::request()
.method("POST")
.body("client_id=confidential&client_secret=secret&foo=baz&bar=foobar")
.filter(&filter)
.await
.unwrap();
assert_eq!(auth, ClientAuthentication::ClientSecretPost);
assert_eq!(client.client_id, "confidential");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
}
#[tokio::test]
async fn client_secret_basic() {
let filter = with_client_auth::<Form>(&oauth2_config());
let (auth, client, body) = warp::test::request()
.method("POST")
.header("Authorization", "Basic Y29uZmlkZW50aWFsOnNlY3JldA==")
.body("foo=baz&bar=foobar")
.filter(&filter)
.await
.unwrap();
assert_eq!(auth, ClientAuthentication::ClientSecretBasic);
assert_eq!(client.client_id, "confidential");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
}
#[tokio::test]
async fn none() {
let filter = with_client_auth::<Form>(&oauth2_config());
let (auth, client, body) = warp::test::request()
.method("POST")
.body("client_id=public&foo=baz&bar=foobar")
.filter(&filter)
.await
.unwrap();
assert_eq!(auth, ClientAuthentication::None);
assert_eq!(client.client_id, "public");
assert_eq!(body.foo, "baz");
assert_eq!(body.bar, "foobar");
}
}

View File

@ -0,0 +1,134 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::convert::Infallible;
use chacha20poly1305::{
aead::{generic_array::GenericArray, Aead, NewAead},
ChaCha20Poly1305,
};
use cookie::Cookie;
use data_encoding::BASE64URL_NOPAD;
use headers::{Header, HeaderValue, SetCookie};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use warp::{Filter, Rejection, Reply};
use super::headers::{typed_header, WithTypedHeader};
use crate::{config::CookiesConfig, errors::WrapError};
#[derive(Serialize, Deserialize)]
struct EncryptedCookie {
nonce: [u8; 12],
ciphertext: Vec<u8>,
}
impl EncryptedCookie {
/// Encrypt from a given key
fn encrypt<T: Serialize>(payload: T, key: &[u8; 32]) -> anyhow::Result<Self> {
let key = GenericArray::from_slice(key);
let aead = ChaCha20Poly1305::new(key);
let message = bincode::serialize(&payload)?;
let nonce: [u8; 12] = rand::random();
let ciphertext = aead.encrypt(GenericArray::from_slice(&nonce[..]), &message[..])?;
Ok(Self { nonce, ciphertext })
}
/// Decrypt the content of the cookie from a given key
fn decrypt<T: DeserializeOwned>(&self, key: &[u8; 32]) -> anyhow::Result<T> {
let key = GenericArray::from_slice(key);
let aead = ChaCha20Poly1305::new(key);
let message = aead.decrypt(
GenericArray::from_slice(&self.nonce[..]),
&self.ciphertext[..],
)?;
let token = bincode::deserialize(&message)?;
Ok(token)
}
/// Encode the encrypted cookie to be then saved as a cookie
fn to_cookie_value(&self) -> anyhow::Result<String> {
let raw = bincode::serialize(self)?;
Ok(BASE64URL_NOPAD.encode(&raw))
}
fn from_cookie_value(value: &str) -> anyhow::Result<Self> {
let raw = BASE64URL_NOPAD.decode(value.as_bytes())?;
let content = bincode::deserialize(&raw)?;
Ok(content)
}
}
#[must_use] pub fn maybe_encrypted<T>(
options: &CookiesConfig,
) -> impl Filter<Extract = (Option<T>,), Error = Infallible> + Clone + Send + Sync + 'static
where
T: DeserializeOwned + EncryptableCookieValue + Send + 'static,
{
let secret = options.secret;
warp::cookie::optional(T::cookie_key()).map(move |maybe_value: Option<String>| {
maybe_value
.and_then(|value| EncryptedCookie::from_cookie_value(&value).ok())
.and_then(|encrypted| encrypted.decrypt(&secret).ok())
})
}
#[must_use] pub fn encrypted<T>(
options: &CookiesConfig,
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static
where
T: DeserializeOwned + EncryptableCookieValue + Send + 'static,
{
let secret = options.secret;
warp::cookie::cookie(T::cookie_key()).and_then(move |value: String| async move {
let encrypted = EncryptedCookie::from_cookie_value(&value).wrap_error()?;
let decrypted = encrypted.decrypt(&secret).wrap_error()?;
Ok::<_, Rejection>(decrypted)
})
}
#[must_use] pub fn with_cookie_saver(
options: &CookiesConfig,
) -> impl Filter<Extract = (EncryptedCookieSaver,), Error = Infallible> + Clone + Send + Sync + 'static
{
let secret = options.secret;
warp::any().map(move || EncryptedCookieSaver { secret })
}
/// A cookie that can be encrypted with a well-known cookie key
pub trait EncryptableCookieValue {
fn cookie_key() -> &'static str;
}
pub struct EncryptedCookieSaver {
secret: [u8; 32],
}
impl EncryptedCookieSaver {
pub fn save_encrypted<T: Serialize + EncryptableCookieValue, R: Reply>(
&self,
cookie: &T,
reply: R,
) -> Result<WithTypedHeader<R, SetCookie>, Rejection> {
let encrypted = EncryptedCookie::encrypt(cookie, &self.secret)
.wrap_error()?
.to_cookie_value()
.wrap_error()?;
let value = Cookie::build(T::cookie_key(), encrypted)
.finish()
.to_string();
let header = SetCookie::decode(&mut [HeaderValue::from_str(&value).wrap_error()?].iter())
.wrap_error()?;
Ok(typed_header(header, reply))
}
}

View File

@ -0,0 +1,159 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Stateless CSRF protection middleware based on a chacha20-poly1305 encrypted
//! and signed token
use chrono::{DateTime, Duration, Utc};
use data_encoding::{DecodeError, BASE64URL_NOPAD};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_with::{serde_as, TimestampSeconds};
use thiserror::Error;
use warp::{reject::Reject, Filter, Rejection};
use super::cookies::EncryptableCookieValue;
use crate::config::{CookiesConfig, CsrfConfig};
#[derive(Debug, Error)]
pub enum CsrfError {
#[error("CSRF token mismatch")]
Mismatch,
#[error("CSRF token expired")]
Expired,
#[error("could not decode CSRF token")]
Decode(#[from] DecodeError),
}
impl Reject for CsrfError {}
#[serde_as]
#[derive(Serialize, Deserialize)]
pub struct CsrfToken {
#[serde_as(as = "TimestampSeconds<i64>")]
expiration: DateTime<Utc>,
token: [u8; 32],
}
impl CsrfToken {
/// Create a new token from a defined value valid for a specified duration
fn new(token: [u8; 32], ttl: Duration) -> Self {
let expiration = Utc::now() + ttl;
Self { expiration, token }
}
/// Generate a new random token valid for a specified duration
fn generate(ttl: Duration) -> Self {
let token = rand::random();
Self::new(token, ttl)
}
/// Generate a new token with the same value but an up to date expiration
fn refresh(self, ttl: Duration) -> Self {
Self::new(self.token, ttl)
}
/// Get the value to include in HTML forms
#[must_use] pub fn form_value(&self) -> String {
BASE64URL_NOPAD.encode(&self.token[..])
}
/// Verifies that the value got from an HTML form matches this token
pub fn verify_form_value(&self, form_value: &str) -> Result<(), CsrfError> {
let form_value = BASE64URL_NOPAD.decode(form_value.as_bytes())?;
if self.token[..] == form_value {
Ok(())
} else {
Err(CsrfError::Mismatch)
}
}
fn verify_expiration(self) -> Result<Self, CsrfError> {
if Utc::now() < self.expiration {
Ok(self)
} else {
Err(CsrfError::Expired)
}
}
}
impl EncryptableCookieValue for CsrfToken {
fn cookie_key() -> &'static str {
"csrf"
}
}
/// A CSRF-protected form
#[derive(Deserialize)]
struct CsrfForm<T> {
csrf: String,
#[serde(flatten)]
inner: T,
}
impl<T> CsrfForm<T> {
fn verify_csrf(self, token: &CsrfToken) -> Result<T, CsrfError> {
// Verify CSRF from request
token.verify_form_value(&self.csrf)?;
Ok(self.inner)
}
}
#[must_use] pub fn csrf_token(
cookies_config: &CookiesConfig,
) -> impl Filter<Extract = (CsrfToken,), Error = Rejection> + Clone + Send + Sync + 'static {
super::cookies::encrypted(cookies_config).and_then(move |token: CsrfToken| async move {
let verified = token.verify_expiration()?;
Ok::<_, Rejection>(verified)
})
}
#[must_use] pub fn updated_csrf_token(
cookies_config: &CookiesConfig,
csrf_config: &CsrfConfig,
) -> impl Filter<Extract = (CsrfToken,), Error = Rejection> + Clone + Send + Sync + 'static {
let ttl = csrf_config.ttl;
super::cookies::maybe_encrypted(cookies_config).and_then(
move |maybe_token: Option<CsrfToken>| async move {
// Explicitely specify the "Error" type here to have the `?` operation working
Ok::<_, Rejection>(
maybe_token
// Verify its TTL (but do not hard-error if it expired)
.and_then(|token| token.verify_expiration().ok())
.map_or_else(
// Generate a new token if no valid one were found
|| CsrfToken::generate(ttl),
// Else, refresh the expiration of the token
|token| token.refresh(ttl),
),
)
},
)
}
#[must_use] pub fn protected_form<T>(
cookies_config: &CookiesConfig,
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static
where
T: DeserializeOwned + Send + 'static,
{
csrf_token(cookies_config).and(warp::body::form()).and_then(
|csrf_token: CsrfToken, protected_form: CsrfForm<T>| async move {
let form = protected_form.verify_csrf(&csrf_token)?;
Ok::<_, Rejection>(form)
},
)
}

View File

@ -0,0 +1,54 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::convert::Infallible;
use sqlx::{pool::PoolConnection, PgPool, Postgres, Transaction};
use warp::{Filter, Rejection};
use crate::errors::WrapError;
fn with_pool(
pool: &PgPool,
) -> impl Filter<Extract = (PgPool,), Error = Infallible> + Clone + Send + Sync + 'static {
let pool = pool.clone();
warp::any().map(move || pool.clone())
}
pub fn with_connection(
pool: &PgPool,
) -> impl Filter<Extract = (PoolConnection<Postgres>,), Error = Rejection> + Clone + Send + Sync + 'static
{
with_pool(pool).and_then(acquire_connection)
}
async fn acquire_connection(pool: PgPool) -> Result<PoolConnection<Postgres>, Rejection> {
let conn = pool.acquire().await.wrap_error()?;
Ok(conn)
}
pub fn with_transaction(
pool: &PgPool,
) -> impl Filter<Extract = (Transaction<'static, Postgres>,), Error = Rejection>
+ Clone
+ Send
+ Sync
+ 'static {
with_pool(pool).and_then(acquire_transaction)
}
async fn acquire_transaction(pool: PgPool) -> Result<Transaction<'static, Postgres>, Rejection> {
let txn = pool.begin().await.wrap_error()?;
Ok(txn)
}

View File

@ -0,0 +1,200 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{cmp::Reverse, future::Future, pin::Pin};
use mime::{Mime, STAR};
use serde::Serialize;
use tera::Context;
use tide::{
http::headers::{ACCEPT, LOCATION},
Body, Request, StatusCode,
};
use tracing::debug;
use crate::{state::State, templates::common_context};
/// Get the weight parameter for a mime type from 0 to 1000
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
fn get_weight(mime: &Mime) -> usize {
let q = mime
.get_param("q")
.map_or(1.0_f64, |q| q.as_str().parse().unwrap_or(0.0))
.min(1.0)
.max(0.0);
// Weight have a 3 digit precision so we can multiply by 1000 and cast to
// int. Sign loss should not happen here because of the min/max up there and
// truncation does not matter here.
(q * 1000.0) as _
}
/// Find what content type should be used for a given request
fn preferred_mime_type<'a>(
request: &Request<State>,
supported_types: &'a [Mime],
) -> Option<&'a Mime> {
let accept = request.header(ACCEPT)?;
// Parse the Accept header as a list of mime types with their associated
// weight
let accepted_types: Vec<(Mime, usize)> = {
let v: Option<Vec<_>> = accept
.into_iter()
.flat_map(|value| value.as_str().split(','))
.map(|mime| {
mime.trim().parse().ok().map(|mime| {
let q = get_weight(&mime);
(mime, q)
})
})
.collect();
let mut v = v?;
v.sort_by_key(|(_, weight)| Reverse(*weight));
v
};
// For each supported content type, find out if it is accepted with what
// weight and specificity
let mut types: Vec<_> = supported_types
.iter()
.enumerate()
.filter_map(|(index, supported)| {
accepted_types.iter().find_map(|(accepted, weight)| {
if accepted.type_() == supported.type_()
&& accepted.subtype() == supported.subtype()
{
// Accept: text/html
Some((supported, *weight, 2_usize, index))
} else if accepted.type_() == supported.type_() && accepted.subtype() == STAR {
// Accept: text/*
Some((supported, *weight, 1, index))
} else if accepted.type_() == STAR && accepted.subtype() == STAR {
// Accept: */*
Some((supported, *weight, 0, index))
} else {
None
}
})
})
.collect();
types.sort_by_key(|(_, weight, specificity, index)| {
(Reverse(*weight), Reverse(*specificity), *index)
});
types.first().map(|(mime, _, _, _)| *mime)
}
#[derive(Serialize)]
struct ErrorContext {
#[serde(skip_serializing_if = "Option::is_none")]
code: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
details: Option<String>,
}
impl ErrorContext {
fn should_render(&self) -> bool {
self.code.is_some() || self.description.is_some() || self.details.is_some()
}
}
pub fn middleware<'a>(
request: tide::Request<State>,
next: tide::Next<'a, State>,
) -> Pin<Box<dyn Future<Output = tide::Result> + Send + 'a>> {
Box::pin(async {
let content_type = preferred_mime_type(
&request,
&[mime::TEXT_PLAIN, mime::TEXT_HTML, mime::APPLICATION_JSON],
);
debug!("Content-Type from Accept: {:?}", content_type);
// TODO: We should not clone here
let templates = request.state().templates().clone();
// TODO: This context should probably be comptuted somewhere else
let pctx = common_context(&request).await?.clone();
let mut response = next.run(request).await;
// Find out what message should be displayed from the response status
// code
let (code, description) = match response.status() {
StatusCode::NotFound => (Some("Not found".to_string()), None),
StatusCode::MethodNotAllowed => (Some("Method not allowed".to_string()), None),
StatusCode::Found
| StatusCode::PermanentRedirect
| StatusCode::TemporaryRedirect
| StatusCode::SeeOther => {
let description = response.header(LOCATION).map(|loc| format!("To {}", loc));
(Some("Redirecting".to_string()), description)
}
StatusCode::InternalServerError => (Some("Internal server error".to_string()), None),
_ => (None, None),
};
// If there is an error associated to the response, format it in a nice
// way with a backtrace if we have one
let details = response.take_error().map(|err| {
format!(
"{:?}{}",
err,
err.backtrace()
.map(|bt| format!("\nBacktrace:\n{}", bt.to_string()))
.unwrap_or_default()
)
});
let error_context = ErrorContext {
code,
description,
details,
};
// This is the case if one of the code, description or details is not
// None
if error_context.should_render() {
match content_type {
Some(c) if c == &mime::APPLICATION_JSON => {
response.set_body(Body::from_json(&error_context)?);
response.set_content_type("application/json");
}
Some(c) if c == &mime::TEXT_HTML => {
let mut ctx = Context::from_serialize(&error_context)?;
ctx.extend(pctx);
response.set_body(templates.render("error.html", &ctx)?);
response.set_content_type("text/html");
}
Some(c) if c == &mime::TEXT_PLAIN => {
let mut ctx = Context::from_serialize(&error_context)?;
ctx.extend(pctx);
response.set_body(templates.render("error.txt", &ctx)?);
response.set_content_type("text/plain");
}
_ => {
response.set_body("Unsupported Content-Type in Accept header");
response.set_content_type("text/plain");
response.set_status(StatusCode::NotAcceptable);
}
}
}
Ok(response)
})
}

View File

@ -0,0 +1,50 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
use headers::{Header, HeaderMapExt, HeaderValue};
use warp::{Filter, Rejection, Reply};
use crate::errors::WrapError;
pub fn typed_header<R, H>(header: H, reply: R) -> WithTypedHeader<R, H> {
WithTypedHeader { reply, header }
}
pub struct WithTypedHeader<R, H> {
reply: R,
header: H,
}
impl<R, H> Reply for WithTypedHeader<R, H>
where
R: Reply,
H: Header + Send,
{
fn into_response(self) -> warp::reply::Response {
let mut res = self.reply.into_response();
res.headers_mut().typed_insert(self.header);
res
}
}
pub fn with_typed_header<T: Header + Send + 'static>(
) -> impl Filter<Extract = (T,), Error = Rejection> + Clone + Send + Sync + 'static {
warp::header::value(T::name().as_str()).and_then(decode_typed_header)
}
async fn decode_typed_header<T: Header>(header: HeaderValue) -> Result<T, Rejection> {
let mut it = std::iter::once(&header);
let decoded = T::decode(&mut it).wrap_error()?;
Ok(decoded)
}

View File

@ -0,0 +1,48 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(clippy::unused_async)] // Some warp filters need that
pub mod csrf;
// mod errors;
pub mod authenticate;
pub mod client;
pub mod cookies;
pub mod database;
pub mod headers;
pub mod session;
use std::convert::Infallible;
use warp::Filter;
pub use self::csrf::CsrfToken;
use crate::{
config::{KeySet, OAuth2Config},
templates::Templates,
};
#[must_use] pub fn with_templates(
templates: &Templates,
) -> impl Filter<Extract = (Templates,), Error = Infallible> + Clone + Send + Sync + 'static {
let templates = templates.clone();
warp::any().map(move || templates.clone())
}
#[must_use] pub fn with_keys(
oauth2_config: &OAuth2Config,
) -> impl Filter<Extract = (KeySet,), Error = Infallible> + Clone + Send + Sync + 'static {
let keyset = oauth2_config.keys.clone();
warp::any().map(move || keyset.clone())
}

View File

@ -0,0 +1,86 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use sqlx::{pool::PoolConnection, Executor, PgPool, Postgres};
use warp::{Filter, Rejection};
use super::{
cookies::{encrypted, maybe_encrypted, EncryptableCookieValue},
database::with_connection,
};
use crate::{
config::CookiesConfig,
errors::WrapError,
storage::{lookup_active_session, SessionInfo},
};
#[derive(Serialize, Deserialize)]
pub struct SessionCookie {
current: i64,
}
impl SessionCookie {
#[must_use] pub fn from_session_info(info: &SessionInfo) -> Self {
Self {
current: info.key(),
}
}
pub async fn load_session_info(
&self,
executor: impl Executor<'_, Database = Postgres>,
) -> anyhow::Result<SessionInfo> {
lookup_active_session(executor, self.current).await
}
}
impl EncryptableCookieValue for SessionCookie {
fn cookie_key() -> &'static str {
"session"
}
}
#[must_use] pub fn with_optional_session(
pool: &PgPool,
cookies_config: &CookiesConfig,
) -> impl Filter<Extract = (Option<SessionInfo>,), Error = Rejection> + Clone + Send + Sync + 'static
{
maybe_encrypted(cookies_config)
.and(with_connection(pool))
.and_then(
|maybe_session: Option<SessionCookie>, mut conn: PoolConnection<Postgres>| async move {
let maybe_session_info = if let Some(session) = maybe_session {
session.load_session_info(&mut conn).await.ok()
} else {
None
};
Ok::<_, Rejection>(maybe_session_info)
},
)
}
#[must_use] pub fn with_session(
pool: &PgPool,
cookies_config: &CookiesConfig,
) -> impl Filter<Extract = (SessionInfo,), Error = Rejection> + Clone + Send + Sync + 'static {
encrypted(cookies_config)
.and(with_connection(pool))
.and_then(
|session: SessionCookie, mut conn: PoolConnection<Postgres>| async move {
let session_info = session.load_session_info(&mut conn).await.wrap_error()?;
Ok::<_, Rejection>(session_info)
},
)
}

View File

@ -0,0 +1,41 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use hyper::header::CONTENT_TYPE;
use mime::TEXT_PLAIN;
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use tracing::{info_span, Instrument};
use warp::{reply::with_header, Filter, Rejection, Reply};
use crate::{errors::WrapError, filters::database::with_connection};
pub fn filter(
pool: &PgPool,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
warp::path!("health")
.and(warp::get())
.and(with_connection(pool))
.and_then(get)
}
async fn get(mut conn: PoolConnection<Postgres>) -> Result<impl Reply, Rejection> {
sqlx::query("SELECT $1")
.bind(1_i64)
.execute(&mut conn)
.instrument(info_span!("DB health"))
.await
.wrap_error()?;
Ok(with_header("ok", CONTENT_TYPE, TEXT_PLAIN.to_string()))
}

View File

@ -0,0 +1,43 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(clippy::unused_async)] // Some warp filters need that
use sqlx::PgPool;
use warp::{Filter, Rejection, Reply};
use crate::{config::RootConfig, templates::Templates};
mod health;
mod oauth2;
mod views;
use self::{health::filter as health, oauth2::filter as oauth2, views::filter as views};
#[must_use] pub fn root(
pool: &PgPool,
templates: &Templates,
config: &RootConfig,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
health(pool)
.or(oauth2(pool, templates, &config.oauth2, &config.cookies))
.or(views(
pool,
templates,
&config.oauth2,
&config.csrf,
&config.cookies,
))
.with(warp::log(module_path!()))
}

View File

@ -0,0 +1,459 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
collections::{HashMap, HashSet},
convert::TryFrom,
};
use chrono::Duration;
use hyper::{
header::LOCATION,
http::uri::{Parts, PathAndQuery, Uri},
StatusCode,
};
use itertools::Itertools;
use oauth2_types::{
errors::{ErrorResponse, InvalidRequest, OAuth2Error},
pkce,
requests::{
AccessTokenResponse, AuthorizationRequest, AuthorizationResponse, ResponseMode,
ResponseType,
},
};
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sqlx::{PgPool, Postgres, Transaction};
use url::Url;
use warp::{
redirect::see_other,
reject::InvalidQuery,
reply::{html, with_header},
Filter, Rejection, Reply,
};
use crate::{
config::{CookiesConfig, OAuth2ClientConfig, OAuth2Config},
errors::WrapError,
filters::{
database::with_transaction,
session::{with_optional_session, with_session},
with_templates,
},
handlers::views::LoginRequest,
storage::{
oauth2::{
access_token::add_access_token,
refresh_token::add_refresh_token,
session::{get_session_by_id, start_session},
},
SessionInfo,
},
templates::{FormPostContext, Templates},
tokens,
};
#[derive(Deserialize)]
struct PartialParams {
client_id: Option<String>,
redirect_uri: Option<String>,
/*
response_type: Option<String>,
response_mode: Option<String>,
*/
}
enum ReplyOrBackToClient {
Reply(Box<dyn Reply>),
BackToClient {
params: Value,
redirect_uri: Url,
response_mode: ResponseMode,
},
Error(Box<dyn OAuth2Error>),
}
fn back_to_client<T>(
mut redirect_uri: Url,
response_mode: ResponseMode,
params: T,
templates: &Templates,
) -> anyhow::Result<Box<dyn Reply>>
where
T: Serialize,
{
#[derive(Serialize)]
struct AllParams<'s, T> {
#[serde(flatten, skip_serializing_if = "Option::is_none")]
existing: Option<HashMap<&'s str, &'s str>>,
#[serde(flatten)]
params: T,
}
match response_mode {
ResponseMode::Query => {
let existing: Option<HashMap<&str, &str>> = redirect_uri
.query()
.map(|qs| serde_urlencoded::from_str(qs))
.transpose()?;
let merged = AllParams { existing, params };
let new_qs = serde_urlencoded::to_string(merged)?;
redirect_uri.set_query(Some(&new_qs));
Ok(Box::new(with_header(
StatusCode::SEE_OTHER,
LOCATION,
redirect_uri.as_str(),
)))
}
ResponseMode::Fragment => {
let existing: Option<HashMap<&str, &str>> = redirect_uri
.fragment()
.map(|qs| serde_urlencoded::from_str(qs))
.transpose()?;
let merged = AllParams { existing, params };
let new_qs = serde_urlencoded::to_string(merged)?;
redirect_uri.set_fragment(Some(&new_qs));
Ok(Box::new(with_header(
StatusCode::SEE_OTHER,
LOCATION,
redirect_uri.as_str(),
)))
}
ResponseMode::FormPost => {
let ctx = FormPostContext::new(redirect_uri, params);
let rendered = templates.render_form_post(&ctx)?;
Ok(Box::new(html(rendered)))
}
}
}
#[derive(Deserialize)]
struct Params {
#[serde(flatten)]
auth: AuthorizationRequest,
#[serde(flatten)]
pkce: Option<pkce::Request>,
}
/// Given a list of response types and an optional user-defined response mode,
/// figure out what response mode must be used, and emit an error if the
/// suggested response mode isn't allowed for the given response types.
fn resolve_response_mode(
response_type: &HashSet<ResponseType>,
suggested_response_mode: Option<ResponseMode>,
) -> anyhow::Result<ResponseMode> {
use ResponseMode as M;
use ResponseType as T;
// If the response type includes either "token" or "id_token", the default
// response mode is "fragment" and the response mode "query" must not be
// used
if response_type.contains(&T::Token) || response_type.contains(&T::IdToken) {
match suggested_response_mode {
None => Ok(M::Fragment),
Some(M::Query) => Err(anyhow::anyhow!("invalid response mode")),
Some(mode) => Ok(mode),
}
} else {
// In other cases, all response modes are allowed, defaulting to "query"
Ok(suggested_response_mode.unwrap_or(M::Query))
}
}
pub fn filter(
pool: &PgPool,
templates: &Templates,
oauth2_config: &OAuth2Config,
cookies_config: &CookiesConfig,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
let clients = oauth2_config.clients.clone();
let authorize = warp::path!("oauth2" / "authorize")
.and(warp::get())
.map(move || clients.clone())
.and(warp::query())
.and(with_optional_session(pool, cookies_config))
.and(with_transaction(pool))
.and_then(get);
let step = warp::path!("oauth2" / "authorize" / "step")
.and(warp::get())
.and(warp::query().map(|s: StepRequest| s.id))
.and(with_session(pool, cookies_config))
.and(with_transaction(pool))
.and_then(step);
let clients = oauth2_config.clients.clone();
authorize
.or(step)
.unify()
.recover(recover)
.unify()
.and(warp::query())
.and(warp::any().map(move || clients.clone()))
.and(with_templates(templates))
.and_then(actually_reply)
}
async fn recover(rejection: Rejection) -> Result<ReplyOrBackToClient, Rejection> {
if rejection.find::<InvalidQuery>().is_some() {
Ok(ReplyOrBackToClient::Error(Box::new(InvalidRequest)))
} else {
Err(rejection)
}
}
async fn actually_reply(
rep: ReplyOrBackToClient,
q: PartialParams,
clients: Vec<OAuth2ClientConfig>,
templates: Templates,
) -> Result<impl Reply, Rejection> {
let (redirect_uri, response_mode, params) = match rep {
ReplyOrBackToClient::Reply(r) => return Ok(r),
ReplyOrBackToClient::BackToClient {
redirect_uri,
response_mode,
params,
} => (redirect_uri, response_mode, params),
ReplyOrBackToClient::Error(error) => {
let PartialParams {
client_id,
redirect_uri,
..
} = q;
// First, disover the client
let client = client_id.and_then(|client_id| {
clients
.into_iter()
.find(|client| client.client_id == client_id)
});
let client = match client {
Some(client) => client,
None => return Ok(Box::new(html(templates.render_error(&error.into())?))),
};
let redirect_uri: Result<Option<Url>, _> = redirect_uri.map(|r| r.parse()).transpose();
let redirect_uri = match redirect_uri {
Ok(r) => r,
Err(_) => return Ok(Box::new(html(templates.render_error(&error.into())?))),
};
let redirect_uri = client.resolve_redirect_uri(&redirect_uri);
let redirect_uri = match redirect_uri {
Ok(r) => r,
Err(_) => return Ok(Box::new(html(templates.render_error(&error.into())?))),
};
let reply: ErrorResponse = error.into();
let reply = serde_json::to_value(&reply).wrap_error()?;
// TODO: resolve response mode
(redirect_uri.clone(), ResponseMode::Query, reply)
}
};
// TODO: we should include the state param in errors
back_to_client(redirect_uri, response_mode, params, &templates).wrap_error()
}
async fn get(
clients: Vec<OAuth2ClientConfig>,
params: Params,
maybe_session: Option<SessionInfo>,
mut txn: Transaction<'_, Postgres>,
) -> Result<ReplyOrBackToClient, Rejection> {
// First, find out what client it is
let client = clients
.into_iter()
.find(|client| client.client_id == params.auth.client_id)
.ok_or_else(|| anyhow::anyhow!("could not find client"))
.wrap_error()?;
let maybe_session_id = maybe_session.as_ref().map(SessionInfo::key);
let scope: String = {
let it = params.auth.scope.iter().map(ToString::to_string);
Itertools::intersperse(it, " ".to_string()).collect()
};
let redirect_uri = client
.resolve_redirect_uri(&params.auth.redirect_uri)
.wrap_error()?;
let response_type = &params.auth.response_type;
let response_mode =
resolve_response_mode(response_type, params.auth.response_mode).wrap_error()?;
let oauth2_session = start_session(
&mut txn,
maybe_session_id,
&client.client_id,
redirect_uri,
&scope,
params.auth.state.as_deref(),
params.auth.nonce.as_deref(),
params.auth.max_age,
response_type,
response_mode,
)
.await
.wrap_error()?;
// Generate the code at this stage, since we have the PKCE params ready
if response_type.contains(&ResponseType::Code) {
// 32 random alphanumeric characters, about 190bit of entropy
let code: String = thread_rng()
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
.collect();
oauth2_session
.add_code(&mut txn, &code, &params.pkce)
.await
.wrap_error()?;
};
// Do we already have a user session for this oauth2 session?
let user_session = oauth2_session.fetch_session(&mut txn).await.wrap_error()?;
if let Some(user_session) = user_session {
step(oauth2_session.id, user_session, txn).await
} else {
// If not, redirect the user to the login page
txn.commit().await.wrap_error()?;
let next = StepRequest::new(oauth2_session.id)
.build_uri()
.wrap_error()?
.to_string();
let destination = LoginRequest::new(Some(next)).build_uri().wrap_error()?;
Ok(ReplyOrBackToClient::Reply(Box::new(see_other(destination))))
}
}
#[derive(Deserialize, Serialize)]
struct StepRequest {
id: i64,
}
impl StepRequest {
fn new(id: i64) -> Self {
Self { id }
}
fn build_uri(&self) -> anyhow::Result<Uri> {
let qs = serde_urlencoded::to_string(self)?;
let path_and_query = PathAndQuery::try_from(format!("/oauth2/authorize/step?{}", qs))?;
let uri = Uri::from_parts({
let mut parts = Parts::default();
parts.path_and_query = Some(path_and_query);
parts
})?;
Ok(uri)
}
}
async fn step(
oauth2_session_id: i64,
user_session: SessionInfo,
mut txn: Transaction<'_, Postgres>,
) -> Result<ReplyOrBackToClient, Rejection> {
let mut oauth2_session = get_session_by_id(&mut txn, oauth2_session_id)
.await
.wrap_error()?;
let user_session = oauth2_session
.match_or_set_session(&mut txn, user_session)
.await
.wrap_error()?;
let response_mode = oauth2_session.response_mode().wrap_error()?;
let response_type = oauth2_session.response_type().wrap_error()?;
let redirect_uri = oauth2_session.redirect_uri().wrap_error()?;
// Check if the active session is valid
let reply = if user_session.active
&& user_session.last_authd_at >= oauth2_session.max_auth_time()
{
// Yep! Let's complete the auth now
let mut params = AuthorizationResponse {
state: oauth2_session.state.clone(),
..AuthorizationResponse::default()
};
// Did they request an auth code?
if response_type.contains(&ResponseType::Code) {
params.code = Some(oauth2_session.fetch_code(&mut txn).await.wrap_error()?);
}
// Did they request an access token?
if response_type.contains(&ResponseType::Token) {
let ttl = Duration::minutes(5);
let (access_token, refresh_token) = {
let mut rng = thread_rng();
(
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
)
};
let access_token = add_access_token(&mut txn, oauth2_session_id, &access_token, ttl)
.await
.wrap_error()?;
let refresh_token =
add_refresh_token(&mut txn, oauth2_session_id, access_token.id, &refresh_token)
.await
.wrap_error()?;
params.response = Some(
AccessTokenResponse::new(access_token.token)
.with_expires_in(ttl)
.with_refresh_token(refresh_token.token),
);
}
// Did they request an ID token?
if response_type.contains(&ResponseType::IdToken) {
todo!("id tokens are not implemented yet");
}
let params = serde_json::to_value(&params).unwrap();
ReplyOrBackToClient::BackToClient {
redirect_uri,
response_mode,
params,
}
} else {
// Ask for a reauth
// TODO: have the OAuth2 session ID in there
ReplyOrBackToClient::Reply(Box::new(see_other(Uri::from_static("/reauth"))))
};
txn.commit().await.wrap_error()?;
Ok(reply)
}

View File

@ -0,0 +1,87 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashSet;
use oauth2_types::{
oidc::Metadata,
requests::{ClientAuthenticationMethod, GrantType, ResponseMode},
};
use warp::{Filter, Rejection, Reply};
use crate::config::OAuth2Config;
pub(super) fn filter(
config: &OAuth2Config,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
let base = config.issuer.clone();
let response_modes_supported = Some({
let mut s = HashSet::new();
s.insert(ResponseMode::FormPost);
s.insert(ResponseMode::Query);
s.insert(ResponseMode::Fragment);
s
});
let response_types_supported = Some({
let mut s = HashSet::new();
s.insert("code".to_string());
s.insert("token".to_string());
s.insert("id_token".to_string());
s.insert("code token".to_string());
s.insert("code id_token".to_string());
s.insert("token id_token".to_string());
s.insert("code token id_token".to_string());
s
});
let grant_types_supported = Some({
let mut s = HashSet::new();
s.insert(GrantType::AuthorizationCode);
s.insert(GrantType::RefreshToken);
s
});
let token_endpoint_auth_methods_supported = Some({
let mut s = HashSet::new();
s.insert(ClientAuthenticationMethod::ClientSecretBasic);
s.insert(ClientAuthenticationMethod::ClientSecretPost);
s.insert(ClientAuthenticationMethod::None);
s
});
let metadata = Metadata {
authorization_endpoint: base.join("oauth2/authorize").ok(),
token_endpoint: base.join("oauth2/token").ok(),
jwks_uri: base.join("oauth2/keys.json").ok(),
introspection_endpoint: base.join("oauth2/introspect").ok(),
userinfo_endpoint: base.join("oauth2/userinfo").ok(),
issuer: base,
registration_endpoint: None,
scopes_supported: None,
response_types_supported,
response_modes_supported,
grant_types_supported,
token_endpoint_auth_methods_supported,
code_challenge_methods_supported: None,
};
let cors = warp::cors().allow_any_origin();
warp::path!(".well-known" / "openid-configuration")
.and(warp::get())
.map(move || warp::reply::json(&metadata))
.with(cors)
}

View File

@ -0,0 +1,136 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use chrono::Utc;
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse, TokenTypeHint};
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use tracing::{info, warn};
use warp::{Filter, Rejection, Reply};
use crate::{
config::{OAuth2ClientConfig, OAuth2Config},
errors::WrapError,
filters::{
client::{with_client_auth, ClientAuthentication},
database::with_connection,
},
storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token},
tokens,
};
pub fn filter(
pool: &PgPool,
oauth2_config: &OAuth2Config,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
warp::path!("oauth2" / "introspect")
.and(warp::post())
.and(with_connection(pool))
.and(with_client_auth(oauth2_config))
.and_then(introspect)
.recover(recover)
}
const INACTIVE: IntrospectionResponse = IntrospectionResponse {
active: false,
scope: None,
client_id: None,
username: None,
token_type: None,
exp: None,
iat: None,
nbf: None,
sub: None,
aud: None,
iss: None,
jti: None,
};
async fn introspect(
mut conn: PoolConnection<Postgres>,
auth: ClientAuthentication,
client: OAuth2ClientConfig,
params: IntrospectionRequest,
) -> Result<impl Reply, Rejection> {
// Token introspection is only allowed by confidential clients
if auth.public() {
warn!(?client, "Client tried to introspect");
// TODO: have a nice error here
return Ok(warp::reply::json(&INACTIVE));
}
let token = &params.token;
let token_type = tokens::check(token).wrap_error()?;
if let Some(hint) = params.token_type_hint {
if token_type != hint {
info!("Token type hint did not match");
return Ok(warp::reply::json(&INACTIVE));
}
}
let reply = match token_type {
tokens::TokenType::AccessToken => {
let token = lookup_access_token(&mut conn, token).await.wrap_error()?;
let exp = token.exp();
// Check it is active and did not expire
if !token.active || exp < Utc::now() {
info!(?token, "Access token expired");
return Ok(warp::reply::json(&INACTIVE));
}
IntrospectionResponse {
active: true,
scope: None, // TODO: parse back scopes
client_id: Some(token.client_id.clone()),
username: Some(token.username.clone()),
token_type: Some(TokenTypeHint::AccessToken),
exp: Some(exp),
iat: Some(token.created_at),
nbf: Some(token.created_at),
sub: None,
aud: None,
iss: None,
jti: None,
}
}
tokens::TokenType::RefreshToken => {
let token = lookup_refresh_token(&mut conn, token).await.wrap_error()?;
IntrospectionResponse {
active: true,
scope: None, // TODO: parse back scopes
client_id: Some(token.client_id),
username: None,
token_type: Some(TokenTypeHint::RefreshToken),
exp: None,
iat: None,
nbf: None,
sub: None,
aud: None,
iss: None,
jti: None,
}
}
};
Ok(warp::reply::json(&reply))
}
async fn recover(rejection: Rejection) -> Result<impl Reply, Rejection> {
if rejection.is_not_found() {
Err(rejection)
} else {
Ok(warp::reply::json(&INACTIVE))
}
}

View File

@ -0,0 +1,30 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use warp::{Filter, Rejection, Reply};
use crate::config::OAuth2Config;
pub(super) fn filter(
config: &OAuth2Config,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
let jwks = config.keys.to_public_jwks();
let cors = warp::cors().allow_any_origin();
warp::path!("oauth2" / "keys.json")
.and(warp::get())
.map(move || warp::reply::json(&jwks))
.with(cors)
}

View File

@ -0,0 +1,53 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use sqlx::PgPool;
use warp::{Filter, Rejection, Reply};
use crate::{
config::{CookiesConfig, OAuth2Config},
templates::Templates,
};
mod authorization;
mod discovery;
mod introspection;
mod keys;
mod token;
mod userinfo;
use self::{
authorization::filter as authorization, discovery::filter as discovery,
introspection::filter as introspection, keys::filter as keys, token::filter as token,
userinfo::filter as userinfo,
};
pub fn filter(
pool: &PgPool,
templates: &Templates,
oauth2_config: &OAuth2Config,
cookies_config: &CookiesConfig,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
discovery(oauth2_config)
.or(keys(oauth2_config))
.or(authorization(
pool,
templates,
oauth2_config,
cookies_config,
))
.or(userinfo(pool, oauth2_config))
.or(introspection(pool, oauth2_config))
.or(token(pool, oauth2_config))
}

View File

@ -0,0 +1,276 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use chrono::Duration;
use data_encoding::BASE64URL_NOPAD;
use headers::{CacheControl, Pragma};
use hyper::StatusCode;
use jwt_compact::{Claims, Header, TimeOptions};
use oauth2_types::{
errors::{InvalidGrant, OAuth2Error, OAuth2ErrorCode, UnauthorizedClient},
requests::{
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant,
},
};
use rand::thread_rng;
use serde::Serialize;
use serde_with::skip_serializing_none;
use sha2::{Digest, Sha256};
use sqlx::{pool::PoolConnection, Acquire, PgPool, Postgres};
use url::Url;
use warp::{
reject::Reject,
reply::{json, with_status},
Filter, Rejection, Reply,
};
use crate::{
config::{KeySet, OAuth2ClientConfig, OAuth2Config},
errors::WrapError,
filters::{
client::{with_client_auth, ClientAuthentication},
database::with_connection,
headers::typed_header,
with_keys,
},
storage::oauth2::{
access_token::{add_access_token, revoke_access_token},
authorization_code::lookup_code,
refresh_token::{add_refresh_token, lookup_refresh_token, replace_refresh_token},
},
tokens,
};
#[skip_serializing_none]
#[derive(Serialize, Debug)]
struct CustomClaims {
#[serde(rename = "iss")]
issuer: Url,
#[serde(rename = "sub")]
subject: String,
#[serde(rename = "aud")]
audiences: Vec<String>,
nonce: Option<String>,
at_hash: String,
c_hash: String,
}
#[derive(Debug)]
struct Error {
json: serde_json::Value,
status: StatusCode,
}
impl Reject for Error {}
fn error<T, E>(e: E) -> Result<T, Rejection>
where
E: OAuth2ErrorCode + 'static,
{
let status = e.status();
let json = serde_json::to_value(e.into_response()).wrap_error()?;
Err(Error { json, status }.into())
}
pub fn filter(
pool: &PgPool,
oauth2_config: &OAuth2Config,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
let issuer = oauth2_config.issuer.clone();
warp::path!("oauth2" / "token")
.and(warp::post())
.and(with_client_auth(oauth2_config))
.and(with_keys(oauth2_config))
.and(warp::any().map(move || issuer.clone()))
.and(with_connection(pool))
.and_then(token)
.recover(recover)
}
async fn recover(rejection: Rejection) -> Result<impl Reply, Rejection> {
if let Some(Error { json, status }) = rejection.find::<Error>() {
Ok(with_status(warp::reply::json(json), *status))
} else {
Err(rejection)
}
}
async fn token(
_auth: ClientAuthentication,
client: OAuth2ClientConfig,
req: AccessTokenRequest,
keys: KeySet,
issuer: Url,
mut conn: PoolConnection<Postgres>,
) -> Result<impl Reply, Rejection> {
let reply = match req {
AccessTokenRequest::AuthorizationCode(grant) => {
let reply = authorization_code_grant(&grant, &client, &keys, issuer, &mut conn).await?;
json(&reply)
}
AccessTokenRequest::RefreshToken(grant) => {
let reply = refresh_token_grant(&grant, &client, &mut conn).await?;
json(&reply)
}
_ => {
let reply = InvalidGrant.into_response();
json(&reply)
}
};
Ok(typed_header(
Pragma::no_cache(),
typed_header(CacheControl::new().with_no_store(), reply),
))
}
fn hash<H: Digest>(mut hasher: H, token: &str) -> anyhow::Result<String> {
hasher.update(token);
let hash = hasher.finalize();
// Left-most 128bit
let bits = hash
.get(..16)
.context("failed to get first 128 bits of hash")?;
Ok(BASE64URL_NOPAD.encode(bits))
}
async fn authorization_code_grant(
grant: &AuthorizationCodeGrant,
client: &OAuth2ClientConfig,
keys: &KeySet,
issuer: Url,
conn: &mut PoolConnection<Postgres>,
) -> Result<AccessTokenResponse, Rejection> {
let mut txn = conn.begin().await.wrap_error()?;
let code = lookup_code(&mut txn, &grant.code).await.wrap_error()?;
if client.client_id != code.client_id {
return error(UnauthorizedClient);
}
// TODO: verify PKCE
// TODO: make the code invalid
let ttl = Duration::minutes(5);
let (access_token, refresh_token) = {
let mut rng = thread_rng();
(
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
)
};
let access_token = add_access_token(&mut txn, code.oauth2_session_id, &access_token, ttl)
.await
.wrap_error()?;
let refresh_token = add_refresh_token(
&mut txn,
code.oauth2_session_id,
access_token.id,
&refresh_token,
)
.await
.wrap_error()?;
// TODO: generate id_token only if the "openid" scope was asked
let header = Header::default();
let options = TimeOptions::default();
let claims = Claims::new(CustomClaims {
issuer,
// TODO: get that from the session
subject: "random-subject".to_string(),
audiences: vec![client.client_id.clone()],
nonce: code.nonce,
at_hash: hash(Sha256::new(), &access_token.token).wrap_error()?,
c_hash: hash(Sha256::new(), &grant.code).wrap_error()?,
})
.set_duration_and_issuance(&options, Duration::minutes(30));
let id_token = keys
.token(crate::config::Algorithm::Rs256, header, claims)
.await
.context("could not sign ID token")
.wrap_error()?;
// TODO: have the scopes back here
let params = AccessTokenResponse::new(access_token.token)
.with_expires_in(ttl)
.with_refresh_token(refresh_token.token)
.with_id_token(id_token);
txn.commit().await.wrap_error()?;
Ok(params)
}
async fn refresh_token_grant(
grant: &RefreshTokenGrant,
client: &OAuth2ClientConfig,
conn: &mut PoolConnection<Postgres>,
) -> Result<AccessTokenResponse, Rejection> {
let mut txn = conn.begin().await.wrap_error()?;
// TODO: scope handling
let refresh_token_lookup = lookup_refresh_token(&mut txn, &grant.refresh_token)
.await
.wrap_error()?;
if client.client_id != refresh_token_lookup.client_id {
// As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
return error(InvalidGrant);
}
let ttl = Duration::minutes(5);
let (access_token, refresh_token) = {
let mut rng = thread_rng();
(
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
)
};
let access_token = add_access_token(
&mut txn,
refresh_token_lookup.oauth2_session_id,
&access_token,
ttl,
)
.await
.wrap_error()?;
let refresh_token = add_refresh_token(
&mut txn,
refresh_token_lookup.oauth2_session_id,
access_token.id,
&refresh_token,
)
.await
.wrap_error()?;
replace_refresh_token(&mut txn, refresh_token_lookup.id, refresh_token.id)
.await
.wrap_error()?;
if let Some(access_token_id) = refresh_token_lookup.oauth2_access_token_id {
revoke_access_token(&mut txn, access_token_id)
.await
.wrap_error()?;
}
let params = AccessTokenResponse::new(access_token.token)
.with_expires_in(ttl)
.with_refresh_token(refresh_token.token);
txn.commit().await.wrap_error()?;
Ok(params)
}

View File

@ -0,0 +1,43 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::Serialize;
use sqlx::PgPool;
use warp::{Filter, Rejection, Reply};
use crate::{
config::OAuth2Config, filters::authenticate::with_authentication,
storage::oauth2::access_token::OAuth2AccessTokenLookup,
};
#[derive(Serialize)]
struct UserInfo {
sub: String,
}
pub(super) fn filter(
pool: &PgPool,
_config: &OAuth2Config,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
warp::path!("oauth2" / "userinfo")
.and(warp::get().or(warp::post()).unify())
.and(with_authentication(pool))
.and_then(userinfo)
}
async fn userinfo(token: OAuth2AccessTokenLookup) -> Result<impl Reply, Rejection> {
Ok(warp::reply::json(&UserInfo {
sub: token.username,
}))
}

View File

@ -0,0 +1,64 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use sqlx::PgPool;
use url::Url;
use warp::{reply::html, Filter, Rejection, Reply};
use crate::{
config::{CookiesConfig, CsrfConfig, OAuth2Config},
filters::{
cookies::{with_cookie_saver, EncryptedCookieSaver},
csrf::updated_csrf_token,
session::with_optional_session,
with_templates, CsrfToken,
},
storage::SessionInfo,
templates::{IndexContext, TemplateContext, Templates},
};
pub(super) fn filter(
pool: &PgPool,
templates: &Templates,
oauth2_config: &OAuth2Config,
csrf_config: &CsrfConfig,
cookies_config: &CookiesConfig,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
let discovery_url = oauth2_config.discovery_url();
warp::path::end()
.and(warp::get())
.map(move || discovery_url.clone())
.and(with_templates(templates))
.and(with_cookie_saver(cookies_config))
.and(updated_csrf_token(cookies_config, csrf_config))
.and(with_optional_session(pool, cookies_config))
.and_then(get)
}
async fn get(
discovery_url: Url,
templates: Templates,
cookie_saver: EncryptedCookieSaver,
csrf_token: CsrfToken,
session: Option<SessionInfo>,
) -> Result<impl Reply, Rejection> {
let ctx = IndexContext::new(discovery_url)
.maybe_with_session(session)
.with_csrf(&csrf_token);
let content = templates.render_index(&ctx)?;
let reply = html(content);
let reply = cookie_saver.save_encrypted(&csrf_token, reply)?;
Ok(Box::new(reply))
}

View File

@ -0,0 +1,154 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::convert::TryFrom;
use hyper::http::uri::{Parts, PathAndQuery, Uri};
use serde::{Deserialize, Serialize};
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use warp::{reply::html, Filter, Rejection, Reply};
use crate::{
config::{CookiesConfig, CsrfConfig},
errors::{WrapError, WrapFormError},
filters::{
cookies::{with_cookie_saver, EncryptedCookieSaver},
csrf::{protected_form, updated_csrf_token},
database::with_connection,
session::{with_optional_session, SessionCookie},
with_templates, CsrfToken,
},
storage::{login, SessionInfo},
templates::{LoginContext, LoginFormField, TemplateContext, Templates},
};
#[derive(Serialize, Deserialize)]
pub struct LoginRequest {
next: Option<String>,
}
impl LoginRequest {
pub fn new(next: Option<String>) -> Self {
Self { next }
}
pub fn build_uri(&self) -> anyhow::Result<Uri> {
let qs = serde_urlencoded::to_string(self)?;
let path_and_query = PathAndQuery::try_from(format!("/login?{}", qs))?;
let uri = Uri::from_parts({
let mut parts = Parts::default();
parts.path_and_query = Some(path_and_query);
parts
})?;
Ok(uri)
}
fn redirect(self) -> Result<impl Reply, Rejection> {
let uri: Uri = Uri::from_parts({
let mut parts = Parts::default();
parts.path_and_query = Some(
self.next
.map(warp::http::uri::PathAndQuery::try_from)
.transpose()
.wrap_error()?
.unwrap_or_else(|| PathAndQuery::from_static("/")),
);
parts
})
.wrap_error()?;
Ok(warp::redirect::see_other(uri))
}
}
#[derive(Deserialize)]
struct LoginForm {
username: String,
password: String,
}
pub(super) fn filter(
pool: &PgPool,
templates: &Templates,
csrf_config: &CsrfConfig,
cookies_config: &CookiesConfig,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
let get = warp::get()
.and(with_templates(templates))
.and(with_cookie_saver(cookies_config))
.and(updated_csrf_token(cookies_config, csrf_config))
.and(warp::query())
.and(with_optional_session(pool, cookies_config))
.and_then(get);
let post = warp::post()
.and(with_templates(templates))
.and(with_connection(pool))
.and(with_cookie_saver(cookies_config))
.and(updated_csrf_token(cookies_config, csrf_config))
.and(protected_form(cookies_config))
.and(warp::query())
.and_then(post);
warp::path!("login").and(get.or(post))
}
async fn get(
templates: Templates,
cookie_saver: EncryptedCookieSaver,
csrf_token: CsrfToken,
query: LoginRequest,
maybe_session: Option<SessionInfo>,
) -> Result<Box<dyn Reply>, Rejection> {
if maybe_session.is_some() {
Ok(Box::new(query.redirect()?))
} else {
let ctx = LoginContext::default().with_csrf(&csrf_token);
let content = templates.render_login(&ctx)?;
let reply = html(content);
let reply = cookie_saver.save_encrypted(&csrf_token, reply)?;
Ok(Box::new(reply))
}
}
async fn post(
templates: Templates,
mut conn: PoolConnection<Postgres>,
cookie_saver: EncryptedCookieSaver,
csrf_token: CsrfToken,
form: LoginForm,
query: LoginRequest,
) -> Result<Box<dyn Reply>, Rejection> {
use crate::storage::user::LoginError;
// TODO: recover
match login(&mut conn, &form.username, form.password).await {
Ok(session_info) => {
let session_cookie = SessionCookie::from_session_info(&session_info);
let reply = query.redirect()?;
let reply = cookie_saver.save_encrypted(&session_cookie, reply)?;
Ok(Box::new(reply))
}
Err(e) => {
let errored_form = match e {
LoginError::NotFound { .. } => e.on_field(LoginFormField::Username),
LoginError::Authentication { .. } => e.on_field(LoginFormField::Password),
LoginError::Other(_) => e.on_form(),
};
let ctx = LoginContext::with_form_error(errored_form).with_csrf(&csrf_token);
let content = templates.render_login(&ctx)?;
let reply = html(content);
let reply = cookie_saver.save_encrypted(&csrf_token, reply)?;
Ok(Box::new(reply))
}
}
}

View File

@ -0,0 +1,44 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use warp::{hyper::Uri, Filter, Rejection, Reply};
use crate::{
config::CookiesConfig,
errors::WrapError,
filters::{csrf::protected_form, database::with_connection, session::with_session},
storage::SessionInfo,
};
pub(super) fn filter(
pool: &PgPool,
cookies_config: &CookiesConfig,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
warp::path!("logout")
.and(warp::post())
.and(with_session(pool, cookies_config))
.and(with_connection(pool))
.and(protected_form(cookies_config))
.and_then(post)
}
async fn post(
session: SessionInfo,
mut conn: PoolConnection<Postgres>,
_form: (),
) -> Result<impl Reply, Rejection> {
session.end(&mut conn).await.wrap_error()?;
Ok(warp::redirect(Uri::from_static("/login")))
}

View File

@ -0,0 +1,48 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use sqlx::PgPool;
use warp::{Filter, Rejection, Reply};
use crate::{
config::{CookiesConfig, CsrfConfig, OAuth2Config},
templates::Templates,
};
mod index;
mod login;
mod logout;
mod reauth;
mod register;
pub use self::login::LoginRequest;
use self::{
index::filter as index, login::filter as login, logout::filter as logout,
reauth::filter as reauth, register::filter as register,
};
pub(super) fn filter(
pool: &PgPool,
templates: &Templates,
oauth2_config: &OAuth2Config,
csrf_config: &CsrfConfig,
cookies_config: &CookiesConfig,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
index(pool, templates, oauth2_config, csrf_config, cookies_config)
.or(login(pool, templates, csrf_config, cookies_config))
.or(register(pool, templates, csrf_config, cookies_config))
.or(logout(pool, cookies_config))
.or(reauth(pool, templates, csrf_config, cookies_config))
.boxed()
}

View File

@ -0,0 +1,85 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::Deserialize;
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use warp::{hyper::Uri, reply::html, Filter, Rejection, Reply};
use crate::{
config::{CookiesConfig, CsrfConfig},
errors::WrapError,
filters::{
cookies::{with_cookie_saver, EncryptedCookieSaver},
csrf::{protected_form, updated_csrf_token},
database::with_connection,
session::with_session,
with_templates, CsrfToken,
},
storage::SessionInfo,
templates::{TemplateContext, Templates},
};
#[derive(Deserialize, Debug)]
struct ReauthForm {
password: String,
}
pub(super) fn filter(
pool: &PgPool,
templates: &Templates,
csrf_config: &CsrfConfig,
cookies_config: &CookiesConfig,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
let get = warp::get()
.and(with_templates(templates))
.and(with_cookie_saver(cookies_config))
.and(updated_csrf_token(cookies_config, csrf_config))
.and(with_session(pool, cookies_config))
.and_then(get);
let post = warp::post()
.and(with_session(pool, cookies_config))
.and(with_connection(pool))
.and(protected_form(cookies_config))
.and_then(post);
warp::path!("reauth").and(get.or(post))
}
async fn get(
templates: Templates,
cookie_saver: EncryptedCookieSaver,
csrf_token: CsrfToken,
session: SessionInfo,
) -> Result<impl Reply, Rejection> {
let ctx = ().with_session(session).with_csrf(&csrf_token);
let content = templates.render_reauth(&ctx)?;
let reply = html(content);
let reply = cookie_saver.save_encrypted(&csrf_token, reply)?;
Ok(reply)
}
async fn post(
session: SessionInfo,
mut conn: PoolConnection<Postgres>,
form: ReauthForm,
) -> Result<impl Reply, Rejection> {
let _session = session
.reauth(&mut conn, form.password)
.await
.wrap_error()?;
Ok(warp::redirect(Uri::from_static("/")))
}

View File

@ -0,0 +1,147 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::convert::TryFrom;
use argon2::Argon2;
use hyper::http::uri::{Parts, PathAndQuery, Uri};
use serde::{Deserialize, Serialize};
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use warp::{reply::html, Filter, Rejection, Reply};
use crate::{
config::{CookiesConfig, CsrfConfig},
errors::WrapError,
filters::{
cookies::{with_cookie_saver, EncryptedCookieSaver},
csrf::{protected_form, updated_csrf_token},
database::with_connection,
session::{with_optional_session, SessionCookie},
with_templates, CsrfToken,
},
storage::{register_user, user::start_session, SessionInfo},
templates::{TemplateContext, Templates},
};
#[derive(Serialize, Deserialize)]
pub struct RegisterRequest {
next: Option<String>,
}
impl RegisterRequest {
#[allow(dead_code)]
pub fn new(next: Option<String>) -> Self {
Self { next }
}
#[allow(dead_code)]
pub fn build_uri(&self) -> anyhow::Result<Uri> {
let qs = serde_urlencoded::to_string(self)?;
let path_and_query = PathAndQuery::try_from(format!("/register?{}", qs))?;
let uri = Uri::from_parts({
let mut parts = Parts::default();
parts.path_and_query = Some(path_and_query);
parts
})?;
Ok(uri)
}
fn redirect(self) -> Result<impl Reply, Rejection> {
let uri: Uri = Uri::from_parts({
let mut parts = Parts::default();
parts.path_and_query = Some(
self.next
.map(warp::http::uri::PathAndQuery::try_from)
.transpose()
.wrap_error()?
.unwrap_or_else(|| PathAndQuery::from_static("/")),
);
parts
})
.wrap_error()?;
Ok(warp::redirect::see_other(uri))
}
}
#[derive(Deserialize)]
struct RegisterForm {
username: String,
password: String,
password_confirm: String,
}
pub(super) fn filter(
pool: &PgPool,
templates: &Templates,
csrf_config: &CsrfConfig,
cookies_config: &CookiesConfig,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
let get = warp::get()
.and(with_templates(templates))
.and(with_cookie_saver(cookies_config))
.and(updated_csrf_token(cookies_config, csrf_config))
.and(warp::query())
.and(with_optional_session(pool, cookies_config))
.and_then(get);
let post = warp::post()
.and(with_connection(pool))
.and(with_cookie_saver(cookies_config))
.and(protected_form(cookies_config))
.and(warp::query())
.and_then(post);
warp::path!("register").and(get.or(post))
}
async fn get(
templates: Templates,
cookie_saver: EncryptedCookieSaver,
csrf_token: CsrfToken,
query: RegisterRequest,
maybe_session: Option<SessionInfo>,
) -> Result<Box<dyn Reply>, Rejection> {
if maybe_session.is_some() {
Ok(Box::new(query.redirect()?))
} else {
let ctx = ().with_csrf(&csrf_token);
let content = templates.render_register(&ctx)?;
let reply = html(content);
let reply = cookie_saver.save_encrypted(&csrf_token, reply)?;
Ok(Box::new(reply))
}
}
async fn post(
mut conn: PoolConnection<Postgres>,
cookie_saver: EncryptedCookieSaver,
form: RegisterForm,
query: RegisterRequest,
) -> Result<impl Reply, Rejection> {
if form.password != form.password_confirm {
return Err(anyhow::anyhow!("password mismatch")).wrap_error();
}
let pfh = Argon2::default();
let user = register_user(&mut conn, pfh, &form.username, &form.password)
.await
.wrap_error()?;
let session_info = start_session(&mut conn, user).await.wrap_error()?;
let session_cookie = SessionCookie::from_session_info(&session_info);
let reply = query.redirect()?;
let reply = cookie_saver.save_encrypted(&session_cookie, reply)?;
Ok(reply)
}

31
crates/core/src/lib.rs Normal file
View File

@ -0,0 +1,31 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![forbid(unsafe_code)]
#![deny(clippy::all)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::implicit_hasher)]
pub(crate) use mas_config as config;
pub mod errors;
pub mod filters;
pub mod handlers;
pub mod storage;
pub mod tasks;
pub mod templates;
pub mod tokens;

View File

@ -0,0 +1,24 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(clippy::used_underscore_binding)] // This is needed by sqlx macros
use sqlx::migrate::Migrator;
pub mod oauth2;
pub mod user;
pub use self::user::{login, lookup_active_session, register_user, SessionInfo, User};
pub static MIGRATOR: Migrator = sqlx::migrate!();

View File

@ -0,0 +1,141 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::convert::TryFrom;
use anyhow::Context;
use chrono::{DateTime, Duration, Utc};
use serde::Serialize;
use sqlx::{Executor, FromRow, Postgres};
#[derive(FromRow, Serialize)]
pub struct OAuth2AccessToken {
pub id: i64,
pub oauth2_session_id: i64,
pub token: String,
expires_after: i32,
created_at: DateTime<Utc>,
}
pub async fn add_access_token(
executor: impl Executor<'_, Database = Postgres>,
oauth2_session_id: i64,
token: &str,
expires_after: Duration,
) -> anyhow::Result<OAuth2AccessToken> {
// Checked convertion of duration to i32, maxing at i32::MAX
let expires_after = i32::try_from(expires_after.num_seconds()).unwrap_or(i32::MAX);
sqlx::query_as!(
OAuth2AccessToken,
r#"
INSERT INTO oauth2_access_tokens
(oauth2_session_id, token, expires_after)
VALUES
($1, $2, $3)
RETURNING
id, oauth2_session_id, token, expires_after, created_at
"#,
oauth2_session_id,
token,
expires_after,
)
.fetch_one(executor)
.await
.context("could not insert oauth2 access token")
}
#[derive(Debug)]
pub struct OAuth2AccessTokenLookup {
pub active: bool,
pub username: String,
pub client_id: String,
pub scope: String,
pub created_at: DateTime<Utc>,
expires_after: i32,
}
impl OAuth2AccessTokenLookup {
#[must_use] pub fn exp(&self) -> DateTime<Utc> {
self.created_at + Duration::seconds(i64::from(self.expires_after))
}
}
pub async fn lookup_access_token(
executor: impl Executor<'_, Database = Postgres>,
token: &str,
) -> anyhow::Result<OAuth2AccessTokenLookup> {
sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT
u.username AS "username!",
us.active AS "active!",
os.client_id AS "client_id!",
os.scope AS "scope!",
at.created_at AS "created_at!",
at.expires_after AS "expires_after!"
FROM oauth2_access_tokens at
INNER JOIN oauth2_sessions os
ON os.id = at.oauth2_session_id
INNER JOIN user_sessions us
ON us.id = os.user_session_id
INNER JOIN users u
ON u.id = us.user_id
WHERE at.token = $1
"#,
token,
)
.fetch_one(executor)
.await
.context("could not introspect oauth2 access token")
}
pub async fn revoke_access_token(
executor: impl Executor<'_, Database = Postgres>,
id: i64,
) -> anyhow::Result<()> {
let res = sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens
WHERE id = $1
"#,
id,
)
.execute(executor)
.await
.context("could not revoke access tokens")?;
if res.rows_affected() == 1 {
Ok(())
} else {
Err(anyhow::anyhow!("no row were affected when revoking token"))
}
}
pub async fn cleanup_expired(
executor: impl Executor<'_, Database = Postgres>,
) -> anyhow::Result<u64> {
let res = sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens
WHERE created_at + (expires_after * INTERVAL '1 second') + INTERVAL '15 minutes' < now()
"#,
)
.execute(executor)
.await
.context("could not cleanup expired access tokens")?;
Ok(res.rows_affected())
}

View File

@ -0,0 +1,92 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use oauth2_types::pkce;
use serde::Serialize;
use sqlx::{Executor, FromRow, Postgres};
#[derive(FromRow, Serialize)]
pub struct OAuth2Code {
id: i64,
oauth2_session_id: i64,
pub code: String,
code_challenge: Option<String>,
code_challenge_method: Option<i16>,
}
pub async fn add_code(
executor: impl Executor<'_, Database = Postgres>,
oauth2_session_id: i64,
code: &str,
code_challenge: &Option<pkce::Request>,
) -> anyhow::Result<OAuth2Code> {
let code_challenge_method = code_challenge
.as_ref()
.map(|c| c.code_challenge_method as i16);
let code_challenge = code_challenge.as_ref().map(|c| &c.code_challenge);
sqlx::query_as!(
OAuth2Code,
r#"
INSERT INTO oauth2_codes
(oauth2_session_id, code, code_challenge_method, code_challenge)
VALUES
($1, $2, $3, $4)
RETURNING
id, oauth2_session_id, code, code_challenge_method, code_challenge
"#,
oauth2_session_id,
code,
code_challenge_method,
code_challenge,
)
.fetch_one(executor)
.await
.context("could not insert oauth2 authorization code")
}
pub struct OAuth2CodeLookup {
pub id: i64,
pub oauth2_session_id: i64,
pub client_id: String,
pub redirect_uri: String,
pub scope: String,
pub nonce: Option<String>,
}
pub async fn lookup_code(
executor: impl Executor<'_, Database = Postgres>,
code: &str,
) -> anyhow::Result<OAuth2CodeLookup> {
sqlx::query_as!(
OAuth2CodeLookup,
r#"
SELECT
oc.id,
os.id AS "oauth2_session_id!",
os.client_id AS "client_id!",
os.redirect_uri,
os.scope AS "scope!",
os.nonce
FROM oauth2_codes oc
INNER JOIN oauth2_sessions os
ON os.id = oc.oauth2_session_id
WHERE oc.code = $1
"#,
code,
)
.fetch_one(executor)
.await
.context("could not lookup oauth2 code")
}

View File

@ -0,0 +1,18 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod access_token;
pub mod authorization_code;
pub mod refresh_token;
pub mod session;

View File

@ -0,0 +1,114 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use chrono::{DateTime, Utc};
use sqlx::{Executor, Postgres};
#[derive(Debug)]
pub struct OAuth2RefreshToken {
pub id: i64,
oauth2_session_id: i64,
oauth2_access_token_id: Option<i64>,
pub token: String,
next_token_id: Option<i64>,
created_at: DateTime<Utc>,
updated_at: DateTime<Utc>,
}
pub async fn add_refresh_token(
executor: impl Executor<'_, Database = Postgres>,
oauth2_session_id: i64,
oauth2_access_token_id: i64,
token: &str,
) -> anyhow::Result<OAuth2RefreshToken> {
sqlx::query_as!(
OAuth2RefreshToken,
r#"
INSERT INTO oauth2_refresh_tokens
(oauth2_session_id, oauth2_access_token_id, token)
VALUES
($1, $2, $3)
RETURNING
id, oauth2_session_id, oauth2_access_token_id, token, next_token_id,
created_at, updated_at
"#,
oauth2_session_id,
oauth2_access_token_id,
token,
)
.fetch_one(executor)
.await
.context("could not insert oauth2 refresh token")
}
pub struct OAuth2RefreshTokenLookup {
pub id: i64,
pub oauth2_session_id: i64,
pub oauth2_access_token_id: Option<i64>,
pub client_id: String,
pub scope: String,
}
pub async fn lookup_refresh_token(
executor: impl Executor<'_, Database = Postgres>,
token: &str,
) -> anyhow::Result<OAuth2RefreshTokenLookup> {
sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT
rt.id,
rt.oauth2_session_id,
rt.oauth2_access_token_id,
os.client_id AS "client_id!",
os.scope AS "scope!"
FROM oauth2_refresh_tokens rt
INNER JOIN oauth2_sessions os
ON os.id = rt.oauth2_session_id
WHERE rt.token = $1 AND rt.next_token_id IS NULL
"#,
token,
)
.fetch_one(executor)
.await
.context("failed to fetch oauth2 refresh token")
}
pub async fn replace_refresh_token(
executor: impl Executor<'_, Database = Postgres>,
refresh_token_id: i64,
next_refresh_token_id: i64,
) -> anyhow::Result<()> {
let res = sqlx::query!(
r#"
UPDATE oauth2_refresh_tokens
SET next_token_id = $2
WHERE id = $1
"#,
refresh_token_id,
next_refresh_token_id
)
.execute(executor)
.await
.context("failed to update oauth2 refresh token")?;
if res.rows_affected() == 1 {
Ok(())
} else {
Err(anyhow::anyhow!(
"no row were affected when updating refresh token"
))
}
}

View File

@ -0,0 +1,214 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashSet, convert::TryFrom, str::FromStr, string::ToString};
use anyhow::Context;
use chrono::{DateTime, Duration, Utc};
use itertools::Itertools;
use oauth2_types::{
pkce,
requests::{ResponseMode, ResponseType},
};
use serde::Serialize;
use sqlx::{Executor, FromRow, Postgres};
use url::Url;
use super::{
super::{user::lookup_session, SessionInfo},
authorization_code::{add_code, OAuth2Code},
};
#[derive(FromRow, Serialize)]
pub struct OAuth2Session {
pub id: i64,
user_session_id: Option<i64>,
pub client_id: String,
redirect_uri: String,
scope: String,
pub state: Option<String>,
nonce: Option<String>,
max_age: Option<i32>,
response_type: String,
response_mode: String,
created_at: DateTime<Utc>,
updated_at: DateTime<Utc>,
}
impl OAuth2Session {
pub async fn add_code<'e>(
&self,
executor: impl Executor<'e, Database = Postgres>,
code: &str,
code_challenge: &Option<pkce::Request>,
) -> anyhow::Result<OAuth2Code> {
add_code(executor, self.id, code, code_challenge).await
}
pub async fn fetch_session(
&self,
executor: impl Executor<'_, Database = Postgres>,
) -> anyhow::Result<Option<SessionInfo>> {
match self.user_session_id {
Some(id) => {
let info = lookup_session(executor, id).await?;
Ok(Some(info))
}
None => Ok(None),
}
}
pub async fn fetch_code(
&self,
executor: impl Executor<'_, Database = Postgres>,
) -> anyhow::Result<String> {
get_code_for_session(executor, self.id).await
}
pub async fn match_or_set_session(
&mut self,
executor: impl Executor<'_, Database = Postgres>,
session: SessionInfo,
) -> anyhow::Result<SessionInfo> {
match self.user_session_id {
Some(id) if id == session.key() => Ok(session),
Some(id) => Err(anyhow::anyhow!(
"session mismatch, expected {}, got {}",
id,
session.key()
)),
None => {
sqlx::query!(
"UPDATE oauth2_sessions SET user_session_id = $1 WHERE id = $2",
session.key(),
self.id,
)
.execute(executor)
.await
.context("could not update oauth2 session")?;
Ok(session)
}
}
}
#[must_use] pub fn max_auth_time(&self) -> Option<DateTime<Utc>> {
self.max_age
.map(|d| Duration::seconds(i64::from(d)))
.map(|d| self.created_at - d)
}
pub fn response_type(&self) -> anyhow::Result<HashSet<ResponseType>> {
self.response_type
.split(' ')
.map(|s| {
ResponseType::from_str(s).with_context(|| format!("invalid response type {}", s))
})
.collect()
}
pub fn response_mode(&self) -> anyhow::Result<ResponseMode> {
self.response_mode.parse().context("invalid response mode")
}
pub fn redirect_uri(&self) -> anyhow::Result<Url> {
self.redirect_uri.parse().context("invalid redirect uri")
}
}
#[allow(clippy::too_many_arguments)]
pub async fn start_session(
executor: impl Executor<'_, Database = Postgres>,
optional_session_id: Option<i64>,
client_id: &str,
redirect_uri: &Url,
scope: &str,
state: Option<&str>,
nonce: Option<&str>,
max_age: Option<Duration>,
response_type: &HashSet<ResponseType>,
response_mode: ResponseMode,
) -> anyhow::Result<OAuth2Session> {
// Checked convertion of duration to i32, maxing at i32::MAX
let max_age = max_age.map(|d| i32::try_from(d.num_seconds()).unwrap_or(i32::MAX));
let response_mode = response_mode.to_string();
let redirect_uri = redirect_uri.to_string();
let response_type: String = {
let it = response_type.iter().map(ToString::to_string);
Itertools::intersperse(it, " ".to_string()).collect()
};
sqlx::query_as!(
OAuth2Session,
r#"
INSERT INTO oauth2_sessions
(user_session_id, client_id, redirect_uri, scope, state, nonce, max_age,
response_type, response_mode)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING
id, user_session_id, client_id, redirect_uri, scope, state, nonce, max_age,
response_type, response_mode, created_at, updated_at
"#,
optional_session_id,
client_id,
redirect_uri,
scope,
state,
nonce,
max_age,
response_type,
response_mode,
)
.fetch_one(executor)
.await
.context("could not insert oauth2 session")
}
pub async fn get_session_by_id(
executor: impl Executor<'_, Database = Postgres>,
oauth2_session_id: i64,
) -> anyhow::Result<OAuth2Session> {
sqlx::query_as!(
OAuth2Session,
r#"
SELECT
id, user_session_id, client_id, redirect_uri, scope, state, nonce,
max_age, response_type, response_mode, created_at, updated_at
FROM oauth2_sessions
WHERE id = $1
"#,
oauth2_session_id
)
.fetch_one(executor)
.await
.context("could not fetch oauth2 session")
}
pub async fn get_code_for_session(
executor: impl Executor<'_, Database = Postgres>,
oauth2_session_id: i64,
) -> anyhow::Result<String> {
sqlx::query_scalar!(
r#"
SELECT code
FROM oauth2_codes
WHERE oauth2_session_id = $1
"#,
oauth2_session_id
)
.fetch_one(executor)
.await
.context("could not fetch oauth2 code")
}

View File

@ -0,0 +1,370 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::BorrowMut;
use anyhow::Context;
use argon2::Argon2;
use chrono::{DateTime, Utc};
use password_hash::{PasswordHash, PasswordHasher, SaltString};
use rand::rngs::OsRng;
use serde::Serialize;
use sqlx::{Acquire, Executor, FromRow, Postgres, Transaction};
use thiserror::Error;
use tokio::task;
use tracing::{info_span, Instrument};
use crate::errors::HtmlError;
#[derive(Serialize, Debug, Clone, FromRow)]
pub struct User {
pub id: i64,
pub username: String,
}
#[derive(Serialize, Debug, Clone, FromRow)]
pub struct SessionInfo {
id: i64,
user_id: i64,
username: String,
pub active: bool,
created_at: DateTime<Utc>,
pub last_authd_at: Option<DateTime<Utc>>,
}
impl SessionInfo {
#[must_use] pub fn key(&self) -> i64 {
self.id
}
pub async fn reauth(
mut self,
conn: impl Acquire<'_, Database = Postgres>,
password: String,
) -> anyhow::Result<Self> {
let mut txn = conn.begin().await?;
self.last_authd_at = Some(authenticate_session(&mut txn, self.id, password).await?);
txn.commit().await?;
Ok(self)
}
pub async fn end(
mut self,
executor: impl Executor<'_, Database = Postgres>,
) -> anyhow::Result<Self> {
end_session(executor, self.id).await?;
self.active = false;
Ok(self)
}
}
#[derive(Debug, Error)]
pub enum LoginError {
#[error("could not find user {username:?}")]
NotFound {
username: String,
#[source]
source: sqlx::Error,
},
#[error("authentication failed for {username:?}")]
Authentication {
username: String,
#[source]
source: AuthenticationError,
},
#[error("failed to login")]
Other(#[from] anyhow::Error),
}
impl HtmlError for LoginError {
fn html_display(&self) -> String {
match self {
LoginError::NotFound { .. } => "Could not find user".to_string(),
LoginError::Authentication { .. } => "Failed to authenticate user".to_string(),
LoginError::Other(e) => format!("Internal error: <pre>{}</pre>", e),
}
}
}
pub async fn login(
conn: impl Acquire<'_, Database = Postgres>,
username: &str,
password: String,
) -> Result<SessionInfo, LoginError> {
let mut txn = conn.begin().await.context("could not start transaction")?;
let user = lookup_user_by_username(&mut txn, username)
.await
.map_err(|source| {
if matches!(source, sqlx::Error::RowNotFound) {
LoginError::NotFound {
username: username.to_string(),
source,
}
} else {
LoginError::Other(source.into())
}
})?;
let mut session = start_session(&mut txn, user).await?;
session.last_authd_at = Some(
authenticate_session(&mut txn, session.id, password)
.await
.map_err(|source| {
if matches!(source, AuthenticationError::Password { .. }) {
LoginError::Authentication {
username: username.to_string(),
source,
}
} else {
LoginError::Other(source.into())
}
})?,
);
txn.commit().await.context("could not commit transaction")?;
Ok(session)
}
pub async fn lookup_active_session(
executor: impl Executor<'_, Database = Postgres>,
id: i64,
) -> anyhow::Result<SessionInfo> {
sqlx::query_as!(
SessionInfo,
r#"
SELECT
s.id,
u.id as user_id,
u.username,
s.active,
s.created_at,
a.created_at as "last_authd_at?"
FROM user_sessions s
INNER JOIN users u
ON s.user_id = u.id
LEFT JOIN user_session_authentications a
ON a.session_id = s.id
WHERE s.id = $1 AND s.active
ORDER BY a.created_at DESC
LIMIT 1
"#,
id,
)
.fetch_one(executor)
.await
.context("could not fetch session")
}
pub async fn lookup_session(
executor: impl Executor<'_, Database = Postgres>,
id: i64,
) -> anyhow::Result<SessionInfo> {
sqlx::query_as!(
SessionInfo,
r#"
SELECT
s.id,
u.id as user_id,
u.username,
s.active,
s.created_at,
a.created_at as "last_authd_at?"
FROM user_sessions s
INNER JOIN users u
ON s.user_id = u.id
LEFT JOIN user_session_authentications a
ON a.session_id = s.id
WHERE s.id = $1
ORDER BY a.created_at DESC
LIMIT 1
"#,
id,
)
.fetch_one(executor)
.await
.context("could not fetch session")
}
pub async fn start_session(
executor: impl Executor<'_, Database = Postgres>,
user: User,
) -> anyhow::Result<SessionInfo> {
let (id, created_at): (i64, DateTime<Utc>) = sqlx::query_as(
r#"
INSERT INTO user_sessions (user_id)
VALUES ($1)
RETURNING id, created_at
"#,
)
.bind(user.id)
.fetch_one(executor)
.await
.context("could not create session")?;
Ok(SessionInfo {
id,
user_id: user.id,
username: user.username,
active: true,
created_at,
last_authd_at: None,
})
}
#[derive(Debug, Error)]
pub enum AuthenticationError {
#[error("could not verify password")]
Password(#[from] password_hash::Error),
#[error("could not fetch user password hash")]
Fetch(sqlx::Error),
#[error("could not save session auth")]
Save(sqlx::Error),
#[error("runtime error")]
Internal(#[from] tokio::task::JoinError),
}
pub async fn authenticate_session(
txn: &mut Transaction<'_, Postgres>,
session_id: i64,
password: String,
) -> Result<DateTime<Utc>, AuthenticationError> {
// First, fetch the hashed password from the user associated with that session
let hashed_password: String = sqlx::query_scalar!(
r#"
SELECT u.hashed_password
FROM user_sessions s
INNER JOIN users u
ON u.id = s.user_id
WHERE s.id = $1
"#,
session_id,
)
.fetch_one(txn.borrow_mut())
.await
.map_err(AuthenticationError::Fetch)?;
// TODO: pass verifiers list as parameter
// Verify the password in a blocking thread to avoid blocking the async executor
task::spawn_blocking(move || {
let context = Argon2::default();
let hasher = PasswordHash::new(&hashed_password).map_err(AuthenticationError::Password)?;
hasher
.verify_password(&[&context], &password)
.map_err(AuthenticationError::Password)
})
.await??;
// That went well, let's insert the auth info
let created_at: DateTime<Utc> = sqlx::query_scalar!(
r#"
INSERT INTO user_session_authentications (session_id)
VALUES ($1)
RETURNING created_at
"#,
session_id,
)
.fetch_one(txn.borrow_mut())
.await
.map_err(AuthenticationError::Save)?;
Ok(created_at)
}
pub async fn register_user(
executor: impl Executor<'_, Database = Postgres>,
phf: impl PasswordHasher,
username: &str,
password: &str,
) -> anyhow::Result<User> {
let salt = SaltString::generate(&mut OsRng);
let hashed_password = PasswordHash::generate(phf, password, salt.as_str())?;
let id: i64 = sqlx::query_scalar!(
r#"
INSERT INTO users (username, hashed_password)
VALUES ($1, $2)
RETURNING id
"#,
username,
hashed_password.to_string(),
)
.fetch_one(executor)
.instrument(info_span!("Register user"))
.await
.context("could not insert user")?;
Ok(User {
id,
username: username.to_string(),
})
}
pub async fn end_session(
executor: impl Executor<'_, Database = Postgres>,
id: i64,
) -> anyhow::Result<()> {
let res = sqlx::query!("UPDATE user_sessions SET active = FALSE WHERE id = $1", id)
.execute(executor)
.instrument(info_span!("End session"))
.await
.context("could not end session")?;
match res.rows_affected() {
1 => Ok(()),
0 => Err(anyhow::anyhow!("no row affected")),
_ => Err(anyhow::anyhow!("too many row affected")),
}
}
#[allow(dead_code)]
pub async fn lookup_user_by_id(
executor: impl Executor<'_, Database = Postgres>,
id: i64,
) -> anyhow::Result<User> {
sqlx::query_as!(
User,
r#"
SELECT id, username
FROM users
WHERE id = $1
"#,
id
)
.fetch_one(executor)
.instrument(info_span!("Fetch user"))
.await
.context("could not fetch user")
}
pub async fn lookup_user_by_username(
executor: impl Executor<'_, Database = Postgres>,
username: &str,
) -> Result<User, sqlx::Error> {
sqlx::query_as!(
User,
r#"
SELECT id, username
FROM users
WHERE username = $1
"#,
username,
)
.fetch_one(executor)
.instrument(info_span!("Fetch user"))
.await
}

View File

@ -0,0 +1,43 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use sqlx::{Pool, Postgres};
use tracing::{debug, error, info};
use super::Task;
#[derive(Clone)]
struct CleanupExpired(Pool<Postgres>);
#[async_trait::async_trait]
impl Task for CleanupExpired {
async fn run(&self) {
let res = crate::storage::oauth2::access_token::cleanup_expired(&self.0).await;
match res {
Ok(0) => {
debug!("no token to clean up");
}
Ok(count) => {
info!(count, "cleaned up expired tokens");
}
Err(error) => {
error!(?error, "failed to cleanup expired tokens");
}
}
}
}
#[must_use] pub fn cleanup_expired(pool: &Pool<Postgres>) -> impl Task + Clone {
CleanupExpired(pool.clone())
}

View File

@ -0,0 +1,102 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::VecDeque, sync::Arc, time::Duration};
use futures_util::StreamExt;
use tokio::{
sync::{Mutex, Notify},
time::Interval,
};
use tokio_stream::wrappers::IntervalStream;
mod database;
pub use self::database::cleanup_expired;
#[async_trait::async_trait]
pub trait Task: Send + Sync + 'static {
async fn run(&self);
}
#[derive(Default)]
struct TaskQueueInner {
pending_tasks: Mutex<VecDeque<Box<dyn Task>>>,
notifier: Notify,
}
impl TaskQueueInner {
async fn recuring<T: Task + Clone>(&self, interval: Interval, task: T) {
let mut stream = IntervalStream::new(interval);
while (stream.next()).await.is_some() {
self.schedule(task.clone()).await;
}
}
async fn schedule<T: Task>(&self, task: T) {
let task = Box::new(task);
self.pending_tasks.lock().await.push_back(task);
self.notifier.notify_one();
}
async fn tick(&self) {
loop {
let pending = {
let mut tasks = self.pending_tasks.lock().await;
tasks.pop_front()
};
if let Some(pending) = pending {
pending.run().await;
} else {
break;
}
}
}
async fn run_forever(&self) {
loop {
self.notifier.notified().await;
self.tick().await;
}
}
}
#[derive(Default)]
pub struct TaskQueue {
inner: Arc<TaskQueueInner>,
}
impl TaskQueue {
pub fn start(&self) {
let queue = self.inner.clone();
tokio::task::spawn(async move {
queue.run_forever().await;
});
}
#[allow(dead_code)]
async fn schedule<T: Task>(&self, task: T) {
let queue = self.inner.clone();
queue.schedule(task).await;
}
pub fn recuring(&self, every: Duration, task: impl Task + Clone) {
let queue = self.inner.clone();
tokio::task::spawn(async move {
queue.recuring(tokio::time::interval(every), task).await;
});
}
}

View File

@ -0,0 +1,299 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashSet, string::ToString, sync::Arc};
use oauth2_types::errors::OAuth2Error;
use serde::Serialize;
use tera::{Context, Error as TeraError, Tera};
use thiserror::Error;
use tracing::{debug, info};
use url::Url;
use warp::reject::Reject;
use crate::{errors::ErroredForm, filters::CsrfToken, storage::SessionInfo};
#[derive(Clone)]
pub struct Templates(Arc<Tera>);
#[derive(Error, Debug)]
pub enum TemplateLoadingError {
#[error("could not load and compile some templates")]
Compile(#[from] TeraError),
#[error("missing templates {missing:?}")]
MissingTemplates {
missing: HashSet<String>,
loaded: HashSet<String>,
},
}
impl Templates {
/// Load the templates and check all needed templates are properly loaded
pub fn load() -> Result<Self, TemplateLoadingError> {
let path = format!("{}/templates/**/*.{{html,txt}}", env!("CARGO_MANIFEST_DIR"));
info!(%path, "Loading templates");
let tera = Tera::new(&path)?;
let loaded: HashSet<_> = tera.get_template_names().collect();
let needed: HashSet<_> = std::array::IntoIter::new(TEMPLATES).collect();
debug!(?loaded, ?needed, "Templates loaded");
let missing: HashSet<_> = needed.difference(&loaded).collect();
if missing.is_empty() {
Ok(Self(Arc::new(tera)))
} else {
let missing = missing.into_iter().map(ToString::to_string).collect();
let loaded = loaded.into_iter().map(ToString::to_string).collect();
Err(TemplateLoadingError::MissingTemplates { missing, loaded })
}
}
}
#[derive(Error, Debug)]
pub enum TemplateError {
#[error("could not prepare context for template {template:?}")]
Context {
template: &'static str,
#[source]
source: TeraError,
},
#[error("could not render template {template:?}")]
Render {
template: &'static str,
#[source]
source: TeraError,
},
}
impl Reject for TemplateError {}
/// Count the number of tokens. Used to have a fixed-sized array for the
/// templates list.
macro_rules! count {
() => (0_usize);
( $x:tt $($xs:tt)* ) => (1_usize + count!($($xs)*));
}
/// Macro that helps generating helper function that renders a specific template
/// with a strongly-typed context. It also register the template in a static
/// array to help detecting missing templates at startup time.
///
/// The syntax looks almost like a function to confuse syntax highlighter as
/// little as possible.
macro_rules! register_templates {
{
$(
// Match any attribute on the function, such as #[doc], #[allow(dead_code)], etc.
$( #[ $attr:meta ] )*
// The function name
pub fn $name:ident
// Optional list of generics. Taken from
// https://newbedev.com/rust-macro-accepting-type-with-generic-parameters
$(< $( $lt:tt $( : $clt:tt $(+ $dlt:tt )* )? ),+ >)?
// Type of context taken by the template
( $param:ty )
{
// The name of the template file
$template:expr
}
)*
} => {
/// List of registered templates
static TEMPLATES: [&'static str; count!( $( $template )* )] = [ $( $template ),* ];
impl Templates {
$(
$(#[$attr])?
pub fn $name
$(< $( $lt $( : $clt $(+ $dlt )* )? ),+ >)?
(&self, context: &$param)
-> Result<String, TemplateError> {
let ctx = Context::from_serialize(context)
.map_err(|source| TemplateError::Context { template: $template, source })?;
self.0.render($template, &ctx)
.map_err(|source| TemplateError::Render { template: $template, source })
}
)*
}
};
}
register_templates! {
/// Render the login page
pub fn render_login(WithCsrf<LoginContext>) { "login.html" }
/// Render the registration page
pub fn render_register(WithCsrf<()>) { "register.html" }
/// Render the home page
pub fn render_index(WithCsrf<WithOptionalSession<IndexContext>>) { "index.html" }
/// Render the re-authentication form
pub fn render_reauth(WithCsrf<WithSession<()>>) { "reauth.html" }
/// Render the form used by the form_post response mode
pub fn render_form_post<T: Serialize>(FormPostContext<T>) { "form_post.html" }
/// Render the HTML error page
pub fn render_error(ErrorContext) { "error.html" }
}
/// Helper trait to construct context wrappers
pub trait TemplateContext: Sized {
fn with_session(self, current_session: SessionInfo) -> WithSession<Self> {
WithSession {
current_session,
inner: self,
}
}
fn maybe_with_session(self, current_session: Option<SessionInfo>) -> WithOptionalSession<Self> {
WithOptionalSession {
current_session,
inner: self,
}
}
fn with_csrf(self, token: &CsrfToken) -> WithCsrf<Self> {
WithCsrf {
csrf_token: token.form_value(),
inner: self,
}
}
}
impl<T: Sized> TemplateContext for T {}
/// Context with a CSRF token in it
#[derive(Serialize)]
pub struct WithCsrf<T> {
csrf_token: String,
#[serde(flatten)]
inner: T,
}
/// Context with a user session in it
#[derive(Serialize)]
pub struct WithSession<T> {
current_session: SessionInfo,
#[serde(flatten)]
inner: T,
}
/// Context with an optional user session in it
#[derive(Serialize)]
pub struct WithOptionalSession<T> {
current_session: Option<SessionInfo>,
#[serde(flatten)]
inner: T,
}
// Context used by the `index.html` template
#[derive(Serialize)]
pub struct IndexContext {
discovery_url: Url,
}
impl IndexContext {
#[must_use] pub fn new(discovery_url: Url) -> Self {
Self { discovery_url }
}
}
#[derive(Serialize, Debug, Clone, Copy, Hash, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum LoginFormField {
Username,
Password,
}
#[derive(Serialize)]
pub struct LoginContext {
form: ErroredForm<LoginFormField>,
}
impl LoginContext {
#[must_use] pub fn with_form_error(form: ErroredForm<LoginFormField>) -> Self {
Self { form }
}
}
impl Default for LoginContext {
fn default() -> Self {
Self {
form: ErroredForm::new(),
}
}
}
/// Context used by the `form_post.html` template
#[derive(Serialize)]
pub struct FormPostContext<T> {
redirect_uri: Url,
params: T,
}
impl<T> FormPostContext<T> {
pub fn new(redirect_uri: Url, params: T) -> Self {
Self {
redirect_uri,
params,
}
}
}
#[derive(Default, Serialize)]
pub struct ErrorContext {
code: Option<&'static str>,
description: Option<String>,
details: Option<String>,
}
impl ErrorContext {
#[must_use] pub fn new() -> Self {
Self::default()
}
#[must_use] pub fn with_code(mut self, code: &'static str) -> Self {
self.code = Some(code);
self
}
#[must_use] pub fn with_description(mut self, description: String) -> Self {
self.description = Some(description);
self
}
#[allow(dead_code)]
#[must_use] pub fn with_details(mut self, details: String) -> Self {
self.details = Some(details);
self
}
}
impl From<Box<dyn OAuth2Error>> for ErrorContext {
fn from(err: Box<dyn OAuth2Error>) -> Self {
let mut ctx = ErrorContext::new().with_code(err.error());
if let Some(desc) = err.description() {
ctx = ctx.with_description(desc);
}
ctx
}
}

176
crates/core/src/tokens.rs Normal file
View File

@ -0,0 +1,176 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::convert::TryInto;
use crc::{Crc, CRC_32_ISO_HDLC};
use oauth2_types::requests::TokenTypeHint;
use rand::{distributions::Alphanumeric, Rng};
use thiserror::Error;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType {
AccessToken,
RefreshToken,
}
impl TokenType {
fn prefix(self) -> &'static str {
match self {
TokenType::AccessToken => "mat",
TokenType::RefreshToken => "mar",
}
}
fn match_prefix(prefix: &str) -> Option<Self> {
match prefix {
"mat" => Some(TokenType::AccessToken),
"mar" => Some(TokenType::RefreshToken),
_ => None,
}
}
}
impl PartialEq<TokenTypeHint> for TokenType {
fn eq(&self, other: &TokenTypeHint) -> bool {
matches!(
(self, other),
(TokenType::AccessToken, TokenTypeHint::AccessToken)
| (TokenType::RefreshToken, TokenTypeHint::RefreshToken)
)
}
}
const NUM: [u8; 62] = *b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
fn base62_encode(mut num: u32) -> String {
let mut res = String::with_capacity(6);
while num > 0 {
res.push(NUM[(num % 62) as usize] as char);
num /= 62;
}
format!("{:0>6}", res)
}
const CRC: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC);
pub fn generate(rng: impl Rng, token_type: TokenType) -> String {
let random_part: String = rng
.sample_iter(&Alphanumeric)
.take(30)
.map(char::from)
.collect();
let base = format!("{}_{}", token_type.prefix(), random_part);
let crc = CRC.checksum(base.as_bytes());
let crc = base62_encode(crc);
format!("{}_{}", base, crc)
}
#[derive(Debug, Error)]
pub enum TokenFormatError {
#[error("invalid token format")]
InvalidFormat,
#[error("unknown token prefix {prefix:?}")]
UnknownPrefix { prefix: String },
#[error("invalid crc {got:?}, expected {expected:?}")]
InvalidCrc { expected: String, got: String },
}
pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
let split: Vec<&str> = token.split('_').collect();
let [prefix, random_part, crc]: [&str; 3] = split
.try_into()
.map_err(|_| TokenFormatError::InvalidFormat)?;
if prefix.len() != 3 || random_part.len() != 30 || crc.len() != 6 {
return Err(TokenFormatError::InvalidFormat);
}
let token_type =
TokenType::match_prefix(prefix).ok_or_else(|| TokenFormatError::UnknownPrefix {
prefix: prefix.to_string(),
})?;
let base = format!("{}_{}", token_type.prefix(), random_part);
let expected_crc = CRC.checksum(base.as_bytes());
let expected_crc = base62_encode(expected_crc);
if crc != expected_crc {
return Err(TokenFormatError::InvalidCrc {
expected: expected_crc,
got: crc.to_string(),
});
}
Ok(token_type)
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use rand::thread_rng;
use super::*;
#[test]
fn test_prefix_match() {
use TokenType::{AccessToken, RefreshToken};
assert_eq!(TokenType::match_prefix("mat"), Some(AccessToken));
assert_eq!(TokenType::match_prefix("mar"), Some(RefreshToken));
assert_eq!(TokenType::match_prefix("matt"), None);
assert_eq!(TokenType::match_prefix("marr"), None);
assert_eq!(TokenType::match_prefix("ma"), None);
assert_eq!(
TokenType::match_prefix(TokenType::AccessToken.prefix()),
Some(TokenType::AccessToken)
);
assert_eq!(
TokenType::match_prefix(TokenType::RefreshToken.prefix()),
Some(TokenType::RefreshToken)
);
}
#[test]
fn test_generate_and_check() {
const COUNT: usize = 500; // Generate 500 of each token type
let mut rng = thread_rng();
// Generate many access tokens
let tokens: HashSet<String> = (0..COUNT)
.map(|_| generate(&mut rng, TokenType::AccessToken))
.collect();
// Check that they are all different
assert_eq!(tokens.len(), COUNT, "All tokens are unique");
// Check that they are all valid and detected as access tokens
for token in tokens {
assert_eq!(check(&token).unwrap(), TokenType::AccessToken);
}
// Same, but for refresh tokens
let tokens: HashSet<String> = (0..COUNT)
.map(|_| generate(&mut rng, TokenType::RefreshToken))
.collect();
assert_eq!(tokens.len(), COUNT, "All tokens are unique");
for token in tokens {
assert_eq!(check(&token).unwrap(), TokenType::RefreshToken);
}
}
}

View File

@ -0,0 +1,66 @@
{#
Copyright 2021 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
#}
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>{% block title %}matrix-authentication-service{% endblock title %}</title>
<meta name="viewport" content="width=device-width, initial-scale=1">
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.3/css/bulma.min.css">
</head>
<body>
<nav class="navbar is-dark" role="navigation" aria-label="main navigation">
<div class="container">
<div class="navbar-brand">
<a class="navbar-item" href="/">
matrix-authentication-service
</a>
</div>
<div class="navbar-end">
{% if current_session %}
<div class="navbar-item">
Howdy {{ current_session.username }}!
</div>
<div class="navbar-item">
<form method="POST" action="/logout">
<input type="hidden" name="csrf" value="{{ csrf_token }}" />
<button class="button is-light" action="submit">
Log out
</button>
</form>
</div>
{% else %}
<div class="navbar-item">
<a class="button is-light" href="/login">
Log in
</a>
</div>
<div class="navbar-item">
<a class="button is-light" href="/register">
Register
</a>
</div>
{% endif %}
</div>
</div>
</nav>
{% block content %}{% endblock content %}
</body>
</html>

View File

@ -0,0 +1,39 @@
{#
Copyright 2021 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
#}
{% extends "base.html" %}
{% block content %}
<section class="hero is-danger">
<div class="hero-body">
<div class="container">
{% if code %}
<p class="title">
{{ code }}
</p>
{% endif %}
{% if description %}
<p class="subtitle">
{{ description }}
</p>
{% endif %}
{% if details %}
<pre><code>{{ details }}</code></pre>
{% endif %}
</div>
</div>
</section>
{% endblock %}

View File

@ -0,0 +1,27 @@
{#
Copyright 2021 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
#}
{% if code %}
{{- code }}
{% endif %}
{%- if description %}
{{ description }}
{% endif %}
{%- if details %}
{{ details }}
{% endif %}

View File

@ -0,0 +1,31 @@
{#
Copyright 2021 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
#}
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Redirecting to client</title>
<meta name="viewport" content="width=device-width, initial-scale=1">
</head>
<body onload="javascript:document.forms[0].submit()">
<form method="post" action="{{ redirect_uri }}">
{% for key, value in params %}
<input type="hidden" name="{{ key }}" value="{{ value }}" />
{% endfor %}
</form>
</body>
</html>

View File

@ -0,0 +1,28 @@
{#
Copyright 2021 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
#}
{% extends "base.html" %}
{% block content %}
<section class="section">
<div class="container content">
<h1>Matrix Authentication Service</h1>
<p>
OpenID Connect discovery document: <a href="{{ discovery_url }}">{{ discovery_url }}</a>
</p>
</div>
</section>
{% endblock content %}

View File

@ -0,0 +1,70 @@
{#
Copyright 2021 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
#}
{% extends "base.html" %}
{% block content %}
<section class="section">
<div class="container is-max-desktop">
<div class="columns">
<div class="column is-half is-offset-one-quarter">
{% if form.has_errors %}
<article class="message is-danger">
<div class="message-body">
{% for message in form.form_errors %}
<p>{{ message | safe }}</p>
{% else %}
<p>Login failed, check the fields below for more details.</p>
{% endfor %}
</div>
</article>
{% endif %}
<form method="POST">
<input type="hidden" name="csrf" value="{{ csrf_token }}" />
<div class="field">
<label class="label" for="login-username">Username</label>
<div class="control">
<input class="input{% if 'username' in form.fields_errors %} is-danger{% endif %}" name="username" id="login-username" type="text">
</div>
{% if 'username' in form.fields_errors %}
{% for message in form.fields_errors.username %}
<p class="help is-danger">{{ message | safe }}</p>
{% endfor %}
{% endif %}
</div>
<div class="field">
<label class="label" for="login-password">Password</label>
<div class="control">
<input class="input{% if 'password' in form.fields_errors %} is-danger{% endif %}" name="password" id="login-password" type="password">
</div>
{% if 'password' in form.fields_errors %}
{% for message in form.fields_errors.password %}
<p class="help is-danger">{{ message | safe }}</p>
{% endfor %}
{% endif %}
</div>
<div class="control">
<button type="submit" class="button is-link">Login</button>
</div>
</form>
</div>
</div>
</div>
</section>
{% endblock content %}

View File

@ -0,0 +1,45 @@
{#
Copyright 2021 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
#}
{% extends "base.html" %}
{% block content %}
<section class="section">
<div class="container is-max-desktop">
<div class="columns">
<div class="column is-one-third">
<form method="POST">
<input type="hidden" name="csrf" value="{{ csrf_token }}" />
<div class="field">
<label class="label" for="login-password">Password</label>
<div class="control">
<input class="input" name="password" id="login-password" type="password">
</div>
</div>
<div class="control">
<button type="submit" class="button is-link">Submit</button>
</div>
</form>
</div>
<div class="column is-two-third">
<pre><code>{{ current_session | json_encode(pretty=True) | safe }}</code></pre>
</div>
</div>
</div>
</section>
{% endblock content %}

View File

@ -0,0 +1,55 @@
{#
Copyright 2021 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
#}
{% extends "base.html" %}
{% block content %}
<section class="section">
<div class="container is-max-desktop">
<div class="columns">
<div class="column is-one-third">
<form method="POST">
<input type="hidden" name="csrf" value="{{ csrf_token }}" />
<div class="field">
<label class="label" for="register-username">Username</label>
<div class="control">
<input class="input" name="username" id="register-username" type="text">
</div>
</div>
<div class="field">
<label class="label" for="register-password">Password</label>
<div class="control">
<input class="input" name="password" id="register-password" type="password">
</div>
</div>
<div class="field">
<label class="label" for="register-password">Confirm password</label>
<div class="control">
<input class="input" name="password_confirm" id="register-password-confirm" type="password">
</div>
</div>
<div class="control">
<button type="submit" class="button is-link">Register</button>
</div>
</form>
</div>
</div>
</div>
</section>
{% endblock content %}

View File

@ -0,0 +1,21 @@
[package]
name = "oauth2-types"
version = "0.1.0"
authors = ["Quentin Gliech <quenting@element.io>"]
edition = "2018"
license = "Apache-2.0"
[dependencies]
http = "0.2.4"
serde = "1.0.130"
serde_json = "1.0.68"
language-tags = { version = "0.3.2", features = ["serde"] }
url = { version = "2.2.2", features = ["serde"] }
parse-display = "0.5.1"
indoc = "1.0.3"
serde_with = { version = "1.10.0", features = ["chrono"] }
sqlx = { version = "0.5.7", default-features = false, optional = true }
chrono = "0.4.19"
[features]
sqlx_type = ["sqlx"]

View File

@ -0,0 +1,268 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use http::status::StatusCode;
use serde::ser::{Serialize, SerializeMap};
use url::Url;
pub trait OAuth2Error: std::fmt::Debug + Send + Sync {
/// A single ASCII error code.
///
/// Maps to the required "error" field.
fn error(&self) -> &'static str;
/// Human-readable ASCII text providing additional information, used to
/// assist the client developer in understanding the error that
/// occurred.
///
/// Maps to the optional `error_description` field.
fn description(&self) -> Option<String> {
None
}
/// A URI identifying a human-readable web page with information about the
/// error, used to provide the client developer with additional
/// information about the error.
///
/// Maps to the optional `error_uri` field.
fn uri(&self) -> Option<Url> {
None
}
/// Wraps the error with an `ErrorResponse` to help serializing.
fn into_response(self) -> ErrorResponse
where
Self: Sized + 'static,
{
ErrorResponse(Box::new(self))
}
}
pub trait OAuth2ErrorCode: OAuth2Error + 'static {
/// The HTTP status code that must be returned by this error
fn status(&self) -> StatusCode;
}
impl OAuth2Error for &Box<dyn OAuth2ErrorCode> {
fn error(&self) -> &'static str {
self.as_ref().error()
}
fn description(&self) -> Option<String> {
self.as_ref().description()
}
fn uri(&self) -> Option<Url> {
self.as_ref().uri()
}
}
#[derive(Debug)]
pub struct ErrorResponse(Box<dyn OAuth2Error>);
impl From<Box<dyn OAuth2Error>> for ErrorResponse {
fn from(b: Box<dyn OAuth2Error>) -> Self {
Self(b)
}
}
impl OAuth2Error for ErrorResponse {
fn error(&self) -> &'static str {
self.0.error()
}
fn description(&self) -> Option<String> {
self.0.description()
}
fn uri(&self) -> Option<Url> {
self.0.uri()
}
}
impl Serialize for ErrorResponse {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let error = self.0.error();
let description = self.0.description();
let uri = self.0.uri();
// Count the number of fields to serialize
let len = {
let mut x = 1;
if description.is_some() {
x += 1;
}
if uri.is_some() {
x += 1;
}
x
};
let mut map = serializer.serialize_map(Some(len))?;
map.serialize_entry("error", error)?;
if let Some(ref description) = description {
map.serialize_entry("error_description", description)?;
}
if let Some(ref uri) = uri {
map.serialize_entry("error_uri", uri)?;
}
map.end()
}
}
macro_rules! oauth2_error_def {
($name:ident) => {
#[derive(Debug, Clone)]
pub struct $name;
};
}
macro_rules! oauth2_error_status {
($name:ident, $code:ident) => {
impl $crate::errors::OAuth2ErrorCode for $name {
fn status(&self) -> ::http::status::StatusCode {
::http::status::StatusCode::$code
}
}
};
}
macro_rules! oauth2_error_error {
($err:literal) => {
fn error(&self) -> &'static str {
$err
}
};
}
macro_rules! oauth2_error_description {
($description:expr) => {
fn description(&self) -> Option<String> {
Some(($description).to_string())
}
};
}
macro_rules! oauth2_error {
($name:ident, $err:literal => $description:expr) => {
oauth2_error_def!($name);
impl $crate::errors::OAuth2Error for $name {
oauth2_error_error!($err);
oauth2_error_description!(indoc::indoc! {$description});
}
};
($name:ident, $err:literal) => {
oauth2_error_def!($name);
impl $crate::errors::OAuth2Error for $name {
oauth2_error_error!($err);
}
};
($name:ident, code: $code:ident, $err:literal => $description:expr) => {
oauth2_error!($name, $err => $description);
oauth2_error_status!($name, $code);
};
($name:ident, code: $code:ident, $err:literal) => {
oauth2_error!($name, $err);
oauth2_error_status!($name, $code);
};
}
pub mod rfc6749 {
oauth2_error! {
InvalidRequest,
code: BAD_REQUEST,
"invalid_request" =>
"The request is missing a required parameter, includes an invalid parameter value, \
includes a parameter more than once, or is otherwise malformed."
}
oauth2_error! {
InvalidClient,
code: BAD_REQUEST,
"invalid_client" =>
"Client authentication failed."
}
oauth2_error! {
InvalidGrant,
code: BAD_REQUEST,
"invalid_grant"
}
oauth2_error! {
UnauthorizedClient,
code: BAD_REQUEST,
"unauthorized_client" =>
"The client is not authorized to request an access token using this method."
}
oauth2_error! {
UnsupportedGrantType,
code: BAD_REQUEST,
"unsupported_grant_type" =>
"The authorization grant type is not supported by the authorization server."
}
oauth2_error! {
AccessDenied,
"access_denied" =>
"The resource owner or authorization server denied the request."
}
oauth2_error! {
UnsupportedResponseType,
"unsupported_response_type" =>
"The authorization server does not support obtaining an access token using this method."
}
oauth2_error! {
InvalidScope,
code: BAD_REQUEST,
"invalid_scope" =>
"The requested scope is invalid, unknown, or malformed."
}
oauth2_error! {
ServerError,
"server_error" =>
"The authorization server encountered an unexpected \
condition that prevented it from fulfilling the request."
}
oauth2_error! {
TemporarilyUnavailable,
"temporarily_unavailable" =>
"The authorization server is currently unable to handle \
the request due to a temporary overloading or maintenance \
of the server."
}
}
pub use rfc6749::*;
#[cfg(test)]
mod tests {
use serde_json::json;
use super::*;
#[test]
fn serialize_error() {
let expected = json!({"error": "invalid_grant"});
let actual = serde_json::to_value(InvalidGrant.into_response()).unwrap();
assert_eq!(expected, actual);
}
}

View File

@ -0,0 +1,25 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![forbid(unsafe_code)]
#![deny(clippy::all)]
#![warn(clippy::pedantic)]
pub mod errors;
pub mod oidc;
pub mod pkce;
pub mod requests;
#[cfg(test)]
mod test_utils;

View File

@ -0,0 +1,75 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashSet;
use serde::Serialize;
use serde_with::skip_serializing_none;
use url::Url;
use crate::{
pkce::CodeChallengeMethod,
requests::{ClientAuthenticationMethod, GrantType, ResponseMode},
};
// TODO: https://datatracker.ietf.org/doc/html/rfc8414#section-2
#[skip_serializing_none]
#[derive(Serialize, Clone)]
pub struct Metadata {
/// The authorization server's issuer identifier, which is a URL that uses
/// the "https" scheme and has no query or fragment components.
pub issuer: Url,
/// URL of the authorization server's authorization endpoint.
pub authorization_endpoint: Option<Url>,
/// URL of the authorization server's token endpoint.
pub token_endpoint: Option<Url>,
/// URL of the authorization server's JWK Set document.
pub jwks_uri: Option<Url>,
/// URL of the authorization server's OAuth 2.0 Dynamic Client Registration
/// endpoint.
pub registration_endpoint: Option<Url>,
/// JSON array containing a list of the OAuth 2.0 "scope" values that this
/// authorization server supports.
pub scopes_supported: Option<HashSet<String>>,
/// JSON array containing a list of the OAuth 2.0 "response_type" values
/// that this authorization server supports.
pub response_types_supported: Option<HashSet<String>>,
/// JSON array containing a list of the OAuth 2.0 "response_mode" values
/// that this authorization server supports, as specified in "OAuth 2.0
/// Multiple Response Type Encoding Practices".
pub response_modes_supported: Option<HashSet<ResponseMode>>,
/// JSON array containing a list of the OAuth 2.0 grant type values that
/// this authorization server supports.
pub grant_types_supported: Option<HashSet<GrantType>>,
/// JSON array containing a list of client authentication methods supported
/// by this token endpoint.
pub token_endpoint_auth_methods_supported: Option<HashSet<ClientAuthenticationMethod>>,
/// PKCE code challenge methods supported by this authorization server
pub code_challenge_methods_supported: Option<HashSet<CodeChallengeMethod>>,
/// URL of the authorization server's OAuth 2.0 introspection endpoint.
pub introspection_endpoint: Option<Url>,
pub userinfo_endpoint: Option<Url>,
}

View File

@ -0,0 +1,48 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use parse_display::{Display, FromStr};
use serde::{Deserialize, Serialize};
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[cfg_attr(feature = "sqlx_type", derive(sqlx::Type))]
#[repr(i8)]
pub enum CodeChallengeMethod {
#[serde(rename = "plain")]
#[display("plain")]
Plain = 0,
#[serde(rename = "S256")]
#[display("S256")]
S256 = 1,
}
#[derive(Serialize, Deserialize)]
pub struct Request {
pub code_challenge_method: CodeChallengeMethod,
pub code_challenge: String,
}

View File

@ -0,0 +1,414 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashSet, hash::Hash};
use chrono::{DateTime, Duration, Utc};
use language_tags::LanguageTag;
use parse_display::{Display, FromStr};
use serde::{Deserialize, Serialize};
use serde_with::{
rust::StringWithSeparator, serde_as, skip_serializing_none, DurationSeconds, SpaceSeparator,
TimestampSeconds,
};
use url::Url;
// ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[display(style = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum ResponseType {
Code,
IdToken,
Token,
None,
}
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum ResponseMode {
Query,
Fragment,
FormPost,
}
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum ClientAuthenticationMethod {
None,
ClientSecretPost,
ClientSecretBasic,
}
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum Display {
Page,
Popup,
Touch,
Wap,
}
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[display(style = "snake_case")]
#[serde(rename_all = "snake_case")]
pub enum Prompt {
None,
Login,
Consent,
SelectAccount,
}
#[serde_as]
#[derive(Serialize, Deserialize)]
pub struct AuthorizationRequest {
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, ResponseType>")]
pub response_type: HashSet<ResponseType>,
pub client_id: String,
pub redirect_uri: Option<Url>,
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, String>")]
pub scope: HashSet<String>,
pub state: Option<String>,
pub response_mode: Option<ResponseMode>,
pub nonce: Option<String>,
display: Option<Display>,
#[serde_as(as = "Option<DurationSeconds<i64>>")]
#[serde(default)]
pub max_age: Option<Duration>,
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, LanguageTag>>")]
#[serde(default)]
ui_locales: Option<Vec<LanguageTag>>,
id_token_hint: Option<String>,
login_hint: Option<String>,
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
#[serde(default)]
acr_values: Option<HashSet<String>>,
}
#[derive(Serialize, Deserialize, Default)]
pub struct AuthorizationResponse<R> {
pub code: Option<String>,
pub state: Option<String>,
#[serde(flatten)]
pub response: R,
}
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
Bearer,
}
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct AuthorizationCodeGrant {
pub code: String,
#[serde(default)]
pub redirect_uri: Option<Url>,
}
#[serde_as]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct RefreshTokenGrant {
pub refresh_token: String,
#[serde(default)]
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
scope: Option<HashSet<String>>,
}
#[serde_as]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct ClientCredentialsGrant {
#[serde(default)]
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
scope: Option<HashSet<String>>,
}
#[derive(
Debug,
Hash,
PartialEq,
Eq,
PartialOrd,
Ord,
Clone,
Copy,
Display,
FromStr,
Serialize,
Deserialize,
)]
#[serde(rename_all = "snake_case")]
pub enum GrantType {
AuthorizationCode,
RefreshToken,
ClientCredentials,
}
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(tag = "grant_type", rename_all = "snake_case")]
pub enum AccessTokenRequest {
AuthorizationCode(AuthorizationCodeGrant),
RefreshToken(RefreshTokenGrant),
ClientCredentials(ClientCredentialsGrant),
#[serde(skip_deserializing, other)]
Unsupported,
}
#[serde_as]
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct AccessTokenResponse {
access_token: String,
refresh_token: Option<String>,
// TODO: this should be somewhere else
id_token: Option<String>,
token_type: TokenType,
#[serde_as(as = "Option<DurationSeconds<i64>>")]
expires_in: Option<Duration>,
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
scope: Option<HashSet<String>>,
}
impl AccessTokenResponse {
#[must_use]
pub fn new(access_token: String) -> AccessTokenResponse {
AccessTokenResponse {
access_token,
refresh_token: None,
id_token: None,
token_type: TokenType::Bearer,
expires_in: None,
scope: None,
}
}
#[must_use]
pub fn with_refresh_token(mut self, refresh_token: String) -> Self {
self.refresh_token = Some(refresh_token);
self
}
#[must_use]
pub fn with_id_token(mut self, id_token: String) -> Self {
self.id_token = Some(id_token);
self
}
#[must_use]
pub fn with_scopes(mut self, scope: HashSet<String>) -> Self {
self.scope = Some(scope);
self
}
#[must_use]
pub fn with_expires_in(mut self, expires_in: Duration) -> Self {
self.expires_in = Some(expires_in);
self
}
}
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum TokenTypeHint {
AccessToken,
RefreshToken,
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct IntrospectionRequest {
pub token: String,
#[serde(default)]
pub token_type_hint: Option<TokenTypeHint>,
}
#[serde_as]
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, PartialEq, Default)]
pub struct IntrospectionResponse {
pub active: bool,
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
pub scope: Option<HashSet<String>>,
pub client_id: Option<String>,
pub username: Option<String>,
pub token_type: Option<TokenTypeHint>,
#[serde_as(as = "Option<TimestampSeconds>")]
pub exp: Option<DateTime<Utc>>,
#[serde_as(as = "Option<TimestampSeconds>")]
pub iat: Option<DateTime<Utc>>,
#[serde_as(as = "Option<TimestampSeconds>")]
pub nbf: Option<DateTime<Utc>>,
pub sub: Option<String>,
pub aud: Option<String>,
pub iss: Option<String>,
pub jti: Option<String>,
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use serde_json::json;
use super::*;
use crate::test_utils::assert_serde_json;
#[test]
fn serde_refresh_token_grant() {
let expected = json!({
"grant_type": "refresh_token",
"refresh_token": "abcd",
"scope": "openid",
});
let scope = {
let mut s = HashSet::new();
// TODO: insert multiple scopes and test it. It's a bit tricky to test since
// HashSet have no guarantees regarding the ordering of items, so right
// now the output is unstable.
s.insert("openid".to_string());
Some(s)
};
let req = AccessTokenRequest::RefreshToken(RefreshTokenGrant {
refresh_token: "abcd".into(),
scope,
});
assert_serde_json(&req, expected);
}
#[test]
fn serde_authorization_code_grant() {
let expected = json!({
"grant_type": "authorization_code",
"code": "abcd",
"redirect_uri": "https://example.com/redirect",
});
let req = AccessTokenRequest::AuthorizationCode(AuthorizationCodeGrant {
code: "abcd".into(),
redirect_uri: Some("https://example.com/redirect".parse().unwrap()),
});
assert_serde_json(&req, expected);
}
}

View File

@ -0,0 +1,30 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt::Debug;
use serde::{de::DeserializeOwned, Serialize};
#[track_caller]
pub(crate) fn assert_serde_json<T: Serialize + DeserializeOwned + PartialEq + Debug>(
got: &T,
expected_value: serde_json::Value,
) {
let got_value = serde_json::to_value(&got).expect("could not serialize object as JSON value");
assert_eq!(got_value, expected_value);
let expected: T =
serde_json::from_value(expected_value).expect("could not serialize object as JSON value");
assert_eq!(got, &expected);
}