diff --git a/crates/handlers/src/upstream_oauth2/template.rs b/crates/handlers/src/upstream_oauth2/template.rs index 2424adbc..56cd8743 100644 --- a/crates/handlers/src/upstream_oauth2/template.rs +++ b/crates/handlers/src/upstream_oauth2/template.rs @@ -14,7 +14,7 @@ use std::{collections::HashMap, sync::Arc}; -use base64ct::{Base64, Encoding}; +use base64ct::{Base64, Base64Unpadded, Base64Url, Base64UrlUnpadded, Encoding}; use minijinja::{Environment, Error, ErrorKind, Value}; fn split(value: &str, separator: Option<&str>) -> Vec { @@ -25,13 +25,19 @@ fn split(value: &str, separator: Option<&str>) -> Vec { } fn b64decode(value: &str) -> Result { - let bytes = Base64::decode_vec(value).map_err(|e| { - Error::new( - ErrorKind::InvalidOperation, - "Failed to decode base64 string", - ) - .with_source(e) - })?; + // We're not too concerned about the performance of this filter, so we'll just + // try all the base64 variants when decoding + let bytes = Base64::decode_vec(value) + .or_else(|_| Base64Url::decode_vec(value)) + .or_else(|_| Base64Unpadded::decode_vec(value)) + .or_else(|_| Base64UrlUnpadded::decode_vec(value)) + .map_err(|e| { + Error::new( + ErrorKind::InvalidOperation, + "Failed to decode base64 string", + ) + .with_source(e) + })?; // It is not obvious, but the cleanest way to get a Value stored as raw bytes is // to wrap it in an Arc, because Value implements From>> @@ -119,4 +125,19 @@ mod tests { .unwrap(); assert_eq!(res, "0-385-28089-0"); } + + #[test] + fn test_base64_decode() { + let env = environment(); + + let res = env + .render_str("{{ 'cGFkZGluZw==' | b64decode }}", ()) + .unwrap(); + assert_eq!(res, "padding"); + + let res = env + .render_str("{{ 'dW5wYWRkZWQ' | b64decode }}", ()) + .unwrap(); + assert_eq!(res, "unpadded"); + } }