From ea669b22f531b6afe13e9bd40f15a94e0e75d107 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 14 Jun 2023 19:35:09 -0400 Subject: [PATCH 01/18] User server parameters struct instead of server info bytesmut --- src/admin.rs | 17 +++--- src/client.rs | 10 ++-- src/pool.rs | 49 ++++++++++------- src/server.rs | 148 ++++++++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 178 insertions(+), 46 deletions(-) diff --git a/src/admin.rs b/src/admin.rs index 9ae5e9d5..52ac0c96 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -1,4 +1,5 @@ use crate::pool::BanReason; +use crate::server::ServerParameters; use crate::stats::pool::PoolStats; use bytes::{Buf, BufMut, BytesMut}; use log::{error, info, trace}; @@ -17,16 +18,16 @@ use crate::pool::ClientServerMap; use crate::pool::{get_all_pools, get_pool}; use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState}; -pub fn generate_server_info_for_admin() -> BytesMut { - let mut server_info = BytesMut::new(); +pub fn generate_server_parameters_for_admin() -> ServerParameters { + let mut server_parameters = ServerParameters::new(); - server_info.put(server_parameter_message("application_name", "")); - server_info.put(server_parameter_message("client_encoding", "UTF8")); - server_info.put(server_parameter_message("server_encoding", "UTF8")); - server_info.put(server_parameter_message("server_version", VERSION)); - server_info.put(server_parameter_message("DateStyle", "ISO, MDY")); + server_parameters.set_dynamic_param("application_name".to_string(), "".to_string()); + server_parameters.set_dynamic_param("client_encoding".to_string(), "UTF8".to_string()); + server_parameters.set_dynamic_param("server_encoding".to_string(), "UTF8".to_string()); + server_parameters.set_dynamic_param("server_version".to_string(), VERSION.to_string()); + server_parameters.set_dynamic_param("DateStyle".to_string(), "ISO, MDY".to_string()); - server_info + server_parameters } /// Handle admin client. diff --git a/src/client.rs b/src/client.rs index 1ff558b5..5acdd706 100644 --- a/src/client.rs +++ b/src/client.rs @@ -11,7 +11,7 @@ use tokio::net::TcpStream; use tokio::sync::broadcast::Receiver; use tokio::sync::mpsc::Sender; -use crate::admin::{generate_server_info_for_admin, handle_admin}; +use crate::admin::{generate_server_parameters_for_admin, handle_admin}; use crate::auth_passthrough::refetch_auth_hash; use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode}; use crate::constants::*; @@ -491,7 +491,7 @@ where }; // Authenticate admin user. - let (transaction_mode, server_info) = if admin { + let (transaction_mode, server_parameters) = if admin { let config = get_config(); // Compare server and client hashes. @@ -510,7 +510,7 @@ where return Err(error); } - (false, generate_server_info_for_admin()) + (false, generate_server_parameters_for_admin()) } // Authenticate normal user. else { @@ -643,13 +643,13 @@ where } } - (transaction_mode, pool.server_info()) + (transaction_mode, pool.server_parameters()) }; debug!("Password authentication successful"); auth_ok(&mut write).await?; - write_all(&mut write, server_info).await?; + write_all(&mut write, server_parameters.get_bytes()).await?; backend_key_data(&mut write, process_id, secret_key).await?; ready_for_query(&mut write).await?; diff --git a/src/pool.rs b/src/pool.rs index b9293521..e686c612 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,7 +1,6 @@ use arc_swap::ArcSwap; use async_trait::async_trait; use bb8::{ManageConnection, Pool, PooledConnection, QueueStrategy}; -use bytes::{BufMut, BytesMut}; use chrono::naive::NaiveDateTime; use log::{debug, error, info, warn}; use once_cell::sync::Lazy; @@ -25,7 +24,7 @@ use crate::errors::Error; use crate::auth_passthrough::AuthPassthrough; use crate::plugins::prewarmer; -use crate::server::Server; +use crate::server::{Server, ServerParameters}; use crate::sharding::ShardingFunction; use crate::stats::{AddressStats, ClientStats, ServerStats}; @@ -188,10 +187,10 @@ pub struct ConnectionPool { /// that should not be queried. banlist: BanList, - /// The server information (K messages) have to be passed to the + /// The server information has to be passed to the /// clients on startup. We pre-connect to all shards and replicas - /// on pool creation and save the K messages here. - server_info: Arc>, + /// on pool creation and save the startup parameters here. + original_server_parameters: ServerParameters, /// Pool configuration. pub settings: PoolSettings, @@ -258,6 +257,7 @@ impl ConnectionPool { .clone() .into_keys() .collect::>(); + let mut original_server_parameters = ServerParameters::new(); // Sort by shard number to ensure consistency. shard_ids.sort_by_key(|k| k.parse::().unwrap()); @@ -415,6 +415,20 @@ impl ConnectionPool { pool.build_unchecked(manager) }; + // Set original server parameters by getting a connection + // If we don't want to validate then a default set of parameters will be used + if config.general.validate_config { + match pool.get().await { + Ok(conn) => { + original_server_parameters = conn.server_parameters(); + } + Err(err) => { + error!("Shard {} down or misconfigured: {:?}", address, err); + return Err(Error::ServerError); + } + }; + } + pools.push(pool); servers.push(address); } @@ -437,7 +451,7 @@ impl ConnectionPool { addresses, banlist: Arc::new(RwLock::new(banlist)), config_hash: new_pool_hash_value, - server_info: Arc::new(RwLock::new(BytesMut::new())), + original_server_parameters, auth_hash: pool_auth_hash, settings: PoolSettings { pool_mode: match user.pool_mode { @@ -488,6 +502,7 @@ impl ConnectionPool { // before setting it globally. // Do this async and somewhere else, we don't have to wait here. if config.general.validate_config { + // TODO: this can't be optional since we need some startup parameters to bootstrap with let mut validate_pool = pool.clone(); tokio::task::spawn(async move { let _ = validate_pool.validate().await; @@ -512,30 +527,22 @@ impl ConnectionPool { pub async fn validate(&mut self) -> Result<(), Error> { let mut futures = Vec::new(); let validated = Arc::clone(&self.validated); + validated.store(true, Ordering::Relaxed); for shard in 0..self.shards() { for server in 0..self.servers(shard) { let databases = self.databases.clone(); let validated = Arc::clone(&validated); - let pool_server_info = Arc::clone(&self.server_info); let task = tokio::task::spawn(async move { - let connection = match databases[shard][server].get().await { - Ok(conn) => conn, + match databases[shard][server].get().await { + Ok(_) => {} Err(err) => { + validated.store(false, Ordering::Relaxed); error!("Shard {} down or misconfigured: {:?}", shard, err); return; } }; - - let proxy = connection; - let server = &*proxy; - let server_info = server.server_info(); - - let mut guard = pool_server_info.write(); - guard.clear(); - guard.put(server_info.clone()); - validated.store(true, Ordering::Relaxed); }); futures.push(task); @@ -546,7 +553,7 @@ impl ConnectionPool { // TODO: compare server information to make sure // all shards are running identical configurations. - if self.server_info.read().is_empty() { + if !self.validated() { error!("Could not validate connection pool"); return Err(Error::AllServersDown); } @@ -906,8 +913,8 @@ impl ConnectionPool { &self.addresses[shard][server] } - pub fn server_info(&self) -> BytesMut { - self.server_info.read().clone() + pub fn server_parameters(&self) -> ServerParameters { + self.original_server_parameters.clone() } fn busy_connection_count(&self, address: &Address) -> u32 { diff --git a/src/server.rs b/src/server.rs index 32dd91f8..f5b05261 100644 --- a/src/server.rs +++ b/src/server.rs @@ -6,7 +6,8 @@ use log::{debug, error, info, trace, warn}; use parking_lot::{Mutex, RwLock}; use postgres_protocol::message; use std::collections::HashMap; -use std::io::Read; +use std::io::{Cursor, Read}; +use std::mem; use std::net::IpAddr; use std::sync::Arc; use std::time::SystemTime; @@ -19,6 +20,7 @@ use crate::config::{get_config, Address, User}; use crate::constants::*; use crate::dns_cache::{AddrSet, CACHED_RESOLVER}; use crate::errors::{Error, ServerIdentifier}; +use crate::messages::BytesMutReader; use crate::messages::*; use crate::mirrors::MirroringManager; use crate::pool::ClientServerMap; @@ -145,6 +147,124 @@ impl std::fmt::Display for CleanupState { } } +#[derive(Debug, Clone)] +pub struct ServerParameters { + base_original: BytesMut, + client_encoding: String, + date_style: String, + timezone: String, + standard_conforming_strings: String, + application_name: String, +} + +impl Default for ServerParameters { + fn default() -> Self { + ServerParameters::new() + } +} + +impl ServerParameters { + pub fn new() -> Self { + ServerParameters { + base_original: BytesMut::new(), + client_encoding: "UTF8".to_string(), + date_style: "ISO".to_string(), + timezone: "UTC".to_string(), + standard_conforming_strings: "on".to_string(), + application_name: "pgcat".to_string(), + } + } + + pub fn set_dynamic_param(&mut self, key: String, value: String) { + match key.as_str() { + "client_encoding" => { + self.client_encoding = value; + } + "date_style" => { + self.date_style = value; + } + "timezone" => { + self.timezone = value; + } + "standard_conforming_strings" => { + self.standard_conforming_strings = value; + } + "application_name" => { + self.application_name = value; + } + _ => {} + } + } + + fn set_param_from_bytes(&mut self, raw_bytes: BytesMut) { + let mut message_cursor = Cursor::new(&raw_bytes); + + message_cursor.get_u8(); + message_cursor.get_i32(); + + let key = match message_cursor.read_string() { + Ok(key) => key, + Err(_) => { + return; + }, + }; + let value = message_cursor.read_string().unwrap(); + + match key.as_str() { + "client_encoding" => { + self.client_encoding = value; + } + "date_style" => { + self.date_style = value; + } + "timezone" => { + self.timezone = value; + } + "standard_conforming_strings" => { + self.standard_conforming_strings = value; + } + "application_name" => { + self.application_name = value; + } + _ => { + self.base_original.extend(raw_bytes); + } + } + } + + pub fn get_bytes(&self) -> BytesMut { + let mut bytes = self.base_original.clone(); + + self.add_parameter_message("client_encoding", &self.client_encoding, &mut bytes); + self.add_parameter_message("date_style", &self.date_style, &mut bytes); + self.add_parameter_message("timezone", &self.timezone, &mut bytes); + self.add_parameter_message( + "standard_conforming_strings", + &self.standard_conforming_strings, + &mut bytes, + ); + self.add_parameter_message("application_name", &self.application_name, &mut bytes); + + bytes + } + + fn add_parameter_message(&self, key: &str, value: &str, buffer: &mut BytesMut) { + buffer.put_u8(b'S'); + + // 4 is len of i32, the plus for the null terminator + let len = 4 + key.len() + 1 + value.len() + 1; + + buffer.put_i32(len as i32); + + buffer.put_slice(key.as_bytes()); + buffer.put_u8(0); + buffer.put_slice(value.as_bytes()); + buffer.put_u8(0); + } +} + +// pub fn compare + /// Server state. pub struct Server { /// Server host, e.g. localhost, @@ -158,7 +278,7 @@ pub struct Server { buffer: BytesMut, /// Server information the server sent us over on startup. - server_info: BytesMut, + server_parameters: ServerParameters, /// Backend id and secret key used for query cancellation. process_id: i32, @@ -341,7 +461,6 @@ impl Server { startup(&mut stream, username, database).await?; - let mut server_info = BytesMut::new(); let mut process_id: i32 = 0; let mut secret_key: i32 = 0; let server_identifier = ServerIdentifier::new(username, &database); @@ -353,6 +472,8 @@ impl Server { None => None, }; + let mut server_parameters = ServerParameters::new(); + loop { let code = match stream.read_u8().await { Ok(code) => code as char, @@ -607,9 +728,15 @@ impl Server { // ParameterStatus 'S' => { - let mut param = vec![0u8; len as usize - 4]; + let mut bytes = BytesMut::with_capacity(len as usize + 1); + bytes.put_u8(code as u8); + bytes.put_i32(len); + bytes.resize(bytes.len() + len as usize - mem::size_of::(), b'0'); + + let slice_start = mem::size_of::() + mem::size_of::(); + let slice_end = slice_start + len as usize - mem::size_of::(); - match stream.read_exact(&mut param).await { + match stream.read_exact(&mut bytes[slice_start..slice_end]).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -622,9 +749,7 @@ impl Server { // Save the parameter so we can pass it to the client later. // These can be server_encoding, client_encoding, server timezone, Postgres version, // and many more interesting things we should know about the Postgres server we are talking to. - server_info.put_u8(b'S'); - server_info.put_i32(len); - server_info.put_slice(¶m[..]); + server_parameters.set_param_from_bytes(bytes); } // BackendKeyData @@ -670,7 +795,7 @@ impl Server { address: address.clone(), stream: BufStream::new(stream), buffer: BytesMut::with_capacity(8196), - server_info, + server_parameters, process_id, secret_key, in_transaction: false, @@ -946,9 +1071,8 @@ impl Server { } /// Get server startup information to forward it to the client. - /// Not used at the moment. - pub fn server_info(&self) -> BytesMut { - self.server_info.clone() + pub fn server_parameters(&self) -> ServerParameters { + self.server_parameters.clone() } /// Indicate that this server connection cannot be re-used and must be discarded. From f0efa97c33c98fc5a2a1e8a561ae856ff2a3c45b Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Thu, 15 Jun 2023 13:34:32 -0400 Subject: [PATCH 02/18] Refactor to use hashmap for all params and add server parameters to client --- src/admin.rs | 10 +-- src/client.rs | 12 +++- src/messages.rs | 16 +++++ src/server.rs | 173 ++++++++++++++++++++---------------------------- 4 files changed, 104 insertions(+), 107 deletions(-) diff --git a/src/admin.rs b/src/admin.rs index 52ac0c96..5c4c663a 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -21,11 +21,11 @@ use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState} pub fn generate_server_parameters_for_admin() -> ServerParameters { let mut server_parameters = ServerParameters::new(); - server_parameters.set_dynamic_param("application_name".to_string(), "".to_string()); - server_parameters.set_dynamic_param("client_encoding".to_string(), "UTF8".to_string()); - server_parameters.set_dynamic_param("server_encoding".to_string(), "UTF8".to_string()); - server_parameters.set_dynamic_param("server_version".to_string(), VERSION.to_string()); - server_parameters.set_dynamic_param("DateStyle".to_string(), "ISO, MDY".to_string()); + server_parameters.set_param("application_name".to_string(), "".to_string(), false); + server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), false); + server_parameters.set_param("server_encoding".to_string(), "UTF8".to_string(), false); + server_parameters.set_param("server_version".to_string(), VERSION.to_string(), false); + server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), false); server_parameters } diff --git a/src/client.rs b/src/client.rs index 5acdd706..bd8c791e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -19,7 +19,7 @@ use crate::messages::*; use crate::plugins::PluginOutput; use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; use crate::query_router::{Command, QueryRouter}; -use crate::server::Server; +use crate::server::{Server, ServerParameters}; use crate::stats::{ClientStats, ServerStats}; use crate::tls::Tls; @@ -91,6 +91,9 @@ pub struct Client { /// Application name for this client (defaults to pgcat) application_name: String, + /// Server startup and session parameters that we're going to track + server_parameters: ServerParameters, + /// Used to notify clients about an impending shutdown shutdown: Receiver<()>, } @@ -491,7 +494,7 @@ where }; // Authenticate admin user. - let (transaction_mode, server_parameters) = if admin { + let (transaction_mode, mut server_parameters) = if admin { let config = get_config(); // Compare server and client hashes. @@ -646,6 +649,9 @@ where (transaction_mode, pool.server_parameters()) }; + // Update the parameters to merge what the application sent and what's originally on the server + server_parameters.set_from_hashmap(¶meters, false); + debug!("Password authentication successful"); auth_ok(&mut write).await?; @@ -680,6 +686,7 @@ where pool_name: pool_name.clone(), username: username.clone(), application_name: application_name.to_string(), + server_parameters, shutdown, connected_to_server: false, }) @@ -714,6 +721,7 @@ where pool_name: String::from("undefined"), username: String::from("undefined"), application_name: String::from("undefined"), + server_parameters: ServerParameters::new(), shutdown, connected_to_server: false, }) diff --git a/src/messages.rs b/src/messages.rs index ee4886df..71668850 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -689,3 +689,19 @@ impl BytesMutReader for Cursor<&BytesMut> { } } } + +impl BytesMutReader for BytesMut { + /// Should only be used when reading strings from the message protocol. + /// Can be used to read multiple strings from the same message which are separated by the null byte + fn read_string(&mut self) -> Result { + let null_index = self.iter().position(|&byte| byte == b'\0'); + + match null_index { + Some(index) => { + let string_bytes = self.split_to(index + 1); + Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string()) + } + None => return Err(Error::ParseBytesError("Could not read string".to_string())), + } + } +} diff --git a/src/server.rs b/src/server.rs index f5b05261..6b030182 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,13 +3,13 @@ use bytes::{Buf, BufMut, BytesMut}; use fallible_iterator::FallibleIterator; use log::{debug, error, info, trace, warn}; +use once_cell::sync::Lazy; use parking_lot::{Mutex, RwLock}; use postgres_protocol::message; -use std::collections::HashMap; -use std::io::{Cursor, Read}; +use std::collections::{HashMap, HashSet}; use std::mem; use std::net::IpAddr; -use std::sync::Arc; +use std::sync::{Arc, Once}; use std::time::SystemTime; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream}; use tokio::net::TcpStream; @@ -147,14 +147,24 @@ impl std::fmt::Display for CleanupState { } } +static INIT: Once = Once::new(); +static TRACKED_PARAMETERS: Lazy> = Lazy::new(|| { + INIT.call_once(|| { + println!("Initializing the hashset"); + }); + + let mut set = HashSet::new(); + set.insert("client_encoding".to_string()); + set.insert("datestyle".to_string()); + set.insert("timezone".to_string()); + set.insert("standard_conforming_strings".to_string()); + set.insert("application_name".to_string()); + set +}); + #[derive(Debug, Clone)] pub struct ServerParameters { - base_original: BytesMut, - client_encoding: String, - date_style: String, - timezone: String, - standard_conforming_strings: String, - application_name: String, + parameters: HashMap, } impl Default for ServerParameters { @@ -166,84 +176,39 @@ impl Default for ServerParameters { impl ServerParameters { pub fn new() -> Self { ServerParameters { - base_original: BytesMut::new(), - client_encoding: "UTF8".to_string(), - date_style: "ISO".to_string(), - timezone: "UTC".to_string(), - standard_conforming_strings: "on".to_string(), - application_name: "pgcat".to_string(), + parameters: HashMap::new(), } } - pub fn set_dynamic_param(&mut self, key: String, value: String) { - match key.as_str() { - "client_encoding" => { - self.client_encoding = value; - } - "date_style" => { - self.date_style = value; - } - "timezone" => { - self.timezone = value; - } - "standard_conforming_strings" => { - self.standard_conforming_strings = value; - } - "application_name" => { - self.application_name = value; + // returns true if parameter was set, false if it already exists or was a non-tracked parameter + pub fn set_param(&mut self, key: String, value: String, startup: bool) -> bool { + println!("set_param: {} = {}", key, value); + + if TRACKED_PARAMETERS.contains(&key) { + self.parameters.insert(key, value); + true + } else { + if startup { + self.parameters.insert(key, value); + return false; } - _ => {} + true } } - fn set_param_from_bytes(&mut self, raw_bytes: BytesMut) { - let mut message_cursor = Cursor::new(&raw_bytes); - - message_cursor.get_u8(); - message_cursor.get_i32(); - - let key = match message_cursor.read_string() { - Ok(key) => key, - Err(_) => { - return; - }, - }; - let value = message_cursor.read_string().unwrap(); - - match key.as_str() { - "client_encoding" => { - self.client_encoding = value; - } - "date_style" => { - self.date_style = value; - } - "timezone" => { - self.timezone = value; - } - "standard_conforming_strings" => { - self.standard_conforming_strings = value; - } - "application_name" => { - self.application_name = value; - } - _ => { - self.base_original.extend(raw_bytes); - } + pub fn set_from_hashmap(&mut self, parameters: &HashMap, startup: bool) { + // iterate through each and call set_param + for (key, value) in parameters { + self.set_param(key.to_string(), value.to_string(), startup); } } pub fn get_bytes(&self) -> BytesMut { - let mut bytes = self.base_original.clone(); - - self.add_parameter_message("client_encoding", &self.client_encoding, &mut bytes); - self.add_parameter_message("date_style", &self.date_style, &mut bytes); - self.add_parameter_message("timezone", &self.timezone, &mut bytes); - self.add_parameter_message( - "standard_conforming_strings", - &self.standard_conforming_strings, - &mut bytes, - ); - self.add_parameter_message("application_name", &self.application_name, &mut bytes); + let mut bytes = BytesMut::new(); + + for (key, value) in &self.parameters { + self.add_parameter_message(key, value, &mut bytes); + } bytes } @@ -277,6 +242,9 @@ pub struct Server { /// Our server response buffer. We buffer data before we give it to the client. buffer: BytesMut, + // Original server parameters that we started with (used when we discard all) + original_server_parameters: ServerParameters, + /// Server information the server sent us over on startup. server_parameters: ServerParameters, @@ -728,15 +696,10 @@ impl Server { // ParameterStatus 'S' => { - let mut bytes = BytesMut::with_capacity(len as usize + 1); - bytes.put_u8(code as u8); - bytes.put_i32(len); - bytes.resize(bytes.len() + len as usize - mem::size_of::(), b'0'); - - let slice_start = mem::size_of::() + mem::size_of::(); - let slice_end = slice_start + len as usize - mem::size_of::(); + let mut bytes = BytesMut::with_capacity(len as usize - 4); + bytes.resize(len as usize - mem::size_of::(), b'0'); - match stream.read_exact(&mut bytes[slice_start..slice_end]).await { + match stream.read_exact(&mut bytes[..]).await { Ok(_) => (), Err(_) => { return Err(Error::ServerStartupError( @@ -746,10 +709,13 @@ impl Server { } }; + let key = bytes.read_string().unwrap(); + let value = bytes.read_string().unwrap(); + // Save the parameter so we can pass it to the client later. // These can be server_encoding, client_encoding, server timezone, Postgres version, // and many more interesting things we should know about the Postgres server we are talking to. - server_parameters.set_param_from_bytes(bytes); + let _ = server_parameters.set_param(key, value, true); } // BackendKeyData @@ -795,6 +761,7 @@ impl Server { address: address.clone(), stream: BufStream::new(stream), buffer: BytesMut::with_capacity(8196), + original_server_parameters: server_parameters.clone(), server_parameters, process_id, secret_key, @@ -951,24 +918,23 @@ impl Server { // CommandComplete 'C' => { - let mut command_tag = String::new(); - match message.reader().read_to_string(&mut command_tag) { - Ok(_) => { + match message.read_string() { + Ok(command) => { // Non-exhaustive list of commands that are likely to change session variables/resources // which can leak between clients. This is a best effort to block bad clients // from poisoning a transaction-mode pool by setting inappropriate session variables - match command_tag.as_str() { - "SET\0" => { - // We don't detect set statements in transactions - // No great way to differentiate between set and set local - // As a result, we will miss cases when set statements are used in transactions - // This will reduce amount of discard statements sent - if !self.in_transaction { - debug!("Server connection marked for clean up"); - self.cleanup_state.needs_cleanup_set = true; - } - } - "PREPARE\0" => { + match command.as_str() { + // "SET" => { + // // We don't detect set statements in transactions + // // No great way to differentiate between set and set local + // // As a result, we will miss cases when set statements are used in transactions + // // This will reduce amount of discard statements sent + // if !self.in_transaction { + // debug!("Server connection marked for clean up"); + // self.cleanup_state.needs_cleanup_set = true; + // } + // } + "PREPARE" => { debug!("Server connection marked for clean up"); self.cleanup_state.needs_cleanup_prepare = true; } @@ -982,6 +948,13 @@ impl Server { } } + 'S' => { + let key = message.read_string().unwrap(); + let value = message.read_string().unwrap(); + + self.server_parameters.set_param(key, value, false); + } + // DataRow 'D' => { // More data is available after this message, this is not the end of the reply. From aee1c1ac31098f860c783dd0403b6f852c76df7f Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Fri, 16 Jun 2023 11:56:49 -0400 Subject: [PATCH 03/18] Sync parameters on client server checkout --- src/client.rs | 5 +--- src/messages.rs | 4 +++ src/server.rs | 72 +++++++++++++++++++++++++++++++------------------ 3 files changed, 51 insertions(+), 30 deletions(-) diff --git a/src/client.rs b/src/client.rs index bd8c791e..495b1dfe 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1043,10 +1043,7 @@ where server.address() ); - // TODO: investigate other parameters and set them too. - - // Set application_name. - server.set_name(&self.application_name).await?; + server.sync_parameters(&self.server_parameters).await?; let mut initial_message = Some(message); diff --git a/src/messages.rs b/src/messages.rs index 71668850..c993d8f5 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -137,6 +137,10 @@ where bytes.put_slice(user.as_bytes()); bytes.put_u8(0); + // Application name + bytes.put(&b"application_name\0"[..]); + bytes.put_slice(&b"pgcat\0"[..]); + // Database bytes.put(&b"database\0"[..]); bytes.put_slice(database.as_bytes()); diff --git a/src/server.rs b/src/server.rs index 6b030182..4bec0d51 100644 --- a/src/server.rs +++ b/src/server.rs @@ -150,13 +150,13 @@ impl std::fmt::Display for CleanupState { static INIT: Once = Once::new(); static TRACKED_PARAMETERS: Lazy> = Lazy::new(|| { INIT.call_once(|| { - println!("Initializing the hashset"); + info!("Initializing the TRACKED_PARAMETERS hashset"); }); let mut set = HashSet::new(); set.insert("client_encoding".to_string()); - set.insert("datestyle".to_string()); - set.insert("timezone".to_string()); + set.insert("DateStyle".to_string()); + set.insert("TimeZone".to_string()); set.insert("standard_conforming_strings".to_string()); set.insert("application_name".to_string()); set @@ -181,8 +181,13 @@ impl ServerParameters { } // returns true if parameter was set, false if it already exists or was a non-tracked parameter - pub fn set_param(&mut self, key: String, value: String, startup: bool) -> bool { - println!("set_param: {} = {}", key, value); + pub fn set_param(&mut self, mut key: String, value: String, startup: bool) -> bool { + // The startup parameter will send uncapitalized keys but parameter status packets will send capitalized keys + if key == "timezone" { + key = "TimeZone".to_string(); + } else if key == "datestyle" { + key = "DateStyle".to_string(); + }; if TRACKED_PARAMETERS.contains(&key) { self.parameters.insert(key, value); @@ -203,6 +208,25 @@ impl ServerParameters { } } + // Gets the diff of the parameters + fn compare_params(&self, incoming_parameters: &ServerParameters) -> HashMap { + let mut diff = HashMap::new(); + + for (key, value) in &self.parameters { + if !TRACKED_PARAMETERS.contains(key) { + continue; + } + + if let Some(incoming_value) = incoming_parameters.parameters.get(key) { + if value != incoming_value { + diff.insert(key.to_string(), incoming_value.to_string()); + } + } + } + + diff + } + pub fn get_bytes(&self) -> BytesMut { let mut bytes = BytesMut::new(); @@ -773,7 +797,7 @@ impl Server { addr_set, connected_at: chrono::offset::Utc::now().naive_utc(), stats, - application_name: String::new(), + application_name: "pgcat".to_string(), last_activity: SystemTime::now(), mirror_manager: match address.mirrors.len() { 0 => None, @@ -786,8 +810,6 @@ impl Server { cleanup_connections, }; - server.set_name("pgcat").await?; - return Ok(server); } @@ -1048,6 +1070,22 @@ impl Server { self.server_parameters.clone() } + pub async fn sync_parameters(&mut self, parameters: &ServerParameters) -> Result<(), Error> { + let parameter_diff = self.server_parameters.compare_params(parameters); + + if parameter_diff.is_empty() { + return Ok(()); + } + + let mut query = String::from(""); + + for (key, value) in parameter_diff { + query.push_str(&format!("SET {} TO '{}';", key, value)); + } + + self.query(&query).await + } + /// Indicate that this server connection cannot be re-used and must be discarded. pub fn mark_bad(&mut self) { error!("Server {:?} marked bad", self.address); @@ -1116,24 +1154,6 @@ impl Server { Ok(()) } - /// A shorthand for `SET application_name = $1`. - pub async fn set_name(&mut self, name: &str) -> Result<(), Error> { - if self.application_name != name { - self.application_name = name.to_string(); - // We don't want `SET application_name` to mark the server connection - // as needing cleanup - let needs_cleanup_before = self.cleanup_state; - - let result = Ok(self - .query(&format!("SET application_name = '{}'", name)) - .await?); - self.cleanup_state = needs_cleanup_before; - result - } else { - Ok(()) - } - } - /// get Server stats pub fn stats(&self) -> Arc { self.stats.clone() From a92f711eed4c1848f8638061eead7d06bffd22ec Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Fri, 16 Jun 2023 12:11:37 -0400 Subject: [PATCH 04/18] minor refactor --- src/client.rs | 49 +++++++++++++++++-------------------------------- src/server.rs | 2 +- 2 files changed, 18 insertions(+), 33 deletions(-) diff --git a/src/client.rs b/src/client.rs index 495b1dfe..e101e220 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1447,38 +1447,13 @@ where pool: &ConnectionPool, client_stats: &ClientStats, ) -> Result { - if pool.settings.user.statement_timeout > 0 { - match tokio::time::timeout( - tokio::time::Duration::from_millis(pool.settings.user.statement_timeout), - server.recv(), - ) - .await - { - Ok(result) => match result { - Ok(message) => Ok(message), - Err(err) => { - pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats)); - error_response_terminal( - &mut self.write, - &format!("error receiving data from server: {:?}", err), - ) - .await?; - Err(err) - } - }, - Err(_) => { - error!( - "Statement timeout while talking to {:?} with user {}", - address, pool.settings.user.username - ); - server.mark_bad(); - pool.ban(address, BanReason::StatementTimeout, Some(client_stats)); - error_response_terminal(&mut self.write, "pool statement timeout").await?; - Err(Error::StatementTimeout) - } - } - } else { - match server.recv().await { + let statement_timeout_duration = match pool.settings.user.statement_timeout { + 0 => tokio::time::Duration::MAX, + timeout => tokio::time::Duration::from_millis(timeout), + }; + + match tokio::time::timeout(statement_timeout_duration, server.recv()).await { + Ok(result) => match result { Ok(message) => Ok(message), Err(err) => { pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats)); @@ -1489,6 +1464,16 @@ where .await?; Err(err) } + }, + Err(_) => { + error!( + "Statement timeout while talking to {:?} with user {}", + address, pool.settings.user.username + ); + server.mark_bad(); + pool.ban(address, BanReason::StatementTimeout, Some(client_stats)); + error_response_terminal(&mut self.write, "pool statement timeout").await?; + Err(Error::StatementTimeout) } } } diff --git a/src/server.rs b/src/server.rs index 4bec0d51..690e145d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -208,7 +208,7 @@ impl ServerParameters { } } - // Gets the diff of the parameters + // Gets the diff of the parameters fn compare_params(&self, incoming_parameters: &ServerParameters) -> HashMap { let mut diff = HashMap::new(); From b8f0b0d87da90ed1d71f333f11bdc71bda43b5db Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Fri, 16 Jun 2023 14:09:01 -0400 Subject: [PATCH 05/18] update client side parameters when changed --- src/client.rs | 2 +- src/mirrors.rs | 2 +- src/server.rs | 13 ++++++++++--- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/client.rs b/src/client.rs index e101e220..50eccd3a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1452,7 +1452,7 @@ where timeout => tokio::time::Duration::from_millis(timeout), }; - match tokio::time::timeout(statement_timeout_duration, server.recv()).await { + match tokio::time::timeout(statement_timeout_duration, server.recv(Some(&mut self.server_parameters))).await { Ok(result) => match result { Ok(message) => Ok(message), Err(err) => { diff --git a/src/mirrors.rs b/src/mirrors.rs index 0f2b02c0..7922e6f8 100644 --- a/src/mirrors.rs +++ b/src/mirrors.rs @@ -78,7 +78,7 @@ impl MirroredClient { } // Incoming data from server (we read to clear the socket buffer and discard the data) - recv_result = server.recv() => { + recv_result = server.recv(None) => { match recv_result { Ok(message) => trace!("Received from mirror: {} {:?}", String::from_utf8_lossy(&message[..]), address.clone()), Err(err) => { diff --git a/src/server.rs b/src/server.rs index 690e145d..71d96f50 100644 --- a/src/server.rs +++ b/src/server.rs @@ -879,7 +879,10 @@ impl Server { /// Receive data from the server in response to a client request. /// This method must be called multiple times while `self.is_data_available()` is true /// in order to receive all data the server has to offer. - pub async fn recv(&mut self) -> Result { + pub async fn recv( + &mut self, + mut client_server_parameters: Option<&mut ServerParameters>, + ) -> Result { loop { let mut message = match read_message(&mut self.stream).await { Ok(message) => message, @@ -974,6 +977,10 @@ impl Server { let key = message.read_string().unwrap(); let value = message.read_string().unwrap(); + if let Some(client_server_parameters) = client_server_parameters.as_mut() { + client_server_parameters.set_param(key.clone(), value.clone(), false); + } + self.server_parameters.set_param(key, value, false); } @@ -1117,7 +1124,7 @@ impl Server { self.send(&query).await?; loop { - let _ = self.recv().await?; + let _ = self.recv(None).await?; if !self.data_available { break; @@ -1211,7 +1218,7 @@ impl Server { .await?; debug!("Connected!, sending query."); server.send(&simple_query(query)).await?; - let mut message = server.recv().await?; + let mut message = server.recv(None).await?; Ok(parse_query_message(&mut message).await?) } From 0e6997c24f83e0df27479d7468355e0c2c638f5b Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Fri, 16 Jun 2023 15:11:11 -0400 Subject: [PATCH 06/18] Move the SET statement logic from the C packet to the S packet. --- src/client.rs | 7 ++++++- src/server.rs | 33 ++++++++++++++------------------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/client.rs b/src/client.rs index 50eccd3a..073976ed 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1452,7 +1452,12 @@ where timeout => tokio::time::Duration::from_millis(timeout), }; - match tokio::time::timeout(statement_timeout_duration, server.recv(Some(&mut self.server_parameters))).await { + match tokio::time::timeout( + statement_timeout_duration, + server.recv(Some(&mut self.server_parameters)), + ) + .await + { Ok(result) => match result { Ok(message) => Ok(message), Err(err) => { diff --git a/src/server.rs b/src/server.rs index 71d96f50..d686e4b2 100644 --- a/src/server.rs +++ b/src/server.rs @@ -180,7 +180,7 @@ impl ServerParameters { } } - // returns true if parameter was set, false if it already exists or was a non-tracked parameter + /// returns true if a tracked parameter was set, false if it was a non-tracked parameter pub fn set_param(&mut self, mut key: String, value: String, startup: bool) -> bool { // The startup parameter will send uncapitalized keys but parameter status packets will send capitalized keys if key == "timezone" { @@ -195,9 +195,8 @@ impl ServerParameters { } else { if startup { self.parameters.insert(key, value); - return false; } - true + false } } @@ -266,9 +265,6 @@ pub struct Server { /// Our server response buffer. We buffer data before we give it to the client. buffer: BytesMut, - // Original server parameters that we started with (used when we discard all) - original_server_parameters: ServerParameters, - /// Server information the server sent us over on startup. server_parameters: ServerParameters, @@ -781,11 +777,10 @@ impl Server { } }; - let mut server = Server { + let server = Server { address: address.clone(), stream: BufStream::new(stream), buffer: BytesMut::with_capacity(8196), - original_server_parameters: server_parameters.clone(), server_parameters, process_id, secret_key, @@ -949,16 +944,6 @@ impl Server { // which can leak between clients. This is a best effort to block bad clients // from poisoning a transaction-mode pool by setting inappropriate session variables match command.as_str() { - // "SET" => { - // // We don't detect set statements in transactions - // // No great way to differentiate between set and set local - // // As a result, we will miss cases when set statements are used in transactions - // // This will reduce amount of discard statements sent - // if !self.in_transaction { - // debug!("Server connection marked for clean up"); - // self.cleanup_state.needs_cleanup_set = true; - // } - // } "PREPARE" => { debug!("Server connection marked for clean up"); self.cleanup_state.needs_cleanup_prepare = true; @@ -981,7 +966,17 @@ impl Server { client_server_parameters.set_param(key.clone(), value.clone(), false); } - self.server_parameters.set_param(key, value, false); + // We set a non-tracked parameter. We should reset the state + if !self.server_parameters.set_param(key, value, false) { + // We don't detect set statements in transactions + // No great way to differentiate between set and set local + // As a result, we will miss cases when set statements are used in transactions + // This will reduce amount of discard statements sent + if !self.in_transaction { + debug!("Server connection marked for clean up"); + self.cleanup_state.needs_cleanup_set = true; + } + }; } // DataRow From 94728a17606c924b1d95e8a3b202e028347b3264 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Fri, 16 Jun 2023 20:51:30 -0400 Subject: [PATCH 07/18] trigger build From 06ab0f961c0a04d8ccb4c2275c611f6629a1a3bd Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Fri, 16 Jun 2023 23:06:10 -0400 Subject: [PATCH 08/18] revert validation changes --- src/pool.rs | 36 ++++++++++++++---------------------- src/server.rs | 13 ++++++++++++- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/src/pool.rs b/src/pool.rs index e686c612..401344da 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -190,7 +190,7 @@ pub struct ConnectionPool { /// The server information has to be passed to the /// clients on startup. We pre-connect to all shards and replicas /// on pool creation and save the startup parameters here. - original_server_parameters: ServerParameters, + original_server_parameters: Arc>, /// Pool configuration. pub settings: PoolSettings, @@ -257,7 +257,6 @@ impl ConnectionPool { .clone() .into_keys() .collect::>(); - let mut original_server_parameters = ServerParameters::new(); // Sort by shard number to ensure consistency. shard_ids.sort_by_key(|k| k.parse::().unwrap()); @@ -415,20 +414,6 @@ impl ConnectionPool { pool.build_unchecked(manager) }; - // Set original server parameters by getting a connection - // If we don't want to validate then a default set of parameters will be used - if config.general.validate_config { - match pool.get().await { - Ok(conn) => { - original_server_parameters = conn.server_parameters(); - } - Err(err) => { - error!("Shard {} down or misconfigured: {:?}", address, err); - return Err(Error::ServerError); - } - }; - } - pools.push(pool); servers.push(address); } @@ -451,7 +436,7 @@ impl ConnectionPool { addresses, banlist: Arc::new(RwLock::new(banlist)), config_hash: new_pool_hash_value, - original_server_parameters, + original_server_parameters: Arc::new(RwLock::new(ServerParameters::default())), auth_hash: pool_auth_hash, settings: PoolSettings { pool_mode: match user.pool_mode { @@ -527,22 +512,29 @@ impl ConnectionPool { pub async fn validate(&mut self) -> Result<(), Error> { let mut futures = Vec::new(); let validated = Arc::clone(&self.validated); - validated.store(true, Ordering::Relaxed); for shard in 0..self.shards() { for server in 0..self.servers(shard) { let databases = self.databases.clone(); let validated = Arc::clone(&validated); + let pool_server_parameters = Arc::clone(&self.original_server_parameters); let task = tokio::task::spawn(async move { - match databases[shard][server].get().await { - Ok(_) => {} + let connection = match databases[shard][server].get().await { + Ok(conn) => conn, Err(err) => { - validated.store(false, Ordering::Relaxed); error!("Shard {} down or misconfigured: {:?}", shard, err); return; } }; + + let proxy = connection; + let server = &*proxy; + let server_parameters: ServerParameters = server.server_parameters(); + + let mut guard = pool_server_parameters.write(); + *guard = server_parameters; + validated.store(true, Ordering::Relaxed); }); futures.push(task); @@ -914,7 +906,7 @@ impl ConnectionPool { } pub fn server_parameters(&self) -> ServerParameters { - self.original_server_parameters.clone() + self.original_server_parameters.read().clone() } fn busy_connection_count(&self, address: &Address) -> u32 { diff --git a/src/server.rs b/src/server.rs index d686e4b2..b75c6228 100644 --- a/src/server.rs +++ b/src/server.rs @@ -169,7 +169,18 @@ pub struct ServerParameters { impl Default for ServerParameters { fn default() -> Self { - ServerParameters::new() + let mut server_parameters = ServerParameters::new(); + server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), false); + server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), false); + server_parameters.set_param("TimeZone".to_string(), "Etc/UTC".to_string(), false); + server_parameters.set_param( + "standard_conforming_strings".to_string(), + "on".to_string(), + false, + ); + server_parameters.set_param("application_name".to_string(), "pgcat".to_string(), false); + + server_parameters } } From a4de9fca584bcfef6bbd3ca9f2b51ea59f1115ed Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Fri, 16 Jun 2023 23:10:22 -0400 Subject: [PATCH 09/18] remove comment --- src/pool.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pool.rs b/src/pool.rs index 401344da..2014cb86 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -487,7 +487,6 @@ impl ConnectionPool { // before setting it globally. // Do this async and somewhere else, we don't have to wait here. if config.general.validate_config { - // TODO: this can't be optional since we need some startup parameters to bootstrap with let mut validate_pool = pool.clone(); tokio::task::spawn(async move { let _ = validate_pool.validate().await; From a2e7563b55b0b12f306d545dc037907916511e62 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Fri, 16 Jun 2023 23:55:00 -0400 Subject: [PATCH 10/18] Try fix --- src/prometheus.rs | 4 ++-- src/server.rs | 27 +++++++++++++-------------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/prometheus.rs b/src/prometheus.rs index 7774b5a5..7e264dca 100644 --- a/src/prometheus.rs +++ b/src/prometheus.rs @@ -1,6 +1,6 @@ use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Method, Request, Response, Server, StatusCode}; -use log::{error, info, debug}; +use log::{debug, error, info}; use phf::phf_map; use std::collections::HashMap; use std::fmt; @@ -364,7 +364,7 @@ fn push_server_stats(lines: &mut Vec) { { lines.push(prometheus_metric.to_string()); } else { - warn!("Metric {} not implemented for {}", key, address.name()); + debug!("Metric {} not implemented for {}", key, address.name()); } } } diff --git a/src/server.rs b/src/server.rs index 2a7eb266..bf428efe 100644 --- a/src/server.rs +++ b/src/server.rs @@ -192,7 +192,7 @@ impl ServerParameters { } /// returns true if a tracked parameter was set, false if it was a non-tracked parameter - pub fn set_param(&mut self, mut key: String, value: String, startup: bool) -> bool { + pub fn set_param(&mut self, mut key: String, value: String, startup: bool) { // The startup parameter will send uncapitalized keys but parameter status packets will send capitalized keys if key == "timezone" { key = "TimeZone".to_string(); @@ -202,12 +202,10 @@ impl ServerParameters { if TRACKED_PARAMETERS.contains(&key) { self.parameters.insert(key, value); - true } else { if startup { self.parameters.insert(key, value); } - false } } @@ -959,6 +957,17 @@ impl Server { // which can leak between clients. This is a best effort to block bad clients // from poisoning a transaction-mode pool by setting inappropriate session variables match command.as_str() { + "SET" => { + // We don't detect set statements in transactions + // No great way to differentiate between set and set local + // As a result, we will miss cases when set statements are used in transactions + // This will reduce amount of discard statements sent + if !self.in_transaction { + debug!("Server connection marked for clean up"); + self.cleanup_state.needs_cleanup_set = true; + } + } + "PREPARE" => { debug!("Server connection marked for clean up"); self.cleanup_state.needs_cleanup_prepare = true; @@ -981,17 +990,7 @@ impl Server { client_server_parameters.set_param(key.clone(), value.clone(), false); } - // We set a non-tracked parameter. We should reset the state - if !self.server_parameters.set_param(key, value, false) { - // We don't detect set statements in transactions - // No great way to differentiate between set and set local - // As a result, we will miss cases when set statements are used in transactions - // This will reduce amount of discard statements sent - if !self.in_transaction { - debug!("Server connection marked for clean up"); - self.cleanup_state.needs_cleanup_set = true; - } - }; + self.server_parameters.set_param(key, value, false); } // DataRow From d10cf90a9fffa6369d9c726d8f514a2bde6f7b83 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Sat, 17 Jun 2023 01:37:08 -0400 Subject: [PATCH 11/18] Reset cleanup state after sync --- src/server.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/server.rs b/src/server.rs index bf428efe..6fa8cc8e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1136,7 +1136,11 @@ impl Server { query.push_str(&format!("SET {} TO '{}';", key, value)); } - self.query(&query).await + let res = self.query(&query).await; + + self.cleanup_state.reset(); + + res } /// Indicate that this server connection cannot be re-used and must be discarded. From 88ae4d057c5b2ca6108640a038571683455287c9 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Sat, 17 Jun 2023 10:34:50 -0400 Subject: [PATCH 12/18] fix server version test --- src/admin.rs | 10 +++++----- src/server.rs | 1 + 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/admin.rs b/src/admin.rs index c0cd00ca..6c6ca9ac 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -21,11 +21,11 @@ use crate::stats::{get_client_stats, get_server_stats, ClientState, ServerState} pub fn generate_server_parameters_for_admin() -> ServerParameters { let mut server_parameters = ServerParameters::new(); - server_parameters.set_param("application_name".to_string(), "".to_string(), false); - server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), false); - server_parameters.set_param("server_encoding".to_string(), "UTF8".to_string(), false); - server_parameters.set_param("server_version".to_string(), VERSION.to_string(), false); - server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), false); + server_parameters.set_param("application_name".to_string(), "".to_string(), true); + server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), true); + server_parameters.set_param("server_encoding".to_string(), "UTF8".to_string(), true); + server_parameters.set_param("server_version".to_string(), VERSION.to_string(), true); + server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), true); server_parameters } diff --git a/src/server.rs b/src/server.rs index 6fa8cc8e..86af0b41 100644 --- a/src/server.rs +++ b/src/server.rs @@ -192,6 +192,7 @@ impl ServerParameters { } /// returns true if a tracked parameter was set, false if it was a non-tracked parameter + /// if startup is false, then then only tracked parameters will be set pub fn set_param(&mut self, mut key: String, value: String, startup: bool) { // The startup parameter will send uncapitalized keys but parameter status packets will send capitalized keys if key == "timezone" { From 6510f45a47d023847220fe2691045f81bd416a63 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Sat, 17 Jun 2023 11:30:28 -0400 Subject: [PATCH 13/18] Track application name through client life for stats --- src/client.rs | 23 +++++++++++++---------- src/pool.rs | 2 +- src/server.rs | 26 ++++++++++++++++---------- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/src/client.rs b/src/client.rs index de681275..c858193f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -96,9 +96,6 @@ pub struct Client { /// Postgres user for this client (This comes from the user in the connection string) username: String, - /// Application name for this client (defaults to pgcat) - application_name: String, - /// Server startup and session parameters that we're going to track server_parameters: ServerParameters, @@ -696,7 +693,6 @@ where last_server_stats: None, pool_name: pool_name.clone(), username: username.clone(), - application_name: application_name.to_string(), server_parameters, shutdown, connected_to_server: false, @@ -732,7 +728,6 @@ where last_server_stats: None, pool_name: String::from("undefined"), username: String::from("undefined"), - application_name: String::from("undefined"), server_parameters: ServerParameters::new(), shutdown, connected_to_server: false, @@ -1238,7 +1233,9 @@ where if !server.in_transaction() { // Report transaction executed statistics. self.stats.transaction(); - server.stats().transaction(&self.application_name); + server + .stats() + .transaction(&self.server_parameters.get_application_name()); // Release server back to the pool if we are in transaction mode. // If we are in session mode, we keep the server until the client disconnects. @@ -1369,7 +1366,9 @@ where if !server.in_transaction() { self.stats.transaction(); - server.stats().transaction(&self.application_name); + server + .stats() + .transaction(&self.server_parameters.get_application_name()); // Release server back to the pool if we are in transaction mode. // If we are in session mode, we keep the server until the client disconnects. @@ -1418,7 +1417,9 @@ where if !server.in_transaction() { self.stats.transaction(); - server.stats().transaction(&self.application_name); + server + .stats() + .transaction(self.server_parameters.get_application_name()); // Release server back to the pool if we are in transaction mode. // If we are in session mode, we keep the server until the client disconnects. @@ -1464,7 +1465,9 @@ where Err(Error::ClientError(format!( "Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}", - self.pool_name, self.username, self.application_name + self.pool_name, + self.username, + self.server_parameters.get_application_name() ))) } } @@ -1621,7 +1624,7 @@ where client_stats.query(); server.stats().query( Instant::now().duration_since(query_start).as_millis() as u64, - &self.application_name, + &self.server_parameters.get_application_name(), ); Ok(()) diff --git a/src/pool.rs b/src/pool.rs index 2014cb86..d7c69669 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -436,7 +436,7 @@ impl ConnectionPool { addresses, banlist: Arc::new(RwLock::new(banlist)), config_hash: new_pool_hash_value, - original_server_parameters: Arc::new(RwLock::new(ServerParameters::default())), + original_server_parameters: Arc::new(RwLock::new(ServerParameters::new())), auth_hash: pool_auth_hash, settings: PoolSettings { pool_mode: match user.pool_mode { diff --git a/src/server.rs b/src/server.rs index 86af0b41..9b4687cf 100644 --- a/src/server.rs +++ b/src/server.rs @@ -169,7 +169,16 @@ pub struct ServerParameters { impl Default for ServerParameters { fn default() -> Self { - let mut server_parameters = ServerParameters::new(); + Self::new() + } +} + +impl ServerParameters { + pub fn new() -> Self { + let mut server_parameters = ServerParameters { + parameters: HashMap::new(), + }; + server_parameters.set_param("client_encoding".to_string(), "UTF8".to_string(), false); server_parameters.set_param("DateStyle".to_string(), "ISO, MDY".to_string(), false); server_parameters.set_param("TimeZone".to_string(), "Etc/UTC".to_string(), false); @@ -182,14 +191,6 @@ impl Default for ServerParameters { server_parameters } -} - -impl ServerParameters { - pub fn new() -> Self { - ServerParameters { - parameters: HashMap::new(), - } - } /// returns true if a tracked parameter was set, false if it was a non-tracked parameter /// if startup is false, then then only tracked parameters will be set @@ -236,6 +237,11 @@ impl ServerParameters { diff } + pub fn get_application_name(&self) -> &String { + // Can unwrap because we set it in the constructor + self.parameters.get("application_name").unwrap() + } + pub fn get_bytes(&self) -> BytesMut { let mut bytes = BytesMut::new(); @@ -748,7 +754,7 @@ impl Server { // Save the parameter so we can pass it to the client later. // These can be server_encoding, client_encoding, server timezone, Postgres version, // and many more interesting things we should know about the Postgres server we are talking to. - let _ = server_parameters.set_param(key, value, true); + server_parameters.set_param(key, value, true); } // BackendKeyData From f4d250d9aed95aaafeae3af0e0f452b4f96b99e4 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Sat, 17 Jun 2023 15:16:01 -0400 Subject: [PATCH 14/18] Add tests --- tests/ruby/helpers/pgcat_process.rb | 10 ++++++++-- tests/ruby/misc_spec.rb | 24 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/ruby/helpers/pgcat_process.rb b/tests/ruby/helpers/pgcat_process.rb index e1dbea8b..dd3fd052 100644 --- a/tests/ruby/helpers/pgcat_process.rb +++ b/tests/ruby/helpers/pgcat_process.rb @@ -112,10 +112,16 @@ def admin_connection_string "postgresql://#{username}:#{password}@0.0.0.0:#{@port}/pgcat" end - def connection_string(pool_name, username, password = nil) + def connection_string(pool_name, username, password = nil, parameters: {}) cfg = current_config user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username } - "postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}" + connection_string = "postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}" + + # Add the additional parameters to the connection string + parameter_string = parameters.map { |key, value| "#{key}=#{value}" }.join("&") + connection_string += "?#{parameter_string}" unless parameter_string.empty? + + connection_string end def example_connection_string diff --git a/tests/ruby/misc_spec.rb b/tests/ruby/misc_spec.rb index fe216e5b..628680bd 100644 --- a/tests/ruby/misc_spec.rb +++ b/tests/ruby/misc_spec.rb @@ -294,6 +294,30 @@ expect(processes.primary.count_query("DISCARD ALL")).to eq(10) end + + it "Respects tracked parameters on startup" do + conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user", parameters: { "application_name" => "my_pgcat_test" })) + + expect(conn.async_exec("SHOW application_name")[0]["application_name"]).to eq("my_pgcat_test") + conn.close + end + + it "Respect tracked parameter on set statemet" do + conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + + conn.async_exec("SET application_name to 'my_pgcat_test'") + expect(conn.async_exec("SHOW application_name")[0]["application_name"]).to eq("my_pgcat_test") + end + + + it "Ignore untracked parameter on set statemet" do + conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + orignal_statement_timeout = conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"] + + conn.async_exec("SET statement_timeout to 1500") + expect(conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"]).to eq(orignal_statement_timeout) + end + end context "transaction mode with transactions" do From 6135196b1c84ac2e0fbb5f9e40d03a334f7036ca Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Sun, 18 Jun 2023 23:24:40 -0400 Subject: [PATCH 15/18] minor refactoring --- src/client.rs | 2 +- src/server.rs | 41 ++++++++++++++++++++--------------------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/src/client.rs b/src/client.rs index c858193f..30e38752 100644 --- a/src/client.rs +++ b/src/client.rs @@ -663,7 +663,7 @@ where debug!("Password authentication successful"); auth_ok(&mut write).await?; - write_all(&mut write, server_parameters.get_bytes()).await?; + write_all(&mut write, (&server_parameters).into()).await?; backend_key_data(&mut write, process_id, secret_key).await?; ready_for_query(&mut write).await?; diff --git a/src/server.rs b/src/server.rs index 9b4687cf..b9bb93cf 100644 --- a/src/server.rs +++ b/src/server.rs @@ -149,9 +149,7 @@ impl std::fmt::Display for CleanupState { static INIT: Once = Once::new(); static TRACKED_PARAMETERS: Lazy> = Lazy::new(|| { - INIT.call_once(|| { - info!("Initializing the TRACKED_PARAMETERS hashset"); - }); + INIT.call_once(|| {}); let mut set = HashSet::new(); set.insert("client_encoding".to_string()); @@ -222,14 +220,13 @@ impl ServerParameters { fn compare_params(&self, incoming_parameters: &ServerParameters) -> HashMap { let mut diff = HashMap::new(); - for (key, value) in &self.parameters { - if !TRACKED_PARAMETERS.contains(key) { - continue; - } - + // iterate through tracked parameters + for key in TRACKED_PARAMETERS.iter() { if let Some(incoming_value) = incoming_parameters.parameters.get(key) { - if value != incoming_value { - diff.insert(key.to_string(), incoming_value.to_string()); + if let Some(value) = self.parameters.get(key) { + if value != incoming_value { + diff.insert(key.to_string(), incoming_value.to_string()); + } } } } @@ -242,17 +239,7 @@ impl ServerParameters { self.parameters.get("application_name").unwrap() } - pub fn get_bytes(&self) -> BytesMut { - let mut bytes = BytesMut::new(); - - for (key, value) in &self.parameters { - self.add_parameter_message(key, value, &mut bytes); - } - - bytes - } - - fn add_parameter_message(&self, key: &str, value: &str, buffer: &mut BytesMut) { + fn add_parameter_message(key: &str, value: &str, buffer: &mut BytesMut) { buffer.put_u8(b'S'); // 4 is len of i32, the plus for the null terminator @@ -267,6 +254,18 @@ impl ServerParameters { } } +impl From<&ServerParameters> for BytesMut { + fn from(server_parameters: &ServerParameters) -> Self { + let mut bytes = BytesMut::new(); + + for (key, value) in &server_parameters.parameters { + ServerParameters::add_parameter_message(key, value, &mut bytes); + } + + bytes + } +} + // pub fn compare /// Server state. From 4a0c7ebcf057b686cceddf6e55181a97ad74d507 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Sun, 18 Jun 2023 23:49:10 -0400 Subject: [PATCH 16/18] fmt --- src/server.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/server.rs b/src/server.rs index b9bb93cf..2279e443 100644 --- a/src/server.rs +++ b/src/server.rs @@ -147,10 +147,7 @@ impl std::fmt::Display for CleanupState { } } -static INIT: Once = Once::new(); static TRACKED_PARAMETERS: Lazy> = Lazy::new(|| { - INIT.call_once(|| {}); - let mut set = HashSet::new(); set.insert("client_encoding".to_string()); set.insert("DateStyle".to_string()); From 4169b47ff134298b61cb9ada07522e405f195eed Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 9 Aug 2023 15:10:48 -0400 Subject: [PATCH 17/18] fix --- src/client.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/client.rs b/src/client.rs index 389b6e2d..c9baa001 100644 --- a/src/client.rs +++ b/src/client.rs @@ -778,7 +778,7 @@ where let mut will_prepare = false; let client_identifier = - ClientIdentifier::new(&self.application_name, &self.username, &self.pool_name); + ClientIdentifier::new(&self.server_parameters.get_application_name(), &self.username, &self.pool_name); // Our custom protocol loop. // We expect the client to either start a transaction with regular queries From d6e0e14f7ffa69048957e0a5bb4e11d4b392a805 Mon Sep 17 00:00:00 2001 From: Zain Kabani Date: Wed, 9 Aug 2023 15:19:42 -0400 Subject: [PATCH 18/18] fmt --- src/client.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/client.rs b/src/client.rs index c9baa001..6cdea987 100644 --- a/src/client.rs +++ b/src/client.rs @@ -777,8 +777,11 @@ where let mut prepared_statement = None; let mut will_prepare = false; - let client_identifier = - ClientIdentifier::new(&self.server_parameters.get_application_name(), &self.username, &self.pool_name); + let client_identifier = ClientIdentifier::new( + &self.server_parameters.get_application_name(), + &self.username, + &self.pool_name, + ); // Our custom protocol loop. // We expect the client to either start a transaction with regular queries