2222//! use std::collections::HashMap;
2323//! use std::iter;
2424//!
25+ //! use readyset_util::redacted::RedactedString;
2526//! use database_utils::TlsMode;
2627//! use mysql::prelude::*;
2728//! use mysql_srv::*;
6465//! Ok(())
6566//! }
6667//!
68+ //! async fn set_auth_info(&mut self, _: &str, _: Option<RedactedString>) -> io::Result<()> {
69+ //! Ok(())
70+ //! }
71+ //!
6772//! async fn on_query(
6873//! &mut self,
6974//! query: &str,
@@ -182,6 +187,7 @@ use error::{other_error, OtherErrorKind};
182187use mysql_common:: constants:: CapabilityFlags ;
183188use readyset_adapter_types:: { DeallocateId , ParsedCommand } ;
184189use readyset_data:: DfType ;
190+ use readyset_util:: redacted:: RedactedString ;
185191use tokio:: io:: { AsyncRead , AsyncWrite } ;
186192use tokio:: net;
187193use tokio_native_tls:: TlsAcceptor ;
@@ -312,6 +318,9 @@ pub trait MySqlShim<S: AsyncRead + AsyncWrite + Unpin + Send> {
312318 /// Called when client switches user.
313319 async fn on_change_user ( & mut self , _: & str , _: & str , _: & str ) -> io:: Result < ( ) > ;
314320
321+ /// Called when client authenticates to inform which users we should use.
322+ async fn set_auth_info ( & mut self , _: & str , _: Option < RedactedString > ) -> io:: Result < ( ) > ;
323+
315324 /// Retrieve the password for the user with the given username, if any.
316325 ///
317326 /// If the user doesn't exist, return [`None`].
@@ -438,7 +447,8 @@ impl<B: MySqlShim<S> + Send, S: AsyncWrite + AsyncRead + Unpin + Send> MySqlInte
438447 tls_acceptor,
439448 tls_mode,
440449 } ;
441- if let ( true , database) = mi. init ( ) . await ? {
450+ if let ( true , username, password, database) = mi. init ( ) . await ? {
451+ mi. shim . set_auth_info ( & username, password) . await ?;
442452 if let Some ( database) = database {
443453 mi. shim . on_init ( & database, None ) . await ?;
444454 }
@@ -455,9 +465,11 @@ impl<B: MySqlShim<S> + Send, S: AsyncWrite + AsyncRead + Unpin + Send> MySqlInte
455465 /// sent and received as needed to complete authentication.
456466 ///
457467 /// If no errors are encountered, the return value contains a tuple of a boolean to indicate
458- /// whether authentication was successful, and a database name if one was specified by the
459- /// client in the handshake response.
460- async fn init ( & mut self ) -> Result < ( bool , Option < String > ) , io:: Error > {
468+ /// whether authentication was successful, the username, the plaintext password if one was
469+ /// provided, and a database name if one was specified by the client in the handshake response.
470+ async fn init (
471+ & mut self ,
472+ ) -> Result < ( bool , String , Option < RedactedString > , Option < String > ) , io:: Error > {
461473 let auth_data =
462474 generate_auth_data ( ) . map_err ( |_| other_error ( OtherErrorKind :: AuthDataErr ) ) ?;
463475 self . auth_data = auth_data;
@@ -546,7 +558,7 @@ impl<B: MySqlShim<S> + Send, S: AsyncWrite + AsyncRead + Unpin + Send> MySqlInte
546558 )
547559 . await ?;
548560 self . conn . flush ( ) . await ?;
549- return Ok ( ( false , None ) ) ;
561+ return Ok ( ( false , "" . to_string ( ) , None , None ) ) ;
550562 }
551563 }
552564
@@ -601,7 +613,7 @@ impl<B: MySqlShim<S> + Send, S: AsyncWrite + AsyncRead + Unpin + Send> MySqlInte
601613 & mut self . conn ,
602614 )
603615 . await ?;
604- return Ok ( ( false , database) ) ;
616+ return Ok ( ( false , "" . to_string ( ) , None , database) ) ;
605617 }
606618
607619 debug ! (
@@ -632,17 +644,24 @@ impl<B: MySqlShim<S> + Send, S: AsyncWrite + AsyncRead + Unpin + Send> MySqlInte
632644 password
633645 } ;
634646
635- let auth_success = !self . shim . require_authentication ( )
636- || self
637- . shim
638- . password_for_username ( & username)
639- . is_some_and ( |password| {
640- let expected = hash_password ( & password, & auth_data) ;
641- let actual = handshake_password. as_slice ( ) ;
642- trace ! ( ?expected, ?actual) ;
643- expected == actual
644- } ) ;
645-
647+ let plain_password = self . shim . password_for_username ( & username) ;
648+ let require_auth = self . shim . require_authentication ( ) ;
649+ let auth_success = !require_auth
650+ || plain_password. as_ref ( ) . is_some_and ( |password| {
651+ let expected = hash_password ( password, & auth_data) ;
652+ let actual = handshake_password. as_slice ( ) ;
653+ trace ! ( ?expected, ?actual) ;
654+ expected == actual
655+ } ) ;
656+ let plain_password = if require_auth {
657+ Some ( RedactedString :: from (
658+ plain_password
659+ . map ( |p| String :: from_utf8_lossy ( & p) . into_owned ( ) )
660+ . unwrap_or_default ( ) ,
661+ ) )
662+ } else {
663+ None
664+ } ;
646665 if auth_success {
647666 debug ! ( %username, "Successfully authenticated client" ) ;
648667 writers:: write_ok_packet ( & mut self . conn , 0 , 0 , StatusFlags :: empty ( ) ) . await ?;
@@ -656,8 +675,7 @@ impl<B: MySqlShim<S> + Send, S: AsyncWrite + AsyncRead + Unpin + Send> MySqlInte
656675 . await ?;
657676 }
658677 self . conn . flush ( ) . await ?;
659-
660- Ok ( ( auth_success, database) )
678+ Ok ( ( auth_success, username, plain_password, database) )
661679 }
662680
663681 async fn run ( mut self ) -> Result < ( ) , io:: Error > {
0 commit comments