-
-
Notifications
You must be signed in to change notification settings - Fork 75
feat(server): #298 postgres 18 oauth support #349
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 25 commits
c77eb2d
fc1435f
0be0139
e5fbb7c
0b6db14
2bf49d4
26aee4a
76cd322
8f3c030
d7571e3
02880c8
90dbbb7
66bdefe
0b9b52e
ab5130d
f878d2a
ba3104c
055e3df
0fe9476
ac80eaa
1575df4
7881fc1
e7c5aca
a5e3c6d
641c85c
9366361
4873324
ab9fcf7
05cf994
3c05362
f2cc997
80f3b3b
d861c9f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,142 @@ | ||
| /// This example shows how to use pgwire with OAuth. | ||
| /// To connect with psql: | ||
| /// 1. Install libq-oauth: sudo apt-get install libpq-oauth | ||
| /// 2. Execute: psql "postgres://postgres@localhost:5432/db?oauth_issuer=https://auth.example.com&oauth_client_id=my-app-client-id" | ||
| use std::collections::HashMap; | ||
| use std::fs::File; | ||
| use std::io::{BufReader, Error as IOError, ErrorKind}; | ||
| use std::sync::Arc; | ||
|
|
||
| use async_trait::async_trait; | ||
|
|
||
| use pgwire::api::auth::sasl::oauth::{Oauth, OauthValidator, ValidatorModuleResult}; | ||
| use rustls_pemfile::{certs, pkcs8_private_keys}; | ||
| use rustls_pki_types::{CertificateDer, PrivateKeyDer}; | ||
| use tokio::net::TcpListener; | ||
| use tokio_rustls::rustls::ServerConfig; | ||
| use tokio_rustls::TlsAcceptor; | ||
|
|
||
| use pgwire::api::auth::sasl::SASLAuthStartupHandler; | ||
| use pgwire::api::auth::{DefaultServerParameterProvider, StartupHandler}; | ||
| use pgwire::api::PgWireServerHandlers; | ||
| use pgwire::error::PgWireResult; | ||
| use pgwire::tokio::process_socket; | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remember to reorder these imports |
||
|
|
||
| pub fn random_salt() -> Vec<u8> { | ||
| Vec::from(rand::random::<[u8; 10]>()) | ||
| } | ||
|
|
||
| // TODO: change to JSON web token validator, this is just for validating the current code | ||
| #[derive(Debug)] | ||
| struct SimpleTokenValidator { | ||
| valid_tokens: HashMap<String, String>, | ||
| } | ||
|
|
||
| impl SimpleTokenValidator { | ||
| pub fn new() -> Self { | ||
| let mut valid_tokens = HashMap::new(); | ||
| valid_tokens.insert( | ||
| "secret_token_123".to_string(), | ||
| "[email protected]".to_string(), | ||
| ); | ||
| valid_tokens.insert( | ||
| "admin_token_456".to_string(), | ||
| "[email protected]".to_string(), | ||
| ); | ||
|
|
||
| Self { valid_tokens } | ||
| } | ||
| } | ||
|
|
||
| #[async_trait] | ||
| impl OauthValidator for SimpleTokenValidator { | ||
| async fn validate( | ||
| &self, | ||
| token: &str, | ||
| username: &str, | ||
| issuer: &str, | ||
| required_scopes: &str, | ||
| ) -> PgWireResult<ValidatorModuleResult> { | ||
| println!("Validating token for user: {}", username); | ||
| println!("Expected issuer: {}", issuer); | ||
| println!("Required scopes: {}", required_scopes); | ||
|
|
||
| // if let Some(authenticated_user) = self.valid_tokens.get(token) { | ||
| // Ok(ValidatorModuleResult { | ||
| // authorized: true, | ||
| // authn_id: Some(authenticated_user.clone()), | ||
| // metadata: None, | ||
| // }) | ||
| // } else { | ||
| // Ok(ValidatorModuleResult { | ||
| // authorized: false, | ||
| // authn_id: None, | ||
| // metadata: None, | ||
| // }) | ||
| // } | ||
| // | ||
|
|
||
| Ok(ValidatorModuleResult { | ||
| authorized: true, | ||
| authn_id: Some(username.to_string()), | ||
| metadata: None, | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| /// configure TlsAcceptor and get server cert for SCRAM channel binding | ||
| fn setup_tls() -> Result<TlsAcceptor, IOError> { | ||
| let cert = certs(&mut BufReader::new(File::open("examples/ssl/server.crt")?)) | ||
| .collect::<Result<Vec<CertificateDer>, IOError>>()?; | ||
|
|
||
| let key = pkcs8_private_keys(&mut BufReader::new(File::open("examples/ssl/server.key")?)) | ||
| .map(|key| key.map(PrivateKeyDer::from)) | ||
| .collect::<Result<Vec<PrivateKeyDer>, IOError>>()? | ||
| .remove(0); | ||
|
|
||
| let config = ServerConfig::builder() | ||
| .with_no_client_auth() | ||
| .with_single_cert(cert, key) | ||
| .map_err(|err| IOError::new(ErrorKind::InvalidInput, err))?; | ||
|
|
||
| Ok(TlsAcceptor::from(Arc::new(config))) | ||
| } | ||
|
|
||
| struct DummyProcessorFactory; | ||
|
|
||
| impl PgWireServerHandlers for DummyProcessorFactory { | ||
| fn startup_handler(&self) -> Arc<impl StartupHandler> { | ||
| let validator = SimpleTokenValidator::new(); | ||
| let oauth = Oauth::new( | ||
| "http://localhost:8080/realms/postgres-realm".to_string(), | ||
| "openid postgres".to_string(), | ||
| Arc::new(validator), | ||
| ); | ||
|
|
||
| let authenticator = | ||
| SASLAuthStartupHandler::new(Arc::new(DefaultServerParameterProvider::default())) | ||
| .with_oauth(oauth); | ||
|
|
||
| Arc::new(authenticator) | ||
| } | ||
| } | ||
|
|
||
| #[tokio::main] | ||
| pub async fn main() { | ||
| let factory = Arc::new(DummyProcessorFactory); | ||
|
|
||
| let server_addr = "127.0.0.1:5432"; | ||
| let tls_acceptor = setup_tls().unwrap(); | ||
| let listener = TcpListener::bind(server_addr).await.unwrap(); | ||
| println!("Listening to {}", server_addr); | ||
| loop { | ||
| let incoming_socket = listener.accept().await.unwrap(); | ||
| let tls_acceptor_ref = tls_acceptor.clone(); | ||
|
|
||
| let factory_ref = factory.clone(); | ||
|
|
||
| tokio::spawn(async move { | ||
| process_socket(incoming_socket.0, Some(tls_acceptor_ref), factory_ref).await | ||
| }); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage}; | |
|
|
||
| use super::{ServerParameterProvider, StartupHandler}; | ||
|
|
||
| pub mod oauth; | ||
| pub mod scram; | ||
|
|
||
| #[derive(Debug)] | ||
|
|
@@ -22,6 +23,10 @@ pub enum SASLState { | |
| ScramClientFirstReceived, | ||
| // cached password, channel_binding and partial auth-message | ||
| ScramServerFirstSent(Password, String, String), | ||
| // oauth authentication method selected | ||
| OauthStateInit, | ||
| // failure during authentication | ||
| OauthStateError, | ||
| // finished | ||
| Finished, | ||
| } | ||
|
|
@@ -33,6 +38,10 @@ impl SASLState { | |
| SASLState::ScramClientFirstReceived | SASLState::ScramServerFirstSent(_, _, _) | ||
| ) | ||
| } | ||
|
|
||
| fn is_oauth(&self) -> bool { | ||
| matches!(self, SASLState::OauthStateInit | SASLState::OauthStateError) | ||
| } | ||
| } | ||
|
|
||
| #[derive(Debug)] | ||
|
|
@@ -42,6 +51,8 @@ pub struct SASLAuthStartupHandler<P> { | |
| state: Mutex<SASLState>, | ||
| /// scram configuration | ||
| scram: Option<scram::ScramAuth>, | ||
| /// oauth configuration | ||
| oauth: Option<oauth::Oauth>, | ||
| } | ||
|
|
||
| #[async_trait] | ||
|
|
@@ -70,39 +81,58 @@ impl<P: ServerParameterProvider> StartupHandler for SASLAuthStartupHandler<P> { | |
| } | ||
| PgWireFrontendMessage::PasswordMessageFamily(mut msg) => { | ||
| let mut state = self.state.lock().await; | ||
| if let SASLState::Initial = *state { | ||
|
|
||
| msg = if let SASLState::Initial = *state { | ||
| let sasl_initial_response = msg.into_sasl_initial_response()?; | ||
| let selected_mechanism = sasl_initial_response.auth_method.as_str(); | ||
|
|
||
| if [Self::SCRAM_SHA_256, Self::SCRAM_SHA_256_PLUS].contains(&selected_mechanism) | ||
| *state = if [Self::SCRAM_SHA_256, Self::SCRAM_SHA_256_PLUS] | ||
| .contains(&selected_mechanism) | ||
| { | ||
| *state = SASLState::ScramClientFirstReceived; | ||
| SASLState::ScramClientFirstReceived | ||
| } else if Self::OAUTHBEARER == selected_mechanism { | ||
| SASLState::OauthStateInit | ||
| } else { | ||
| return Err(PgWireError::UnsupportedSASLAuthMethod( | ||
| selected_mechanism.to_string(), | ||
| )); | ||
| } | ||
| }; | ||
|
|
||
| msg = PasswordMessageFamily::SASLInitialResponse(sasl_initial_response); | ||
| PasswordMessageFamily::SASLInitialResponse(sasl_initial_response) | ||
| } else { | ||
| let sasl_response = msg.into_sasl_response()?; | ||
| msg = PasswordMessageFamily::SASLResponse(sasl_response); | ||
| } | ||
|
|
||
| // SCRAM authentication | ||
| if state.is_scram() { | ||
| if let Some(scram) = &self.scram { | ||
| let (resp, new_state) = | ||
| scram.process_scram_message(client, msg, &state).await?; | ||
| PasswordMessageFamily::SASLResponse(sasl_response) | ||
| }; | ||
|
|
||
| let (res, new_state) = if state.is_scram() { | ||
| let scram = self.scram.as_ref().ok_or_else(|| { | ||
| PgWireError::UnsupportedSASLAuthMethod("SCRAM".to_string()) | ||
| })?; | ||
| scram.process_scram_message(client, msg, &state).await? | ||
| } else if state.is_oauth() { | ||
| let oauth = self.oauth.as_ref().ok_or_else(|| { | ||
| PgWireError::UnsupportedSASLAuthMethod("OAUTHBEARER".to_string()) | ||
| })?; | ||
| let r = oauth.process_oauth_message(client, msg, &state).await?; | ||
| r | ||
Lilit0x marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } else { | ||
| return Err(PgWireError::InvalidSASLState); | ||
| }; | ||
|
|
||
| // we need to skip sending Authentication::Ok for Oauth after successful | ||
| // validation, but we mustn't also prevent other messages from getting sent. | ||
| match (state.is_oauth(), &res, &new_state) { | ||
| (true, Authentication::Ok, SASLState::Finished) => { | ||
| // we skip sending Authentication::Ok for OAuth because finish_authentication will send it | ||
| } | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why we need different workflow for oauth and scram here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Per the docs:
Here, Oauth differs from SCRAM in the sense that we don't need to send SASLFinal after successful oauth validation, but we still need to return
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it. By design, I was not going to return I think we can change |
||
| _ => { | ||
| client | ||
| .send(PgWireBackendMessage::Authentication(resp)) | ||
| .send(PgWireBackendMessage::Authentication(res)) | ||
| .await?; | ||
| *state = new_state; | ||
| } else { | ||
| // scram is not configured | ||
| return Err(PgWireError::UnsupportedSASLAuthMethod("SCRAM".to_string())); | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| *state = new_state; | ||
|
|
||
| if matches!(*state, SASLState::Finished) { | ||
| super::finish_authentication(client, self.parameter_provider.as_ref()).await?; | ||
|
|
@@ -121,6 +151,7 @@ impl<P> SASLAuthStartupHandler<P> { | |
| parameter_provider, | ||
| state: Mutex::new(SASLState::Initial), | ||
| scram: None, | ||
| oauth: None, | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -129,8 +160,14 @@ impl<P> SASLAuthStartupHandler<P> { | |
| self | ||
| } | ||
|
|
||
| pub fn with_oauth(mut self, oauth: oauth::Oauth) -> Self { | ||
| self.oauth = Some(oauth); | ||
| self | ||
| } | ||
|
|
||
| const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; | ||
| const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; | ||
| const OAUTHBEARER: &str = "OAUTHBEARER"; | ||
|
|
||
| fn supported_mechanisms(&self) -> Vec<String> { | ||
| let mut mechanisms = vec![]; | ||
|
|
@@ -143,6 +180,10 @@ impl<P> SASLAuthStartupHandler<P> { | |
| } | ||
| } | ||
|
|
||
| if self.oauth.is_some() { | ||
| mechanisms.push(Self::OAUTHBEARER.to_owned()); | ||
| } | ||
|
|
||
| mechanisms | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.