Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/gateway/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,12 @@ impl ChannelService for LiveChannelService {
return Err("redirect_uri is required".into());
}

tracing::debug!(
account_id = %account_id,
redirect_uri = %redirect_uri,
"channels.oauth_start called"
);

// Merge caller-provided config (ownership_mode, policies, etc.) with
// the required OIDC fields.
let mut config = params
Expand Down
281 changes: 257 additions & 24 deletions crates/matrix/src/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//! (`Client::oauth()`), which handles PKCE, dynamic client registration,
//! token exchange, and automatic refresh.

use std::path::PathBuf;
use std::{error::Error as StdError, fmt, path::PathBuf};

use {
matrix_sdk::{
Expand All @@ -20,7 +20,7 @@ use {
moltis_common::secret_serde,
secrecy::{ExposeSecret, Secret},
serde::{Deserialize, Serialize},
tracing::{info, instrument, warn},
tracing::{debug, info, instrument, warn},
url::Url,
};

Expand Down Expand Up @@ -52,8 +52,8 @@ struct PersistedOidcSession {
refresh_token: Option<Secret<String>>,
}

impl std::fmt::Debug for PersistedOidcSession {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
impl fmt::Debug for PersistedOidcSession {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("PersistedOidcSession")
.field("client_id", &self.client_id)
.field("user_id", &self.user_id)
Expand Down Expand Up @@ -150,6 +150,85 @@ async fn load_oidc_session(account_id: &str) -> ChannelResult<Option<PersistedOi
/// MAS validates this URL and rejects loopback addresses.
const MOLTIS_CLIENT_URI: &str = "https://moltis.org/";

#[derive(Clone, Debug, Eq, PartialEq)]
struct ClientRegistrationDiagnostics {
original_redirect_uri: String,
registration_redirect_uri: String,
is_loopback: bool,
application_type: String,
issuer: String,
registration_endpoint: Option<String>,
client_metadata_json: String,
}

impl ClientRegistrationDiagnostics {
fn new(
original_redirect_uri: &Url,
registration_redirect_uri: &Url,
metadata: &ClientMetadata,
issuer: &Url,
registration_endpoint: Option<&Url>,
raw_metadata: &Raw<ClientMetadata>,
) -> Self {
Self {
original_redirect_uri: original_redirect_uri.to_string(),
registration_redirect_uri: registration_redirect_uri.to_string(),
is_loopback: is_loopback_uri(original_redirect_uri),
application_type: metadata.application_type.to_string(),
issuer: issuer.to_string(),
registration_endpoint: registration_endpoint.map(ToString::to_string),
client_metadata_json: raw_metadata.json().get().to_string(),
}
}
}

impl fmt::Display for ClientRegistrationDiagnostics {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"original_redirect_uri={}, registration_redirect_uri={}, is_loopback={}, \
application_type={}, issuer={}, registration_endpoint={}, client_metadata_json={}",
self.original_redirect_uri,
self.registration_redirect_uri,
self.is_loopback,
self.application_type,
self.issuer,
self.registration_endpoint.as_deref().unwrap_or("none"),
self.client_metadata_json
)
}
}

#[derive(Debug)]
struct ClientRegistrationFailure {
diagnostics: ClientRegistrationDiagnostics,
source: Box<dyn StdError + Send + Sync>,
}

impl ClientRegistrationFailure {
fn new(
diagnostics: ClientRegistrationDiagnostics,
source: impl StdError + Send + Sync + 'static,
) -> Self {
Self {
diagnostics,
source: Box::new(source),
}
}
}

impl fmt::Display for ClientRegistrationFailure {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}; diagnostics: {}", self.source, self.diagnostics)
}
}

impl StdError for ClientRegistrationFailure {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
Some(self.source.as_ref())
}
}

