diff --git a/.circleci/config.yml b/.circleci/config.yml index c7f5c9fa..c8344911 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -9,7 +9,7 @@ jobs: # Specify the execution environment. You can specify an image from Dockerhub or use one of our Convenience Images from CircleCI's Developer Hub. # See: https://circleci.com/docs/2.0/configuration-reference/#docker-machine-macos-windows-executor docker: - - image: ghcr.io/levkk/pgcat-ci:1.67 + - image: ghcr.io/postgresml/pgcat-ci:latest environment: RUST_LOG: info LLVM_PROFILE_FILE: /tmp/pgcat-%m-%p.profraw diff --git a/pgcat.toml b/pgcat.toml index 654f5e89..772a1365 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -177,6 +177,12 @@ primary_reads_enabled = true # shard_id_regex = '/\* shard_id: (\d+) \*/' # regex_search_limit = 1000 # only look at the first 1000 characters of SQL statements +# Defines the behavior when no shard is selected in a sharded system. +# `random`: picks a shard at random +# `random_healthy`: picks a shard at random, favoring shards with the least number of recent errors +# `shard_<number>`: e.g. shard_0, shard_4, etc. always picks the specified shard +# default_shard = "shard_0" + # So what if you wanted to implement a different hashing function, # or you've already built one and you want this pooler to use it? # Current options: diff --git a/src/client.rs b/src/client.rs index 8edecea1..4b281121 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1009,23 +1009,27 @@ where // SET SHARD TO Some((Command::SetShard, _)) => { - // Selected shard is not configured. - if query_router.shard() >= pool.shards() { - // Set the shard back to what it was. - query_router.set_shard(current_shard); - - error_response( - &mut self.write, - &format!( - "shard {} is more than configured {}, staying on shard {} (shard numbers start at 0)", - query_router.shard(), - pool.shards(), - current_shard, - ), - ) - .await?; - } else { - custom_protocol_response_ok(&mut self.write, "SET SHARD").await?; + match query_router.shard() { + None => (), + Some(selected_shard) => { + if selected_shard >= pool.shards() { + // Bad shard number, send error message to client.
+ query_router.set_shard(current_shard); + + error_response( + &mut self.write, + &format!( + "shard {} is not configured, only {} shards are configured, staying on shard {:?} (shard numbers start at 0)", + selected_shard, + pool.shards(), + current_shard, + ), + ) + .await?; + } else { + custom_protocol_response_ok(&mut self.write, "SET SHARD").await?; + } + } } continue; } @@ -1093,8 +1097,11 @@ where self.buffer.clear(); } - error_response(&mut self.write, "could not get connection from the pool") - .await?; + error_response( + &mut self.write, + format!("could not get connection from the pool - {}", err).as_str(), + ) + .await?; error!( "Could not get connection from pool: \ @@ -1234,7 +1241,7 @@ where {{ \ pool_name: {}, \ username: {}, \ - shard: {}, \ + shard: {:?}, \ role: \"{:?}\" \ }}", self.pool_name, diff --git a/src/config.rs b/src/config.rs index dc915d57..0404abc9 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,11 +3,14 @@ use arc_swap::ArcSwap; use log::{error, info}; use once_cell::sync::Lazy; use regex::Regex; +use serde::{Deserializer, Serializer}; use serde_derive::{Deserialize, Serialize}; + use std::collections::hash_map::DefaultHasher; use std::collections::{BTreeMap, HashMap, HashSet}; use std::hash::{Hash, Hasher}; use std::path::Path; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use tokio::fs::File; use tokio::io::AsyncReadExt; @@ -101,6 +104,9 @@ pub struct Address { /// Address stats pub stats: Arc<AddressStats>, + + /// Number of errors encountered since last successful checkout + pub error_count: Arc<AtomicU64>, } impl Default for Address { @@ -118,6 +124,7 @@ impl Default for Address { pool_name: String::from("pool_name"), mirrors: Vec::new(), stats: Arc::new(AddressStats::default()), + error_count: Arc::new(AtomicU64::new(0)), } } } @@ -182,6 +189,18 @@ impl Address { ), } } + + pub fn error_count(&self) -> u64 { + self.error_count.load(Ordering::Relaxed) + } + + pub fn increment_error_count(&self) { + self.error_count.fetch_add(1, Ordering::Relaxed); + } + + pub fn reset_error_count(&self) { + self.error_count.store(0, Ordering::Relaxed); + } } /// PostgreSQL user. @@ -540,6 +559,9 @@ pub struct Pool { pub shard_id_regex: Option<String>, pub regex_search_limit: Option<usize>, + #[serde(default = "Pool::default_default_shard")] + pub default_shard: DefaultShard, + pub auth_query: Option<String>, pub auth_query_user: Option<String>, pub auth_query_password: Option<String>, @@ -575,6 +597,10 @@ impl Pool { PoolMode::Transaction } + pub fn default_default_shard() -> DefaultShard { + DefaultShard::default() + } + pub fn default_load_balancing_mode() -> LoadBalancingMode { LoadBalancingMode::Random } @@ -666,6 +692,16 @@ impl Pool { None => None, }; + match self.default_shard { + DefaultShard::Shard(shard_number) => { + if shard_number >= self.shards.len() { + error!("Invalid shard {:?}", shard_number); + return Err(Error::BadConfig); + } + } + _ => (), + } + for (_, user) in &self.users { user.validate()?; } @@ -693,6 +729,7 @@ impl Default for Pool { sharding_key_regex: None, shard_id_regex: None, regex_search_limit: Some(1000), + default_shard: Self::default_default_shard(), auth_query: None, auth_query_user: None, auth_query_password: None, @@ -711,6 +748,50 @@ pub struct ServerConfig { pub role: Role, } +// Handling for when no shard is specified.
+#[derive(Debug, PartialEq, Clone, Eq, Hash, Copy)] +pub enum DefaultShard { + Shard(usize), + Random, + RandomHealthy, +} +impl Default for DefaultShard { + fn default() -> Self { + DefaultShard::Shard(0) + } +} +impl serde::Serialize for DefaultShard { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + match self { + DefaultShard::Shard(shard) => { + serializer.serialize_str(&format!("shard_{}", &shard.to_string())) + } + DefaultShard::Random => serializer.serialize_str("random"), + DefaultShard::RandomHealthy => serializer.serialize_str("random_healthy"), + } + } +} +impl<'de> serde::Deserialize<'de> for DefaultShard { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + if s.starts_with("shard_") { + let shard = s[6..].parse::<usize>().map_err(serde::de::Error::custom)?; + return Ok(DefaultShard::Shard(shard)); + } + + match s.as_str() { + "random" => Ok(DefaultShard::Random), + "random_healthy" => Ok(DefaultShard::RandomHealthy), + _ => Err(serde::de::Error::custom( + "invalid value for no_shard_specified_behavior", + )), + } + } +} + #[derive(Clone, PartialEq, Serialize, Deserialize, Debug, Hash, Eq)] pub struct MirrorServerConfig { pub host: String, diff --git a/src/errors.rs b/src/errors.rs index 014a1340..a6aebc50 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -28,6 +28,7 @@ pub enum Error { UnsupportedStatement, QueryRouterParserError(String), QueryRouterError(String), + InvalidShardId(usize), } #[derive(Clone, PartialEq, Debug)] diff --git a/src/pool.rs b/src/pool.rs index 7e110ce2..18123407 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -10,6 +10,7 @@ use rand::thread_rng; use regex::Regex; use std::collections::HashMap; use std::fmt::{Display, Formatter}; +use std::sync::atomic::AtomicU64; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -18,7 +19,7 @@ use std::time::Instant; use tokio::sync::Notify; use crate::config::{ - get_config, Address, General, LoadBalancingMode, Plugins, PoolMode, Role, User, + get_config, Address, DefaultShard, General, LoadBalancingMode, Plugins, PoolMode, Role, User, }; use crate::errors::Error; @@ -140,6 +141,9 @@ pub struct PoolSettings { // Regex for searching for the shard id in SQL statements pub shard_id_regex: Option<Regex>, + // What to do when no shard is selected in a sharded system + pub default_shard: DefaultShard, + // Limit how much of each query is searched for a potential shard regex match pub regex_search_limit: usize, @@ -173,6 +177,7 @@ impl Default for PoolSettings { sharding_key_regex: None, shard_id_regex: None, regex_search_limit: 1000, + default_shard: DefaultShard::Shard(0), auth_query: None, auth_query_user: None, auth_query_password: None, @@ -299,6 +304,7 @@ impl ConnectionPool { pool_name: pool_name.clone(), mirrors: vec![], stats: Arc::new(AddressStats::default()), + error_count: Arc::new(AtomicU64::new(0)), }); address_id += 1; } @@ -317,6 +323,7 @@ impl ConnectionPool { pool_name: pool_name.clone(), mirrors: mirror_addresses, stats: Arc::new(AddressStats::default()), + error_count: Arc::new(AtomicU64::new(0)), }; address_id += 1; @@ -482,6 +489,7 @@ impl ConnectionPool { .clone() .map(|regex| Regex::new(regex.as_str()).unwrap()), regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000), + default_shard: pool_config.default_shard.clone(), auth_query: pool_config.auth_query.clone(), auth_query_user: pool_config.auth_query_user.clone(), auth_query_password: 
pool_config.auth_query_password.clone(), @@ -603,19 +611,51 @@ impl ConnectionPool { /// Get a connection from the pool. pub async fn get( &self, - shard: usize, // shard number + shard: Option<usize>, // shard number role: Option<Role>, // primary or replica client_stats: &ClientStats, // client id ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> { - let mut candidates: Vec<&Address> = self.addresses[shard] + let effective_shard_id = if self.shards() == 1 { + // The base, unsharded case + Some(0) + } else { + if !self.valid_shard_id(shard) { + // `None` is a valid shard ID, so an invalid ID must be `Some` and the unwrap is safe here + return Err(Error::InvalidShardId(shard.unwrap())); + } + shard + }; + + let mut candidates = self + .addresses .iter() + .flatten() .filter(|address| address.role == role) - .collect(); + .collect::<Vec<&Address>>(); - // We shuffle even if least_outstanding_queries is used to avoid imbalance - // in cases where all candidates have more or less the same number of outstanding - // queries + // We start with a shuffled list of addresses even if we end up re-sorting; + // this avoids hitting instance 0 every time if the sorting metric + // ends up being the same for all instances candidates.shuffle(&mut thread_rng()); + + match effective_shard_id { + Some(shard_id) => candidates.retain(|address| address.shard == shard_id), + None => match self.settings.default_shard { + DefaultShard::Shard(shard_id) => { + candidates.retain(|address| address.shard == shard_id) + } + DefaultShard::Random => (), + DefaultShard::RandomHealthy => { + candidates.sort_by(|a, b| { + b.error_count + .load(Ordering::Relaxed) + .partial_cmp(&a.error_count.load(Ordering::Relaxed)) + .unwrap() + }); + } + }, + }; + if self.settings.load_balancing_mode == LoadBalancingMode::LeastOutstandingConnections { candidates.sort_by(|a, b| { self.busy_connection_count(b) @@ -651,7 +691,10 @@ impl ConnectionPool { .get() .await { - Ok(conn) => conn, + Ok(conn) => { + address.reset_error_count(); + conn + } Err(err) => { error!( "Connection checkout error for instance {:?}, error: {:?}", @@ -766,6 +809,18 @@ impl ConnectionPool { /// traffic for any new transactions. Existing transactions on that replica /// will finish successfully or error out to the clients. pub fn ban(&self, address: &Address, reason: BanReason, client_info: Option<&ClientStats>) { + // Count the number of errors since the last successful checkout + // This is used to determine if the shard is down + match reason { + BanReason::FailedHealthCheck + | BanReason::FailedCheckout + | BanReason::MessageSendFailed + | BanReason::MessageReceiveFailed => { + address.increment_error_count(); + } + _ => (), + }; + // Primary can never be banned if address.role == Role::Primary { return; @@ -920,6 +975,7 @@ impl ConnectionPool { self.original_server_parameters.read().clone() } + /// Get the number of checked out connections for an address fn busy_connection_count(&self, address: &Address) -> u32 { let state = self.pool_state(address.shard, address.address_index); let idle = state.idle_connections; @@ -933,6 +989,13 @@ impl ConnectionPool { debug!("{:?} has {:?} busy connections", address, busy); return busy; } + + fn valid_shard_id(&self, shard: Option<usize>) -> bool { + match shard { + None => true, + Some(shard) => shard < self.shards(), + } + } } /// Wrapper for the bb8 connection pool.
diff --git a/src/query_router.rs b/src/query_router.rs index efca499f..9d7a106a 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -143,13 +143,14 @@ impl QueryRouter { let code = message_cursor.get_u8() as char; let len = message_cursor.get_i32() as usize; + let comment_shard_routing_enabled = self.pool_settings.shard_id_regex.is_some() + || self.pool_settings.sharding_key_regex.is_some(); + // Check for any sharding regex matches in any queries - match code as char { - // For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement - 'P' | 'Q' => { - if self.pool_settings.shard_id_regex.is_some() - || self.pool_settings.sharding_key_regex.is_some() - { + if comment_shard_routing_enabled { + match code as char { + // For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement + 'P' | 'Q' => { // Check only the first block of bytes configured by the pool settings let seg = cmp::min(len - 5, self.pool_settings.regex_search_limit); @@ -166,7 +167,7 @@ impl QueryRouter { }); if let Some(shard_id) = shard_id { debug!("Setting shard to {:?}", shard_id); - self.set_shard(shard_id); + self.set_shard(Some(shard_id)); // Skip other command processing since a sharding command was found return None; } @@ -188,8 +189,8 @@ impl QueryRouter { } } } + _ => {} } - _ => {} } // Only simple protocol supported for commands processed below @@ -248,7 +249,9 @@ impl QueryRouter { } } - Command::ShowShard => self.shard().to_string(), + Command::ShowShard => self + .shard() + .map_or_else(|| "unset".to_string(), |x| x.to_string()), Command::ShowServerRole => match self.active_role { Some(Role::Primary) => Role::Primary.to_string(), Some(Role::Replica) => Role::Replica.to_string(), @@ -581,7 +584,7 @@ impl QueryRouter { // TODO: Support multi-shard queries some day. if shards.len() == 1 { debug!("Found one sharding key"); - self.set_shard(*shards.first().unwrap()); + self.set_shard(Some(*shards.first().unwrap())); true } else { debug!("Found no sharding keys"); @@ -865,7 +868,7 @@ impl QueryRouter { self.pool_settings.sharding_function, ); let shard = sharder.shard(sharding_key); - self.set_shard(shard); + self.set_shard(Some(shard)); self.active_shard } @@ -875,12 +878,12 @@ impl QueryRouter { } /// Get desired shard we should be talking to. - pub fn shard(&self) -> usize { - self.active_shard.unwrap_or(0) + pub fn shard(&self) -> Option<usize> { + self.active_shard } - pub fn set_shard(&mut self, shard: usize) { - self.active_shard = Some(shard); + pub fn set_shard(&mut self, shard: Option<usize>) { + self.active_shard = shard; } /// Should we attempt to parse queries? 
@@ -1090,7 +1093,7 @@ mod test { qr.try_execute_command(&query), Some((Command::SetShardingKey, String::from("0"))) ); - assert_eq!(qr.shard(), 0); + assert_eq!(qr.shard().unwrap(), 0); // SetShard let query = simple_query("SET SHARD TO '1'"); @@ -1098,7 +1101,7 @@ mod test { qr.try_execute_command(&query), Some((Command::SetShard, String::from("1"))) ); - assert_eq!(qr.shard(), 1); + assert_eq!(qr.shard().unwrap(), 1); // ShowShard let query = simple_query("SHOW SHARD"); @@ -1204,6 +1207,7 @@ mod test { ban_time: PoolSettings::default().ban_time, sharding_key_regex: None, shard_id_regex: None, + default_shard: crate::config::DefaultShard::Shard(0), regex_search_limit: 1000, auth_query: None, auth_query_password: None, @@ -1281,6 +1285,7 @@ mod test { ban_time: PoolSettings::default().ban_time, sharding_key_regex: Some(Regex::new(r"/\* sharding_key: (\d+) \*/").unwrap()), shard_id_regex: Some(Regex::new(r"/\* shard_id: (\d+) \*/").unwrap()), + default_shard: crate::config::DefaultShard::Shard(0), regex_search_limit: 1000, auth_query: None, auth_query_password: None, @@ -1331,7 +1336,7 @@ mod test { .unwrap(), ) .is_ok()); - assert_eq!(qr.shard(), 2); + assert_eq!(qr.shard().unwrap(), 2); assert!(qr .infer( @@ -1341,7 +1346,7 @@ mod test { .unwrap() ) .is_ok()); - assert_eq!(qr.shard(), 0); + assert_eq!(qr.shard().unwrap(), 0); assert!(qr .infer( @@ -1354,7 +1359,7 @@ mod test { .unwrap() ) .is_ok()); - assert_eq!(qr.shard(), 2); + assert_eq!(qr.shard().unwrap(), 2); // Shard did not move because we couldn't determine the sharding key since it could be ambiguous // in the query. @@ -1366,7 +1371,7 @@ mod test { .unwrap() ) .is_ok()); - assert_eq!(qr.shard(), 2); + assert_eq!(qr.shard().unwrap(), 2); assert!(qr .infer( @@ -1376,7 +1381,7 @@ mod test { .unwrap() ) .is_ok()); - assert_eq!(qr.shard(), 0); + assert_eq!(qr.shard().unwrap(), 0); assert!(qr .infer( @@ -1386,7 +1391,7 @@ mod test { .unwrap() ) .is_ok()); - assert_eq!(qr.shard(), 2); + assert_eq!(qr.shard().unwrap(), 2); // Super unique sharding key qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string()); @@ -1398,7 +1403,7 @@ mod test { .unwrap() ) .is_ok()); - assert_eq!(qr.shard(), 0); + assert_eq!(qr.shard().unwrap(), 0); assert!(qr .infer( @@ -1406,7 +1411,7 @@ mod test { .unwrap() ) .is_ok()); - assert_eq!(qr.shard(), 0); + assert_eq!(qr.shard().unwrap(), 0); } #[test] @@ -1434,7 +1439,7 @@ mod test { assert_eq!(qr.placeholders.len(), 1); assert!(qr.infer_shard_from_bind(&bind)); - assert_eq!(qr.shard(), 2); + assert_eq!(qr.shard().unwrap(), 2); assert!(qr.placeholders.is_empty()); } diff --git a/tests/docker/Dockerfile b/tests/docker/Dockerfile index 99fd694d..261adb05 100644 --- a/tests/docker/Dockerfile +++ b/tests/docker/Dockerfile @@ -1,5 +1,7 @@ FROM rust:bullseye +COPY --from=sclevine/yj /bin/yj /bin/yj +RUN /bin/yj -h RUN apt-get update && apt-get install llvm-11 psmisc postgresql-contrib postgresql-client ruby ruby-dev libpq-dev python3 python3-pip lcov curl sudo iproute2 -y RUN cargo install cargo-binutils rustfilt RUN rustup component add llvm-tools-preview diff --git a/tests/ruby/auth_query_spec.rb b/tests/ruby/auth_query_spec.rb index 1ac62164..c1ee744a 100644 --- a/tests/ruby/auth_query_spec.rb +++ b/tests/ruby/auth_query_spec.rb @@ -185,7 +185,7 @@ }, } } - } + } context 'and with cleartext passwords set' do it 'it uses local passwords' do diff --git a/tests/ruby/helpers/auth_query_helper.rb b/tests/ruby/helpers/auth_query_helper.rb index 60e85713..43d7c785 100644 --- 
a/tests/ruby/helpers/auth_query_helper.rb +++ b/tests/ruby/helpers/auth_query_helper.rb @@ -33,18 +33,18 @@ def self.single_shard_auth_query( "0" => { "database" => "shard0", "servers" => [ - ["localhost", primary.port.to_s, "primary"], - ["localhost", replica.port.to_s, "replica"], + ["localhost", primary.port.to_i, "primary"], + ["localhost", replica.port.to_i, "replica"], ] }, }, "users" => { "0" => user.merge(config_user) } } } - pgcat_cfg["general"]["port"] = pgcat.port + pgcat_cfg["general"]["port"] = pgcat.port.to_i pgcat.update_config(pgcat_cfg) pgcat.start - + pgcat.wait_until_ready( pgcat.connection_string( "sharded_db", @@ -92,13 +92,13 @@ def self.two_pools_auth_query( "0" => { "database" => database, "servers" => [ - ["localhost", primary.port.to_s, "primary"], - ["localhost", replica.port.to_s, "replica"], + ["localhost", primary.port.to_i, "primary"], + ["localhost", replica.port.to_i, "replica"], ] }, }, "users" => { "0" => user.merge(config_user) } - } + } end # Main proxy configs pgcat_cfg["pools"] = { @@ -109,7 +109,7 @@ def self.two_pools_auth_query( pgcat_cfg["general"]["port"] = pgcat.port pgcat.update_config(pgcat_cfg.deep_merge(extra_conf)) pgcat.start - + pgcat.wait_until_ready(pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password'])) OpenStruct.new.tap do |struct| diff --git a/tests/ruby/helpers/pg_instance.rb b/tests/ruby/helpers/pg_instance.rb index a3828248..53617c24 100644 --- a/tests/ruby/helpers/pg_instance.rb +++ b/tests/ruby/helpers/pg_instance.rb @@ -7,10 +7,24 @@ class PgInstance attr_reader :password attr_reader :database_name + def self.mass_takedown(databases) + raise StandardError, "block missing" unless block_given? + + databases.each do |database| + database.toxiproxy.toxic(:limit_data, bytes: 1).toxics.each(&:save) + end + sleep 0.1 + yield + ensure + databases.each do |database| + database.toxiproxy.toxics.each(&:destroy) + end + end + def initialize(port, username, password, database_name) - @original_port = port + @original_port = port.to_i @toxiproxy_port = 10000 + port.to_i - @port = @toxiproxy_port + @port = @toxiproxy_port.to_i @username = username @password = password @@ -48,9 +62,9 @@ def toxiproxy def take_down if block_given?
- Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 5).apply { yield } + Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 1).apply { yield } else - Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 5).toxics.each(&:save) + Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 1).toxics.each(&:save) end end @@ -89,6 +103,6 @@ def count_query(query) end def count_select_1_plus_2 - with_connection { |c| c.async_exec("SELECT SUM(calls) FROM pg_stat_statements WHERE query = 'SELECT $1 + $2'")[0]["sum"].to_i } + with_connection { |c| c.async_exec("SELECT SUM(calls) FROM pg_stat_statements WHERE query LIKE '%SELECT $1 + $2%'")[0]["sum"].to_i } end end diff --git a/tests/ruby/helpers/pgcat_helper.rb b/tests/ruby/helpers/pgcat_helper.rb index 9b764d87..9b95dbfa 100644 --- a/tests/ruby/helpers/pgcat_helper.rb +++ b/tests/ruby/helpers/pgcat_helper.rb @@ -38,9 +38,9 @@ def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mod "automatic_sharding_key" => "data.id", "sharding_function" => "pg_bigint_hash", "shards" => { - "0" => { "database" => "shard0", "servers" => [["localhost", primary0.port.to_s, "primary"]] }, - "1" => { "database" => "shard1", "servers" => [["localhost", primary1.port.to_s, "primary"]] }, - "2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_s, "primary"]] }, + "0" => { "database" => "shard0", "servers" => [["localhost", primary0.port.to_i, "primary"]] }, + "1" => { "database" => "shard1", "servers" => [["localhost", primary1.port.to_i, "primary"]] }, + "2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_i, "primary"]] }, }, "users" => { "0" => user }, "plugins" => { @@ -100,7 +100,7 @@ def self.single_instance_setup(pool_name, pool_size, pool_mode="transaction", lb "0" => { "database" => "shard0", "servers" => [ - ["localhost", primary.port.to_s, "primary"] + ["localhost", primary.port.to_i, "primary"] ] }, }, @@ -146,10 +146,10 @@ def self.single_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mo "0" => { "database" => "shard0", "servers" => [ - ["localhost", primary.port.to_s, "primary"], - ["localhost", replica0.port.to_s, "replica"], - ["localhost", replica1.port.to_s, "replica"], - ["localhost", replica2.port.to_s, "replica"] + ["localhost", primary.port.to_i, "primary"], + ["localhost", replica0.port.to_i, "replica"], + ["localhost", replica1.port.to_i, "replica"], + ["localhost", replica2.port.to_i, "replica"] ] }, }, diff --git a/tests/ruby/helpers/pgcat_process.rb b/tests/ruby/helpers/pgcat_process.rb index dd3fd052..9328ff60 100644 --- a/tests/ruby/helpers/pgcat_process.rb +++ b/tests/ruby/helpers/pgcat_process.rb @@ -1,8 +1,10 @@ require 'pg' -require 'toml' +require 'json' +require 'tempfile' require 'fileutils' require 'securerandom' +class ConfigReloadFailed < StandardError; end class PgcatProcess attr_reader :port attr_reader :pid @@ -18,7 +20,7 @@ def self.finalize(pid, log_filename, config_filename) end def initialize(log_level) - @env = {"RUST_LOG" => log_level} + @env = {} @port = rand(20000..32760) @log_level = log_level @log_filename = "/tmp/pgcat_log_#{SecureRandom.urlsafe_base64}.log" @@ -30,7 +32,7 @@ def initialize(log_level) '../../target/debug/pgcat' end - @command = "#{command_path} #{@config_filename}" + @command = "#{command_path} #{@config_filename} --log-level #{@log_level}" FileUtils.cp("../../pgcat.toml", @config_filename) cfg = current_config @@ -46,22 +48,34 @@ def logs def update_config(config_hash) @original_config = current_config - 
output_to_write = TOML::Generator.new(config_hash).body - output_to_write = output_to_write.gsub(/,\s*["|'](\d+)["|']\s*,/, ',\1,') - output_to_write = output_to_write.gsub(/,\s*["|'](\d+)["|']\s*\]/, ',\1]') - File.write(@config_filename, output_to_write) + Tempfile.create('json_out', '/tmp') do |f| + f.write(config_hash.to_json) + f.flush + `cat #{f.path} | yj -jt > #{@config_filename}` + end end def current_config - loadable_string = File.read(@config_filename) - loadable_string = loadable_string.gsub(/,\s*(\d+)\s*,/, ', "\1",') - loadable_string = loadable_string.gsub(/,\s*(\d+)\s*\]/, ', "\1"]') - TOML.load(loadable_string) + JSON.parse(`cat #{@config_filename} | yj -tj`) + end + + def raw_config_file + File.read(@config_filename) end def reload_config - `kill -s HUP #{@pid}` - sleep 0.5 + conn = PG.connect(admin_connection_string) + + conn.async_exec("RELOAD") + rescue PG::ConnectionBad => e + errors = logs.split("Reloading config").last + errors = errors.gsub(/\e\[([;\d]+)?m/, '') # Remove color codes + errors = errors. + split("\n").select{|line| line.include?("ERROR") }. + map { |line| line.split("pgcat::config: ").last } + raise ConfigReloadFailed, errors.join("\n") + ensure + conn&.close end def start @@ -116,11 +130,11 @@ def connection_string(pool_name, username, password = nil, parameters: {}) cfg = current_config user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username } connection_string = "postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}" - + # Add the additional parameters to the connection string parameter_string = parameters.map { |key, value| "#{key}=#{value}" }.join("&") connection_string += "?#{parameter_string}" unless parameter_string.empty? 
- + connection_string end diff --git a/tests/ruby/mirrors_spec.rb b/tests/ruby/mirrors_spec.rb index 898d0d71..b6a4514c 100644 --- a/tests/ruby/mirrors_spec.rb +++ b/tests/ruby/mirrors_spec.rb @@ -11,9 +11,9 @@ before do new_configs = processes.pgcat.current_config new_configs["pools"]["sharded_db"]["shards"]["0"]["mirrors"] = [ - [mirror_host, mirror_pg.port.to_s, "0"], - [mirror_host, mirror_pg.port.to_s, "0"], - [mirror_host, mirror_pg.port.to_s, "0"], + [mirror_host, mirror_pg.port.to_i, 0], + [mirror_host, mirror_pg.port.to_i, 0], + [mirror_host, mirror_pg.port.to_i, 0], ] processes.pgcat.update_config(new_configs) processes.pgcat.reload_config @@ -31,7 +31,8 @@ runs.times { conn.async_exec("SELECT 1 + 2") } sleep 0.5 expect(processes.all_databases.first.count_select_1_plus_2).to eq(runs) - expect(mirror_pg.count_select_1_plus_2).to eq(runs * 3) + # Allow some slack in mirroring successes + expect(mirror_pg.count_select_1_plus_2).to be > ((runs - 5) * 3) end context "when main server connection is closed" do @@ -42,9 +43,9 @@ new_configs = processes.pgcat.current_config new_configs["pools"]["sharded_db"]["idle_timeout"] = 5000 + i new_configs["pools"]["sharded_db"]["shards"]["0"]["mirrors"] = [ - [mirror_host, mirror_pg.port.to_s, "0"], - [mirror_host, mirror_pg.port.to_s, "0"], - [mirror_host, mirror_pg.port.to_s, "0"], + [mirror_host, mirror_pg.port.to_i, 0], + [mirror_host, mirror_pg.port.to_i, 0], + [mirror_host, mirror_pg.port.to_i, 0], ] processes.pgcat.update_config(new_configs) processes.pgcat.reload_config diff --git a/tests/ruby/misc_spec.rb b/tests/ruby/misc_spec.rb index 1d4ade4c..aa17e8ec 100644 --- a/tests/ruby/misc_spec.rb +++ b/tests/ruby/misc_spec.rb @@ -252,7 +252,7 @@ end expect(processes.primary.count_query("RESET ROLE")).to eq(10) - end + end end context "transaction mode" do @@ -317,7 +317,7 @@ conn.async_exec("SET statement_timeout to 1500") expect(conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"]).to eq(orignal_statement_timeout) end - + end context "transaction mode with transactions" do @@ -354,7 +354,6 @@ conn.async_exec("SET statement_timeout TO 1000") conn.close - puts processes.pgcat.logs expect(processes.primary.count_query("RESET ALL")).to eq(0) end @@ -365,7 +364,6 @@ conn.close - puts processes.pgcat.logs expect(processes.primary.count_query("RESET ALL")).to eq(0) end end @@ -376,10 +374,9 @@ before do current_configs = processes.pgcat.current_config correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"] - puts(current_configs["general"]["idle_client_in_transaction_timeout"]) - + current_configs["general"]["idle_client_in_transaction_timeout"] = 0 - + processes.pgcat.update_config(current_configs) # with timeout 0 processes.pgcat.reload_config end @@ -397,9 +394,9 @@ context "idle transaction timeout set to 500ms" do before do current_configs = processes.pgcat.current_config - correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"] + correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"] current_configs["general"]["idle_client_in_transaction_timeout"] = 500 - + processes.pgcat.update_config(current_configs) # with timeout 500 processes.pgcat.reload_config end @@ -418,7 +415,7 @@ conn.async_exec("BEGIN") conn.async_exec("SELECT 1") sleep(1) # above 500ms - expect{ conn.async_exec("COMMIT") }.to raise_error(PG::SystemError, /idle transaction timeout/) + expect{ 
conn.async_exec("COMMIT") }.to raise_error(PG::SystemError, /idle transaction timeout/) conn.async_exec("SELECT 1") # should be able to send another query conn.close end diff --git a/tests/ruby/sharding_spec.rb b/tests/ruby/sharding_spec.rb index 123c10dc..746627d1 100644 --- a/tests/ruby/sharding_spec.rb +++ b/tests/ruby/sharding_spec.rb @@ -7,11 +7,11 @@ before do conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) - # Setup the sharding data 3.times do |i| conn.exec("SET SHARD TO '#{i}'") - conn.exec("DELETE FROM data WHERE id > 0") + + conn.exec("DELETE FROM data WHERE id > 0") rescue nil end 18.times do |i| @@ -19,10 +19,11 @@ conn.exec("SET SHARDING KEY TO '#{i}'") conn.exec("INSERT INTO data (id, value) VALUES (#{i}, 'value_#{i}')") end + + conn.close end after do - processes.all_databases.map(&:reset) processes.pgcat.shutdown end @@ -48,4 +49,148 @@ end end end + + describe "no_shard_specified_behavior config" do + context "when default shard number is invalid" do + it "prevents config reload" do + admin_conn = PG::connect(processes.pgcat.admin_connection_string) + + current_configs = processes.pgcat.current_config + current_configs["pools"]["sharded_db"]["default_shard"] = "shard_99" + + processes.pgcat.update_config(current_configs) + + expect { processes.pgcat.reload_config }.to raise_error(ConfigReloadFailed, /Invalid shard 99/) + end + end + end + + describe "comment-based routing" do + context "when no configs are set" do + it "routes queries with a shard_id comment to the default shard" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + 10.times { conn.async_exec("/* shard_id: 2 */ SELECT 1 + 2") } + + expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([10, 0, 0]) + end + + it "does not honor no_shard_specified_behavior directives" do + end + end + + [ + ["shard_id_regex", "/\\* the_shard_id: (\\d+) \\*/", "/* the_shard_id: 1 */"], + ["sharding_key_regex", "/\\* the_sharding_key: (\\d+) \\*/", "/* the_sharding_key: 3 */"], + ].each do |config_name, config_value, comment_to_use| + context "when #{config_name} config is set" do + let(:no_shard_specified_behavior) { nil } + + before do + admin_conn = PG::connect(processes.pgcat.admin_connection_string) + + current_configs = processes.pgcat.current_config + current_configs["pools"]["sharded_db"][config_name] = config_value + if no_shard_specified_behavior + current_configs["pools"]["sharded_db"]["default_shard"] = no_shard_specified_behavior + else + current_configs["pools"]["sharded_db"].delete("default_shard") + end + + processes.pgcat.update_config(current_configs) + processes.pgcat.reload_config + end + + it "routes queries with a shard_id comment to the correct shard" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + 25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") } + + expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0]) + end + + context "when no_shard_specified_behavior config is set to random" do + let(:no_shard_specified_behavior) { "random" } + + context "with no shard comment" do + it "sends queries to random shard" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + 25.times { conn.async_exec("SELECT 1 + 2") } + + expect(processes.all_databases.map(&:count_select_1_plus_2).all?(&:positive?)).to be true + end + end + + context "with a shard comment" do + it "honors the comment" do + conn = 
PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + 25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") } + + expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0]) + end + end + end + + context "when no_shard_specified_behavior config is set to random_healthy" do + let(:no_shard_specified_behavior) { "random_healthy" } + + context "with no shard comment" do + it "sends queries to random healthy shard" do + + good_databases = [processes.all_databases[0], processes.all_databases[2]] + bad_database = processes.all_databases[1] + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + 250.times { conn.async_exec("SELECT 99") } + bad_database.take_down do + 250.times do + conn.async_exec("SELECT 99") + rescue PG::ConnectionBad => e + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + end + end + + # Routes traffic away from bad shard + 25.times { conn.async_exec("SELECT 1 + 2") } + expect(good_databases.map(&:count_select_1_plus_2).all?(&:positive?)).to be true + expect(bad_database.count_select_1_plus_2).to eq(0) + + # Routes traffic to the bad shard if the shard_id is specified + 25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") } + bad_database = processes.all_databases[1] + expect(bad_database.count_select_1_plus_2).to eq(25) + end + end + + context "with a shard comment" do + it "honors the comment" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + 25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") } + + expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0]) + end + end + end + + context "when no_shard_specified_behavior config is set to shard_x" do + let(:no_shard_specified_behavior) { "shard_2" } + + context "with no shard comment" do + it "sends queries to the specified shard" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + 25.times { conn.async_exec("SELECT 1 + 2") } + + expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 0, 25]) + end + end + + context "with a shard comment" do + it "honors the comment" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + 25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") } + + expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0]) + end + end + end + end + end + end end
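For illustration only, a minimal sketch of how the new setting might appear in pgcat.toml, assuming a pool laid out like the `sharded_db` example used in the tests (the pool name, regex, and other keys shown are placeholders drawn from the existing sample config, not part of this change):

    [pools.sharded_db]
    sharding_function = "pg_bigint_hash"
    # Route queries carrying a /* shard_id: N */ comment to shard N.
    shard_id_regex = '/\* shard_id: (\d+) \*/'
    regex_search_limit = 1000
    # Behavior when a query selects no shard (no SET SHARD, no shard comment):
    # "random", "random_healthy", or "shard_<number>", e.g. "shard_0".
    default_shard = "random_healthy"

With `random_healthy`, candidate addresses are ordered by their error count, which is incremented on failed health checks, failed checkouts, and message send/receive failures and reset on every successful checkout, so unrouted traffic drifts away from recently failing shards; an explicit `SET SHARD TO '<n>'` or a matching shard comment still overrides the default.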