1
0
mirror of https://github.com/matrix-org/matrix-authentication-service.git synced 2025-07-29 22:01:14 +03:00

Create mas-oidc-client crate

Methods to interact as an RP with an OIDC OP.
This commit is contained in:
Kévin Commaille
2022-11-07 11:15:22 +01:00
committed by Quentin Gliech
parent c590e8df92
commit 90d0e12b7f
35 changed files with 6200 additions and 40 deletions

355
Cargo.lock generated
View File

@ -68,7 +68,7 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
dependencies = [
"getrandom",
"getrandom 0.2.8",
"once_cell",
"version_check",
]
@ -135,12 +135,33 @@ version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71938f30533e4d95a6d17aa530939da3842c2ab6f4f84b9dae68447e4129f74a"
[[package]]
name = "assert-json-diff"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "assert_matches"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9"
[[package]]
name = "async-channel"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e14485364214912d3b19cc3435dde4df66065127f05fa0d75c712f36f12c2f28"
dependencies = [
"concurrent-queue",
"event-listener",
"futures-core",
]
[[package]]
name = "async-compression"
version = "0.3.15"
@ -851,6 +872,12 @@ dependencies = [
"either",
]
[[package]]
name = "cache-padded"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1db59621ec70f09c5e9b597b220c7a2b43611f4710dc03ceb8748637775692c"
[[package]]
name = "camino"
version = "1.1.1"
@ -1004,6 +1031,15 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "concurrent-queue"
version = "1.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af4780a44ab5696ea9e28294517f1fffb421a83a25af521333c838635509db9c"
dependencies = [
"cache-padded",
]
[[package]]
name = "console"
version = "0.15.2"
@ -1042,7 +1078,7 @@ dependencies = [
"base64",
"hkdf",
"percent-encoding",
"rand",
"rand 0.8.5",
"sha2 0.10.6",
"subtle",
"time 0.3.17",
@ -1241,7 +1277,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef2b4b23cddf68b89b8f8069890e8c270d54e2d5fe1b143820234805e4cb17ef"
dependencies = [
"generic-array",
"rand_core",
"rand_core 0.6.4",
"subtle",
"zeroize",
]
@ -1253,7 +1289,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [
"generic-array",
"rand_core",
"rand_core 0.6.4",
"typenum",
]
@ -1386,6 +1422,25 @@ version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57"
[[package]]
name = "deadpool"
version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "421fe0f90f2ab22016f32a9881be5134fdd71c65298917084b0c7477cbc3856e"
dependencies = [
"async-trait",
"deadpool-runtime",
"num_cpus",
"retain_mut",
"tokio",
]
[[package]]
name = "deadpool-runtime"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eaa37046cc0f6c3cc6090fbdbf73ef0b8ef4cfcc37f6befc0020f63e8cf121e1"
[[package]]
name = "der"
version = "0.6.0"
@ -1465,6 +1520,12 @@ dependencies = [
"winapi",
]
[[package]]
name = "discard"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "212d0f5754cb6769937f4501cc0e67f4f4483c8d2c3e1e922ee9edbe4ab4c7c0"
[[package]]
name = "dotenv"
version = "0.15.0"
@ -1520,7 +1581,7 @@ dependencies = [
"hkdf",
"pem-rfc7468",
"pkcs8",
"rand_core",
"rand_core 0.6.4",
"sec1",
"subtle",
"zeroize",
@ -1627,7 +1688,7 @@ version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d013fc25338cc558c5c2cfbad646908fb23591e2404481826742b651c9af7160"
dependencies = [
"rand_core",
"rand_core 0.6.4",
"subtle",
]
@ -1753,6 +1814,21 @@ version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00f5fb52a06bdcadeb54e8d3671f8888a39697dcb0b81b23b55174030427f4eb"
[[package]]
name = "futures-lite"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7694489acd39452c77daa48516b894c153f192c3578d5a839b62c58099fcbf48"
dependencies = [
"fastrand",
"futures-core",
"futures-io",
"memchr",
"parking",
"pin-project-lite",
"waker-fn",
]
[[package]]
name = "futures-macro"
version = "0.3.25"
@ -1764,6 +1840,21 @@ dependencies = [
"syn",
]
[[package]]
name = "futures-signals"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3acc659ba666cff13fdf65242d16428f2f11935b688f82e4024ad39667a5132"
dependencies = [
"discard",
"futures-channel",
"futures-core",
"futures-util",
"log",
"pin-project",
"serde",
]
[[package]]
name = "futures-sink"
version = "0.3.25"
@ -1776,6 +1867,12 @@ version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2ffb393ac5d9a6eaa9d3fdf37ae2776656b706e200c8e16b1bdb227f5198e6ea"
[[package]]
name = "futures-timer"
version = "3.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c"
[[package]]
name = "futures-util"
version = "0.3.25"
@ -1814,6 +1911,17 @@ dependencies = [
"version_check",
]
[[package]]
name = "getrandom"
version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
dependencies = [
"cfg-if",
"libc",
"wasi 0.9.0+wasi-snapshot-preview1",
]
[[package]]
name = "getrandom"
version = "0.2.8"
@ -1877,7 +1985,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5dfbfb3a6cfbd390d5c9564ab283a0349b9b9fcd46a706c1eb10e0db70bfbac7"
dependencies = [
"ff",
"rand_core",
"rand_core 0.6.4",
"subtle",
]
@ -2034,6 +2142,27 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29"
[[package]]
name = "http-types"
version = "2.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e9b187a72d63adbfba487f48095306ac823049cb504ee195541e91c7775f5ad"
dependencies = [
"anyhow",
"async-channel",
"base64",
"futures-lite",
"http",
"infer",
"pin-project-lite",
"rand 0.7.3",
"serde",
"serde_json",
"serde_qs",
"serde_urlencoded",
"url",
]
[[package]]
name = "httparse"
version = "1.8.0"
@ -2091,6 +2220,7 @@ dependencies = [
"http",
"hyper",
"rustls",
"rustls-native-certs",
"tokio",
"tokio-rustls",
]
@ -2193,6 +2323,12 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adab1eaa3408fb7f0c777a73e7465fd5656136fc93b670eb6df3c88c2c1344e3"
[[package]]
name = "infer"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64e9829a50b42bb782c1df523f78d332fe371b10c661e78b7a3c34b0198e9fac"
[[package]]
name = "inlinable_string"
version = "0.1.15"
@ -2478,7 +2614,7 @@ dependencies = [
"mas-storage",
"mas-templates",
"mime",
"rand",
"rand 0.8.5",
"serde",
"serde_json",
"serde_urlencoded",
@ -2527,8 +2663,8 @@ dependencies = [
"opentelemetry-semantic-conventions",
"opentelemetry-zipkin",
"prometheus",
"rand",
"rand_chacha",
"rand 0.8.5",
"rand_chacha 0.3.1",
"rustls",
"serde_json",
"serde_yaml",
@ -2559,8 +2695,8 @@ dependencies = [
"mas-jose",
"mas-keystore",
"pem-rfc7468",
"rand",
"rand_chacha",
"rand 0.8.5",
"rand_chacha 0.3.1",
"rustls-pemfile",
"schemars",
"serde",
@ -2583,7 +2719,7 @@ dependencies = [
"mas-iana",
"mas-jose",
"oauth2-types",
"rand",
"rand 0.8.5",
"serde",
"thiserror",
"url",
@ -2656,8 +2792,8 @@ dependencies = [
"mas-templates",
"mime",
"oauth2-types",
"rand",
"rand_chacha",
"rand 0.8.5",
"rand_chacha 0.3.1",
"serde",
"serde_json",
"serde_urlencoded",
@ -2749,8 +2885,8 @@ dependencies = [
"mas-iana",
"p256",
"p384",
"rand",
"rand_chacha",
"rand 0.8.5",
"rand_chacha 0.3.1",
"rsa",
"schemars",
"sec1",
@ -2787,8 +2923,8 @@ dependencies = [
"pem-rfc7468",
"pkcs1",
"pkcs8",
"rand",
"rand_chacha",
"rand 0.8.5",
"rand_chacha 0.3.1",
"rsa",
"sec1",
"spki",
@ -2817,6 +2953,47 @@ dependencies = [
"tracing-subscriber",
]
[[package]]
name = "mas-oidc-client"
version = "0.1.0"
dependencies = [
"assert_matches",
"base64ct",
"bitflags",
"bytes 1.3.0",
"chrono",
"form_urlencoded",
"futures 0.3.25",
"futures-signals",
"futures-util",
"headers",
"http",
"http-body",
"hyper",
"hyper-rustls",
"mas-http",
"mas-iana",
"mas-jose",
"mas-keystore",
"mime",
"oauth2-types",
"once_cell",
"rand 0.8.5",
"rand_chacha 0.3.1",
"rustls",
"serde",
"serde_json",
"serde_urlencoded",
"serde_with",
"thiserror",
"tokio",
"tower",
"tower-http",
"tracing",
"url",
"wiremock",
]
[[package]]
name = "mas-policy"
version = "0.1.0"
@ -2872,8 +3049,8 @@ dependencies = [
"mas-jose",
"oauth2-types",
"password-hash",
"rand",
"rand_chacha",
"rand 0.8.5",
"rand_chacha 0.3.1",
"serde",
"serde_json",
"sqlx",
@ -3075,7 +3252,7 @@ dependencies = [
"num-integer",
"num-iter",
"num-traits",
"rand",
"rand 0.8.5",
"smallvec",
"zeroize",
]
@ -3175,7 +3352,7 @@ dependencies = [
"json-patch",
"md-5",
"parse-size",
"rand",
"rand 0.8.5",
"semver",
"serde",
"serde_json",
@ -3348,7 +3525,7 @@ dependencies = [
"once_cell",
"opentelemetry_api",
"percent-encoding",
"rand",
"rand 0.8.5",
"thiserror",
"tokio",
"tokio-stream",
@ -3397,6 +3574,12 @@ dependencies = [
"sha2 0.10.6",
]
[[package]]
name = "parking"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "427c3892f9e783d91cc128285287e70a59e206ca452770ece88a76f7a3eddd72"
[[package]]
name = "parking_lot"
version = "0.11.2"
@ -3493,7 +3676,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700"
dependencies = [
"base64ct",
"rand_core",
"rand_core 0.6.4",
"subtle",
]
@ -3630,7 +3813,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1181c94580fa345f50f19d738aaa39c0ed30a600d95cb2d3e23f94266f14fbf"
dependencies = [
"phf_shared",
"rand",
"rand 0.8.5",
]
[[package]]
@ -3711,7 +3894,7 @@ checksum = "9eca2c590a5f85da82668fa685c09ce2888b9430e83299debf1f34b65fd4a4ba"
dependencies = [
"der",
"pkcs5",
"rand_core",
"rand_core 0.6.4",
"spki",
]
@ -3911,6 +4094,19 @@ version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fee2dce59f7a43418e3382c766554c614e06a552d53a8f07ef499ea4b332c0f"
[[package]]
name = "rand"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
dependencies = [
"getrandom 0.1.16",
"libc",
"rand_chacha 0.2.2",
"rand_core 0.5.1",
"rand_hc",
]
[[package]]
name = "rand"
version = "0.8.5"
@ -3918,8 +4114,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
"rand_chacha 0.3.1",
"rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
dependencies = [
"ppv-lite86",
"rand_core 0.5.1",
]
[[package]]
@ -3929,7 +4135,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
"rand_core 0.6.4",
]
[[package]]
name = "rand_core"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
dependencies = [
"getrandom 0.1.16",
]
[[package]]
@ -3938,7 +4153,16 @@ version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
"getrandom 0.2.8",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
dependencies = [
"rand_core 0.5.1",
]
[[package]]
@ -3956,7 +4180,7 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b"
dependencies = [
"getrandom",
"getrandom 0.2.8",
"redox_syscall",
"thiserror",
]
@ -4047,6 +4271,12 @@ dependencies = [
"winreg",
]
[[package]]
name = "retain_mut"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4389f1d5789befaf6029ebd9f7dac4af7f7e3d61b69d4f30e2ac02b57e7712b0"
[[package]]
name = "rfc6979"
version = "0.3.1"
@ -4087,7 +4317,7 @@ dependencies = [
"num-traits",
"pkcs1",
"pkcs8",
"rand_core",
"rand_core 0.6.4",
"signature",
"smallvec",
"subtle",
@ -4357,6 +4587,17 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_qs"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7715380eec75f029a4ef7de39a9200e0a63823176b759d055b613f5a87df6a6"
dependencies = [
"percent-encoding",
"serde",
"thiserror",
]
[[package]]
name = "serde_urlencoded"
version = "0.7.1"
@ -4481,7 +4722,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74233d3b3b2f6d4b006dc19dee745e73e2a6bfb6f93607cd3b02bd5b00797d7c"
dependencies = [
"digest 0.10.6",
"rand_core",
"rand_core 0.6.4",
]
[[package]]
@ -4621,7 +4862,7 @@ dependencies = [
"once_cell",
"paste",
"percent-encoding",
"rand",
"rand 0.8.5",
"rustls",
"rustls-pemfile",
"serde",
@ -4781,7 +5022,7 @@ dependencies = [
"percent-encoding",
"pest",
"pest_derive",
"rand",
"rand 0.8.5",
"regex",
"serde",
"serde_json",
@ -5096,7 +5337,7 @@ dependencies = [
"indexmap",
"pin-project",
"pin-project-lite",
"rand",
"rand 0.8.5",
"slab",
"tokio",
"tokio-util 0.7.4",
@ -5271,7 +5512,7 @@ dependencies = [
"http",
"httparse",
"log",
"rand",
"rand 0.8.5",
"sha-1",
"thiserror",
"url",
@ -5307,7 +5548,7 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13a3aaa69b04e5b66cc27309710a569ea23593612387d67daaf102e73aa974fd"
dependencies = [
"rand",
"rand 0.8.5",
"serde",
"uuid",
]
@ -5483,6 +5724,12 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "waker-fn"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d5b2c62b4012a3e1eca5a7e077d13b3bf498c4073e33ccd58626607748ceeca"
[[package]]
name = "walkdir"
version = "2.3.2"
@ -5504,6 +5751,12 @@ dependencies = [
"try-lock",
]
[[package]]
name = "wasi"
version = "0.9.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
[[package]]
name = "wasi"
version = "0.10.0+wasi-snapshot-preview1"
@ -5750,7 +6003,7 @@ dependencies = [
"mach",
"memoffset",
"paste",
"rand",
"rand 0.8.5",
"rustix",
"thiserror",
"wasmtime-asm-macros",
@ -5981,6 +6234,28 @@ dependencies = [
"winapi",
]
[[package]]
name = "wiremock"
version = "0.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "249dc68542861d17eae4b4e5e8fb381c2f9e8f255a84f6771d5fdf8b6c03ce3c"
dependencies = [
"assert-json-diff",
"async-trait",
"base64",
"deadpool",
"futures 0.3.25",
"futures-timer",
"http-types",
"hyper",
"log",
"once_cell",
"regex",
"serde",
"serde_json",
"tokio",
]
[[package]]
name = "xmlparser"
version = "0.13.5"

View File

@ -0,0 +1,71 @@
[package]
name = "mas-oidc-client"
version = "0.1.0"
authors = ["Quentin Gliech <quenting@element.io>"]
edition = "2021"
license = "Apache-2.0"
[features]
default = ["hyper", "keystore"]
hyper = [
"dep:http-body",
"dep:hyper",
"dep:hyper-rustls",
"dep:rustls",
"dep:tower-http",
"tower/limit",
]
keystore = ["dep:mas-keystore"]
[dependencies]
base64ct = { version = "1.5.3", features = ["std"] }
bytes = "1.3.0"
chrono = "0.4.23"
form_urlencoded = "1.1.0"
futures = "0.3.25"
futures-signals = "0.3.31"
futures-util = "0.3.25"
headers = "0.3.8"
http = "0.2.8"
once_cell = "1.16.0"
mime = "0.3.16"
rand = "0.8.5"
serde = { version = "1.0.147", features = ["derive"] }
serde_json = "1.0.88"
serde_urlencoded = "0.7.1"
serde_with = "2.1.0"
thiserror = "1.0.37"
tokio = { version = "1.22.0", features = ["rt", "macros", "rt-multi-thread"] }
tower = { version = "0.4.13", features = ["full"] }
tracing = "0.1.37"
url = { version = "2.3.1", features = ["serde"] }
mas-http = { path = "../http" }
mas-iana = { path = "../iana" }
mas-jose = { path = "../jose" }
mas-keystore = { path = "../keystore", optional = true }
oauth2-types = { path = "../oauth2-types" }
# Default http service
http-body = { version = "0.4.5", optional = true }
rustls = {version = "0.20.7", optional = true }
[dependencies.hyper-rustls]
version = "0.23.1"
features = ["http1", "http2", "rustls-native-certs"]
default-features = false
optional = true
[dependencies.hyper]
version = "0.14.23"
features = ["client", "http1", "http2", "stream", "runtime" ]
optional = true
[dependencies.tower-http]
version = "0.3.4"
features = ["follow-redirect", "decompression-full", "set-header", "timeout"]
optional = true
[dev-dependencies]
assert_matches = "1.5.0"
bitflags = "1.3.2"
mas-keystore = { path = "../keystore" }
rand_chacha = "0.3.1"
wiremock = "0.5.15"

View File

@ -0,0 +1,714 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! The error types used in this crate.
use std::{str::Utf8Error, sync::Arc};
use headers::authorization::InvalidBearerToken;
use http::{header::ToStrError, StatusCode};
use mas_http::{catch_http_codes, form_urlencoded_request, json_request, json_response};
use mas_jose::{
claims::ClaimError,
jwa::InvalidAlgorithm,
jwt::{JwtDecodeError, JwtSignatureError, NoKeyWorked},
};
use mas_keystore::WrongAlgorithmError;
use oauth2_types::{
errors::ClientErrorCode, oidc::ProviderMetadataVerificationError, pkce::CodeChallengeError,
};
use serde::{Deserialize, Serialize};
use thiserror::Error;
pub use tower::BoxError;
/// All possible errors when using this crate.
#[derive(Debug, Error)]
#[error(transparent)]
pub enum Error {
/// An error occurred fetching provider metadata.
Discovery(#[from] DiscoveryError),
/// An error occurred fetching the provider JWKS.
Jwks(#[from] JwksError),
/// An error occurred during client registration.
Registration(#[from] RegistrationError),
/// An error occurred building the authorization URL.
Authorization(#[from] AuthorizationError),
/// An error occurred exchanging an authorization code for an access token.
TokenAuthorizationCode(#[from] TokenAuthorizationCodeError),
/// An error occurred requesting an access token with client credentials.
TokenClientCredentials(#[from] TokenRequestError),
/// An error occurred refreshing an access token.
TokenRefresh(#[from] TokenRefreshError),
/// An error occurred revoking a token.
TokenRevoke(#[from] TokenRevokeError),
/// An error occurred requesting user info.
UserInfo(#[from] UserInfoError),
/// An error occurred introspecting a token.
Introspection(#[from] IntrospectionError),
}
/// All possible errors when fetching provider metadata.
#[derive(Debug, Error)]
pub enum DiscoveryError {
/// An error occurred building the request's URL.
#[error(transparent)]
IntoUrl(#[from] url::ParseError),
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// An error occurred deserializing the response.
#[error(transparent)]
FromJson(#[from] serde_json::Error),
/// An error occurred validating the metadata.
#[error(transparent)]
Validation(#[from] ProviderMetadataVerificationError),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<json_response::Error<S>> for DiscoveryError
where
S: Into<DiscoveryError>,
{
fn from(err: json_response::Error<S>) -> Self {
match err {
json_response::Error::Deserialize { inner } => inner.into(),
json_response::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for DiscoveryError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
Self::Http(HttpError::new(status_code, inner))
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when registering the client.
#[derive(Debug, Error)]
pub enum RegistrationError {
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// An error occurred serializing the request or deserializing the response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// No client secret was received although one was expected because of the
/// authentication method.
#[error("missing client secret in response")]
MissingClientSecret,
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<json_request::Error<S>> for RegistrationError
where
S: Into<RegistrationError>,
{
fn from(err: json_request::Error<S>) -> Self {
match err {
json_request::Error::Serialize { inner } => inner.into(),
json_request::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<json_response::Error<S>> for RegistrationError
where
S: Into<RegistrationError>,
{
fn from(err: json_response::Error<S>) -> Self {
match err {
json_response::Error::Deserialize { inner } => inner.into(),
json_response::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for RegistrationError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when making a pushed authorization request.
#[derive(Debug, Error)]
pub enum PushedAuthorizationError {
/// An error occurred serializing the request.
#[error(transparent)]
UrlEncoded(#[from] serde_urlencoded::ser::Error),
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// An error occurred adding the client credentials to the request.
#[error(transparent)]
Credentials(#[from] CredentialsError),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// An error occurred deserializing the response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<form_urlencoded_request::Error<S>> for PushedAuthorizationError
where
S: Into<PushedAuthorizationError>,
{
fn from(err: form_urlencoded_request::Error<S>) -> Self {
match err {
form_urlencoded_request::Error::Serialize { inner } => inner.into(),
form_urlencoded_request::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<json_response::Error<S>> for PushedAuthorizationError
where
S: Into<PushedAuthorizationError>,
{
fn from(err: json_response::Error<S>) -> Self {
match err {
json_response::Error::Deserialize { inner } => inner.into(),
json_response::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for PushedAuthorizationError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when authorizing the client.
#[derive(Debug, Error)]
pub enum AuthorizationError {
/// An error occurred constructing the PKCE code challenge.
#[error(transparent)]
Pkce(#[from] CodeChallengeError),
/// An error occurred serializing the request.
#[error(transparent)]
UrlEncoded(#[from] serde_urlencoded::ser::Error),
/// An error occurred making the PAR request.
#[error(transparent)]
PushedAuthorization(#[from] PushedAuthorizationError),
}
/// All possible errors when requesting an access token.
#[derive(Debug, Error)]
pub enum TokenRequestError {
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// An error occurred adding the client credentials to the request.
#[error(transparent)]
Credentials(#[from] CredentialsError),
/// An error occurred serializing the request.
#[error(transparent)]
UrlEncoded(#[from] serde_urlencoded::ser::Error),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// An error occurred deserializing the response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<form_urlencoded_request::Error<S>> for TokenRequestError
where
S: Into<TokenRequestError>,
{
fn from(err: form_urlencoded_request::Error<S>) -> Self {
match err {
form_urlencoded_request::Error::Serialize { inner } => inner.into(),
form_urlencoded_request::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<json_response::Error<S>> for TokenRequestError
where
S: Into<TokenRequestError>,
{
fn from(err: json_response::Error<S>) -> Self {
match err {
json_response::Error::Deserialize { inner } => inner.into(),
json_response::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for TokenRequestError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when exchanging a code for an access token.
#[derive(Debug, Error)]
pub enum TokenAuthorizationCodeError {
/// The nonce doesn't match the one that was sent.
#[error("wrong nonce")]
WrongNonce,
/// An error occurred requesting the access token.
#[error(transparent)]
Token(#[from] TokenRequestError),
/// An error occurred validating the ID Token.
#[error(transparent)]
IdToken(#[from] IdTokenError),
}
/// All possible errors when refreshing an access token.
#[derive(Debug, Error)]
pub enum TokenRefreshError {
/// An error occurred requesting the access token.
#[error(transparent)]
Token(#[from] TokenRequestError),
/// An error occurred validating the ID Token.
#[error(transparent)]
IdToken(#[from] IdTokenError),
}
/// All possible errors when revoking a token.
#[derive(Debug, Error)]
pub enum TokenRevokeError {
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// An error occurred adding the client credentials to the request.
#[error(transparent)]
Credentials(#[from] CredentialsError),
/// An error occurred serializing the request.
#[error(transparent)]
UrlEncoded(#[from] serde_urlencoded::ser::Error),
/// An error occurred deserializing the error response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<form_urlencoded_request::Error<S>> for TokenRevokeError
where
S: Into<TokenRevokeError>,
{
fn from(err: form_urlencoded_request::Error<S>) -> Self {
match err {
form_urlencoded_request::Error::Serialize { inner } => inner.into(),
form_urlencoded_request::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for TokenRevokeError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when requesting user info.
#[derive(Debug, Error)]
pub enum UserInfoError {
/// An error occurred getting the provider metadata.
#[error(transparent)]
Discovery(#[from] Arc<DiscoveryError>),
/// The provider doesn't support requesting user info.
#[error("missing UserInfo support")]
MissingUserInfoSupport,
/// No token is available to get info from.
#[error("missing token")]
MissingToken,
/// No client metadata is available.
#[error("missing client metadata")]
MissingClientMetadata,
/// The access token is invalid.
#[error(transparent)]
Token(#[from] InvalidBearerToken),
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// The content-type header is missing from the response.
#[error("missing response content-type")]
MissingResponseContentType,
/// The content-type header could not be decoded.
#[error("could not decoded response content-type: {0}")]
DecodeResponseContentType(#[from] ToStrError),
/// The content-type is not the one that was expected.
#[error("invalid response content-type {got:?}, expected {expected:?}")]
InvalidResponseContentType {
/// The expected content-type.
expected: String,
/// The returned content-type.
got: String,
},
/// An error occurred reading the response.
#[error(transparent)]
FromUtf8(#[from] Utf8Error),
/// An error occurred deserializing the JSON or error response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// An error occurred verifying the Id Token.
#[error(transparent)]
IdToken(#[from] IdTokenError),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for UserInfoError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when introspecting a token.
#[derive(Debug, Error)]
pub enum IntrospectionError {
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// An error occurred adding the client credentials to the request.
#[error(transparent)]
Credentials(#[from] CredentialsError),
/// The access token is invalid.
#[error(transparent)]
Token(#[from] InvalidBearerToken),
/// An error occurred serializing the request.
#[error(transparent)]
UrlEncoded(#[from] serde_urlencoded::ser::Error),
/// An error occurred deserializing the JSON or error response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<form_urlencoded_request::Error<S>> for IntrospectionError
where
S: Into<IntrospectionError>,
{
fn from(err: form_urlencoded_request::Error<S>) -> Self {
match err {
form_urlencoded_request::Error::Serialize { inner } => inner.into(),
form_urlencoded_request::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<json_response::Error<S>> for IntrospectionError
where
S: Into<IntrospectionError>,
{
fn from(err: json_response::Error<S>) -> Self {
match err {
json_response::Error::Deserialize { inner } => inner.into(),
json_response::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for IntrospectionError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when requesting a JWKS.
#[derive(Debug, Error)]
pub enum JwksError {
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// An error occurred deserializing the response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<json_response::Error<S>> for JwksError
where
S: Into<BoxError>,
{
fn from(err: json_response::Error<S>) -> Self {
match err {
json_response::Error::Service { inner } => Self::Service(inner.into()),
json_response::Error::Deserialize { inner } => Self::Json(inner),
}
}
}
/// All possible errors when verifying a JWT.
#[derive(Debug, Error)]
pub enum JwtVerificationError {
/// An error occured decoding the JWT.
#[error(transparent)]
JwtDecode(#[from] JwtDecodeError),
/// No key worked for verifying the JWT's signature.
#[error(transparent)]
JwtSignature(#[from] NoKeyWorked),
/// An error occurred extracting a claim.
#[error(transparent)]
Claim(#[from] ClaimError),
/// The issuer is not the one that sent the JWT.
#[error("wrong issuer claim")]
WrongIssuer,
/// The audience of the JWT is not this client.
#[error("wrong aud claim")]
WrongAudience,
/// The algorithm used for signing the JWT is not the one that was
/// requested.
#[error("wrong signature alg")]
WrongSignatureAlg,
}
/// All possible errors when verifying an ID token.
#[derive(Debug, Error)]
pub enum IdTokenError {
/// No ID Token was found in the response although one was expected.
#[error("ID token is missing")]
MissingIdToken,
/// The ID Token from the latest Authorization was not provided although
/// this request expects to be verified against one.
#[error("Authorization ID token is missing")]
MissingAuthIdToken,
/// An error occurred validating the ID Token's signature and basic claims.
#[error(transparent)]
Jwt(#[from] JwtVerificationError),
/// An error occurred extracting a claim.
#[error(transparent)]
Claim(#[from] ClaimError),
/// The subject identifier returned by the issuer is not the same as the one
/// we got before.
#[error("wrong subject identifier")]
WrongSubjectIdentifier,
/// The authentication time returned by the issuer is not the same as the
/// one we got before.
#[error("wrong authentication time")]
WrongAuthTime,
}
/// An error that can be returned by an OpenID Provider.
#[derive(Debug, Clone, Error)]
#[error("{status}: {body:?}")]
pub struct HttpError {
/// The status code of the error.
pub status: StatusCode,
/// The body of the error, if any.
pub body: Option<ErrorBody>,
}
impl HttpError {
/// Creates a new `HttpError` with the given status code and optional body.
#[must_use]
pub fn new(status: StatusCode, body: Option<ErrorBody>) -> Self {
Self { status, body }
}
}
/// The body of an error that can be returned by an OpenID Provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorBody {
/// The error code.
pub error: ClientErrorCode,
/// Additional text description of the error for debugging.
pub error_description: Option<String>,
}
/// All errors that can occur when adding client credentials to the request.
#[derive(Debug, Error)]
pub enum CredentialsError {
/// Trying to use an unsupported authentication method.
#[error("unsupported authentication method")]
UnsupportedMethod,
/// When authenticationg with `private_key_jwt`, no private key was found
/// for the given algorithm.
#[error("no private key was found for the given algorithm")]
NoPrivateKeyFound,
/// The signing algorithm is invalid for this authentication method.
#[error("invalid algorithm: {0}")]
InvalidSigningAlgorithm(#[from] InvalidAlgorithm),
/// An error occurred when building the claims of the JWT.
#[error(transparent)]
JwtClaims(#[from] ClaimError),
/// The key found cannot be used with the algorithm.
#[error(transparent)]
JwtWrongAlgorithm(#[from] WrongAlgorithmError),
/// An error occurred when signing the JWT.
#[error(transparent)]
JwtSignature(#[from] JwtSignatureError),
/// An error occurred with a custom signing method.
#[error(transparent)]
Custom(BoxError),
}

View File

@ -0,0 +1,88 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::task::Poll;
use bytes::Bytes;
use futures_util::future::BoxFuture;
use http::{Request, Response};
use http_body::{Body, Full};
use hyper::body::to_bytes;
use thiserror::Error;
use tower::{BoxError, Layer, Service};
#[derive(Debug, Error)]
#[error(transparent)]
pub enum BodyError<E> {
Decompression(BoxError),
Service(E),
}
#[derive(Clone)]
pub struct BodyService<S> {
inner: S,
}
impl<S> BodyService<S> {
pub const fn new(inner: S) -> Self {
Self { inner }
}
}
impl<S, E, ResBody> Service<Request<Bytes>> for BodyService<S>
where
S: Service<Request<Full<Bytes>>, Response = Response<ResBody>, Error = E>,
ResBody: Body<Data = Bytes, Error = BoxError> + Send,
S::Future: Send + 'static,
{
type Error = BodyError<E>;
type Response = Response<Bytes>;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx).map_err(BodyError::Service)
}
fn call(&mut self, request: Request<Bytes>) -> Self::Future {
let (parts, body) = request.into_parts();
let body = Full::new(body);
let request = Request::from_parts(parts, body);
let fut = self.inner.call(request);
let fut = async {
let response = fut.await.map_err(BodyError::Service)?;
let (parts, body) = response.into_parts();
let body = to_bytes(body).await.map_err(BodyError::Decompression)?;
let response = Response::from_parts(parts, body);
Ok(response)
};
Box::pin(fut)
}
}
#[derive(Default, Clone, Copy)]
pub struct BodyLayer(());
impl<S> Layer<S> for BodyLayer {
type Service = BodyService<S>;
fn layer(&self, inner: S) -> Self::Service {
BodyService::new(inner)
}
}

View File

@ -0,0 +1,75 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! A [`HttpService`] that uses [hyper] as a backend.
//!
//! [hyper]: https://crates.io/crates/hyper
use std::time::Duration;
use http::{header::USER_AGENT, HeaderValue};
use hyper::client::{connect::dns::GaiResolver, HttpConnector};
use hyper_rustls::{ConfigBuilderExt, HttpsConnectorBuilder};
use tower::{limit::ConcurrencyLimitLayer, BoxError, ServiceBuilder};
use tower_http::{
decompression::DecompressionLayer, follow_redirect::FollowRedirectLayer,
set_header::SetRequestHeaderLayer, timeout::TimeoutLayer,
};
mod body_layer;
use self::body_layer::BodyLayer;
use super::HttpService;
static MAS_USER_AGENT: HeaderValue = HeaderValue::from_static("mas-oidc-client/0.0.1");
/// Constructs a [`HttpService`] using [hyper] as a backend.
///
/// [hyper]: https://crates.io/crates/hyper
#[must_use]
pub fn hyper_service() -> HttpService {
let resolver = ServiceBuilder::new().service(GaiResolver::new());
let mut http = HttpConnector::new_with_resolver(resolver);
http.enforce_http(false);
let tls_config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_native_roots()
.with_no_client_auth();
let https = HttpsConnectorBuilder::new()
.with_tls_config(tls_config)
.https_or_http()
.enable_http1()
.enable_http2()
.wrap_connector(http);
let client = hyper::Client::builder().build(https);
let client = ServiceBuilder::new()
.map_err(BoxError::from)
.layer(BodyLayer::default())
.layer(DecompressionLayer::new())
.layer(SetRequestHeaderLayer::overriding(
USER_AGENT,
MAS_USER_AGENT.clone(),
))
.layer(ConcurrencyLimitLayer::new(10))
.layer(FollowRedirectLayer::new())
.layer(TimeoutLayer::new(Duration::from_secs(10)))
.service(client);
HttpService::new(client)
}

View File

@ -0,0 +1,109 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Traits to implement to provide a custom HTTP service for `Client`.
use std::{
fmt,
task::{Context, Poll},
};
use bytes::Bytes;
use futures::future::BoxFuture;
use tower::{BoxError, Service, ServiceExt};
#[cfg(feature = "hyper")]
pub mod hyper;
/// Type for the underlying HTTP service.
///
/// Allows implementors to use different libraries that provide a [`Service`]
/// that implements [`Clone`] + [`Send`] + [`Sync`].
pub type HttpService = BoxCloneSyncService<http::Request<Bytes>, http::Response<Bytes>, BoxError>;
impl fmt::Debug for HttpService {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("HttpService").finish()
}
}
/// A [`Clone`] + [`Send`] + [`Sync`] boxed [`Service`].
///
/// [`BoxCloneSyncService`] turns a service into a trait object, allowing the
/// response future type to be dynamic, and allowing the service to be cloned.
#[allow(clippy::type_complexity)]
pub struct BoxCloneSyncService<T, U, E>(
Box<
dyn CloneSyncService<T, Response = U, Error = E, Future = BoxFuture<'static, Result<U, E>>>,
>,
);
impl<T, U, E> BoxCloneSyncService<T, U, E> {
/// Create a new `BoxCloneSyncService`.
pub fn new<S>(inner: S) -> Self
where
S: Service<T, Response = U, Error = E> + Clone + Send + Sync + 'static,
S::Future: Send + 'static,
{
let inner = inner.map_future(|f| Box::pin(f) as _);
Self(Box::new(inner))
}
}
impl<T, U, E> Service<T> for BoxCloneSyncService<T, U, E> {
type Response = U;
type Error = E;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
#[inline]
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
}
#[inline]
fn call(&mut self, request: T) -> Self::Future {
self.0.call(request)
}
}
impl<T, U, E> Clone for BoxCloneSyncService<T, U, E> {
fn clone(&self) -> Self {
Self(self.0.clone_sync_box())
}
}
trait CloneSyncService<R>: Service<R> + Send + Sync {
fn clone_sync_box(
&self,
) -> Box<
dyn CloneSyncService<
R,
Response = Self::Response,
Error = Self::Error,
Future = Self::Future,
>,
>;
}
impl<R, T> CloneSyncService<R> for T
where
T: Service<R> + Send + Sync + Clone + 'static,
{
fn clone_sync_box(
&self,
) -> Box<dyn CloneSyncService<R, Response = T::Response, Error = T::Error, Future = T::Future>>
{
Box::new(self.clone())
}
}

View File

@ -0,0 +1,88 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! An [OpenID Connect] client library for the [Matrix] specification.
//!
//! This is part of the [Matrix Authentication Service] project.
//!
//! # Scope
//!
//! The scope of this crate is to support OIDC features required by the
//! Matrix specification according to [MSC3861] and its sub-proposals.
//!
//! As such, it is compatible with the OpenID Connect 1.0 specification, but
//! also enforces Matrix-specific requirements or adds compatibility with new
//! [OAuth 2.0] features.
//!
//! # OpenID Connect and OAuth 2.0 Features
//!
//! - Grant Types:
//! - [Authorization Code](https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth)
//! - [Client Credentials](https://www.rfc-editor.org/rfc/rfc6749#section-4.4)
//! - [Device Code](https://www.rfc-editor.org/rfc/rfc8628) (TBD)
//! - [User Info](https://openid.net/specs/openid-connect-core-1_0.html#UserInfo)
//! - Token:
//! - [Refresh Token](https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens)
//! - [Introspection](https://www.rfc-editor.org/rfc/rfc7662)
//! - [Revocation](https://www.rfc-editor.org/rfc/rfc7009)
//! - [Dynamic Client Registration](https://openid.net/specs/openid-connect-registration-1_0.html)
//! - [PKCE](https://www.rfc-editor.org/rfc/rfc7636)
//! - [Pushed Authorization Requests](https://www.rfc-editor.org/rfc/rfc9126)
//!
//! # Matrix features
//!
//! - Client registration
//! - Login
//! - Matrix API Scopes
//! - Logout
//!
//! [OpenID Connect]: https://openid.net/connect/
//! [Matrix]: https://matrix.org/
//! [Matrix Authentication Service]: https://github.com/matrix-org/matrix-authentication-service
//! [MSC3861]: https://github.com/matrix-org/matrix-spec-proposals/pull/3861
//! [OAuth 2.0]: https://oauth.net/2/
#![forbid(unsafe_code)]
#![deny(
clippy::all,
clippy::str_to_string,
rustdoc::broken_intra_doc_links,
missing_docs
)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions, clippy::implicit_hasher)]
pub mod error;
pub mod http_service;
pub mod requests;
pub mod types;
mod utils;
use std::fmt;
#[doc(inline)]
pub use mas_jose as jose;
// Wrapper around `String` that cannot be used in a meaningful way outside of
// this crate. Used for string enums that only allow certain characters because
// their variant can't be private.
#[doc(hidden)]
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PrivString(String);
impl fmt::Debug for PrivString {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}

View File

@ -0,0 +1,460 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for the [Authorization Code flow].
//!
//! [Authorization Code flow]: https://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth
use base64ct::{Base64UrlUnpadded, Encoding};
use chrono::{DateTime, Utc};
use http::header::CONTENT_TYPE;
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer};
use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, PkceCodeChallengeMethod};
use mas_jose::claims::{self, TokenHash};
use oauth2_types::{
pkce,
prelude::CodeChallengeMethodExt,
requests::{
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, AuthorizationRequest,
Prompt, PushedAuthorizationResponse,
},
scope::Scope,
};
use rand::{
distributions::{Alphanumeric, DistString},
Rng,
};
use serde::Serialize;
use serde_with::skip_serializing_none;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use super::jose::JwtVerificationData;
use crate::{
error::{
AuthorizationError, IdTokenError, PushedAuthorizationError, TokenAuthorizationCodeError,
},
http_service::HttpService,
requests::{jose::verify_id_token, token::request_access_token},
types::{
client_credentials::ClientCredentials,
scope::{ScopeExt, ScopeToken},
IdToken,
},
utils::{http_all_error_status_codes, http_error_mapper},
};
/// The data necessary to build an authorization request.
#[derive(Debug, Clone, Copy)]
pub struct AuthorizationRequestData<'a> {
/// The ID obtained when registering the client.
pub client_id: &'a str,
/// The PKCE methods supported by the issuer, from its metadata.
pub code_challenge_methods_supported: Option<&'a [PkceCodeChallengeMethod]>,
/// The scope to authorize.
///
/// If the OpenID Connect scope token (`openid`) is not included, it will be
/// added.
pub scope: &'a Scope,
/// The URI to redirect the end-user to after the authorization.
///
/// It must be one of the redirect URIs provided during registration.
pub redirect_uri: &'a Url,
/// Optional hints for the action to be performed.
pub prompt: Option<&'a [Prompt]>,
}
/// The data necessary to validate a response from the Token endpoint in the
/// Authorization Code flow.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthorizationValidationData {
/// A unique identifier for the request.
pub state: String,
/// A string to mitigate replay attacks.
pub nonce: String,
/// The URI where the end-user will be redirected after authorization.
pub redirect_uri: Url,
/// A string to correlate the authorization request to the token request.
pub code_challenge_verifier: Option<String>,
}
#[skip_serializing_none]
#[derive(Clone, Serialize)]
struct FullAuthorizationRequest {
#[serde(flatten)]
inner: AuthorizationRequest,
#[serde(flatten)]
pkce: Option<pkce::AuthorizationRequest>,
}
/// Build the authorization request.
fn build_authorization_request(
authorization_data: AuthorizationRequestData<'_>,
rng: &mut impl Rng,
) -> Result<(FullAuthorizationRequest, AuthorizationValidationData), AuthorizationError> {
let AuthorizationRequestData {
client_id,
code_challenge_methods_supported,
scope,
redirect_uri,
prompt,
} = authorization_data;
let mut scope = scope.clone();
// Generate a random CSRF "state" token and a nonce.
let state = Alphanumeric.sample_string(rng, 16);
let nonce = Alphanumeric.sample_string(rng, 16);
// Use PKCE, whenever possible.
let (pkce, code_challenge_verifier) = if code_challenge_methods_supported
.iter()
.any(|methods| methods.contains(&PkceCodeChallengeMethod::S256))
{
let mut verifier = [0u8; 32];
rng.fill(&mut verifier);
let method = PkceCodeChallengeMethod::S256;
let verifier = Base64UrlUnpadded::encode_string(&verifier);
let code_challenge = method.compute_challenge(&verifier)?.into();
let pkce = pkce::AuthorizationRequest {
code_challenge_method: method,
code_challenge,
};
(Some(pkce), Some(verifier))
} else {
(None, None)
};
scope.insert_token(ScopeToken::Openid);
let auth_request = FullAuthorizationRequest {
inner: AuthorizationRequest {
response_type: OAuthAuthorizationEndpointResponseType::Code.into(),
client_id: client_id.to_owned(),
redirect_uri: Some(redirect_uri.clone()),
scope,
state: Some(state.clone()),
response_mode: None,
nonce: Some(nonce.clone()),
display: None,
prompt: prompt.map(ToOwned::to_owned),
max_age: None,
ui_locales: None,
id_token_hint: None,
login_hint: None,
acr_values: None,
request: None,
request_uri: None,
registration: None,
},
pkce,
};
let auth_data = AuthorizationValidationData {
state,
nonce,
redirect_uri: redirect_uri.clone(),
code_challenge_verifier,
};
Ok((auth_request, auth_data))
}
/// Build the URL for authenticating at the Authorization endpoint.
///
/// # Arguments
///
/// * `authorization_endpoint` - The URL of the issuer's authorization endpoint.
///
/// * `authorization_data` - The data necessary to build the authorization
/// request.
///
/// * `rng` - A random number generator.
///
/// # Returns
///
/// A URL to be opened in a web browser where the end-user will be able to
/// authorize the given scope, and the [`AuthorizationValidationData`] to
/// validate this request.
///
/// The redirect URI will receive parameters in its query:
///
/// * A successful response will receive a `code` and a `state`.
///
/// * If the authorization fails, it should receive an `error` parameter with a
/// [`ClientErrorCode`] and optionally an `error_description`.
///
/// # Errors
///
/// Returns an error if preparing the URL fails.
///
/// [`VerifiedClientMetadata`]: oauth2_types::registration::VerifiedClientMetadata
/// [`ClientErrorCode`]: oauth2_types::errors::ClientErrorCode
#[allow(clippy::too_many_lines)]
pub fn build_authorization_url(
authorization_endpoint: Url,
authorization_data: AuthorizationRequestData<'_>,
rng: &mut impl Rng,
) -> Result<(Url, AuthorizationValidationData), AuthorizationError> {
tracing::debug!(
scope = ?authorization_data.scope,
"Authorizing..."
);
let (authorization_request, validation_data) =
build_authorization_request(authorization_data, rng)?;
let authorization_query = serde_urlencoded::to_string(authorization_request)?;
let mut authorization_url = authorization_endpoint;
// Add our parameters to the query, because the URL might already have one.
let mut full_query = authorization_url
.query()
.map(ToOwned::to_owned)
.unwrap_or_default();
if !full_query.is_empty() {
full_query.push('&');
}
full_query.push_str(&authorization_query);
authorization_url.set_query(Some(&full_query));
Ok((authorization_url, validation_data))
}
/// Make a [Pushed Authorization Request] and build the URL for authenticating
/// at the Authorization endpoint.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `client_credentials` - The credentials obtained when registering the
/// client.
///
/// * `par_endpoint` - The URL of the issuer's Pushed Authorization Request
/// endpoint.
///
/// * `authorization_endpoint` - The URL of the issuer's Authorization endpoint.
///
/// * `authorization_data` - The data necessary to build the authorization
/// request.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Returns
///
/// A URL to be opened in a web browser where the end-user will be able to
/// authorize the given scope, and the [`AuthorizationValidationData`] to
/// validate this request.
///
/// The redirect URI will receive parameters in its query:
///
/// * A successful response will receive a `code` and a `state`.
///
/// * If the authorization fails, it should receive an `error` parameter with a
/// [`ClientErrorCode`] and optionally an `error_description`.
///
/// # Errors
///
/// Returns an error if the request fails, the response is invalid or building
/// the URL fails.
///
/// [Pushed Authorization Request]: https://oauth.net/2/pushed-authorization-requests/
/// [`ClientErrorCode`]: oauth2_types::errors::ClientErrorCode
#[allow(clippy::too_many_lines)]
#[tracing::instrument(skip_all, fields(par_endpoint))]
pub async fn build_par_authorization_url(
http_service: &HttpService,
client_credentials: ClientCredentials,
par_endpoint: &Url,
authorization_endpoint: Url,
authorization_data: AuthorizationRequestData<'_>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<(Url, AuthorizationValidationData), AuthorizationError> {
tracing::debug!(
scope = ?authorization_data.scope,
"Authorizing with a PAR..."
);
let client_id = client_credentials.client_id().to_owned();
let (authorization_request, validation_data) =
build_authorization_request(authorization_data, rng)?;
let par_request = http::Request::post(par_endpoint.as_str())
.header(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref())
.body(authorization_request)
.map_err(PushedAuthorizationError::from)?;
let par_request = client_credentials
.apply_to_request(par_request, now, rng)
.map_err(PushedAuthorizationError::from)?;
let service = (
FormUrlencodedRequestLayer::default(),
JsonResponseLayer::<PushedAuthorizationResponse>::default(),
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
)
.layer(http_service.clone());
let par_response = service
.ready_oneshot()
.await
.map_err(PushedAuthorizationError::from)?
.call(par_request)
.await
.map_err(PushedAuthorizationError::from)?
.into_body();
let authorization_query = serde_urlencoded::to_string([
("request_uri", par_response.request_uri.as_str()),
("client_id", &client_id),
])?;
let mut authorization_url = authorization_endpoint;
// Add our parameters to the query, because the URL might already have one.
let mut full_query = authorization_url
.query()
.map(ToOwned::to_owned)
.unwrap_or_default();
if !full_query.is_empty() {
full_query.push('&');
}
full_query.push_str(&authorization_query);
authorization_url.set_query(Some(&full_query));
Ok((authorization_url, validation_data))
}
/// Exchange an authorization code for an access token.
///
/// This should be used as the first step for logging in, and to request a
/// token with a new scope.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `client_credentials` - The credentials obtained when registering the
/// client.
///
/// * `token_endpoint` - The URL of the issuer's Token endpoint.
///
/// * `code` - The authorization code returned at the Authorization endpoint.
///
/// * `validation_data` - The validation data that was returned when building
/// the Authorization URL, for the state returned at the Authorization
/// endpoint.
///
/// * `id_token_verification_data` - The data required to verify the ID Token in
/// the response.
///
/// The signing algorithm corresponds to the `id_token_signed_response_alg`
/// field in the client metadata.
///
/// If it is not provided, the ID Token won't be verified. Note that in the
/// OpenID Connect specification, this verification is required.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Errors
///
/// Returns an error if the request fails, the response is invalid or the
/// verification of the ID Token fails.
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip_all, fields(token_endpoint))]
pub async fn access_token_with_authorization_code(
http_service: &HttpService,
client_credentials: ClientCredentials,
token_endpoint: &Url,
code: String,
validation_data: AuthorizationValidationData,
id_token_verification_data: Option<JwtVerificationData<'_>>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<(AccessTokenResponse, Option<IdToken<'static>>), TokenAuthorizationCodeError> {
tracing::debug!("Exchanging authorization code for access token...");
let token_response = request_access_token(
http_service,
client_credentials,
token_endpoint,
AccessTokenRequest::AuthorizationCode(AuthorizationCodeGrant {
code: code.clone(),
redirect_uri: Some(validation_data.redirect_uri),
code_verifier: validation_data.code_challenge_verifier,
}),
now,
rng,
)
.await?;
let id_token = if let Some(verification_data) = id_token_verification_data {
let signing_alg = verification_data.signing_algorithm;
let id_token = token_response
.id_token
.as_deref()
.ok_or(IdTokenError::MissingIdToken)?;
let id_token = verify_id_token(id_token, verification_data, None, now)?;
let mut claims = id_token.payload().clone();
// Access token hash must match.
claims::AT_HASH
.extract_optional_with_options(
&mut claims,
TokenHash::new(signing_alg, &token_response.access_token),
)
.map_err(IdTokenError::from)?;
// Code hash must match.
claims::C_HASH
.extract_optional_with_options(&mut claims, TokenHash::new(signing_alg, &code))
.map_err(IdTokenError::from)?;
// Nonce must match.
let token_nonce = claims::NONCE
.extract_required(&mut claims)
.map_err(IdTokenError::from)?;
if token_nonce != validation_data.nonce {
return Err(TokenAuthorizationCodeError::WrongNonce);
}
Some(id_token.into_owned())
} else {
None
};
Ok((token_response, id_token))
}

View File

@ -0,0 +1,75 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for the [Client Credentials flow].
//!
//! [Client Credentials flow]: https://www.rfc-editor.org/rfc/rfc6749#section-4.4
use chrono::{DateTime, Utc};
use oauth2_types::{
requests::{AccessTokenRequest, AccessTokenResponse, ClientCredentialsGrant},
scope::Scope,
};
use rand::Rng;
use url::Url;
use crate::{
error::TokenRequestError, http_service::HttpService, requests::token::request_access_token,
types::client_credentials::ClientCredentials,
};
/// Exchange an authorization code for an access token.
///
/// This should be used as the first step for logging in, and to request a
/// token with a new scope.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `client_credentials` - The credentials obtained when registering the
/// client.
///
/// * `token_endpoint` - The URL of the issuer's Token endpoint.
///
/// * `scope` - The scope to authorize.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Errors
///
/// Returns an error if the request fails or the response is invalid.
#[tracing::instrument(skip_all, fields(token_endpoint))]
pub async fn access_token_with_client_credentials(
http_service: &HttpService,
client_credentials: ClientCredentials,
token_endpoint: &Url,
scope: Option<Scope>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<AccessTokenResponse, TokenRequestError> {
tracing::debug!("Requesting access token with client credentials...");
request_access_token(
http_service,
client_credentials,
token_endpoint,
AccessTokenRequest::ClientCredentials(ClientCredentialsGrant { scope }),
now,
rng,
)
.await
}

View File

@ -0,0 +1,109 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for OpenID Connect Provider [Discovery].
//!
//! [Discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html
use bytes::Bytes;
use mas_http::{CatchHttpCodesLayer, JsonResponseLayer};
use oauth2_types::oidc::{ProviderMetadata, VerifiedProviderMetadata};
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::DiscoveryError,
http_service::HttpService,
utils::{http_all_error_status_codes, http_error_mapper},
};
/// Fetch the provider metadata.
async fn discover_inner(
http_service: &HttpService,
issuer: &Url,
) -> Result<ProviderMetadata, DiscoveryError> {
tracing::debug!("Fetching provider metadata...");
let mut config_url = issuer.clone();
// If the path doesn't end with a slash, the last segment is removed when
// using `join`.
if !config_url.path().ends_with('/') {
let mut path = config_url.path().to_owned();
path.push('/');
config_url.set_path(&path);
}
let config_url = config_url.join(".well-known/openid-configuration")?;
let config_req = http::Request::get(config_url.as_str()).body(Bytes::new())?;
let service = (
JsonResponseLayer::<ProviderMetadata>::default(),
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
)
.layer(http_service.clone());
let response = service.ready_oneshot().await?.call(config_req).await?;
tracing::debug!(?response);
Ok(response.into_body())
}
/// Fetch the provider metadata and validate it.
///
/// # Errors
///
/// Returns an error if the request fails or if the data is invalid.
#[tracing::instrument(skip_all, fields(issuer))]
pub async fn discover(
http_service: &HttpService,
issuer: &Url,
) -> Result<VerifiedProviderMetadata, DiscoveryError> {
let provider_metadata = discover_inner(http_service, issuer).await?;
Ok(provider_metadata.validate(issuer)?)
}
/// Fetch the [provider metadata] and make basic checks.
///
/// Contrary to [`discover()`], this uses
/// [`ProviderMetadata::insecure_verify_metadata()`] to check the received
/// metadata instead of validating it according to the specification.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `issuer` - The URL of the OpenID Connect Provider to fetch metadata for.
///
/// # Errors
///
/// Returns an error if the request fails or if the data is invalid.
///
/// # Warning
///
/// It is not recommended to use this method in production as it doesn't
/// ensure that the issuer implements the proper security practices.
///
/// [provider metadata]: https://openid.net/specs/openid-connect-discovery-1_0.html
#[tracing::instrument(skip_all, fields(issuer))]
pub async fn insecure_discover(
http_service: &HttpService,
issuer: &Url,
) -> Result<VerifiedProviderMetadata, DiscoveryError> {
let provider_metadata = discover_inner(http_service, issuer).await?;
Ok(provider_metadata.insecure_verify_metadata()?)
}

View File

@ -0,0 +1,153 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for [Token Introspection].
//!
//! [Token Introspection]: https://www.rfc-editor.org/rfc/rfc7662
use chrono::{DateTime, Utc};
use headers::{Authorization, HeaderMapExt};
use http::Request;
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer};
use mas_iana::oauth::OAuthTokenTypeHint;
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse};
use rand::Rng;
use serde::Serialize;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::IntrospectionError,
http_service::HttpService,
types::client_credentials::{ClientCredentials, RequestWithClientCredentials},
utils::{http_all_error_status_codes, http_error_mapper},
};
/// The method used to authenticate at the introspection endpoint.
pub enum IntrospectionAuthentication<'a> {
/// Using client authentication.
Credentials(ClientCredentials),
/// Using a bearer token.
BearerToken(&'a str),
}
impl<'a> IntrospectionAuthentication<'a> {
/// Constructs an `IntrospectionAuthentication` from the given client
/// credentials.
#[must_use]
pub fn with_client_credentials(credentials: ClientCredentials) -> Self {
Self::Credentials(credentials)
}
/// Constructs an `IntrospectionAuthentication` from the given bearer token.
#[must_use]
pub fn with_bearer_token(token: &'a str) -> Self {
Self::BearerToken(token)
}
fn apply_to_request<T: Serialize>(
self,
request: Request<T>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<Request<RequestWithClientCredentials<T>>, IntrospectionError> {
let res = match self {
IntrospectionAuthentication::Credentials(client_credentials) => {
client_credentials.apply_to_request(request, now, rng)?
}
IntrospectionAuthentication::BearerToken(access_token) => {
let (mut parts, body) = request.into_parts();
parts
.headers
.typed_insert(Authorization::bearer(access_token)?);
let body = RequestWithClientCredentials {
body,
credentials: None,
};
http::Request::from_parts(parts, body)
}
};
Ok(res)
}
}
impl<'a> From<ClientCredentials> for IntrospectionAuthentication<'a> {
fn from(credentials: ClientCredentials) -> Self {
Self::with_client_credentials(credentials)
}
}
/// Obtain information about a token.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `authentication` - The method used to authenticate the request.
///
/// * `revocation_endpoint` - The URL of the issuer's Revocation endpoint.
///
/// * `token` - The token to introspect.
///
/// * `token_type_hint` - Hint about the type of the token.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Errors
///
/// Returns an error if the request fails or the response is invalid.
#[tracing::instrument(skip_all, fields(introspection_endpoint))]
pub async fn introspect_token(
http_service: &HttpService,
authentication: IntrospectionAuthentication<'_>,
introspection_endpoint: &Url,
token: String,
token_type_hint: Option<OAuthTokenTypeHint>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<IntrospectionResponse, IntrospectionError> {
tracing::debug!("Introspecting token…");
let introspection_request = IntrospectionRequest {
token,
token_type_hint,
};
let introspection_request =
http::Request::post(introspection_endpoint.as_str()).body(introspection_request)?;
let introspection_request = authentication.apply_to_request(introspection_request, now, rng)?;
let service = (
FormUrlencodedRequestLayer::default(),
JsonResponseLayer::<IntrospectionResponse>::default(),
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
)
.layer(http_service.clone());
let introspection_response = service
.ready_oneshot()
.await?
.call(introspection_request)
.await?
.into_body();
Ok(introspection_response)
}

View File

@ -0,0 +1,223 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests and method related to JSON Object Signing and Encryption.
use std::collections::HashMap;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use mas_http::JsonResponseLayer;
use mas_iana::jose::JsonWebSignatureAlg;
use mas_jose::{
claims::{self, TimeOptions},
jwk::PublicJsonWebKeySet,
jwt::Jwt,
};
use serde_json::Value;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::{IdTokenError, JwksError, JwtVerificationError},
http_service::HttpService,
types::IdToken,
};
/// Fetch a JWKS at the given URL.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `jwks_uri` - The URL where the JWKS can be retrieved.
///
/// # Errors
///
/// Returns an error if the request fails or if the data is invalid.
#[tracing::instrument(skip_all, fields(jwks_uri))]
pub async fn fetch_jwks(
http_service: &HttpService,
jwks_uri: &Url,
) -> Result<PublicJsonWebKeySet, JwksError> {
tracing::debug!("Fetching JWKS...");
let jwks_request = http::Request::get(jwks_uri.as_str()).body(Bytes::new())?;
let service = JsonResponseLayer::<PublicJsonWebKeySet>::default().layer(http_service.clone());
let response = service.ready_oneshot().await?.call(jwks_request).await?;
Ok(response.into_body())
}
/// The data required to verify a JWT.
#[derive(Clone, Copy)]
pub struct JwtVerificationData<'a> {
/// The URL of the issuer that generated the ID Token.
pub issuer: &'a Url,
/// The issuer's JWKS.
pub jwks: &'a PublicJsonWebKeySet,
/// The ID obtained when registering the client.
pub client_id: &'a String,
/// The JWA that should have been used to sign the JWT, as set during
/// client registration.
pub signing_algorithm: &'a JsonWebSignatureAlg,
}
/// Decode and verify a signed JWT.
///
/// The following checks are performed:
///
/// * The signature is verified with the given JWKS.
///
/// * The `iss` claim must be present and match the issuer.
///
/// * The `aud` claim must be present and match the client ID.
///
/// * The `alg` in the header must match the signing algorithm.
///
/// # Arguments
///
/// * `jwt` - The serialized JWT to decode and verify.
///
/// * `jwks` - The JWKS that should contain the public key to verify the JWT's
/// signature.
///
/// * `issuer` - The issuer of the JWT.
///
/// * `audience` - The audience that the JWT is intended for.
///
/// * `signing_algorithm` - The JWA that should have been used to sign the JWT.
///
/// # Errors
///
/// Returns an error if the data is invalid or verification fails.
pub fn verify_signed_jwt<'a>(
jwt: &'a str,
verification_data: JwtVerificationData<'_>,
) -> Result<Jwt<'a, HashMap<String, Value>>, JwtVerificationError> {
tracing::debug!("Validating JWT...");
let JwtVerificationData {
issuer,
jwks,
client_id,
signing_algorithm,
} = verification_data;
let jwt: Jwt<HashMap<String, Value>> = jwt.try_into()?;
jwt.verify_with_jwks(jwks)?;
let (header, mut claims) = jwt.clone().into_parts();
// Must have the proper issuer.
let iss = claims::ISS.extract_required(&mut claims)?;
if iss != issuer.as_str() {
return Err(JwtVerificationError::WrongIssuer);
}
// Must have the proper audience.
let aud = claims::AUD.extract_required(&mut claims)?;
if !aud.contains(client_id) {
return Err(JwtVerificationError::WrongAudience);
}
// Must use the proper algorithm.
if header.alg() != signing_algorithm {
return Err(JwtVerificationError::WrongSignatureAlg);
}
Ok(jwt)
}
/// Decode and verify an ID Token.
///
/// Besides the checks of [`verify_signed_jwt()`], the following checks are
/// performed:
///
/// * The `exp` claim must be present and the token must not have expired.
///
/// * The `iat` claim must be present must be in the past.
///
/// * The `sub` claim must be present.
///
/// If an authorization ID token is provided, these extra checks are performed:
///
/// * The `sub` claims must match.
///
/// * The `auth_time` claims must match.
///
/// # Arguments
///
/// * `id_token` - The serialized ID Token to decode and verify.
///
/// * `verification_data` - The data necessary to verify the ID Token.
///
/// * `auth_id_token` - If the ID Token is not verified during an authorization
/// request, the ID token that was returned from the latest authorization
/// request.
///
/// # Errors
///
/// Returns an error if the data is invalid or verification fails.
pub fn verify_id_token<'a>(
id_token: &'a str,
verification_data: JwtVerificationData<'_>,
auth_id_token: Option<&IdToken<'_>>,
now: DateTime<Utc>,
) -> Result<IdToken<'a>, IdTokenError> {
let id_token = verify_signed_jwt(id_token, verification_data)?;
let mut claims = id_token.payload().clone();
let time_options = TimeOptions::new(now);
// Must not have expired.
claims::EXP.extract_required_with_options(&mut claims, &time_options)?;
// `iat` claim must be present.
claims::IAT.extract_required_with_options(&mut claims, time_options)?;
// Subject identifier must be present.
let sub = claims::SUB.extract_required(&mut claims)?;
// No more checks if there is no previous ID token.
let auth_id_token = match auth_id_token {
Some(id_token) => id_token,
None => return Ok(id_token),
};
let mut auth_claims = auth_id_token.payload().clone();
// Subject identifier must always be the same.
let auth_sub = claims::SUB.extract_required(&mut auth_claims)?;
if sub != auth_sub {
return Err(IdTokenError::WrongSubjectIdentifier);
}
// If the authentication time is present, it must be unchanged.
if let Some(auth_time) = claims::AUTH_TIME.extract_optional(&mut claims)? {
let prev_auth_time = claims::AUTH_TIME.extract_required(&mut auth_claims)?;
if prev_auth_time != auth_time {
return Err(IdTokenError::WrongAuthTime);
}
}
Ok(id_token)
}

View File

@ -0,0 +1,26 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Methods to interact with OpenID Connect and OAuth2.0 endpoints.
pub mod authorization_code;
pub mod client_credentials;
pub mod discovery;
pub mod introspection;
pub mod jose;
pub mod refresh_token;
pub mod registration;
pub mod revocation;
pub mod token;
pub mod userinfo;

View File

@ -0,0 +1,128 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for using [Refresh Tokens].
//!
//! [Refresh Tokens]: https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokens
use chrono::{DateTime, Utc};
use mas_jose::claims::{self, TokenHash};
use oauth2_types::{
requests::{AccessTokenRequest, AccessTokenResponse, RefreshTokenGrant},
scope::Scope,
};
use rand::Rng;
use url::Url;
use super::jose::JwtVerificationData;
use crate::{
error::{IdTokenError, TokenRefreshError},
http_service::HttpService,
requests::{jose::verify_id_token, token::request_access_token},
types::{client_credentials::ClientCredentials, IdToken},
};
/// Exchange an authorization code for an access token.
///
/// This should be used as the first step for logging in, and to request a
/// token with a new scope.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `client_credentials` - The credentials obtained when registering the
/// client.
///
/// * `token_endpoint` - The URL of the issuer's Token endpoint.
///
/// * `refresh_token` - The token used to refresh the access token returned at
/// the Token endpoint.
///
/// * `scope` - The scope of the access token. The requested scope must not
/// include any scope not originally granted to the access token, and if
/// omitted is treated as equal to the scope originally granted by the issuer.
///
/// * `id_token_verification_data` - The data required to verify the ID Token in
/// the response.
///
/// The signing algorithm corresponds to the `id_token_signed_response_alg`
/// field in the client metadata.
///
/// If it is not provided, the ID Token won't be verified.
///
/// * `auth_id_token` - If an ID Token is expected in the response, the ID token
/// that was returned from the latest authorization request.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Errors
///
/// Returns an error if the request fails, the response is invalid or the
/// verification of the ID Token fails.
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip_all, fields(token_endpoint))]
pub async fn refresh_access_token(
http_service: &HttpService,
client_credentials: ClientCredentials,
token_endpoint: &Url,
refresh_token: String,
scope: Option<Scope>,
id_token_verification_data: Option<JwtVerificationData<'_>>,
auth_id_token: Option<&IdToken<'_>>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<(AccessTokenResponse, Option<IdToken<'static>>), TokenRefreshError> {
tracing::debug!("Refreshing access token…");
let token_response = request_access_token(
http_service,
client_credentials,
token_endpoint,
AccessTokenRequest::RefreshToken(RefreshTokenGrant {
refresh_token,
scope,
}),
now,
rng,
)
.await?;
let id_token = if let Some((verification_data, id_token)) =
id_token_verification_data.zip(token_response.id_token.as_ref())
{
let auth_id_token = auth_id_token.ok_or(IdTokenError::MissingAuthIdToken)?;
let signing_alg = verification_data.signing_algorithm;
let id_token = verify_id_token(id_token, verification_data, Some(auth_id_token), now)?;
let mut claims = id_token.payload().clone();
// Access token hash must match.
claims::AT_HASH
.extract_optional_with_options(
&mut claims,
TokenHash::new(signing_alg, &token_response.access_token),
)
.map_err(IdTokenError::from)?;
Some(id_token.into_owned())
} else {
None
};
Ok((token_response, id_token))
}

View File

@ -0,0 +1,82 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for [Dynamic Registration].
//!
//! [Dynamic Registration]: https://openid.net/specs/openid-connect-registration-1_0.html
use mas_http::{CatchHttpCodesLayer, JsonRequestLayer, JsonResponseLayer};
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use oauth2_types::registration::{ClientRegistrationResponse, VerifiedClientMetadata};
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::RegistrationError,
http_service::HttpService,
utils::{http_all_error_status_codes, http_error_mapper},
};
/// Register a client with an OpenID Provider.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `registration_endpoint` - The URL of the issuer's Registration endpoint.
///
/// * `client_metadata` - The metadata to register with the issuer.
///
/// # Errors
///
/// Returns an error if the request fails or the response is invalid.
#[tracing::instrument(skip_all, fields(registration_endpoint))]
pub async fn register_client(
http_service: &HttpService,
registration_endpoint: &Url,
client_metadata: VerifiedClientMetadata,
) -> Result<ClientRegistrationResponse, RegistrationError> {
tracing::debug!("Registering client...");
let registration_req =
http::Request::post(registration_endpoint.as_str()).body(client_metadata.clone())?;
let service = (
JsonRequestLayer::default(),
JsonResponseLayer::<ClientRegistrationResponse>::default(),
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
)
.layer(http_service.clone());
let response = service
.ready_oneshot()
.await?
.call(registration_req)
.await?
.into_body();
match client_metadata.token_endpoint_auth_method() {
OAuthClientAuthenticationMethod::ClientSecretPost
| OAuthClientAuthenticationMethod::ClientSecretBasic
| OAuthClientAuthenticationMethod::ClientSecretJwt => {
response
.client_secret
.as_ref()
.ok_or(RegistrationError::MissingClientSecret)?;
}
_ => {}
}
Ok(response)
}

View File

@ -0,0 +1,90 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for [Token Revocation].
//!
//! [Token Revocation]: https://www.rfc-editor.org/rfc/rfc7009.html
use chrono::{DateTime, Utc};
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer};
use mas_iana::oauth::OAuthTokenTypeHint;
use oauth2_types::requests::IntrospectionRequest;
use rand::Rng;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::TokenRevokeError,
http_service::HttpService,
types::client_credentials::ClientCredentials,
utils::{http_all_error_status_codes, http_error_mapper},
};
/// Revoke a token.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `client_credentials` - The credentials obtained when registering the
/// client.
///
/// * `revocation_endpoint` - The URL of the issuer's Revocation endpoint.
///
/// * `token` - The token to revoke.
///
/// * `token_type_hint` - Hint about the type of the token.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Errors
///
/// Returns an error if the request fails or the response is invalid.
#[tracing::instrument(skip_all, fields(revocation_endpoint))]
pub async fn revoke_token(
http_service: &HttpService,
client_credentials: ClientCredentials,
revocation_endpoint: &Url,
token: String,
token_type_hint: Option<OAuthTokenTypeHint>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<(), TokenRevokeError> {
tracing::debug!("Revoking token…");
let request = IntrospectionRequest {
token,
token_type_hint,
};
let revocation_request = http::Request::post(revocation_endpoint.as_str()).body(request)?;
let revocation_request = client_credentials.apply_to_request(revocation_request, now, rng)?;
let service = (
FormUrlencodedRequestLayer::default(),
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
)
.layer(http_service.clone());
service
.ready_oneshot()
.await?
.call(revocation_request)
.await?;
Ok(())
}

View File

@ -0,0 +1,78 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for the Token endpoint.
use chrono::{DateTime, Utc};
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer};
use oauth2_types::requests::{AccessTokenRequest, AccessTokenResponse};
use rand::Rng;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::TokenRequestError,
http_service::HttpService,
types::client_credentials::ClientCredentials,
utils::{http_all_error_status_codes, http_error_mapper},
};
/// Request an access token.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `client_credentials` - The credentials obtained when registering the
/// client.
///
/// * `token_endpoint` - The URL of the issuer's Token endpoint.
///
/// * `request` - The request to make at the Token endpoint.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Errors
///
/// Returns an error if the request fails or the response is invalid.
#[tracing::instrument(skip_all, fields(token_endpoint, request))]
pub async fn request_access_token(
http_service: &HttpService,
client_credentials: ClientCredentials,
token_endpoint: &Url,
request: AccessTokenRequest,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<AccessTokenResponse, TokenRequestError> {
tracing::debug!(?request, "Requesting access token...");
let token_request = http::Request::post(token_endpoint.as_str()).body(request)?;
let token_request = client_credentials.apply_to_request(token_request, now, rng)?;
let service = (
FormUrlencodedRequestLayer::default(),
JsonResponseLayer::<AccessTokenResponse>::default(),
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
)
.layer(http_service.clone());
let res = service.ready_oneshot().await?.call(token_request).await?;
let token_response = res.into_body();
Ok(token_response)
}

View File

@ -0,0 +1,139 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Requests for obtaining [Claims] about an end-user.
//!
//! [Claims]: https://openid.net/specs/openid-connect-core-1_0.html#Claims
use std::collections::HashMap;
use bytes::Bytes;
use headers::{Authorization, HeaderMapExt, HeaderValue};
use http::header::{ACCEPT, CONTENT_TYPE};
use mas_http::CatchHttpCodesLayer;
use mas_jose::claims;
use serde_json::Value;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use super::jose::JwtVerificationData;
use crate::{
error::{IdTokenError, UserInfoError},
http_service::HttpService,
requests::jose::verify_signed_jwt,
types::IdToken,
utils::{http_all_error_status_codes, http_error_mapper},
};
/// Obtain information about an authenticated end-user.
///
/// Returns a map of claims with their value, that should be extracted with
/// one of the [`Claim`] methods.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `userinfo_endpoint` - The URL of the issuer's User Info endpoint.
///
/// * `access_token` - The access token of the end-user.
///
/// * `jwt_verification_data` - The data required to verify the response if a
/// signed response was requested during client registration.
///
/// The signing algorithm corresponds to the `userinfo_signed_response_alg`
/// field in the client metadata.
///
/// * `auth_id_token` - The ID token that was returned from the latest
/// authorization request.
///
/// # Errors
///
/// Returns an error if the request fails, the response is invalid or the
/// validation of the signed response fails.
///
/// [`Claim`]: mas_jose::claims::Claim
#[tracing::instrument(skip_all, fields(userinfo_endpoint))]
pub async fn fetch_userinfo(
http_service: &HttpService,
userinfo_endpoint: &Url,
access_token: &str,
jwt_verification_data: Option<JwtVerificationData<'_>>,
auth_id_token: &IdToken<'_>,
) -> Result<HashMap<String, Value>, UserInfoError> {
tracing::debug!("Obtaining user info…");
let mut userinfo_request = http::Request::get(userinfo_endpoint.as_str());
let expected_content_type = if jwt_verification_data.is_some() {
"application/jwt"
} else {
mime::APPLICATION_JSON.as_ref()
};
if let Some(headers) = userinfo_request.headers_mut() {
headers.typed_insert(Authorization::bearer(access_token)?);
headers.insert(ACCEPT, HeaderValue::from_static(expected_content_type));
}
let userinfo_request = userinfo_request.body(Bytes::new())?;
let service = CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper)
.layer(http_service.clone());
let userinfo_response = service
.ready_oneshot()
.await?
.call(userinfo_request)
.await?;
let content_type = userinfo_response
.headers()
.get(CONTENT_TYPE)
.ok_or(UserInfoError::MissingResponseContentType)?
.to_str()?;
if content_type != expected_content_type {
return Err(UserInfoError::InvalidResponseContentType {
expected: expected_content_type.to_owned(),
got: content_type.to_owned(),
});
}
let response_body = std::str::from_utf8(userinfo_response.body())?;
let mut claims = if let Some(verification_data) = jwt_verification_data {
verify_signed_jwt(response_body, verification_data)
.map_err(IdTokenError::from)?
.into_parts()
.1
} else {
serde_json::from_str(response_body)?
};
let mut auth_claims = auth_id_token.payload().clone();
// Subject identifier must always be the same.
let sub = claims::SUB
.extract_required(&mut claims)
.map_err(IdTokenError::from)?;
let auth_sub = claims::SUB
.extract_required(&mut auth_claims)
.map_err(IdTokenError::from)?;
if sub != auth_sub {
return Err(IdTokenError::WrongSubjectIdentifier.into());
}
Ok(claims)
}

View File

@ -0,0 +1,669 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Types and methods for client credentials.
use std::{collections::HashMap, fmt};
use base64ct::{Base64UrlUnpadded, Encoding};
use chrono::{DateTime, Duration, Utc};
use headers::{Authorization, HeaderMapExt};
use http::Request;
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_jose::{
claims::{self, ClaimError},
jwa::SymmetricKey,
jwt::{JsonWebSignatureHeader, Jwt},
};
#[cfg(feature = "keystore")]
use mas_keystore::Keystore;
use rand::Rng;
use serde::Serialize;
use serde_json::Value;
use serde_with::skip_serializing_none;
use tower::BoxError;
use url::Url;
use crate::error::CredentialsError;
/// The supported authentication methods of this library.
///
/// During client registration, make sure that you only use one of the values
/// defined here.
pub const CLIENT_SUPPORTED_AUTH_METHODS: &[OAuthClientAuthenticationMethod] = &[
OAuthClientAuthenticationMethod::None,
OAuthClientAuthenticationMethod::ClientSecretBasic,
OAuthClientAuthenticationMethod::ClientSecretPost,
OAuthClientAuthenticationMethod::ClientSecretJwt,
OAuthClientAuthenticationMethod::PrivateKeyJwt,
];
/// A function that takes a map of claims and a signing algorithm and returns a
/// signed JWT.
pub type JwtSigningFn =
dyn Fn(HashMap<String, Value>, JsonWebSignatureAlg) -> Result<String, BoxError> + Send + Sync;
/// The method used to sign JWTs with a private key.
pub enum JwtSigningMethod {
/// Sign the JWTs with this library, by providing the signing keys.
#[cfg(feature = "keystore")]
Keystore(Keystore),
/// Sign the JWTs in a callback.
Custom(Box<JwtSigningFn>),
}
impl JwtSigningMethod {
/// Creates a new [`JwtSigningMethod`] from a [`Keystore`].
#[cfg(feature = "keystore")]
#[must_use]
pub fn with_keystore(keystore: Keystore) -> Self {
Self::Keystore(keystore)
}
/// Creates a new [`JwtSigningMethod`] from a [`JwtSigningFn`].
#[must_use]
pub fn with_custom_signing_method<F>(signing_fn: F) -> Self
where
F: Fn(HashMap<String, Value>, JsonWebSignatureAlg) -> Result<String, BoxError>
+ Send
+ Sync
+ 'static,
{
Self::Custom(Box::new(signing_fn))
}
/// Get the [`Keystore`] from this [`JwtSigningMethod`].
#[cfg(feature = "keystore")]
#[must_use]
pub fn keystore(&self) -> Option<&Keystore> {
match self {
JwtSigningMethod::Keystore(k) => Some(k),
JwtSigningMethod::Custom(_) => None,
}
}
/// Get the [`JwtSigningFn`] from this [`JwtSigningMethod`].
#[must_use]
pub fn jwt_custom(&self) -> Option<&JwtSigningFn> {
match self {
JwtSigningMethod::Custom(s) => Some(s),
JwtSigningMethod::Keystore(_) => None,
}
}
}
/// The credentials obtained during registration, to authenticate a client on
/// endpoints that require it.
pub enum ClientCredentials {
/// No client authentication is used.
///
/// This is used if the client is public.
None {
/// The unique ID for the client.
client_id: String,
},
/// The client authentication is sent via the Authorization HTTP header.
ClientSecretBasic {
/// The unique ID for the client.
client_id: String,
/// The secret of the client.
client_secret: String,
},
/// The client authentication is sent with the body of the request.
ClientSecretPost {
/// The unique ID for the client.
client_id: String,
/// The secret of the client.
client_secret: String,
},
/// The client authentication uses a JWT signed with a key derived from the
/// client secret.
ClientSecretJwt {
/// The unique ID for the client.
client_id: String,
/// The secret of the client.
client_secret: String,
/// The algorithm used to sign the JWT.
signing_algorithm: JsonWebSignatureAlg,
/// The URL of the issuer's Token endpoint.
token_endpoint: Url,
},
/// The client authentication uses a JWT signed with a private key.
PrivateKeyJwt {
/// The unique ID for the client.
client_id: String,
/// The method used to sign the JWT.
jwt_signing_method: JwtSigningMethod,
/// The algorithm used to sign the JWT.
signing_algorithm: JsonWebSignatureAlg,
/// The URL of the issuer's Token endpoint.
token_endpoint: Url,
},
}
impl ClientCredentials {
/// Get the client ID of these `ClientCredentials`.
#[must_use]
pub fn client_id(&self) -> &str {
match self {
ClientCredentials::None { client_id }
| ClientCredentials::ClientSecretBasic { client_id, .. }
| ClientCredentials::ClientSecretPost { client_id, .. }
| ClientCredentials::ClientSecretJwt { client_id, .. }
| ClientCredentials::PrivateKeyJwt { client_id, .. } => client_id,
}
}
/// Apply these `ClientCredentials` to the given request.
pub(crate) fn apply_to_request<T: Serialize>(
self,
request: Request<T>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<Request<RequestWithClientCredentials<T>>, CredentialsError> {
let credentials = RequestClientCredentials::try_from_credentials(self, now, rng)?;
let (parts, body) = request.into_parts();
let mut body = RequestWithClientCredentials {
body,
credentials: None,
};
let request = match credentials {
RequestClientCredentials::Body(credentials) => {
body.credentials = Some(credentials);
Request::from_parts(parts, body)
}
RequestClientCredentials::Header(credentials) => {
let HeaderClientCredentials {
client_id,
client_secret,
} = credentials;
let mut request = Request::from_parts(parts, body);
// Encode the values with `application/x-www-form-urlencoded`.
let client_id =
form_urlencoded::byte_serialize(client_id.as_bytes()).collect::<String>();
let client_secret =
form_urlencoded::byte_serialize(client_secret.as_bytes()).collect::<String>();
let auth = Authorization::basic(&client_id, &client_secret);
request.headers_mut().typed_insert(auth);
request
}
};
Ok(request)
}
}
impl fmt::Debug for ClientCredentials {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::None { client_id } => f
.debug_struct("None")
.field("client_id", client_id)
.finish(),
Self::ClientSecretBasic { client_id, .. } => f
.debug_struct("ClientSecretBasic")
.field("client_id", client_id)
.finish_non_exhaustive(),
Self::ClientSecretPost { client_id, .. } => f
.debug_struct("ClientSecretPost")
.field("client_id", client_id)
.finish_non_exhaustive(),
Self::ClientSecretJwt {
client_id,
signing_algorithm,
token_endpoint,
..
} => f
.debug_struct("ClientSecretJwt")
.field("client_id", client_id)
.field("signing_algorithm", signing_algorithm)
.field("token_endpoint", token_endpoint)
.finish_non_exhaustive(),
Self::PrivateKeyJwt {
client_id,
signing_algorithm,
token_endpoint,
..
} => f
.debug_struct("PrivateKeyJwt")
.field("client_id", client_id)
.field("signing_algorithm", signing_algorithm)
.field("token_endpoint", token_endpoint)
.finish_non_exhaustive(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")]
pub(crate) struct JwtBearerClientAssertionType;
enum RequestClientCredentials {
Body(BodyClientCredentials),
Header(HeaderClientCredentials),
}
impl RequestClientCredentials {
fn try_from_credentials(
credentials: ClientCredentials,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<Self, CredentialsError> {
let res = match credentials {
ClientCredentials::None { client_id } => Self::Body(BodyClientCredentials {
client_id,
client_secret: None,
client_assertion: None,
client_assertion_type: None,
}),
ClientCredentials::ClientSecretBasic {
client_id,
client_secret,
} => Self::Header(HeaderClientCredentials {
client_id,
client_secret,
}),
ClientCredentials::ClientSecretPost {
client_id,
client_secret,
} => Self::Body(BodyClientCredentials {
client_id,
client_secret: Some(client_secret),
client_assertion: None,
client_assertion_type: None,
}),
ClientCredentials::ClientSecretJwt {
client_id,
client_secret,
signing_algorithm,
token_endpoint,
} => {
let claims =
prepare_claims(client_id.clone(), token_endpoint.to_string(), now, rng)?;
let key = SymmetricKey::new_for_alg(client_secret.into(), &signing_algorithm)?;
let header = JsonWebSignatureHeader::new(signing_algorithm);
let jwt = Jwt::sign(header, claims, &key)?;
Self::Body(BodyClientCredentials {
client_id,
client_secret: None,
client_assertion: Some(jwt.to_string()),
client_assertion_type: Some(JwtBearerClientAssertionType),
})
}
ClientCredentials::PrivateKeyJwt {
client_id,
jwt_signing_method,
signing_algorithm,
token_endpoint,
} => {
let claims =
prepare_claims(client_id.clone(), token_endpoint.to_string(), now, rng)?;
let client_assertion = match jwt_signing_method {
#[cfg(feature = "keystore")]
JwtSigningMethod::Keystore(keystore) => {
let key = keystore
.signing_key_for_algorithm(&signing_algorithm)
.ok_or(CredentialsError::NoPrivateKeyFound)?;
let signer = key.params().signing_key_for_alg(&signing_algorithm)?;
let header = JsonWebSignatureHeader::new(signing_algorithm);
Jwt::sign(header, claims, &signer)?.to_string()
}
JwtSigningMethod::Custom(jwt_signing_fn) => {
jwt_signing_fn(claims, signing_algorithm)
.map_err(CredentialsError::Custom)?
}
};
Self::Body(BodyClientCredentials {
client_id,
client_secret: None,
client_assertion: Some(client_assertion),
client_assertion_type: Some(JwtBearerClientAssertionType),
})
}
};
Ok(res)
}
}
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub(crate) struct BodyClientCredentials {
client_id: String,
client_secret: Option<String>,
client_assertion: Option<String>,
client_assertion_type: Option<JwtBearerClientAssertionType>,
}
#[derive(Debug, Clone)]
struct HeaderClientCredentials {
client_id: String,
client_secret: String,
}
fn prepare_claims(
iss: String,
aud: String,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<HashMap<String, Value>, ClaimError> {
let mut claims = HashMap::new();
claims::ISS.insert(&mut claims, iss.clone())?;
claims::SUB.insert(&mut claims, iss)?;
claims::AUD.insert(&mut claims, aud)?;
claims::IAT.insert(&mut claims, now)?;
claims::EXP.insert(&mut claims, now + Duration::minutes(5))?;
let mut jti = [0u8; 16];
rng.fill(&mut jti);
let jti = Base64UrlUnpadded::encode_string(&jti);
claims::JTI.insert(&mut claims, jti)?;
Ok(claims)
}
/// A request with client credentials added to it.
#[derive(Clone, Serialize)]
#[skip_serializing_none]
pub struct RequestWithClientCredentials<T: Serialize> {
#[serde(flatten)]
pub(crate) body: T,
#[serde(flatten)]
pub(crate) credentials: Option<BodyClientCredentials>,
}
#[cfg(test)]
mod test {
use assert_matches::assert_matches;
use headers::authorization::Basic;
#[cfg(feature = "keystore")]
use mas_keystore::{JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use super::*;
const CLIENT_ID: &str = "abcd$++";
const CLIENT_SECRET: &str = "xyz!;?";
const REQUEST_BODY: &str = "some_body";
#[derive(Serialize)]
struct Body {
body: &'static str,
}
fn now() -> DateTime<Utc> {
#[allow(clippy::disallowed_methods)]
Utc::now()
}
#[test]
fn serialize_credentials() {
assert_eq!(
serde_urlencoded::to_string(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_assertion: None,
client_assertion_type: None,
})
.unwrap(),
"client_id=abcd%24%2B%2B"
);
assert_eq!(
serde_urlencoded::to_string(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: Some(CLIENT_SECRET.to_owned()),
client_assertion: None,
client_assertion_type: None,
})
.unwrap(),
"client_id=abcd%24%2B%2B&client_secret=xyz%21%3B%3F"
);
assert_eq!(
serde_urlencoded::to_string(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_assertion: Some(CLIENT_SECRET.to_owned()),
client_assertion_type: Some(JwtBearerClientAssertionType)
})
.unwrap(),
"client_id=abcd%24%2B%2B&client_assertion=xyz%21%3B%3F&client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"
);
}
#[test]
fn serialize_request_with_credentials() {
let req = RequestWithClientCredentials {
body: Body { body: REQUEST_BODY },
credentials: None,
};
assert_eq!(serde_urlencoded::to_string(req).unwrap(), "body=some_body");
let req = RequestWithClientCredentials {
body: Body { body: REQUEST_BODY },
credentials: Some(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_assertion: None,
client_assertion_type: None,
}),
};
assert_eq!(
serde_urlencoded::to_string(req).unwrap(),
"body=some_body&client_id=abcd%24%2B%2B"
);
let req = RequestWithClientCredentials {
body: Body { body: REQUEST_BODY },
credentials: Some(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: Some(CLIENT_SECRET.to_owned()),
client_assertion: None,
client_assertion_type: None,
}),
};
assert_eq!(
serde_urlencoded::to_string(req).unwrap(),
"body=some_body&client_id=abcd%24%2B%2B&client_secret=xyz%21%3B%3F"
);
let req = RequestWithClientCredentials {
body: Body { body: REQUEST_BODY },
credentials: Some(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_assertion: Some(CLIENT_SECRET.to_owned()),
client_assertion_type: Some(JwtBearerClientAssertionType),
}),
};
assert_eq!(
serde_urlencoded::to_string(req).unwrap(),
"body=some_body&client_id=abcd%24%2B%2B&client_assertion=xyz%21%3B%3F&client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"
);
}
#[tokio::test]
async fn build_request_none() {
let credentials = ClientCredentials::None {
client_id: CLIENT_ID.to_owned(),
};
let request = Request::new(Body { body: REQUEST_BODY });
let now = now();
let mut rng = ChaCha8Rng::seed_from_u64(42);
let request = credentials
.apply_to_request(request, now, &mut rng)
.unwrap();
assert_eq!(request.headers().typed_get::<Authorization<Basic>>(), None);
let body = request.into_body();
assert_eq!(body.body.body, REQUEST_BODY);
let credentials = body.credentials.unwrap();
assert_eq!(credentials.client_id, CLIENT_ID);
assert_eq!(credentials.client_secret, None);
assert_eq!(credentials.client_assertion, None);
assert_eq!(credentials.client_assertion_type, None);
}
#[tokio::test]
async fn build_request_client_secret_basic() {
let credentials = ClientCredentials::ClientSecretBasic {
client_id: CLIENT_ID.to_owned(),
client_secret: CLIENT_SECRET.to_owned(),
};
let now = now();
let mut rng = ChaCha8Rng::seed_from_u64(42);
let request = Request::new(Body { body: REQUEST_BODY });
let request = credentials
.apply_to_request(request, now, &mut rng)
.unwrap();
let auth = assert_matches!(
request.headers().typed_get::<Authorization<Basic>>(),
Some(auth) => auth
);
assert_eq!(
form_urlencoded::parse(auth.username().as_bytes())
.next()
.unwrap()
.0,
CLIENT_ID
);
assert_eq!(
form_urlencoded::parse(auth.password().as_bytes())
.next()
.unwrap()
.0,
CLIENT_SECRET
);
let body = request.into_body();
assert_eq!(body.body.body, REQUEST_BODY);
assert_eq!(body.credentials, None);
}
#[tokio::test]
async fn build_request_client_secret_post() {
let credentials = ClientCredentials::ClientSecretPost {
client_id: CLIENT_ID.to_owned(),
client_secret: CLIENT_SECRET.to_owned(),
};
let now = now();
let mut rng = ChaCha8Rng::seed_from_u64(42);
let request = Request::new(Body { body: REQUEST_BODY });
let request = credentials
.apply_to_request(request, now, &mut rng)
.unwrap();
assert_eq!(request.headers().typed_get::<Authorization<Basic>>(), None);
let body = request.into_body();
assert_eq!(body.body.body, REQUEST_BODY);
let credentials = body.credentials.unwrap();
assert_eq!(credentials.client_id, CLIENT_ID);
assert_eq!(credentials.client_secret.unwrap(), CLIENT_SECRET);
assert_eq!(credentials.client_assertion, None);
assert_eq!(credentials.client_assertion_type, None);
}
#[tokio::test]
async fn build_request_client_secret_jwt() {
let credentials = ClientCredentials::ClientSecretJwt {
client_id: CLIENT_ID.to_owned(),
client_secret: CLIENT_SECRET.to_owned(),
signing_algorithm: JsonWebSignatureAlg::Hs256,
token_endpoint: Url::parse("http://localhost").unwrap(),
};
let now = now();
let mut rng = ChaCha8Rng::seed_from_u64(42);
let request = Request::new(Body { body: REQUEST_BODY });
let request = credentials
.apply_to_request(request, now, &mut rng)
.unwrap();
assert_eq!(request.headers().typed_get::<Authorization<Basic>>(), None);
let body = request.into_body();
assert_eq!(body.body.body, REQUEST_BODY);
let credentials = body.credentials.unwrap();
assert_eq!(credentials.client_id, CLIENT_ID);
assert_eq!(credentials.client_secret, None);
credentials.client_assertion.unwrap();
credentials.client_assertion_type.unwrap();
}
#[tokio::test]
#[cfg(feature = "keystore")]
async fn build_request_private_key_jwt() {
let rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let key = PrivateKey::generate_rsa(rng).unwrap();
let keystore = Keystore::new(JsonWebKeySet::<PrivateKey>::new(vec![JsonWebKey::new(key)]));
let jwt_signing_method = JwtSigningMethod::with_keystore(keystore);
let now = now();
let mut rng = ChaCha8Rng::seed_from_u64(42);
let credentials = ClientCredentials::PrivateKeyJwt {
client_id: CLIENT_ID.to_owned(),
jwt_signing_method,
signing_algorithm: JsonWebSignatureAlg::Rs256,
token_endpoint: Url::parse("http://localhost").unwrap(),
};
let request = Request::new(Body { body: REQUEST_BODY });
let request = credentials
.apply_to_request(request, now, &mut rng)
.unwrap();
assert_eq!(request.headers().typed_get::<Authorization<Basic>>(), None);
let body = request.into_body();
assert_eq!(body.body.body, REQUEST_BODY);
let credentials = body.credentials.unwrap();
assert_eq!(credentials.client_id, CLIENT_ID);
assert_eq!(credentials.client_secret, None);
credentials.client_assertion.unwrap();
credentials.client_assertion_type.unwrap();
}
}

View File

@ -0,0 +1,31 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! OAuth 2.0 and OpenID Connect types.
pub mod client_credentials;
pub mod scope;
use std::collections::HashMap;
#[doc(inline)]
pub use mas_iana as iana;
use mas_jose::jwt::Jwt;
pub use oauth2_types::*;
use serde_json::Value;
/// An OpenID Connect [ID Token].
///
/// [ID Token]: https://openid.net/specs/openid-connect-core-1_0.html#IDToken
pub type IdToken<'a> = Jwt<'a, HashMap<String, Value>>;

View File

@ -0,0 +1,226 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Helpers types to use scopes.
use std::{fmt, str::FromStr};
use oauth2_types::scope::{InvalidScope, Scope, ScopeToken as StrScopeToken};
use crate::PrivString;
/// Tokens to define the scope of an access token or to request specific claims.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ScopeToken {
/// `openid`
///
/// Required for OpenID Connect requests.
Openid,
/// `profile`
///
/// Requests access to the end-user's profile.
Profile,
/// `email`
///
/// Requests access to the end-user's email address.
Email,
/// `address`
///
/// Requests access to the end-user's address.
Address,
/// `phone`
///
/// Requests access to the end-user's phone number.
Phone,
/// `offline_access`
///
/// Requests that an OAuth 2.0 refresh token be issued that can be used to
/// obtain an access token that grants access to the end-user's UserInfo
/// Endpoint even when the end-user is not present (not logged in).
OfflineAccess,
/// `urn:matrix:org.matrix.msc2967.client:api:*`
///
/// Requests access to the Matrix Client API.
MatrixApi,
/// `urn:matrix:org.matrix.msc2967.client:device:{device_id}`
///
/// Requests access to the Matrix device with the given `device_id`.
///
/// To access the device ID, use [`ScopeToken::matrix_device_id`].
MatrixDevice(PrivString),
/// Another scope token.
///
/// To access it's value use this type's `Display` implementation.
Custom(PrivString),
}
impl ScopeToken {
/// Creates a Matrix device scope token with the given device ID.
///
/// # Errors
///
/// Returns an error if the device ID string is not compatible with the
/// scope syntax.
pub fn try_with_matrix_device(device_id: String) -> Result<Self, InvalidScope> {
// Check that the device ID is compatible with the scope format.
StrScopeToken::from_str(&device_id)?;
Ok(Self::MatrixDevice(PrivString(device_id)))
}
/// Get the device ID of this scope token, if it is a
/// [`ScopeToken::MatrixDevice`].
#[must_use]
pub fn matrix_device_id(&self) -> Option<&str> {
match &self {
Self::MatrixDevice(id) => Some(&id.0),
_ => None,
}
}
}
impl fmt::Display for ScopeToken {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ScopeToken::Openid => write!(f, "openid"),
ScopeToken::Profile => write!(f, "profile"),
ScopeToken::Email => write!(f, "email"),
ScopeToken::Address => write!(f, "address"),
ScopeToken::Phone => write!(f, "phone"),
ScopeToken::OfflineAccess => write!(f, "offline_access"),
ScopeToken::MatrixApi => write!(f, "urn:matrix:org.matrix.msc2967.client:api:*"),
ScopeToken::MatrixDevice(s) => {
write!(f, "urn:matrix:org.matrix.msc2967.client:device:{}", s.0)
}
ScopeToken::Custom(s) => f.write_str(&s.0),
}
}
}
impl From<StrScopeToken> for ScopeToken {
fn from(t: StrScopeToken) -> Self {
match &*t {
"openid" => Self::Openid,
"profile" => Self::Profile,
"email" => Self::Email,
"address" => Self::Address,
"phone" => Self::Phone,
"offline_access" => Self::OfflineAccess,
"urn:matrix:org.matrix.msc2967.client:api:*" => Self::MatrixApi,
s => {
if let Some(device_id) =
s.strip_prefix("urn:matrix:org.matrix.msc2967.client:device:")
{
Self::MatrixDevice(PrivString(device_id.to_owned()))
} else {
Self::Custom(PrivString(s.to_owned()))
}
}
}
}
}
impl From<ScopeToken> for StrScopeToken {
fn from(t: ScopeToken) -> Self {
let s = t.to_string();
match StrScopeToken::from_str(&s) {
Ok(t) => t,
Err(_) => unreachable!(),
}
}
}
impl FromStr for ScopeToken {
type Err = InvalidScope;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let t = StrScopeToken::from_str(s)?;
Ok(t.into())
}
}
/// Helpers for [`Scope`] to work with [`ScopeToken`].
pub trait ScopeExt {
/// Insert the given `ScopeToken` into this `Scope`.
fn insert_token(&mut self, token: ScopeToken) -> bool;
/// Whether this `Scope` contains the given `ScopeToken`.
fn contains_token(&self, token: &ScopeToken) -> bool;
}
impl ScopeExt for Scope {
fn insert_token(&mut self, token: ScopeToken) -> bool {
self.insert(token.into())
}
fn contains_token(&self, token: &ScopeToken) -> bool {
self.contains(&token.to_string())
}
}
impl FromIterator<ScopeToken> for Scope {
fn from_iter<T: IntoIterator<Item = ScopeToken>>(iter: T) -> Self {
iter.into_iter().map(Into::<StrScopeToken>::into).collect()
}
}
#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use super::*;
#[test]
fn parse_scope_token() {
assert_eq!(ScopeToken::from_str("openid"), Ok(ScopeToken::Openid));
let scope =
ScopeToken::from_str("urn:matrix:org.matrix.msc2967.client:device:ABCDEFGHIJKL")
.unwrap();
assert_matches!(scope, ScopeToken::MatrixDevice(_));
assert_eq!(scope.matrix_device_id(), Some("ABCDEFGHIJKL"));
assert_eq!(ScopeToken::from_str("invalid\\scope"), Err(InvalidScope));
}
#[test]
fn parse_scope() {
let scope = Scope::from_str("openid profile address").unwrap();
assert_eq!(scope.len(), 3);
assert!(scope.contains_token(&ScopeToken::Openid));
assert!(scope.contains_token(&ScopeToken::Profile));
assert!(scope.contains_token(&ScopeToken::Address));
assert!(!scope.contains_token(&ScopeToken::OfflineAccess));
}
#[test]
fn display_scope() {
let mut scope: Scope = [ScopeToken::Profile].into_iter().collect();
assert_eq!(scope.to_string(), "profile");
scope.insert_token(ScopeToken::MatrixApi);
assert_eq!(
scope.to_string(),
"profile urn:matrix:org.matrix.msc2967.client:api:*"
);
}
}

View File

@ -0,0 +1,41 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ops::RangeBounds;
use bytes::Buf;
use http::{Response, StatusCode};
use crate::error::ErrorBody;
pub fn http_error_mapper<T>(response: Response<T>) -> Option<ErrorBody>
where
T: Buf,
{
let body = response.into_body();
serde_json::from_reader(body.reader()).ok()
}
pub fn http_all_error_status_codes() -> impl RangeBounds<StatusCode> {
let client_errors_start_code = match StatusCode::from_u16(400) {
Ok(code) => code,
Err(_) => unreachable!(),
};
let server_errors_end_code = match StatusCode::from_u16(599) {
Ok(code) => code,
Err(_) => unreachable!(),
};
client_errors_start_code..=server_errors_end_code
}

View File

@ -0,0 +1,173 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use chrono::{DateTime, Duration, Utc};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_jose::{
claims::{self, hash_token},
constraints::Constrainable,
jwk::PublicJsonWebKeySet,
jwt::{JsonWebSignatureHeader, Jwt},
};
use mas_keystore::{JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
use mas_oidc_client::{
http_service::{hyper::hyper_service, HttpService},
types::{
client_credentials::{ClientCredentials, JwtSigningFn, JwtSigningMethod},
IdToken,
},
};
use rand::{
distributions::{Alphanumeric, DistString},
SeedableRng,
};
use url::Url;
use wiremock::MockServer;
mod requests;
mod types;
const REDIRECT_URI: &str = "http://localhost/";
const CLIENT_ID: &str = "client!+ID";
const CLIENT_SECRET: &str = "SECRET?%Gclient";
const REQUEST_URI: &str = "REQUESTur1";
const AUTHORIZATION_CODE: &str = "authC0D3";
const CODE_VERIFIER: &str = "cODEv3R1f1ER";
const NONCE: &str = "No0o0o0once";
const ACCESS_TOKEN: &str = "AccessToken1";
const REFRESH_TOKEN: &str = "RefreshToken1";
const SUBJECT_IDENTIFIER: &str = "SubjectID";
const ID_TOKEN_SIGNING_ALG: JsonWebSignatureAlg = JsonWebSignatureAlg::Rs256;
fn now() -> DateTime<Utc> {
#[allow(clippy::disallowed_methods)]
Utc::now()
}
async fn init_test() -> (HttpService, MockServer, Url) {
let http_service = hyper_service();
let mock_server = MockServer::start().await;
let issuer = Url::parse(&mock_server.uri()).expect("Couldn't parse URL");
(http_service, mock_server, issuer)
}
/// Generate a keystore with a single key for the given algorithm.
fn keystore(alg: &JsonWebSignatureAlg) -> Keystore {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let private_key = match alg {
JsonWebSignatureAlg::Rs256
| JsonWebSignatureAlg::Rs384
| JsonWebSignatureAlg::Rs512
| JsonWebSignatureAlg::Ps256
| JsonWebSignatureAlg::Ps384
| JsonWebSignatureAlg::Ps512 => PrivateKey::generate_rsa(&mut rng).unwrap(),
JsonWebSignatureAlg::Es256 => PrivateKey::generate_ec_p256(&mut rng),
JsonWebSignatureAlg::Es384 => PrivateKey::generate_ec_p384(&mut rng),
_ => unimplemented!(),
};
let jwk = JsonWebKey::new(private_key).with_kid(Alphanumeric.sample_string(&mut rng, 10));
Keystore::new(JsonWebKeySet::new(vec![jwk]))
}
/// Generate an ID token.
fn id_token(issuer: &Url) -> (IdToken, PublicJsonWebKeySet) {
let signing_alg = ID_TOKEN_SIGNING_ALG;
let keystore = keystore(&signing_alg);
let mut claims = HashMap::new();
let now = now();
claims::ISS.insert(&mut claims, issuer.to_string()).unwrap();
claims::SUB
.insert(&mut claims, SUBJECT_IDENTIFIER.to_owned())
.unwrap();
claims::AUD
.insert(&mut claims, CLIENT_ID.to_owned())
.unwrap();
claims::NONCE.insert(&mut claims, NONCE.to_owned()).unwrap();
claims::IAT.insert(&mut claims, now).unwrap();
claims::EXP
.insert(&mut claims, now + Duration::hours(1))
.unwrap();
claims::AT_HASH
.insert(&mut claims, hash_token(&signing_alg, ACCESS_TOKEN).unwrap())
.unwrap();
claims::C_HASH
.insert(
&mut claims,
hash_token(&signing_alg, AUTHORIZATION_CODE).unwrap(),
)
.unwrap();
let key = keystore.signing_key_for_algorithm(&signing_alg).unwrap();
let signer = key.params().signing_key_for_alg(&signing_alg).unwrap();
let header = JsonWebSignatureHeader::new(signing_alg).with_kid(key.kid().unwrap());
let id_token = Jwt::sign(header, claims, &signer).unwrap();
(id_token, keystore.public_jwks())
}
/// Generate client credentials for the given authentication method.
fn client_credentials(
auth_method: OAuthClientAuthenticationMethod,
issuer: &Url,
custom_signing: Option<Box<JwtSigningFn>>,
) -> ClientCredentials {
match auth_method {
OAuthClientAuthenticationMethod::None => ClientCredentials::None {
client_id: CLIENT_ID.to_owned(),
},
OAuthClientAuthenticationMethod::ClientSecretPost => ClientCredentials::ClientSecretPost {
client_id: CLIENT_ID.to_owned(),
client_secret: CLIENT_SECRET.to_owned(),
},
OAuthClientAuthenticationMethod::ClientSecretBasic => {
ClientCredentials::ClientSecretBasic {
client_id: CLIENT_ID.to_owned(),
client_secret: CLIENT_SECRET.to_owned(),
}
}
OAuthClientAuthenticationMethod::ClientSecretJwt => ClientCredentials::ClientSecretJwt {
client_id: CLIENT_ID.to_owned(),
client_secret: CLIENT_SECRET.to_owned(),
signing_algorithm: JsonWebSignatureAlg::Hs256,
token_endpoint: issuer.join("token").unwrap(),
},
OAuthClientAuthenticationMethod::PrivateKeyJwt => {
let signing_algorithm = JsonWebSignatureAlg::Es256;
let jwt_signing_method = if let Some(signing_fn) = custom_signing {
JwtSigningMethod::with_custom_signing_method(signing_fn)
} else {
JwtSigningMethod::with_keystore(keystore(&signing_algorithm))
};
ClientCredentials::PrivateKeyJwt {
client_id: CLIENT_ID.to_owned(),
jwt_signing_method,
signing_algorithm,
token_endpoint: issuer.join("token").unwrap(),
}
}
_ => unimplemented!(),
}
}

View File

@ -0,0 +1,421 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use assert_matches::assert_matches;
use chrono::Duration;
use mas_iana::oauth::{
OAuthAccessTokenType, OAuthClientAuthenticationMethod, PkceCodeChallengeMethod,
};
use mas_jose::jwk::PublicJsonWebKeySet;
use mas_oidc_client::{
error::{
AuthorizationError, IdTokenError, PushedAuthorizationError, TokenAuthorizationCodeError,
},
requests::{
authorization_code::{
access_token_with_authorization_code, build_authorization_url,
build_par_authorization_url, AuthorizationRequestData, AuthorizationValidationData,
},
jose::JwtVerificationData,
},
types::scope::{ScopeExt, ScopeToken},
};
use oauth2_types::requests::{AccessTokenResponse, PushedAuthorizationResponse};
use rand::SeedableRng;
use tokio::sync::oneshot;
use url::Url;
use wiremock::{
matchers::{method, path},
Mock, Request, ResponseTemplate,
};
use crate::{
client_credentials, id_token, init_test, now, ACCESS_TOKEN, AUTHORIZATION_CODE, CLIENT_ID,
CODE_VERIFIER, ID_TOKEN_SIGNING_ALG, NONCE, REDIRECT_URI, REQUEST_URI,
};
#[test]
fn pass_authorization_url() {
let issuer = Url::parse("http://localhost/").unwrap();
let authorization_endpoint = issuer.join("authorize").unwrap();
let redirect_uri = Url::parse(REDIRECT_URI).unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let (url, validation_data) = build_authorization_url(
authorization_endpoint,
AuthorizationRequestData {
client_id: CLIENT_ID,
code_challenge_methods_supported: Some(&[PkceCodeChallengeMethod::S256]),
scope: &[ScopeToken::Openid].into_iter().collect(),
redirect_uri: &redirect_uri,
prompt: None,
},
&mut rng,
)
.unwrap();
assert_eq!(validation_data.state, "OrJ8xbWovSpJUTKz");
assert_eq!(
validation_data.code_challenge_verifier.unwrap(),
"TSgZ_hr3TJPjhq4aDp34K_8ksjLwaa1xDcPiRGBcjhM"
);
assert_eq!(url.path(), "/authorize");
let query_pairs = url.query_pairs().collect::<HashMap<_, _>>();
assert_eq!(query_pairs.get("scope").unwrap(), "openid");
assert_eq!(query_pairs.get("response_type").unwrap(), "code");
assert_eq!(query_pairs.get("client_id").unwrap(), CLIENT_ID);
assert_eq!(query_pairs.get("redirect_uri").unwrap(), REDIRECT_URI);
assert_eq!(*query_pairs.get("state").unwrap(), validation_data.state);
assert_eq!(query_pairs.get("nonce").unwrap(), "ox0PigY5l9xl5uTL");
let code_challenge = query_pairs.get("code_challenge").unwrap();
assert!(code_challenge.len() >= 43);
assert_eq!(query_pairs.get("code_challenge_method").unwrap(), "S256");
}
#[tokio::test]
async fn pass_pushed_authorization_request() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials =
client_credentials(OAuthClientAuthenticationMethod::None, &issuer, None);
let authorization_endpoint = issuer.join("authorize").unwrap();
let par_endpoint = issuer.join("par").unwrap();
let redirect_uri = Url::parse(REDIRECT_URI).unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let (sender, receiver) = oneshot::channel();
let sender_mutex = Arc::new(Mutex::new(Some(sender)));
Mock::given(method("POST"))
.and(path("/par"))
.and(move |req: &Request| {
let body = form_urlencoded::parse(&req.body)
.into_owned()
.collect::<HashMap<_, _>>();
if let Some(sender) = sender_mutex.lock().unwrap().take() {
sender.send(body).unwrap();
true
} else {
false
}
})
.respond_with(
ResponseTemplate::new(200).set_body_json(PushedAuthorizationResponse {
request_uri: REQUEST_URI.to_owned(),
expires_in: Duration::seconds(30),
}),
)
.mount(&mock_server)
.await;
let (url, validation_data) = build_par_authorization_url(
&http_service,
client_credentials,
&par_endpoint,
authorization_endpoint,
AuthorizationRequestData {
client_id: CLIENT_ID,
code_challenge_methods_supported: Some(&[PkceCodeChallengeMethod::S256]),
scope: &[ScopeToken::Openid].into_iter().collect(),
redirect_uri: &redirect_uri,
prompt: None,
},
now(),
&mut rng,
)
.await
.unwrap();
assert_eq!(validation_data.state, "OrJ8xbWovSpJUTKz");
assert_eq!(
validation_data.code_challenge_verifier.unwrap(),
"TSgZ_hr3TJPjhq4aDp34K_8ksjLwaa1xDcPiRGBcjhM"
);
let request_pairs = receiver.await.unwrap();
assert_eq!(url.path(), "/authorize");
let query_pairs = url.query_pairs().collect::<HashMap<_, _>>();
assert_eq!(query_pairs.get("request_uri").unwrap(), REQUEST_URI,);
assert_eq!(query_pairs.get("client_id").unwrap(), CLIENT_ID);
assert_eq!(request_pairs.get("scope").unwrap(), "openid");
assert_eq!(request_pairs.get("response_type").unwrap(), "code");
assert_eq!(request_pairs.get("client_id").unwrap(), CLIENT_ID);
assert_eq!(request_pairs.get("redirect_uri").unwrap(), REDIRECT_URI);
assert_eq!(*request_pairs.get("state").unwrap(), validation_data.state);
assert_eq!(request_pairs.get("nonce").unwrap(), "ox0PigY5l9xl5uTL");
let code_challenge = request_pairs.get("code_challenge").unwrap();
assert!(code_challenge.len() >= 43);
assert_eq!(request_pairs.get("code_challenge_method").unwrap(), "S256");
}
#[tokio::test]
async fn fail_pushed_authorization_request_404() {
let (http_service, _, issuer) = init_test().await;
let client_credentials =
client_credentials(OAuthClientAuthenticationMethod::None, &issuer, None);
let authorization_endpoint = issuer.join("authorize").unwrap();
let par_endpoint = issuer.join("par").unwrap();
let redirect_uri = Url::parse(REDIRECT_URI).unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let error = build_par_authorization_url(
&http_service,
client_credentials,
&par_endpoint,
authorization_endpoint,
AuthorizationRequestData {
client_id: CLIENT_ID,
code_challenge_methods_supported: Some(&[PkceCodeChallengeMethod::S256]),
scope: &[ScopeToken::Openid].into_iter().collect(),
redirect_uri: &redirect_uri,
prompt: None,
},
now(),
&mut rng,
)
.await
.unwrap_err();
assert_matches!(
error,
AuthorizationError::PushedAuthorization(PushedAuthorizationError::Http(_))
)
}
/// Check if the given request to the token endpoint is valid.
fn is_valid_token_endpoint_request(req: &Request) -> bool {
let body = form_urlencoded::parse(&req.body).collect::<HashMap<_, _>>();
if body.get("client_id").filter(|s| *s == CLIENT_ID).is_none() {
println!("Missing or wrong client ID");
return false;
}
if body
.get("grant_type")
.filter(|s| *s == "authorization_code")
.is_none()
{
println!("Missing or wrong grant type");
return false;
}
if body
.get("code")
.filter(|s| *s == AUTHORIZATION_CODE)
.is_none()
{
println!("Missing or wrong authorization code");
return false;
}
if body
.get("redirect_uri")
.filter(|s| *s == REDIRECT_URI)
.is_none()
{
println!("Missing or wrong redirect URI");
return false;
}
if body
.get("code_verifier")
.filter(|s| *s == CODE_VERIFIER)
.is_none()
{
println!("Missing or wrong code verifier");
return false;
}
true
}
#[tokio::test]
async fn pass_access_token_with_authorization_code() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials =
client_credentials(OAuthClientAuthenticationMethod::None, &issuer, None);
let token_endpoint = issuer.join("token").unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let redirect_uri = Url::parse(REDIRECT_URI).unwrap();
let validation_data = AuthorizationValidationData {
state: "some_state".to_owned(),
nonce: NONCE.to_owned(),
redirect_uri,
code_challenge_verifier: Some(CODE_VERIFIER.to_owned()),
};
let (id_token, jwks) = id_token(&issuer);
let id_token_verification_data = JwtVerificationData {
issuer: &issuer,
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
};
Mock::given(method("POST"))
.and(path("/token"))
.and(is_valid_token_endpoint_request)
.respond_with(
ResponseTemplate::new(200).set_body_json(AccessTokenResponse {
access_token: ACCESS_TOKEN.to_owned(),
refresh_token: None,
id_token: Some(id_token.to_string()),
token_type: OAuthAccessTokenType::Bearer,
expires_in: None,
scope: Some([ScopeToken::Openid].into_iter().collect()),
}),
)
.mount(&mock_server)
.await;
let (response, response_id_token) = access_token_with_authorization_code(
&http_service,
client_credentials,
&token_endpoint,
AUTHORIZATION_CODE.to_owned(),
validation_data,
Some(id_token_verification_data),
now(),
&mut rng,
)
.await
.unwrap();
assert_eq!(response.access_token, ACCESS_TOKEN);
assert_eq!(response.refresh_token, None);
assert!(response.scope.unwrap().contains_token(&ScopeToken::Openid));
assert_eq!(response_id_token.unwrap().as_str(), id_token.as_str());
}
#[tokio::test]
async fn fail_access_token_with_authorization_code_wrong_nonce() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials =
client_credentials(OAuthClientAuthenticationMethod::None, &issuer, None);
let token_endpoint = issuer.join("token").unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let redirect_uri = Url::parse(REDIRECT_URI).unwrap();
let validation_data = AuthorizationValidationData {
state: "some_state".to_owned(),
nonce: "wrong_nonce".to_owned(),
redirect_uri,
code_challenge_verifier: Some(CODE_VERIFIER.to_owned()),
};
let (id_token, jwks) = id_token(&issuer);
let id_token_verification_data = JwtVerificationData {
issuer: &issuer,
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
};
Mock::given(method("POST"))
.and(path("/token"))
.and(is_valid_token_endpoint_request)
.respond_with(
ResponseTemplate::new(200).set_body_json(AccessTokenResponse {
access_token: ACCESS_TOKEN.to_owned(),
refresh_token: None,
id_token: Some(id_token.into_string()),
token_type: OAuthAccessTokenType::Bearer,
expires_in: None,
scope: Some([ScopeToken::Openid].into_iter().collect()),
}),
)
.mount(&mock_server)
.await;
let error = access_token_with_authorization_code(
&http_service,
client_credentials,
&token_endpoint,
AUTHORIZATION_CODE.to_owned(),
validation_data,
Some(id_token_verification_data),
now(),
&mut rng,
)
.await
.unwrap_err();
assert_matches!(error, TokenAuthorizationCodeError::WrongNonce);
}
#[tokio::test]
async fn fail_access_token_with_authorization_code_no_id_token() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials =
client_credentials(OAuthClientAuthenticationMethod::None, &issuer, None);
let token_endpoint = issuer.join("token").unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let redirect_uri = Url::parse(REDIRECT_URI).unwrap();
let nonce = "some_nonce".to_owned();
let validation_data = AuthorizationValidationData {
state: "some_state".to_owned(),
nonce: nonce.clone(),
redirect_uri,
code_challenge_verifier: Some(CODE_VERIFIER.to_owned()),
};
let id_token_verification_data = JwtVerificationData {
issuer: &issuer,
jwks: &PublicJsonWebKeySet::default(),
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
};
Mock::given(method("POST"))
.and(path("/token"))
.and(is_valid_token_endpoint_request)
.respond_with(
ResponseTemplate::new(200).set_body_json(AccessTokenResponse {
access_token: ACCESS_TOKEN.to_owned(),
refresh_token: None,
id_token: None,
token_type: OAuthAccessTokenType::Bearer,
expires_in: None,
scope: Some([ScopeToken::Openid].into_iter().collect()),
}),
)
.mount(&mock_server)
.await;
let error = access_token_with_authorization_code(
&http_service,
client_credentials,
&token_endpoint,
AUTHORIZATION_CODE.to_owned(),
validation_data,
Some(id_token_verification_data),
now(),
&mut rng,
)
.await
.unwrap_err();
assert_matches!(
error,
TokenAuthorizationCodeError::IdToken(IdTokenError::MissingIdToken)
);
}

View File

@ -0,0 +1,110 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use mas_iana::oauth::{OAuthAccessTokenType, OAuthClientAuthenticationMethod};
use mas_oidc_client::{
requests::client_credentials::access_token_with_client_credentials,
types::scope::{ScopeExt, ScopeToken},
};
use oauth2_types::{requests::AccessTokenResponse, scope::Scope};
use rand::SeedableRng;
use wiremock::{
matchers::{method, path},
Mock, Request, ResponseTemplate,
};
use crate::{client_credentials, init_test, now, ACCESS_TOKEN, CLIENT_ID, CLIENT_SECRET};
#[tokio::test]
async fn pass_access_token_with_client_credentials() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials = client_credentials(
OAuthClientAuthenticationMethod::ClientSecretPost,
&issuer,
None,
);
let token_endpoint = issuer.join("token").unwrap();
let scope = [ScopeToken::Profile].into_iter().collect::<Scope>();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
Mock::given(method("POST"))
.and(path("/token"))
.and(|req: &Request| {
let query_pairs = form_urlencoded::parse(&req.body).collect::<HashMap<_, _>>();
if query_pairs
.get("grant_type")
.filter(|s| *s == "client_credentials")
.is_none()
{
println!("Wrong or missing grant type");
return false;
}
if query_pairs
.get("scope")
.filter(|s| *s == "profile")
.is_none()
{
println!("Wrong or missing scope");
return false;
}
if query_pairs
.get("client_id")
.filter(|s| *s == CLIENT_ID)
.is_none()
{
println!("Wrong or missing client ID");
return false;
}
if query_pairs
.get("client_secret")
.filter(|s| *s == CLIENT_SECRET)
.is_none()
{
println!("Wrong or missing client secret");
return false;
}
true
})
.respond_with(
ResponseTemplate::new(200).set_body_json(AccessTokenResponse {
access_token: ACCESS_TOKEN.to_owned(),
refresh_token: None,
id_token: None,
token_type: OAuthAccessTokenType::Bearer,
expires_in: None,
scope: Some(scope.clone()),
}),
)
.mount(&mock_server)
.await;
let response = access_token_with_client_credentials(
&http_service,
client_credentials,
&token_endpoint,
Some(scope),
now(),
&mut rng,
)
.await
.unwrap();
assert_eq!(response.access_token, ACCESS_TOKEN);
assert_eq!(response.refresh_token, None);
assert!(response.scope.unwrap().contains_token(&ScopeToken::Profile));
}

View File

@ -0,0 +1,97 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use assert_matches::assert_matches;
use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, PkceCodeChallengeMethod};
use mas_jose::jwa::SUPPORTED_SIGNING_ALGORITHMS;
use mas_oidc_client::{
error::DiscoveryError,
requests::discovery::{discover, insecure_discover},
};
use oauth2_types::oidc::{ProviderMetadata, SubjectType};
use url::Url;
use wiremock::{
matchers::{method, path},
Mock, ResponseTemplate,
};
use crate::init_test;
fn provider_metadata(issuer: &Url) -> ProviderMetadata {
ProviderMetadata {
issuer: Some(issuer.clone()),
authorization_endpoint: issuer.join("authorize").ok(),
token_endpoint: issuer.join("token").ok(),
jwks_uri: issuer.join("jwks").ok(),
response_types_supported: Some(vec![OAuthAuthorizationEndpointResponseType::Code.into()]),
subject_types_supported: Some(vec![SubjectType::Pairwise, SubjectType::Public]),
id_token_signing_alg_values_supported: Some(SUPPORTED_SIGNING_ALGORITHMS.into()),
code_challenge_methods_supported: Some(vec![PkceCodeChallengeMethod::S256]),
..Default::default()
}
}
#[tokio::test]
async fn pass_discover() {
let (http_service, mock_server, issuer) = init_test().await;
Mock::given(method("GET"))
.and(path("/.well-known/openid-configuration"))
.respond_with(ResponseTemplate::new(200).set_body_json(provider_metadata(&issuer)))
.mount(&mock_server)
.await;
let provider_metadata = insecure_discover(&http_service, &issuer).await.unwrap();
assert_eq!(*provider_metadata.issuer(), issuer);
}
#[tokio::test]
async fn fail_discover_404() {
let (http_service, _mock_server, issuer) = init_test().await;
let error = discover(&http_service, &issuer).await.unwrap_err();
assert_matches!(error, DiscoveryError::Http(_));
}
#[tokio::test]
async fn fail_discover_not_json() {
let (http_service, mock_server, issuer) = init_test().await;
Mock::given(method("GET"))
.and(path("/.well-known/openid-configuration"))
.respond_with(ResponseTemplate::new(200))
.mount(&mock_server)
.await;
let error = discover(&http_service, &issuer).await.unwrap_err();
assert_matches!(error, DiscoveryError::FromJson(_));
}
#[tokio::test]
async fn fail_discover_invalid_metadata() {
let (http_service, mock_server, issuer) = init_test().await;
Mock::given(method("GET"))
.and(path("/.well-known/openid-configuration"))
.respond_with(ResponseTemplate::new(200).set_body_json(ProviderMetadata::default()))
.mount(&mock_server)
.await;
let error = discover(&http_service, &issuer).await.unwrap_err();
assert_matches!(error, DiscoveryError::Validation(_));
}

View File

@ -0,0 +1,108 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
use mas_oidc_client::{
requests::introspection::introspect_token,
types::scope::{ScopeExt, ScopeToken},
};
use oauth2_types::requests::IntrospectionResponse;
use rand::SeedableRng;
use wiremock::{
matchers::{method, path},
Mock, Request, ResponseTemplate,
};
use crate::{client_credentials, init_test, now, ACCESS_TOKEN, CLIENT_ID, SUBJECT_IDENTIFIER};
#[tokio::test]
async fn pass_introspect_token() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials =
client_credentials(OAuthClientAuthenticationMethod::None, &issuer, None);
let introspection_endpoint = issuer.join("introspect").unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
Mock::given(method("POST"))
.and(path("/introspect"))
.and(|req: &Request| {
let query_pairs = form_urlencoded::parse(&req.body).collect::<HashMap<_, _>>();
if query_pairs
.get("token")
.filter(|s| *s == ACCESS_TOKEN)
.is_none()
{
println!("Wrong or missing token");
return false;
}
if query_pairs
.get("token_type_hint")
.filter(|s| *s == "access_token")
.is_none()
{
println!("Wrong or missing token type hint");
return false;
}
if query_pairs
.get("client_id")
.filter(|s| *s == CLIENT_ID)
.is_none()
{
println!("Wrong or missing client ID");
return false;
}
true
})
.respond_with(
ResponseTemplate::new(200).set_body_json(IntrospectionResponse {
active: true,
scope: Some([ScopeToken::Profile].into_iter().collect()),
client_id: Some(CLIENT_ID.to_owned()),
username: None,
token_type: Some(OAuthTokenTypeHint::AccessToken),
exp: None,
iat: None,
nbf: None,
sub: Some(SUBJECT_IDENTIFIER.to_owned()),
aud: Some(CLIENT_ID.to_owned()),
iss: Some(issuer.to_string()),
jti: None,
}),
)
.mount(&mock_server)
.await;
let response = introspect_token(
&http_service,
client_credentials.into(),
&introspection_endpoint,
ACCESS_TOKEN.to_owned(),
Some(OAuthTokenTypeHint::AccessToken),
now(),
&mut rng,
)
.await
.unwrap();
assert!(response.active);
assert_eq!(response.aud.unwrap(), CLIENT_ID);
assert!(response.scope.unwrap().contains_token(&ScopeToken::Profile));
assert_eq!(response.client_id.unwrap(), CLIENT_ID);
assert_eq!(response.iss.unwrap(), issuer.as_str());
assert_eq!(response.sub.unwrap(), SUBJECT_IDENTIFIER);
}

View File

@ -0,0 +1,242 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use assert_matches::assert_matches;
use chrono::{DateTime, Duration, Utc};
use mas_iana::jose::JsonWebSignatureAlg;
use mas_jose::{
claims,
constraints::Constrainable,
jwk::PublicJsonWebKeySet,
jwt::{JsonWebSignatureHeader, Jwt},
};
use mas_oidc_client::{
error::{IdTokenError, JwtVerificationError},
requests::jose::{verify_id_token, JwtVerificationData},
types::IdToken,
};
use url::Url;
use crate::{keystore, now, CLIENT_ID, ID_TOKEN_SIGNING_ALG, SUBJECT_IDENTIFIER};
#[derive(Clone, Copy, PartialEq, Eq)]
enum IdTokenFlag {
WrongExpiration,
WrongSubject,
}
/// Generate an ID token with the given settings.
fn id_token(
issuer: &Url,
flag: Option<IdTokenFlag>,
auth_time: Option<DateTime<Utc>>,
) -> (IdToken, PublicJsonWebKeySet) {
let signing_alg = ID_TOKEN_SIGNING_ALG;
let keystore = keystore(&signing_alg);
let mut claims = HashMap::new();
let now = now();
claims::ISS.insert(&mut claims, issuer.to_string()).unwrap();
claims::AUD
.insert(&mut claims, CLIENT_ID.to_owned())
.unwrap();
if flag == Some(IdTokenFlag::WrongSubject) {
claims::SUB
.insert(&mut claims, "wrong_subject".to_owned())
.unwrap();
} else {
claims::SUB
.insert(&mut claims, SUBJECT_IDENTIFIER.to_owned())
.unwrap();
}
claims::IAT.insert(&mut claims, now).unwrap();
if flag == Some(IdTokenFlag::WrongExpiration) {
claims::EXP
.insert(&mut claims, now - Duration::hours(1))
.unwrap();
} else {
claims::EXP
.insert(&mut claims, now + Duration::hours(1))
.unwrap();
}
if let Some(auth_time) = auth_time {
claims::AUTH_TIME.insert(&mut claims, auth_time).unwrap();
}
let key = keystore.signing_key_for_algorithm(&signing_alg).unwrap();
let signer = key.params().signing_key_for_alg(&signing_alg).unwrap();
let header = JsonWebSignatureHeader::new(signing_alg).with_kid(key.kid().unwrap());
let id_token = Jwt::sign(header, claims, &signer).unwrap();
(id_token, keystore.public_jwks())
}
#[tokio::test]
async fn pass_verify_id_token() {
let issuer = Url::parse("http://localhost/").unwrap();
let now = now();
let (auth_id_token, _) = id_token(&issuer, None, Some(now));
let (id_token, jwks) = id_token(&issuer, None, Some(now));
let verification_data = JwtVerificationData {
issuer: &issuer,
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
};
verify_id_token(
id_token.as_str(),
verification_data,
Some(&auth_id_token),
now,
)
.unwrap();
}
#[tokio::test]
async fn fail_verify_id_token_wrong_issuer() {
let issuer = Url::parse("http://localhost/").unwrap();
let wrong_issuer = Url::parse("http://distanthost/").unwrap();
let (id_token, jwks) = id_token(&issuer, None, None);
let now = now();
let verification_data = JwtVerificationData {
issuer: &wrong_issuer,
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
};
let error = verify_id_token(id_token.as_str(), verification_data, None, now).unwrap_err();
assert_matches!(error, IdTokenError::Jwt(JwtVerificationError::WrongIssuer));
}
#[tokio::test]
async fn fail_verify_id_token_wrong_audience() {
let issuer = Url::parse("http://localhost/").unwrap();
let (id_token, jwks) = id_token(&issuer, None, None);
let now = now();
let verification_data = JwtVerificationData {
issuer: &issuer,
jwks: &jwks,
client_id: &"wrong_client_id".to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
};
let error = verify_id_token(id_token.as_str(), verification_data, None, now).unwrap_err();
assert_matches!(
error,
IdTokenError::Jwt(JwtVerificationError::WrongAudience)
);
}
#[tokio::test]
async fn fail_verify_id_token_wrong_signing_algorithm() {
let issuer = Url::parse("http://localhost/").unwrap();
let (id_token, jwks) = id_token(&issuer, None, None);
let now = now();
let verification_data = JwtVerificationData {
issuer: &issuer,
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &JsonWebSignatureAlg::Unknown("wrong_algorithm".to_owned()),
};
let error = verify_id_token(id_token.as_str(), verification_data, None, now).unwrap_err();
assert_matches!(
error,
IdTokenError::Jwt(JwtVerificationError::WrongSignatureAlg)
);
}
#[tokio::test]
async fn fail_verify_id_token_wrong_expiration() {
let issuer = Url::parse("http://localhost/").unwrap();
let (id_token, jwks) = id_token(&issuer, Some(IdTokenFlag::WrongExpiration), None);
let now = now();
let verification_data = JwtVerificationData {
issuer: &issuer,
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
};
let error = verify_id_token(id_token.as_str(), verification_data, None, now).unwrap_err();
assert_matches!(error, IdTokenError::Claim(_));
}
#[tokio::test]
async fn fail_verify_id_token_wrong_subject() {
let issuer = Url::parse("http://localhost/").unwrap();
let now = now();
let (auth_id_token, _) = id_token(&issuer, None, Some(now));
let (id_token, jwks) = id_token(&issuer, Some(IdTokenFlag::WrongSubject), None);
let verification_data = JwtVerificationData {
issuer: &issuer,
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
};
let error = verify_id_token(
id_token.as_str(),
verification_data,
Some(&auth_id_token),
now,
)
.unwrap_err();
assert_matches!(error, IdTokenError::WrongSubjectIdentifier);
}
#[tokio::test]
async fn fail_verify_id_token_wrong_auth_time() {
let issuer = Url::parse("http://localhost/").unwrap();
let now = now();
let (auth_id_token, _) = id_token(&issuer, None, Some(now));
let (id_token, jwks) = id_token(&issuer, None, Some(now + Duration::hours(1)));
let verification_data = JwtVerificationData {
issuer: &issuer,
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
};
let error = verify_id_token(
id_token.as_str(),
verification_data,
Some(&auth_id_token),
now,
)
.unwrap_err();
assert_matches!(error, IdTokenError::WrongAuthTime)
}

View File

@ -0,0 +1,23 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod authorization_code;
mod client_credentials;
mod discovery;
mod introspection;
mod jose;
mod refresh_token;
mod registration;
mod revocation;
mod userinfo;

View File

@ -0,0 +1,99 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use assert_matches::assert_matches;
use mas_iana::oauth::{OAuthAccessTokenType, OAuthClientAuthenticationMethod};
use mas_oidc_client::requests::refresh_token::refresh_access_token;
use oauth2_types::requests::AccessTokenResponse;
use rand::SeedableRng;
use wiremock::{
matchers::{method, path},
Mock, Request, ResponseTemplate,
};
use crate::{client_credentials, init_test, now, ACCESS_TOKEN, CLIENT_ID, REFRESH_TOKEN};
#[tokio::test]
async fn pass_refresh_access_token() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials =
client_credentials(OAuthClientAuthenticationMethod::None, &issuer, None);
let token_endpoint = issuer.join("token").unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
Mock::given(method("POST"))
.and(path("/token"))
.and(|req: &Request| {
let query_pairs = form_urlencoded::parse(&req.body).collect::<HashMap<_, _>>();
if query_pairs
.get("grant_type")
.filter(|s| *s == "refresh_token")
.is_none()
{
println!("Wrong or missing grant type");
return false;
}
if query_pairs
.get("refresh_token")
.filter(|s| *s == REFRESH_TOKEN)
.is_none()
{
println!("Wrong or missing refresh token");
return false;
}
if query_pairs
.get("client_id")
.filter(|s| *s == CLIENT_ID)
.is_none()
{
println!("Wrong or missing client ID");
return false;
}
true
})
.respond_with(
ResponseTemplate::new(200).set_body_json(AccessTokenResponse {
access_token: ACCESS_TOKEN.to_owned(),
refresh_token: None,
id_token: None,
token_type: OAuthAccessTokenType::Bearer,
expires_in: None,
scope: None,
}),
)
.mount(&mock_server)
.await;
let (response, response_id_token) = refresh_access_token(
&http_service,
client_credentials,
&token_endpoint,
REFRESH_TOKEN.to_owned(),
None,
None,
None,
now(),
&mut rng,
)
.await
.unwrap();
assert_eq!(response.access_token, ACCESS_TOKEN);
assert_eq!(response.refresh_token, None);
assert_matches!(response_id_token, None);
}

View File

@ -0,0 +1,259 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use assert_matches::assert_matches;
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_jose::jwk::PublicJsonWebKeySet;
use mas_oidc_client::{error::RegistrationError, requests::registration::register_client};
use oauth2_types::{
oidc::ApplicationType,
registration::{ClientMetadata, ClientRegistrationResponse, VerifiedClientMetadata},
};
use serde_json::json;
use url::Url;
use wiremock::{
matchers::{body_partial_json, method, path},
Mock, Request, ResponseTemplate,
};
use crate::{init_test, CLIENT_ID, CLIENT_SECRET, REDIRECT_URI};
/// Generate valid client metadata for the given authentication method.
fn client_metadata(auth_method: OAuthClientAuthenticationMethod) -> VerifiedClientMetadata {
let (signing_alg, jwks) = match &auth_method {
OAuthClientAuthenticationMethod::ClientSecretJwt => {
(Some(JsonWebSignatureAlg::Hs256), None)
}
OAuthClientAuthenticationMethod::PrivateKeyJwt => (
Some(JsonWebSignatureAlg::Es256),
Some(PublicJsonWebKeySet::default()),
),
_ => (None, None),
};
ClientMetadata {
redirect_uris: Some(vec![Url::parse(REDIRECT_URI).expect("Couldn't parse URL")]),
application_type: Some(ApplicationType::Native),
token_endpoint_auth_method: Some(auth_method),
token_endpoint_auth_signing_alg: signing_alg,
jwks,
..Default::default()
}
.validate()
.unwrap()
}
#[tokio::test]
async fn pass_register_client_none() {
let (http_service, mock_server, issuer) = init_test().await;
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::None);
let registration_endpoint = issuer.join("register").unwrap();
Mock::given(method("POST"))
.and(path("/register"))
.and(body_partial_json(json!({
"redirect_uris": [REDIRECT_URI],
"token_endpoint_auth_method": "none",
})))
.respond_with(
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_id_issued_at: None,
client_secret_expires_at: None,
}),
)
.mount(&mock_server)
.await;
let response = register_client(&http_service, &registration_endpoint, client_metadata)
.await
.unwrap();
assert_eq!(response.client_id, CLIENT_ID);
assert_eq!(response.client_secret, None);
}
#[tokio::test]
async fn pass_register_client_client_secret_basic() {
let (http_service, mock_server, issuer) = init_test().await;
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::ClientSecretBasic);
let registration_endpoint = issuer.join("register").unwrap();
Mock::given(method("POST"))
.and(path("/register"))
.and(body_partial_json(json!({
"redirect_uris": [REDIRECT_URI],
"token_endpoint_auth_method": "client_secret_basic",
})))
.respond_with(
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
client_id: CLIENT_ID.to_owned(),
client_secret: Some(CLIENT_SECRET.to_owned()),
client_id_issued_at: None,
client_secret_expires_at: None,
}),
)
.mount(&mock_server)
.await;
let response = register_client(&http_service, &registration_endpoint, client_metadata)
.await
.unwrap();
assert_eq!(response.client_id, CLIENT_ID);
assert_eq!(response.client_secret.unwrap(), CLIENT_SECRET);
}
#[tokio::test]
async fn pass_register_client_client_secret_post() {
let (http_service, mock_server, issuer) = init_test().await;
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::ClientSecretPost);
let registration_endpoint = issuer.join("register").unwrap();
Mock::given(method("POST"))
.and(path("/register"))
.and(body_partial_json(json!({
"redirect_uris": [REDIRECT_URI],
"token_endpoint_auth_method": "client_secret_post",
})))
.respond_with(
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
client_id: CLIENT_ID.to_owned(),
client_secret: Some(CLIENT_SECRET.to_owned()),
client_id_issued_at: None,
client_secret_expires_at: None,
}),
)
.mount(&mock_server)
.await;
let response = register_client(&http_service, &registration_endpoint, client_metadata)
.await
.unwrap();
assert_eq!(response.client_id, CLIENT_ID);
assert_eq!(response.client_secret.unwrap(), CLIENT_SECRET);
}
#[tokio::test]
async fn pass_register_client_client_secret_jwt() {
let (http_service, mock_server, issuer) = init_test().await;
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::ClientSecretJwt);
let registration_endpoint = issuer.join("register").unwrap();
Mock::given(method("POST"))
.and(path("/register"))
.and(body_partial_json(json!({
"redirect_uris": [REDIRECT_URI],
"token_endpoint_auth_method": "client_secret_jwt",
"token_endpoint_auth_signing_alg": "HS256",
})))
.respond_with(
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
client_id: CLIENT_ID.to_owned(),
client_secret: Some(CLIENT_SECRET.to_owned()),
client_id_issued_at: None,
client_secret_expires_at: None,
}),
)
.mount(&mock_server)
.await;
let response = register_client(&http_service, &registration_endpoint, client_metadata)
.await
.unwrap();
assert_eq!(response.client_id, CLIENT_ID);
assert_eq!(response.client_secret.unwrap(), CLIENT_SECRET);
}
#[tokio::test]
async fn pass_register_client_private_key_jwt() {
let (http_service, mock_server, issuer) = init_test().await;
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::PrivateKeyJwt);
let registration_endpoint = issuer.join("register").unwrap();
Mock::given(method("POST"))
.and(path("/register"))
.and(|req: &Request| {
let metadata = match req.body_json::<ClientMetadata>() {
Ok(body) => body,
Err(_) => return false,
};
*metadata.token_endpoint_auth_method() == OAuthClientAuthenticationMethod::PrivateKeyJwt
&& metadata.token_endpoint_auth_signing_alg == Some(JsonWebSignatureAlg::Es256)
&& metadata.jwks.is_some()
})
.respond_with(
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_id_issued_at: None,
client_secret_expires_at: None,
}),
)
.mount(&mock_server)
.await;
let response = register_client(&http_service, &registration_endpoint, client_metadata)
.await
.unwrap();
assert_eq!(response.client_id, CLIENT_ID);
assert_eq!(response.client_secret, None);
}
#[tokio::test]
async fn fail_register_client_404() {
let (http_service, _, issuer) = init_test().await;
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::None);
let registration_endpoint = issuer.join("register").unwrap();
let error = register_client(&http_service, &registration_endpoint, client_metadata)
.await
.unwrap_err();
assert_matches!(error, RegistrationError::Http(_));
}
#[tokio::test]
async fn fail_register_client_missing_secret() {
let (http_service, mock_server, issuer) = init_test().await;
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::ClientSecretBasic);
let registration_endpoint = issuer.join("register").unwrap();
Mock::given(method("POST"))
.and(path("/register"))
.and(body_partial_json(json!({
"token_endpoint_auth_method": "client_secret_basic",
})))
.respond_with(
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_id_issued_at: None,
client_secret_expires_at: None,
}),
)
.mount(&mock_server)
.await;
let error = register_client(&http_service, &registration_endpoint, client_metadata)
.await
.unwrap_err();
assert_matches!(error, RegistrationError::MissingClientSecret);
}

View File

@ -0,0 +1,82 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
use mas_oidc_client::requests::revocation::revoke_token;
use rand::SeedableRng;
use wiremock::{
matchers::{method, path},
Mock, Request, ResponseTemplate,
};
use crate::{client_credentials, init_test, ACCESS_TOKEN, CLIENT_ID};
#[tokio::test]
async fn pass_revoke_token() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials =
client_credentials(OAuthClientAuthenticationMethod::None, &issuer, None);
let revocation_endpoint = issuer.join("revoke").unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
Mock::given(method("POST"))
.and(path("/revoke"))
.and(|req: &Request| {
let query_pairs = form_urlencoded::parse(&req.body).collect::<HashMap<_, _>>();
if query_pairs
.get("token")
.filter(|s| *s == ACCESS_TOKEN)
.is_none()
{
println!("Wrong or missing refresh token");
return false;
}
if query_pairs
.get("token_type_hint")
.filter(|s| *s == "access_token")
.is_none()
{
println!("Wrong or missing token type hint");
return false;
}
if query_pairs
.get("client_id")
.filter(|s| *s == CLIENT_ID)
.is_none()
{
println!("Wrong or missing client ID");
return false;
}
true
})
.respond_with(ResponseTemplate::new(200))
.mount(&mock_server)
.await;
revoke_token(
&http_service,
client_credentials,
&revocation_endpoint,
ACCESS_TOKEN.to_owned(),
Some(OAuthTokenTypeHint::AccessToken),
crate::now(),
&mut rng,
)
.await
.unwrap();
}

View File

@ -0,0 +1,93 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use assert_matches::assert_matches;
use mas_oidc_client::{
error::{IdTokenError, UserInfoError},
requests::userinfo::fetch_userinfo,
};
use serde_json::json;
use wiremock::{
matchers::{header, method, path},
Mock, ResponseTemplate,
};
use crate::{id_token, init_test, ACCESS_TOKEN, SUBJECT_IDENTIFIER};
#[tokio::test]
async fn pass_fetch_userinfo() {
let (http_service, mock_server, issuer) = init_test().await;
let userinfo_endpoint = issuer.join("userinfo").unwrap();
let (auth_id_token, _) = id_token(&issuer);
Mock::given(method("GET"))
.and(path("/userinfo"))
.and(header(
"authorization",
format!("Bearer {ACCESS_TOKEN}").as_str(),
))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"sub": SUBJECT_IDENTIFIER,
"email": "janedoe@example.com",
})))
.mount(&mock_server)
.await;
let claims = fetch_userinfo(
&http_service,
&userinfo_endpoint,
ACCESS_TOKEN,
None,
&auth_id_token,
)
.await
.unwrap();
assert_eq!(claims.get("email").unwrap(), "janedoe@example.com");
}
#[tokio::test]
async fn fail_wrong_subject_identifier() {
let (http_service, mock_server, issuer) = init_test().await;
let userinfo_endpoint = issuer.join("userinfo").unwrap();
let (auth_id_token, _) = id_token(&issuer);
Mock::given(method("GET"))
.and(path("/userinfo"))
.and(header(
"authorization",
format!("Bearer {ACCESS_TOKEN}").as_str(),
))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
"sub": "wrong_subject_identifier",
"email": "janedoe@example.com",
})))
.mount(&mock_server)
.await;
let error = fetch_userinfo(
&http_service,
&userinfo_endpoint,
ACCESS_TOKEN,
None,
&auth_id_token,
)
.await
.unwrap_err();
assert_matches!(
error,
UserInfoError::IdToken(IdTokenError::WrongSubjectIdentifier)
);
}

View File

@ -0,0 +1,488 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use assert_matches::assert_matches;
use base64ct::Encoding;
use mas_iana::oauth::{OAuthAccessTokenType, OAuthClientAuthenticationMethod};
use mas_jose::{
claims::{self, TimeOptions},
jwt::Jwt,
};
use mas_oidc_client::{
error::{CredentialsError, TokenRequestError},
requests::client_credentials::access_token_with_client_credentials,
types::client_credentials::ClientCredentials,
};
use oauth2_types::requests::AccessTokenResponse;
use rand::SeedableRng;
use serde_json::Value;
use tower::BoxError;
use wiremock::{
matchers::{header, method, path},
Mock, Request, ResponseTemplate,
};
use crate::{client_credentials, init_test, now, ACCESS_TOKEN, CLIENT_ID, CLIENT_SECRET};
#[tokio::test]
async fn pass_none() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials =
client_credentials(OAuthClientAuthenticationMethod::None, &issuer, None);
let token_endpoint = issuer.join("token").unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
Mock::given(method("POST"))
.and(path("/token"))
.and(|req: &Request| {
let query_pairs = form_urlencoded::parse(&req.body).collect::<HashMap<_, _>>();
if query_pairs
.get("client_id")
.filter(|s| *s == CLIENT_ID)
.is_none()
{
println!("Wrong or missing client ID");
return false;
}
true
})
.respond_with(
ResponseTemplate::new(200).set_body_json(AccessTokenResponse {
access_token: ACCESS_TOKEN.to_owned(),
refresh_token: None,
id_token: None,
token_type: OAuthAccessTokenType::Bearer,
expires_in: None,
scope: None,
}),
)
.mount(&mock_server)
.await;
access_token_with_client_credentials(
&http_service,
client_credentials,
&token_endpoint,
None,
now(),
&mut rng,
)
.await
.unwrap();
}
#[tokio::test]
async fn pass_client_secret_basic() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials = client_credentials(
OAuthClientAuthenticationMethod::ClientSecretBasic,
&issuer,
None,
);
let token_endpoint = issuer.join("token").unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let username = form_urlencoded::byte_serialize(CLIENT_ID.as_bytes()).collect::<String>();
let password = form_urlencoded::byte_serialize(CLIENT_SECRET.as_bytes()).collect::<String>();
let enc_user_pass =
base64ct::Base64::encode_string(format!("{username}:{password}").as_bytes());
let authorization_header = format!("Basic {enc_user_pass}");
Mock::given(method("POST"))
.and(path("/token"))
.and(header("authorization", authorization_header.as_str()))
.respond_with(
ResponseTemplate::new(200).set_body_json(AccessTokenResponse {
access_token: ACCESS_TOKEN.to_owned(),
refresh_token: None,
id_token: None,
token_type: OAuthAccessTokenType::Bearer,
expires_in: None,
scope: None,
}),
)
.mount(&mock_server)
.await;
access_token_with_client_credentials(
&http_service,
client_credentials,
&token_endpoint,
None,
now(),
&mut rng,
)
.await
.unwrap();
}
#[tokio::test]
async fn pass_client_secret_post() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials = client_credentials(
OAuthClientAuthenticationMethod::ClientSecretPost,
&issuer,
None,
);
let token_endpoint = issuer.join("token").unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
Mock::given(method("POST"))
.and(path("/token"))
.and(|req: &Request| {
let query_pairs = form_urlencoded::parse(&req.body).collect::<HashMap<_, _>>();
if query_pairs
.get("client_id")
.filter(|s| *s == CLIENT_ID)
.is_none()
{
println!("Wrong or missing client ID");
return false;
}
if query_pairs
.get("client_secret")
.filter(|s| *s == CLIENT_SECRET)
.is_none()
{
println!("Wrong or missing client secret");
return false;
}
true
})
.respond_with(
ResponseTemplate::new(200).set_body_json(AccessTokenResponse {
access_token: ACCESS_TOKEN.to_owned(),
refresh_token: None,
id_token: None,
token_type: OAuthAccessTokenType::Bearer,
expires_in: None,
scope: None,
}),
)
.mount(&mock_server)
.await;
access_token_with_client_credentials(
&http_service,
client_credentials,
&token_endpoint,
None,
now(),
&mut rng,
)
.await
.unwrap();
}
#[tokio::test]
async fn pass_client_secret_jwt() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials = client_credentials(
OAuthClientAuthenticationMethod::ClientSecretJwt,
&issuer,
None,
);
let token_endpoint = issuer.join("token").unwrap();
let endpoint = token_endpoint.to_string();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
Mock::given(method("POST"))
.and(path("/token"))
.and(move |req: &Request| {
let query_pairs = form_urlencoded::parse(&req.body).collect::<HashMap<_, _>>();
if query_pairs
.get("client_id")
.filter(|s| *s == CLIENT_ID)
.is_none()
{
println!("Wrong or missing client ID");
return false;
}
if query_pairs
.get("client_assertion_type")
.filter(|s| *s == "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
.is_none()
{
println!("Wrong or missing client assertion type");
return false;
}
let jwt = match query_pairs.get("client_assertion") {
Some(jwt) => jwt,
None => {
println!("Missing client assertion");
return false;
}
};
let jwt = Jwt::<HashMap<String, Value>>::try_from(jwt.as_ref()).unwrap();
if jwt
.verify_with_shared_secret(CLIENT_SECRET.as_bytes().to_owned())
.is_err()
{
println!("Client assertion signature verification failed");
return false;
}
let mut claims = jwt.into_parts().1;
if let Err(error) = verify_client_jwt(&mut claims, &endpoint) {
println!("Client assertion claims verification failed: {error}");
return false;
}
true
})
.respond_with(
ResponseTemplate::new(200).set_body_json(AccessTokenResponse {
access_token: ACCESS_TOKEN.to_owned(),
refresh_token: None,
id_token: None,
token_type: OAuthAccessTokenType::Bearer,
expires_in: None,
scope: None,
}),
)
.mount(&mock_server)
.await;
access_token_with_client_credentials(
&http_service,
client_credentials,
&token_endpoint,
None,
now(),
&mut rng,
)
.await
.unwrap();
}
#[tokio::test]
async fn pass_private_key_jwt_with_keystore() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials = client_credentials(
OAuthClientAuthenticationMethod::PrivateKeyJwt,
&issuer,
None,
);
let token_endpoint = issuer.join("token").unwrap();
let endpoint = token_endpoint.to_string();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let client_jwks = if let ClientCredentials::PrivateKeyJwt {
jwt_signing_method, ..
} = &client_credentials
{
let keystore = jwt_signing_method.keystore().unwrap();
keystore.public_jwks()
} else {
panic!("should be PrivateKeyJwt")
};
Mock::given(method("POST"))
.and(path("/token"))
.and(move |req: &Request| {
let query_pairs = form_urlencoded::parse(&req.body).collect::<HashMap<_, _>>();
if query_pairs
.get("client_id")
.filter(|s| *s == CLIENT_ID)
.is_none()
{
println!("Wrong or missing client ID");
return false;
}
if query_pairs
.get("client_assertion_type")
.filter(|s| *s == "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
.is_none()
{
println!("Wrong or missing client assertion type");
return false;
}
let jwt = match query_pairs.get("client_assertion") {
Some(jwt) => jwt,
None => {
println!("Missing client assertion");
return false;
}
};
let jwt = Jwt::<HashMap<String, Value>>::try_from(jwt.as_ref()).unwrap();
if jwt.verify_with_jwks(&client_jwks).is_err() {
println!("Client assertion signature verification failed");
return false;
}
let mut claims = jwt.into_parts().1;
if let Err(error) = verify_client_jwt(&mut claims, &endpoint) {
println!("Client assertion claims verification failed: {error}");
return false;
}
true
})
.respond_with(
ResponseTemplate::new(200).set_body_json(AccessTokenResponse {
access_token: ACCESS_TOKEN.to_owned(),
refresh_token: None,
id_token: None,
token_type: OAuthAccessTokenType::Bearer,
expires_in: None,
scope: None,
}),
)
.mount(&mock_server)
.await;
access_token_with_client_credentials(
&http_service,
client_credentials,
&token_endpoint,
None,
now(),
&mut rng,
)
.await
.unwrap();
}
#[tokio::test]
async fn pass_private_key_jwt_with_custom_signing() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials = client_credentials(
OAuthClientAuthenticationMethod::PrivateKeyJwt,
&issuer,
Some(Box::new(|_claims, _alg| Ok("fake.signed.jwt".to_owned()))),
);
let token_endpoint = issuer.join("token").unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
Mock::given(method("POST"))
.and(path("/token"))
.and(move |req: &Request| {
let query_pairs = form_urlencoded::parse(&req.body).collect::<HashMap<_, _>>();
if query_pairs
.get("client_id")
.filter(|s| *s == CLIENT_ID)
.is_none()
{
println!("Wrong or missing client ID");
return false;
}
if query_pairs
.get("client_assertion_type")
.filter(|s| *s == "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
.is_none()
{
println!("Wrong or missing client assertion type");
return false;
}
if query_pairs
.get("client_assertion")
.filter(|s| *s == "fake.signed.jwt")
.is_none()
{
println!("Wrong or missing client assertion");
return false;
}
true
})
.respond_with(
ResponseTemplate::new(200).set_body_json(AccessTokenResponse {
access_token: ACCESS_TOKEN.to_owned(),
refresh_token: None,
id_token: None,
token_type: OAuthAccessTokenType::Bearer,
expires_in: None,
scope: None,
}),
)
.mount(&mock_server)
.await;
access_token_with_client_credentials(
&http_service,
client_credentials,
&token_endpoint,
None,
now(),
&mut rng,
)
.await
.unwrap();
}
#[tokio::test]
async fn fail_private_key_jwt_with_custom_signing() {
let (http_service, _, issuer) = init_test().await;
let client_credentials = client_credentials(
OAuthClientAuthenticationMethod::PrivateKeyJwt,
&issuer,
Some(Box::new(|_claims, _alg| Err("Something went wrong".into()))),
);
let token_endpoint = issuer.join("token").unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let error = access_token_with_client_credentials(
&http_service,
client_credentials,
&token_endpoint,
None,
now(),
&mut rng,
)
.await
.unwrap_err();
assert_matches!(
error,
TokenRequestError::Credentials(CredentialsError::Custom(_))
);
}
fn verify_client_jwt(
claims: &mut HashMap<String, Value>,
token_endpoint: &String,
) -> Result<(), BoxError> {
let iss = claims::ISS.extract_required(claims)?;
if iss != CLIENT_ID {
return Err("Wrong iss".into());
}
let sub = claims::SUB.extract_required(claims)?;
if sub != CLIENT_ID {
return Err("Wrong sub".into());
}
let aud = claims::AUD.extract_required(claims)?;
if !aud.contains(token_endpoint) {
return Err("Wrong aud".into());
}
claims::EXP.extract_required_with_options(claims, TimeOptions::new(now()))?;
Ok(())
}

View File

@ -0,0 +1,15 @@
// Copyright 2022 Kévin Commaille.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
mod client_credentials;