fn is_loopback_uri(uri: &Url) -> bool {
let host = uri.host_str().unwrap_or_default();
if host == "localhost" || host == "::1" || host.ends_with(".localhost") {
Expand Down Expand Up @@ -177,30 +256,26 @@ fn normalize_loopback_redirect(redirect_uri: &Url) -> Url {
}
}

/// Build OIDC client metadata for dynamic registration.
///
/// `redirect_uri` must already be normalized (loopback https -> http) via
/// [`normalize_loopback_redirect`] before calling this.
fn build_client_metadata(redirect_uri: &Url) -> ChannelResult<ClientMetadata> {
let client_uri_url: Url = MOLTIS_CLIENT_URI
.parse()
.map_err(|error| ChannelError::external("matrix oidc parse client uri", error))?;
let client_uri = Localized::new(client_uri_url, std::iter::empty());
let is_loopback = is_loopback_uri(redirect_uri);
let registration_redirect = if is_loopback && redirect_uri.scheme() == "https" {
let mut normalized = redirect_uri.clone();
let _ = normalized.set_scheme("http");
normalized
} else {
redirect_uri.clone()
};
// MAS requires `Native` for loopback redirect URIs (RFC 8252) and `Web`
// for non-loopback URIs (e.g. behind a reverse proxy).
let app_type = if is_loopback {
let app_type = if is_loopback_uri(redirect_uri) {
ApplicationType::Native
} else {
ApplicationType::Web
};
Ok(ClientMetadata::new(
app_type,
vec![OAuthGrantType::AuthorizationCode {
redirect_uris: vec![registration_redirect],
redirect_uris: vec![redirect_uri.clone()],
}],
client_uri,
))
Expand All @@ -218,16 +293,50 @@ pub(crate) async fn start_oidc_login(
device_id: Option<&str>,
) -> ChannelResult<OidcLoginPending> {
// Verify the homeserver supports OIDC.
client
.oauth()
.server_metadata()
.await
.map_err(|error| ChannelError::external("matrix oidc server metadata discovery", error))?;
let server_metadata =
client.oauth().server_metadata().await.map_err(|error| {
ChannelError::external("matrix oidc server metadata discovery", error)
})?;

debug!(
account_id,
issuer = %server_metadata.issuer,
registration_endpoint = ?server_metadata.registration_endpoint,
"matrix OIDC server metadata discovered"
);

let registration_redirect = normalize_loopback_redirect(redirect_uri);
let metadata = build_client_metadata(redirect_uri)?;
let metadata = build_client_metadata(&registration_redirect)?;

let is_loopback = is_loopback_uri(redirect_uri);
debug!(
account_id,
original_redirect_uri = %redirect_uri,
registration_redirect_uri = %registration_redirect,
is_loopback,
application_type = ?metadata.application_type,
"matrix OIDC client registration parameters"
);

let raw_metadata: Raw<ClientMetadata> = Raw::new(&metadata)
.map_err(|error| ChannelError::external("matrix oidc serialize client metadata", error))?;
let diagnostics = ClientRegistrationDiagnostics::new(
redirect_uri,
&registration_redirect,
&metadata,
&server_metadata.issuer,
server_metadata.registration_endpoint.as_ref(),
&raw_metadata,
);

// Log the serialized metadata so operators can see exactly what is sent
// to the MAS registration endpoint.
debug!(
account_id,
client_metadata = %diagnostics.client_metadata_json,
"matrix OIDC client metadata for dynamic registration"
);

let registration_data = ClientRegistrationData::new(raw_metadata);

let device_id_owned = device_id
Expand All @@ -247,7 +356,25 @@ pub(crate) async fn start_oidc_login(
)
.build()
.await
.map_err(|error| ChannelError::external("matrix oidc authorization code build", error))?;
.map_err(|error| {
let failure = ClientRegistrationFailure::new(diagnostics.clone(), error);
warn!(
account_id,
original_redirect_uri = %diagnostics.original_redirect_uri,
registration_redirect_uri = %diagnostics.registration_redirect_uri,
is_loopback = diagnostics.is_loopback,
application_type = %diagnostics.application_type,
issuer = %diagnostics.issuer,
registration_endpoint = %diagnostics.registration_endpoint.as_deref().unwrap_or("none"),
client_metadata = %diagnostics.client_metadata_json,
error = %failure,
error_debug = ?failure,
"matrix OIDC client registration failed, \
check that redirect_uri is reachable and the homeserver's \
MAS allows dynamic client registration with this URI"
);
ChannelError::external("matrix oidc authorization code build", failure)
})?;

info!(account_id, auth_url = %url, "matrix OIDC login started");

Expand Down Expand Up @@ -398,7 +525,18 @@ fn spawn_session_persistence_task(client: &Client, account_id: &str) {

#[cfg(test)]
mod tests {
use super::*;
use {
crate::oidc::{
ClientRegistrationDiagnostics, MOLTIS_CLIENT_URI, PersistedOidcSession,
build_client_metadata, is_loopback_uri, normalize_loopback_redirect, oidc_session_path,
},
matrix_sdk::{
authentication::oauth::registration::{ApplicationType, OAuthGrantType},
ruma::serde::Raw,
},
secrecy::{ExposeSecret, Secret},
url::Url,
};

#[test]
fn oidc_session_path_returns_expected_path() {
Expand Down Expand Up @@ -437,11 +575,13 @@ mod tests {
}

#[test]
fn build_client_metadata_normalizes_loopback_and_uses_project_client_uri() {
fn build_client_metadata_uses_pre_normalized_loopback_and_project_client_uri() {
// Caller must normalize before passing to build_client_metadata.
let redirect: Url = "https://localhost:52979/auth/callback"
.parse()
.unwrap_or_else(|error| panic!("{error}"));
let metadata = build_client_metadata(&redirect).unwrap_or_else(|error| panic!("{error}"));
let normalized = normalize_loopback_redirect(&redirect);
let metadata = build_client_metadata(&normalized).unwrap_or_else(|error| panic!("{error}"));
match &metadata.grant_types[0] {
OAuthGrantType::AuthorizationCode { redirect_uris } => {
assert_eq!(
Expand Down Expand Up @@ -528,6 +668,99 @@ mod tests {
);
}

#[test]
fn registration_diagnostics_include_metadata_json_and_redirect_context() {
let original_redirect: Url = "https://localhost:52979/auth/callback"
.parse()
.unwrap_or_else(|error| panic!("{error}"));
let registration_redirect = normalize_loopback_redirect(&original_redirect);
let metadata =
build_client_metadata(&registration_redirect).unwrap_or_else(|error| panic!("{error}"));
let raw_metadata =
Raw::new(&metadata).unwrap_or_else(|error| panic!("raw metadata: {error}"));
let issuer: Url = "https://matrix.org/"
.parse()
.unwrap_or_else(|error| panic!("{error}"));
let registration_endpoint: Url =
"https://matrix.org/_matrix/client/unstable/org.matrix.msc2965/auth_issuer/register"
.parse()
.unwrap_or_else(|error| panic!("{error}"));

let diagnostics = ClientRegistrationDiagnostics::new(
&original_redirect,
&registration_redirect,
&metadata,
&issuer,
Some(&registration_endpoint),
&raw_metadata,
);

assert_eq!(
diagnostics.original_redirect_uri,
"https://localhost:52979/auth/callback"
);
assert_eq!(
diagnostics.registration_redirect_uri,
"http://localhost:52979/auth/callback"
);
assert!(diagnostics.is_loopback);
assert_eq!(diagnostics.application_type, "native");
assert_eq!(diagnostics.issuer, "https://matrix.org/");
assert_eq!(
diagnostics.registration_endpoint.as_deref(),
Some(registration_endpoint.as_str())
);
assert!(diagnostics.client_metadata_json.contains("redirect_uris"));
assert!(
diagnostics
.client_metadata_json
.contains("http://localhost:52979/auth/callback")
);

let display = diagnostics.to_string();
assert!(display.contains("original_redirect_uri=https://localhost:52979/auth/callback"));
assert!(display.contains("registration_redirect_uri=http://localhost:52979/auth/callback"));
assert!(display.contains("application_type=native"));
assert!(display.contains("registration_endpoint=https://matrix.org/"));
assert!(display.contains("client_metadata_json="));
}

#[test]
fn registration_failure_display_includes_source_and_diagnostics() {
let original_redirect: Url = "https://moltis.example.com/auth/callback"
.parse()
.unwrap_or_else(|error| panic!("{error}"));
let metadata =
build_client_metadata(&original_redirect).unwrap_or_else(|error| panic!("{error}"));
let raw_metadata =
Raw::new(&metadata).unwrap_or_else(|error| panic!("raw metadata: {error}"));
let issuer: Url = "https://matrix.org/"
.parse()
.unwrap_or_else(|error| panic!("{error}"));
let diagnostics = ClientRegistrationDiagnostics::new(
&original_redirect,
&original_redirect,
&metadata,
&issuer,
None,
&raw_metadata,
);
let failure = crate::oidc::ClientRegistrationFailure::new(
diagnostics,
std::io::Error::other("invalid_redirect_uri: invalid redirect_uri"),
);

let display = failure.to_string();
assert!(display.contains("invalid_redirect_uri: invalid redirect_uri"));
assert!(display.contains("original_redirect_uri=https://moltis.example.com/auth/callback"));
assert!(
display.contains("registration_redirect_uri=https://moltis.example.com/auth/callback")
);
assert!(display.contains("is_loopback=false"));
assert!(display.contains("application_type=web"));
assert!(display.contains("registration_endpoint=none"));
}

#[test]
fn is_loopback_uri_covers_full_127_range() {
let url_127_0_0_2: Url = "http://127.0.0.2:8080/auth/callback"
Expand Down
Loading