From 75836da2e27b960abf4c5da0a1d88bedac524bdc Mon Sep 17 00:00:00 2001
From: Mostafa Abdelraouf <mostafa.abdelraouf@instacart.com>
Date: Mon, 4 Sep 2023 08:27:57 -0500
Subject: [PATCH 1/8] Allow configuring routing decision when no shard is
 selected

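When a query reaches a sharded pool without a shard having been selected
(no SET SHARD / SET SHARDING KEY and no regex comment match), pgcat used to
fall back to shard 0 unconditionally. This adds a per-pool
`no_shard_specified_behavior` setting with three modes: `shard_<n>`
(default `shard_0`), `random`, and `random_healthy`, the last favoring
addresses with the fewest errors since their last successful checkout.

A minimal sketch of the new pgcat.toml knobs (pool name is illustrative,
regexes are the commented-out examples from pgcat.toml):

    [pools.sharded_db]
    sharding_key_regex = '/\* sharding_key: (\d+) \*/'
    shard_id_regex = '/\* shard_id: (\d+) \*/'
    no_shard_specified_behavior = "random_healthy"
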
---
 .circleci/config.yml                    |   2 +-
 pgcat.toml                              |   9 +-
 src/client.rs                           |  12 +-
 src/config.rs                           |  84 +++++++++++++
 src/errors.rs                           |   1 +
 src/pool.rs                             |  78 ++++++++++--
 src/query_router.rs                     |  57 +++++----
 tests/docker/Dockerfile                 |   2 +
 tests/ruby/auth_query_spec.rb           |   2 +-
 tests/ruby/helpers/auth_query_helper.rb |  16 +--
 tests/ruby/helpers/pg_instance.rb       |  24 +++-
 tests/ruby/helpers/pgcat_helper.rb      |  16 +--
 tests/ruby/helpers/pgcat_process.rb     |  44 ++++---
 tests/ruby/mirrors_spec.rb              |  19 +--
 tests/ruby/misc_spec.rb                 |  17 ++-
 tests/ruby/sharding_spec.rb             | 151 +++++++++++++++++++++++-
 16 files changed, 434 insertions(+), 100 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index c7f5c9fa..9fe3c256 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -9,7 +9,7 @@ jobs:
     # Specify the execution environment. You can specify an image from Dockerhub or use one of our Convenience Images from CircleCI's Developer Hub.
     # See: https://circleci.com/docs/2.0/configuration-reference/#docker-machine-macos-windows-executor
     docker:
-      - image: ghcr.io/levkk/pgcat-ci:1.67
+      - image: ghcr.io/levkk/pgcat-ci:latest
         environment:
           RUST_LOG: info
           LLVM_PROFILE_FILE: /tmp/pgcat-%m-%p.profraw
diff --git a/pgcat.toml b/pgcat.toml
index 654f5e89..39689cd0 100644
--- a/pgcat.toml
+++ b/pgcat.toml
@@ -171,12 +171,17 @@ query_parser_read_write_splitting = true
 # queries. The primary can always be explicitly selected with our custom protocol.
 primary_reads_enabled = true
 
-# Allow sharding commands to be passed as statement comments instead of
-# separate commands. If these are unset this functionality is disabled.
 # sharding_key_regex = '/\* sharding_key: (\d+) \*/'
 # shard_id_regex = '/\* shard_id: (\d+) \*/'
 # regex_search_limit = 1000 # only look at the first 1000 characters of SQL statements
 
+# Defines the behavior when no shard_id or sharding_key is specified for a query against
+# a sharded system with either sharding_key_regex or shard_id_regex specified.
+# `random`: picks a shard at random
+# `random_healthy`: picks a shard at random favoring shards with the least number of recent errors
+# `shard_<number>`: e.g. shard_0, shard_4, etc. picks a specific shard every time
+# no_shard_specified_behavior = "random"
+
 # So what if you wanted to implement a different hashing function,
 # or you've already built one and you want this pooler to use it?
 # Current options:
diff --git a/src/client.rs b/src/client.rs
index 8edecea1..56fddbfe 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -1010,14 +1010,14 @@ where
                 // SET SHARD TO
                 Some((Command::SetShard, _)) => {
                     // Selected shard is not configured.
-                    if query_router.shard() >= pool.shards() {
+                    if query_router.shard().unwrap_or(0) >= pool.shards() {
                         // Set the shard back to what it was.
                         query_router.set_shard(current_shard);
 
                         error_response(
                             &mut self.write,
                             &format!(
-                                "shard {} is more than configured {}, staying on shard {} (shard numbers start at 0)",
+                                "shard {:?} is more than configured {}, staying on shard {:?} (shard numbers start at 0)",
                                 query_router.shard(),
                                 pool.shards(),
                                 current_shard,
@@ -1093,8 +1093,10 @@ where
                         self.buffer.clear();
                     }
 
-                    error_response(&mut self.write, "could not get connection from the pool")
-                        .await?;
+                    error_response(
+                        &mut self.write,
+                        format!("could not get connection from the pool - {}", err).as_str()
+                ).await?;
 
                     error!(
                         "Could not get connection from pool: \
@@ -1234,7 +1236,7 @@ where
                                     {{ \
                                         pool_name: {}, \
                                         username: {}, \
-                                        shard: {}, \
+                                        shard: {:?}, \
                                         role: \"{:?}\" \
                                     }}",
                                     self.pool_name,
diff --git a/src/config.rs b/src/config.rs
index dc915d57..8ac6187b 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -3,11 +3,14 @@ use arc_swap::ArcSwap;
 use log::{error, info};
 use once_cell::sync::Lazy;
 use regex::Regex;
+use serde::{Deserializer, Serializer};
 use serde_derive::{Deserialize, Serialize};
+
 use std::collections::hash_map::DefaultHasher;
 use std::collections::{BTreeMap, HashMap, HashSet};
 use std::hash::{Hash, Hasher};
 use std::path::Path;
+use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;
 use tokio::fs::File;
 use tokio::io::AsyncReadExt;
@@ -101,6 +104,9 @@ pub struct Address {
 
     /// Address stats
     pub stats: Arc<AddressStats>,
+
+    /// Number of errors encountered since last successful checkout
+    pub error_count: Arc<AtomicU64>,
 }
 
 impl Default for Address {
@@ -118,6 +124,7 @@ impl Default for Address {
             pool_name: String::from("pool_name"),
             mirrors: Vec::new(),
             stats: Arc::new(AddressStats::default()),
+            error_count: Arc::new(AtomicU64::new(0)),
         }
     }
 }
@@ -182,6 +189,18 @@ impl Address {
             ),
         }
     }
+
+    pub fn error_count(&self) -> u64 {
+        self.error_count.load(Ordering::Relaxed)
+    }
+
+    pub fn increment_error_count(&self) {
+        self.error_count.fetch_add(1, Ordering::Relaxed);
+    }
+
+    pub fn reset_error_count(&self) {
+        self.error_count.store(0, Ordering::Relaxed);
+    }
 }
 
 /// PostgreSQL user.
@@ -540,6 +559,9 @@ pub struct Pool {
     pub shard_id_regex: Option<String>,
     pub regex_search_limit: Option<usize>,
 
+    #[serde(default = "Pool::default_no_shard_specified_behavior")]
+    pub no_shard_specified_behavior: NoShardSpecifiedHandling,
+
     pub auth_query: Option<String>,
     pub auth_query_user: Option<String>,
     pub auth_query_password: Option<String>,
@@ -575,6 +597,10 @@ impl Pool {
         PoolMode::Transaction
     }
 
+    pub fn default_no_shard_specified_behavior()-> NoShardSpecifiedHandling {
+        NoShardSpecifiedHandling::default()
+    }
+
     pub fn default_load_balancing_mode() -> LoadBalancingMode {
         LoadBalancingMode::Random
     }
@@ -666,6 +692,19 @@ impl Pool {
             None => None,
         };
 
+        match self.no_shard_specified_behavior  {
+            NoShardSpecifiedHandling::Shard(shard_number) => {
+                if shard_number >= self.shards.len() {
+                    error!(
+                        "Invalid shard {:?}",
+                        shard_number
+                    );
+                    return Err(Error::BadConfig);
+                }
+            }
+            _ => (),
+        }
+
         for (_, user) in &self.users {
             user.validate()?;
         }
@@ -693,6 +732,7 @@ impl Default for Pool {
             sharding_key_regex: None,
             shard_id_regex: None,
             regex_search_limit: Some(1000),
+            no_shard_specified_behavior: Self::default_no_shard_specified_behavior(),
             auth_query: None,
             auth_query_user: None,
             auth_query_password: None,
@@ -711,6 +751,50 @@ pub struct ServerConfig {
     pub role: Role,
 }
 
+// Handling of queries when no shard is specified.
+#[derive(Debug, PartialEq, Clone, Eq, Hash, Copy)]
+pub enum NoShardSpecifiedHandling {
+    Shard(usize),
+    Random,
+    RandomHealthy,
+}
+impl Default for NoShardSpecifiedHandling {
+    fn default() -> Self {
+        NoShardSpecifiedHandling::Shard(0)
+    }
+}
+impl serde::Serialize for NoShardSpecifiedHandling {
+    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+        match self {
+            NoShardSpecifiedHandling::Shard(shard) => {
+                serializer.serialize_str(&format!("shard_{}", &shard.to_string()))
+            }
+            NoShardSpecifiedHandling::Random => serializer.serialize_str("random"),
+            NoShardSpecifiedHandling::RandomHealthy => serializer.serialize_str("random_healthy"),
+        }
+    }
+}
+impl<'de> serde::Deserialize<'de> for NoShardSpecifiedHandling {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let s = String::deserialize(deserializer)?;
+        if s.starts_with("shard_") {
+            let shard = s[6..].parse::<usize>().map_err(serde::de::Error::custom)?;
+            return Ok(NoShardSpecifiedHandling::Shard(shard));
+        }
+
+        match s.as_str() {
+            "random" => Ok(NoShardSpecifiedHandling::Random),
+            "random_healthy" => Ok(NoShardSpecifiedHandling::RandomHealthy),
+            _ => Err(serde::de::Error::custom(
+                "invalid value for no_shard_specified_behavior",
+            )),
+        }
+    }
+}
+
 #[derive(Clone, PartialEq, Serialize, Deserialize, Debug, Hash, Eq)]
 pub struct MirrorServerConfig {
     pub host: String,
diff --git a/src/errors.rs b/src/errors.rs
index 014a1340..a6aebc50 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -28,6 +28,7 @@ pub enum Error {
     UnsupportedStatement,
     QueryRouterParserError(String),
     QueryRouterError(String),
+    InvalidShardId(usize),
 }
 
 #[derive(Clone, PartialEq, Debug)]
diff --git a/src/pool.rs b/src/pool.rs
index 7e110ce2..80b2d1e6 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -10,6 +10,7 @@ use rand::thread_rng;
 use regex::Regex;
 use std::collections::HashMap;
 use std::fmt::{Display, Formatter};
+use std::sync::atomic::AtomicU64;
 use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
@@ -18,7 +19,8 @@ use std::time::Instant;
 use tokio::sync::Notify;
 
 use crate::config::{
-    get_config, Address, General, LoadBalancingMode, Plugins, PoolMode, Role, User,
+    get_config, Address, General, LoadBalancingMode, NoShardSpecifiedHandling, Plugins, PoolMode,
+    Role, User,
 };
 use crate::errors::Error;
 
@@ -140,6 +142,9 @@ pub struct PoolSettings {
     // Regex for searching for the shard id in SQL statements
     pub shard_id_regex: Option<Regex>,
 
+    // What to do when no shard is selected in a sharded system
+    pub no_shard_specified_behavior: NoShardSpecifiedHandling,
+
     // Limit how much of each query is searched for a potential shard regex match
     pub regex_search_limit: usize,
 
@@ -173,6 +178,7 @@ impl Default for PoolSettings {
             sharding_key_regex: None,
             shard_id_regex: None,
             regex_search_limit: 1000,
+            no_shard_specified_behavior: NoShardSpecifiedHandling::Shard(0),
             auth_query: None,
             auth_query_user: None,
             auth_query_password: None,
@@ -299,6 +305,7 @@ impl ConnectionPool {
                                     pool_name: pool_name.clone(),
                                     mirrors: vec![],
                                     stats: Arc::new(AddressStats::default()),
+                                    error_count: Arc::new(AtomicU64::new(0)),
                                 });
                                 address_id += 1;
                             }
@@ -317,6 +324,7 @@ impl ConnectionPool {
                             pool_name: pool_name.clone(),
                             mirrors: mirror_addresses,
                             stats: Arc::new(AddressStats::default()),
+                            error_count: Arc::new(AtomicU64::new(0)),
                         };
 
                         address_id += 1;
@@ -482,6 +490,9 @@ impl ConnectionPool {
                             .clone()
                             .map(|regex| Regex::new(regex.as_str()).unwrap()),
                         regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
+                        no_shard_specified_behavior: pool_config
+                            .no_shard_specified_behavior
+                            .clone(),
                         auth_query: pool_config.auth_query.clone(),
                         auth_query_user: pool_config.auth_query_user.clone(),
                         auth_query_password: pool_config.auth_query_password.clone(),
@@ -603,19 +614,56 @@ impl ConnectionPool {
     /// Get a connection from the pool.
     pub async fn get(
         &self,
-        shard: usize,               // shard number
+        shard: Option<usize>,       // shard number
         role: Option<Role>,         // primary or replica
         client_stats: &ClientStats, // client id
     ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
-        let mut candidates: Vec<&Address> = self.addresses[shard]
-            .iter()
-            .filter(|address| address.role == role)
-            .collect();
+
+        let mut effective_shard_id:Option<usize> = shard;
+
+        // The base, unsharded case
+        if self.shards() == 1 {
+            effective_shard_id = Some(0);
+        }
+
+        let mut sort_by_error_count = false;
+        let mut candidates: Vec<_> = match effective_shard_id {
+            Some(shard_id) => self.addresses[shard_id].iter().collect(),
+            None => {
+                match self.settings.no_shard_specified_behavior {
+                    NoShardSpecifiedHandling::Random => self.addresses.iter().flatten().collect(),
+                    NoShardSpecifiedHandling::RandomHealthy => {
+                        sort_by_error_count = true;
+                        self.addresses.iter().flatten().collect()
+                    }
+                    NoShardSpecifiedHandling::Shard(shard) => {
+                        if shard >= self.shards() {
+                            return Err(Error::InvalidShardId(shard));
+                        } else {
+                            self.addresses[shard].iter().collect()
+                        }
+                    }
+                }
+            }
+        };
+        candidates.retain(|address| address.role == role);
 
         // We shuffle even if least_outstanding_queries is used to avoid imbalance
         // in cases where all candidates have more or less the same number of outstanding
         // queries
         candidates.shuffle(&mut thread_rng());
+
+        // This branch is only hit when no shard was specified and the routing mode is
+        // random_healthy
+        if sort_by_error_count {
+            candidates.sort_by(|a, b| {
+                b.error_count
+                    .load(Ordering::Relaxed)
+                    .partial_cmp(&a.error_count.load(Ordering::Relaxed))
+                    .unwrap()
+            });
+        }
+
         if self.settings.load_balancing_mode == LoadBalancingMode::LeastOutstandingConnections {
             candidates.sort_by(|a, b| {
                 self.busy_connection_count(b)
@@ -651,7 +699,10 @@ impl ConnectionPool {
                 .get()
                 .await
             {
-                Ok(conn) => conn,
+                Ok(conn) => {
+                    address.reset_error_count();
+                    conn
+                }
                 Err(err) => {
                     error!(
                         "Connection checkout error for instance {:?}, error: {:?}",
@@ -766,6 +817,18 @@ impl ConnectionPool {
     /// traffic for any new transactions. Existing transactions on that replica
     /// will finish successfully or error out to the clients.
     pub fn ban(&self, address: &Address, reason: BanReason, client_info: Option<&ClientStats>) {
+        // Count the number of errors since the last successful checkout
+        // This is used to determine if the shard is down
+        match reason {
+            BanReason::FailedHealthCheck
+            | BanReason::FailedCheckout
+            | BanReason::MessageSendFailed
+            | BanReason::MessageReceiveFailed => {
+                address.increment_error_count();
+            }
+            _ => (),
+        };
+
         // Primary can never be banned
         if address.role == Role::Primary {
             return;
@@ -920,6 +983,7 @@ impl ConnectionPool {
         self.original_server_parameters.read().clone()
     }
 
+    /// Get the number of checked-out connections for an address
     fn busy_connection_count(&self, address: &Address) -> u32 {
         let state = self.pool_state(address.shard, address.address_index);
         let idle = state.idle_connections;
diff --git a/src/query_router.rs b/src/query_router.rs
index efca499f..f2f20b93 100644
--- a/src/query_router.rs
+++ b/src/query_router.rs
@@ -143,13 +143,14 @@ impl QueryRouter {
         let code = message_cursor.get_u8() as char;
         let len = message_cursor.get_i32() as usize;
 
+        let comment_shard_routing_enabled = self.pool_settings.shard_id_regex.is_some()
+            || self.pool_settings.sharding_key_regex.is_some();
+
         // Check for any sharding regex matches in any queries
-        match code as char {
-            // For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
-            'P' | 'Q' => {
-                if self.pool_settings.shard_id_regex.is_some()
-                    || self.pool_settings.sharding_key_regex.is_some()
-                {
+        if comment_shard_routing_enabled {
+            match code as char {
+                // For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
+                'P' | 'Q' => {
                     // Check only the first block of bytes configured by the pool settings
                     let seg = cmp::min(len - 5, self.pool_settings.regex_search_limit);
 
@@ -166,7 +167,7 @@ impl QueryRouter {
                         });
                         if let Some(shard_id) = shard_id {
                             debug!("Setting shard to {:?}", shard_id);
-                            self.set_shard(shard_id);
+                            self.set_shard(Some(shard_id));
                             // Skip other command processing since a sharding command was found
                             return None;
                         }
@@ -188,8 +189,8 @@ impl QueryRouter {
                         }
                     }
                 }
+                _ => {}
             }
-            _ => {}
         }
 
         // Only simple protocol supported for commands processed below
@@ -248,7 +249,9 @@ impl QueryRouter {
                 }
             }
 
-            Command::ShowShard => self.shard().to_string(),
+            Command::ShowShard => self
+                .shard()
+                .map_or_else(|| "unset".to_string(), |x| x.to_string()),
             Command::ShowServerRole => match self.active_role {
                 Some(Role::Primary) => Role::Primary.to_string(),
                 Some(Role::Replica) => Role::Replica.to_string(),
@@ -581,7 +584,7 @@ impl QueryRouter {
         // TODO: Support multi-shard queries some day.
         if shards.len() == 1 {
             debug!("Found one sharding key");
-            self.set_shard(*shards.first().unwrap());
+            self.set_shard(Some(*shards.first().unwrap()));
             true
         } else {
             debug!("Found no sharding keys");
@@ -865,7 +868,7 @@ impl QueryRouter {
             self.pool_settings.sharding_function,
         );
         let shard = sharder.shard(sharding_key);
-        self.set_shard(shard);
+        self.set_shard(Some(shard));
         self.active_shard
     }
 
@@ -875,12 +878,12 @@ impl QueryRouter {
     }
 
     /// Get desired shard we should be talking to.
-    pub fn shard(&self) -> usize {
-        self.active_shard.unwrap_or(0)
+    pub fn shard(&self) -> Option<usize> {
+        self.active_shard
     }
 
-    pub fn set_shard(&mut self, shard: usize) {
-        self.active_shard = Some(shard);
+    pub fn set_shard(&mut self, shard: Option<usize>) {
+        self.active_shard = shard;
     }
 
     /// Should we attempt to parse queries?
@@ -1090,7 +1093,7 @@ mod test {
             qr.try_execute_command(&query),
             Some((Command::SetShardingKey, String::from("0")))
         );
-        assert_eq!(qr.shard(), 0);
+        assert_eq!(qr.shard().unwrap(), 0);
 
         // SetShard
         let query = simple_query("SET SHARD TO '1'");
@@ -1098,7 +1101,7 @@ mod test {
             qr.try_execute_command(&query),
             Some((Command::SetShard, String::from("1")))
         );
-        assert_eq!(qr.shard(), 1);
+        assert_eq!(qr.shard().unwrap(), 1);
 
         // ShowShard
         let query = simple_query("SHOW SHARD");
@@ -1204,6 +1207,7 @@ mod test {
             ban_time: PoolSettings::default().ban_time,
             sharding_key_regex: None,
             shard_id_regex: None,
+            no_shard_specified_behavior: crate::config::NoShardSpecifiedHandling::Shard(0),
             regex_search_limit: 1000,
             auth_query: None,
             auth_query_password: None,
@@ -1281,6 +1285,7 @@ mod test {
             ban_time: PoolSettings::default().ban_time,
             sharding_key_regex: Some(Regex::new(r"/\* sharding_key: (\d+) \*/").unwrap()),
             shard_id_regex: Some(Regex::new(r"/\* shard_id: (\d+) \*/").unwrap()),
+            no_shard_specified_behavior: crate::config::NoShardSpecifiedHandling::Shard(0),
             regex_search_limit: 1000,
             auth_query: None,
             auth_query_password: None,
@@ -1331,7 +1336,7 @@ mod test {
                     .unwrap(),
             )
             .is_ok());
-        assert_eq!(qr.shard(), 2);
+        assert_eq!(qr.shard().unwrap(), 2);
 
         assert!(qr
             .infer(
@@ -1341,7 +1346,7 @@ mod test {
                 .unwrap()
             )
             .is_ok());
-        assert_eq!(qr.shard(), 0);
+        assert_eq!(qr.shard().unwrap(), 0);
 
         assert!(qr
             .infer(
@@ -1354,7 +1359,7 @@ mod test {
                 .unwrap()
             )
             .is_ok());
-        assert_eq!(qr.shard(), 2);
+        assert_eq!(qr.shard().unwrap(), 2);
 
         // Shard did not move because we couldn't determine the sharding key since it could be ambiguous
         // in the query.
@@ -1366,7 +1371,7 @@ mod test {
                 .unwrap()
             )
             .is_ok());
-        assert_eq!(qr.shard(), 2);
+        assert_eq!(qr.shard().unwrap(), 2);
 
         assert!(qr
             .infer(
@@ -1376,7 +1381,7 @@ mod test {
                 .unwrap()
             )
             .is_ok());
-        assert_eq!(qr.shard(), 0);
+        assert_eq!(qr.shard().unwrap(), 0);
 
         assert!(qr
             .infer(
@@ -1386,7 +1391,7 @@ mod test {
                 .unwrap()
             )
             .is_ok());
-        assert_eq!(qr.shard(), 2);
+        assert_eq!(qr.shard().unwrap(), 2);
 
         // Super unique sharding key
         qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
@@ -1398,7 +1403,7 @@ mod test {
                 .unwrap()
             )
             .is_ok());
-        assert_eq!(qr.shard(), 0);
+        assert_eq!(qr.shard().unwrap(), 0);
 
         assert!(qr
             .infer(
@@ -1406,7 +1411,7 @@ mod test {
                     .unwrap()
             )
             .is_ok());
-        assert_eq!(qr.shard(), 0);
+        assert_eq!(qr.shard().unwrap(), 0);
     }
 
     #[test]
@@ -1434,7 +1439,7 @@ mod test {
         assert_eq!(qr.placeholders.len(), 1);
 
         assert!(qr.infer_shard_from_bind(&bind));
-        assert_eq!(qr.shard(), 2);
+        assert_eq!(qr.shard().unwrap(), 2);
         assert!(qr.placeholders.is_empty());
     }
 
diff --git a/tests/docker/Dockerfile b/tests/docker/Dockerfile
index 99fd694d..261adb05 100644
--- a/tests/docker/Dockerfile
+++ b/tests/docker/Dockerfile
@@ -1,5 +1,7 @@
 FROM rust:bullseye
 
+COPY --from=sclevine/yj /bin/yj /bin/yj
+RUN /bin/yj -h
 RUN apt-get update && apt-get install llvm-11 psmisc postgresql-contrib postgresql-client ruby ruby-dev libpq-dev python3 python3-pip lcov curl sudo iproute2 -y
 RUN cargo install cargo-binutils rustfilt
 RUN rustup component add llvm-tools-preview
diff --git a/tests/ruby/auth_query_spec.rb b/tests/ruby/auth_query_spec.rb
index 1ac62164..c1ee744a 100644
--- a/tests/ruby/auth_query_spec.rb
+++ b/tests/ruby/auth_query_spec.rb
@@ -185,7 +185,7 @@
               },
             }
           }
-        } 
+        }
 
         context 'and with cleartext passwords set' do
           it 'it uses local passwords' do
diff --git a/tests/ruby/helpers/auth_query_helper.rb b/tests/ruby/helpers/auth_query_helper.rb
index 60e85713..43d7c785 100644
--- a/tests/ruby/helpers/auth_query_helper.rb
+++ b/tests/ruby/helpers/auth_query_helper.rb
@@ -33,18 +33,18 @@ def self.single_shard_auth_query(
             "0" => {
               "database" => "shard0",
               "servers" => [
-                ["localhost", primary.port.to_s, "primary"],
-                ["localhost", replica.port.to_s, "replica"],
+                ["localhost", primary.port.to_i, "primary"],
+                ["localhost", replica.port.to_i, "replica"],
               ]
             },
           },
           "users" => { "0" => user.merge(config_user) }
         }
       }
-      pgcat_cfg["general"]["port"] = pgcat.port
+      pgcat_cfg["general"]["port"] = pgcat.port.to_i
       pgcat.update_config(pgcat_cfg)
       pgcat.start
-      
+
       pgcat.wait_until_ready(
         pgcat.connection_string(
           "sharded_db",
@@ -92,13 +92,13 @@ def self.two_pools_auth_query(
             "0" => {
               "database" => database,
               "servers" => [
-                ["localhost", primary.port.to_s, "primary"],
-                ["localhost", replica.port.to_s, "replica"],
+                ["localhost", primary.port.to_i, "primary"],
+                ["localhost", replica.port.to_i, "replica"],
               ]
             },
           },
           "users" => { "0" => user.merge(config_user) }
-        }                                  
+        }
       end
       # Main proxy configs
       pgcat_cfg["pools"] = {
@@ -109,7 +109,7 @@ def self.two_pools_auth_query(
       pgcat_cfg["general"]["port"] = pgcat.port
       pgcat.update_config(pgcat_cfg.deep_merge(extra_conf))
       pgcat.start
-      
+
       pgcat.wait_until_ready(pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password']))
 
       OpenStruct.new.tap do |struct|
diff --git a/tests/ruby/helpers/pg_instance.rb b/tests/ruby/helpers/pg_instance.rb
index a3828248..53617c24 100644
--- a/tests/ruby/helpers/pg_instance.rb
+++ b/tests/ruby/helpers/pg_instance.rb
@@ -7,10 +7,24 @@ class PgInstance
   attr_reader :password
   attr_reader :database_name
 
+  def self.mass_takedown(databases)
+    raise StandardError, "block missing" unless block_given?
+
+    databases.each do |database|
+      database.toxiproxy.toxic(:limit_data, bytes: 1).toxics.each(&:save)
+    end
+    sleep 0.1
+    yield
+  ensure
+    databases.each do |database|
+      database.toxiproxy.toxics.each(&:destroy)
+    end
+  end
+
   def initialize(port, username, password, database_name)
-    @original_port = port
+    @original_port = port.to_i
     @toxiproxy_port = 10000 + port.to_i
-    @port = @toxiproxy_port
+    @port = @toxiproxy_port.to_i
 
     @username = username
     @password = password
@@ -48,9 +62,9 @@ def toxiproxy
 
   def take_down
     if block_given?
-      Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 5).apply { yield }
+      Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 1).apply { yield }
     else
-      Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 5).toxics.each(&:save)
+      Toxiproxy[@toxiproxy_name].toxic(:limit_data, bytes: 1).toxics.each(&:save)
     end
   end
 
@@ -89,6 +103,6 @@ def count_query(query)
   end
 
   def count_select_1_plus_2
-    with_connection { |c| c.async_exec("SELECT SUM(calls) FROM pg_stat_statements WHERE query = 'SELECT $1 + $2'")[0]["sum"].to_i }
+    with_connection { |c| c.async_exec("SELECT SUM(calls) FROM pg_stat_statements WHERE query LIKE '%SELECT $1 + $2%'")[0]["sum"].to_i }
   end
 end
diff --git a/tests/ruby/helpers/pgcat_helper.rb b/tests/ruby/helpers/pgcat_helper.rb
index 9b764d87..9b95dbfa 100644
--- a/tests/ruby/helpers/pgcat_helper.rb
+++ b/tests/ruby/helpers/pgcat_helper.rb
@@ -38,9 +38,9 @@ def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mod
           "automatic_sharding_key" => "data.id",
           "sharding_function" => "pg_bigint_hash",
           "shards" => {
-            "0" => { "database" => "shard0", "servers" => [["localhost", primary0.port.to_s, "primary"]] },
-            "1" => { "database" => "shard1", "servers" => [["localhost", primary1.port.to_s, "primary"]] },
-            "2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_s, "primary"]] },
+            "0" => { "database" => "shard0", "servers" => [["localhost", primary0.port.to_i, "primary"]] },
+            "1" => { "database" => "shard1", "servers" => [["localhost", primary1.port.to_i, "primary"]] },
+            "2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_i, "primary"]] },
           },
           "users" => { "0" => user },
           "plugins" => {
@@ -100,7 +100,7 @@ def self.single_instance_setup(pool_name, pool_size, pool_mode="transaction", lb
             "0" => {
               "database" => "shard0",
               "servers" => [
-                ["localhost", primary.port.to_s, "primary"]
+                ["localhost", primary.port.to_i, "primary"]
               ]
             },
           },
@@ -146,10 +146,10 @@ def self.single_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mo
           "0" => {
             "database" => "shard0",
             "servers" => [
-              ["localhost", primary.port.to_s, "primary"],
-              ["localhost", replica0.port.to_s, "replica"],
-              ["localhost", replica1.port.to_s, "replica"],
-              ["localhost", replica2.port.to_s, "replica"]
+              ["localhost", primary.port.to_i, "primary"],
+              ["localhost", replica0.port.to_i, "replica"],
+              ["localhost", replica1.port.to_i, "replica"],
+              ["localhost", replica2.port.to_i, "replica"]
             ]
           },
         },
diff --git a/tests/ruby/helpers/pgcat_process.rb b/tests/ruby/helpers/pgcat_process.rb
index dd3fd052..9328ff60 100644
--- a/tests/ruby/helpers/pgcat_process.rb
+++ b/tests/ruby/helpers/pgcat_process.rb
@@ -1,8 +1,10 @@
 require 'pg'
-require 'toml'
+require 'json'
+require 'tempfile'
 require 'fileutils'
 require 'securerandom'
 
+class ConfigReloadFailed < StandardError; end
 class PgcatProcess
   attr_reader :port
   attr_reader :pid
@@ -18,7 +20,7 @@ def self.finalize(pid, log_filename, config_filename)
   end
 
   def initialize(log_level)
-    @env = {"RUST_LOG" => log_level}
+    @env = {}
     @port = rand(20000..32760)
     @log_level = log_level
     @log_filename = "/tmp/pgcat_log_#{SecureRandom.urlsafe_base64}.log"
@@ -30,7 +32,7 @@ def initialize(log_level)
                      '../../target/debug/pgcat'
                    end
 
-    @command = "#{command_path} #{@config_filename}"
+    @command = "#{command_path} #{@config_filename} --log-level #{@log_level}"
 
     FileUtils.cp("../../pgcat.toml", @config_filename)
     cfg = current_config
@@ -46,22 +48,34 @@ def logs
 
   def update_config(config_hash)
     @original_config = current_config
-    output_to_write = TOML::Generator.new(config_hash).body
-    output_to_write = output_to_write.gsub(/,\s*["|'](\d+)["|']\s*,/, ',\1,')
-    output_to_write = output_to_write.gsub(/,\s*["|'](\d+)["|']\s*\]/, ',\1]')
-    File.write(@config_filename, output_to_write)
+    Tempfile.create('json_out', '/tmp') do |f|
+      f.write(config_hash.to_json)
+      f.flush
+      `cat #{f.path} | yj -jt > #{@config_filename}`
+    end
   end
 
   def current_config
-    loadable_string = File.read(@config_filename)
-    loadable_string = loadable_string.gsub(/,\s*(\d+)\s*,/,  ', "\1",')
-    loadable_string = loadable_string.gsub(/,\s*(\d+)\s*\]/, ', "\1"]')
-    TOML.load(loadable_string)
+    JSON.parse(`cat #{@config_filename} | yj -tj`)
+  end
+
+  def raw_config_file
+    File.read(@config_filename)
   end
 
   def reload_config
-    `kill -s HUP #{@pid}`
-    sleep 0.5
+    conn = PG.connect(admin_connection_string)
+
+    conn.async_exec("RELOAD")
+  rescue PG::ConnectionBad => e
+    errors = logs.split("Reloading config").last
+    errors = errors.gsub(/\e\[([;\d]+)?m/, '') # Remove color codes
+    errors = errors.
+                split("\n").select{|line| line.include?("ERROR") }.
+                map { |line| line.split("pgcat::config: ").last }
+    raise ConfigReloadFailed, errors.join("\n")
+  ensure
+    conn&.close
   end
 
   def start
@@ -116,11 +130,11 @@ def connection_string(pool_name, username, password = nil, parameters: {})
     cfg = current_config
     user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username }
     connection_string = "postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}"
-  
+
     # Add the additional parameters to the connection string
     parameter_string = parameters.map { |key, value| "#{key}=#{value}" }.join("&")
     connection_string += "?#{parameter_string}" unless parameter_string.empty?
-  
+
     connection_string
   end
 
diff --git a/tests/ruby/mirrors_spec.rb b/tests/ruby/mirrors_spec.rb
index 898d0d71..91123a16 100644
--- a/tests/ruby/mirrors_spec.rb
+++ b/tests/ruby/mirrors_spec.rb
@@ -11,9 +11,9 @@
   before do
     new_configs = processes.pgcat.current_config
     new_configs["pools"]["sharded_db"]["shards"]["0"]["mirrors"] = [
-      [mirror_host, mirror_pg.port.to_s, "0"],
-      [mirror_host, mirror_pg.port.to_s, "0"],
-      [mirror_host, mirror_pg.port.to_s, "0"],
+      [mirror_host, mirror_pg.port.to_i, 0],
+      [mirror_host, mirror_pg.port.to_i, 0],
+      [mirror_host, mirror_pg.port.to_i, 0],
     ]
     processes.pgcat.update_config(new_configs)
     processes.pgcat.reload_config
@@ -25,13 +25,14 @@
     processes.pgcat.shutdown
   end
 
-  xit "can mirror a query" do
+  it "can mirror a query" do
     conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
     runs = 15
     runs.times { conn.async_exec("SELECT 1 + 2") }
     sleep 0.5
     expect(processes.all_databases.first.count_select_1_plus_2).to eq(runs)
-    expect(mirror_pg.count_select_1_plus_2).to eq(runs * 3)
+    # Allow some slack in mirroring successes
+    expect(mirror_pg.count_select_1_plus_2).to be > ((runs - 5) * 3)
   end
 
   context "when main server connection is closed" do
@@ -42,9 +43,9 @@
         new_configs = processes.pgcat.current_config
         new_configs["pools"]["sharded_db"]["idle_timeout"] = 5000 + i
         new_configs["pools"]["sharded_db"]["shards"]["0"]["mirrors"] = [
-          [mirror_host, mirror_pg.port.to_s, "0"],
-          [mirror_host, mirror_pg.port.to_s, "0"],
-          [mirror_host, mirror_pg.port.to_s, "0"],
+          [mirror_host, mirror_pg.port.to_i, 0],
+          [mirror_host, mirror_pg.port.to_i, 0],
+          [mirror_host, mirror_pg.port.to_i, 0],
         ]
         processes.pgcat.update_config(new_configs)
         processes.pgcat.reload_config
@@ -57,7 +58,7 @@
     end
   end
 
-  xcontext "when mirror server goes down temporarily" do
+  context "when mirror server goes down temporarily" do
     it "continues to transmit queries after recovery" do
       conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
       mirror_pg.take_down do
diff --git a/tests/ruby/misc_spec.rb b/tests/ruby/misc_spec.rb
index 1d4ade4c..aa17e8ec 100644
--- a/tests/ruby/misc_spec.rb
+++ b/tests/ruby/misc_spec.rb
@@ -252,7 +252,7 @@
         end
 
         expect(processes.primary.count_query("RESET ROLE")).to eq(10)
-      end 
+      end
     end
 
     context "transaction mode" do
@@ -317,7 +317,7 @@
         conn.async_exec("SET statement_timeout to 1500")
         expect(conn.async_exec("SHOW statement_timeout")[0]["statement_timeout"]).to eq(orignal_statement_timeout)
       end
- 
+
     end
 
     context "transaction mode with transactions" do
@@ -354,7 +354,6 @@
         conn.async_exec("SET statement_timeout TO 1000")
         conn.close
 
-        puts processes.pgcat.logs
         expect(processes.primary.count_query("RESET ALL")).to eq(0)
       end
 
@@ -365,7 +364,6 @@
 
         conn.close
 
-        puts processes.pgcat.logs
         expect(processes.primary.count_query("RESET ALL")).to eq(0)
       end
     end
@@ -376,10 +374,9 @@
       before do
         current_configs = processes.pgcat.current_config
         correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"]
-        puts(current_configs["general"]["idle_client_in_transaction_timeout"])
-  
+
         current_configs["general"]["idle_client_in_transaction_timeout"] = 0
-  
+
         processes.pgcat.update_config(current_configs) # with timeout 0
         processes.pgcat.reload_config
       end
@@ -397,9 +394,9 @@
     context "idle transaction timeout set to 500ms" do
       before do
         current_configs = processes.pgcat.current_config
-        correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"]  
+        correct_idle_client_transaction_timeout = current_configs["general"]["idle_client_in_transaction_timeout"]
         current_configs["general"]["idle_client_in_transaction_timeout"] = 500
-  
+
         processes.pgcat.update_config(current_configs) # with timeout 500
         processes.pgcat.reload_config
       end
@@ -418,7 +415,7 @@
         conn.async_exec("BEGIN")
         conn.async_exec("SELECT 1")
         sleep(1) # above 500ms
-        expect{ conn.async_exec("COMMIT") }.to raise_error(PG::SystemError, /idle transaction timeout/) 
+        expect{ conn.async_exec("COMMIT") }.to raise_error(PG::SystemError, /idle transaction timeout/)
         conn.async_exec("SELECT 1") # should be able to send another query
         conn.close
       end
diff --git a/tests/ruby/sharding_spec.rb b/tests/ruby/sharding_spec.rb
index 123c10dc..47bc2a12 100644
--- a/tests/ruby/sharding_spec.rb
+++ b/tests/ruby/sharding_spec.rb
@@ -7,11 +7,11 @@
 
   before do
     conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
-
     # Setup the sharding data
     3.times do |i|
       conn.exec("SET SHARD TO '#{i}'")
-      conn.exec("DELETE FROM data WHERE id > 0")
+
+      conn.exec("DELETE FROM data WHERE id > 0") rescue nil
     end
 
     18.times do |i|
@@ -19,10 +19,11 @@
       conn.exec("SET SHARDING KEY TO '#{i}'")
       conn.exec("INSERT INTO data (id, value) VALUES (#{i}, 'value_#{i}')")
     end
+
+    conn.close
   end
 
   after do
-
     processes.all_databases.map(&:reset)
     processes.pgcat.shutdown
   end
@@ -48,4 +49,148 @@
       end
     end
   end
+
+  describe "no_shard_specified_behavior config" do
+    context "when default shard number is invalid" do
+      it "prevents config reload" do
+        admin_conn = PG::connect(processes.pgcat.admin_connection_string)
+
+        current_configs = processes.pgcat.current_config
+        current_configs["pools"]["sharded_db"]["no_shard_specified_behavior"] = "shard_99"
+
+        processes.pgcat.update_config(current_configs)
+
+        expect { processes.pgcat.reload_config }.to raise_error(ConfigReloadFailed, /Invalid shard 99/)
+      end
+    end
+  end
+
+  describe "comment-based routing" do
+    context "when no configs are set" do
+      it "routes queries with a shard_id comment to the default shard" do
+        conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
+        10.times { conn.async_exec("/* shard_id: 2 */ SELECT 1 + 2") }
+
+        expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([10, 0, 0])
+      end
+
+      it "does not honor no_shard_specified_behavior directives" do
+      end
+    end
+
+    [
+      ["shard_id_regex", "/\\* the_shard_id: (\\d+) \\*/", "/* the_shard_id: 1 */"],
+      ["sharding_key_regex", "/\\* the_sharding_key: (\\d+) \\*/", "/* the_sharding_key: 3 */"],
+    ].each do |config_name, config_value, comment_to_use|
+      context "when #{config_name} config is set" do
+        let(:no_shard_specified_behavior) { nil }
+
+        before do
+          admin_conn = PG::connect(processes.pgcat.admin_connection_string)
+
+          current_configs = processes.pgcat.current_config
+          current_configs["pools"]["sharded_db"][config_name] = config_value
+          if no_shard_specified_behavior
+            current_configs["pools"]["sharded_db"]["no_shard_specified_behavior"] = no_shard_specified_behavior
+          else
+            current_configs["pools"]["sharded_db"].delete("no_shard_specified_behavior")
+          end
+
+          processes.pgcat.update_config(current_configs)
+          processes.pgcat.reload_config
+        end
+
+        it "routes queries with a shard_id comment to the correct shard" do
+          conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
+          25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
+
+          expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0])
+        end
+
+        context "when no_shard_specified_behavior config is set to random" do
+          let(:no_shard_specified_behavior) { "random" }
+
+          context "with no shard comment" do
+            it "sends queries to random shard" do
+              conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
+              25.times { conn.async_exec("SELECT 1 + 2") }
+
+              expect(processes.all_databases.map(&:count_select_1_plus_2).all?(&:positive?)).to be true
+            end
+          end
+
+          context "with a shard comment" do
+            it "honors the comment" do
+              conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
+              25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
+
+              expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0])
+            end
+          end
+        end
+
+        context "when no_shard_specified_behavior config is set to random_healthy" do
+          let(:no_shard_specified_behavior) { "random_healthy" }
+
+          context "with no shard comment" do
+            it "sends queries to random healthy shard" do
+
+              good_databases = [processes.all_databases[0], processes.all_databases[2]]
+              bad_database = processes.all_databases[1]
+              conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
+              250.times { conn.async_exec("SELECT 99") }
+              bad_database.take_down do
+                250.times do
+                  conn.async_exec("SELECT 99")
+                rescue PG::ConnectionBad => e
+                  conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
+                end
+              end
+
+              # Routes traffic away from bad shard
+              25.times { conn.async_exec("SELECT 1 + 2") }
+              expect(good_databases.map(&:count_select_1_plus_2).all?(&:positive?)).to be true
+              expect(bad_database.count_select_1_plus_2).to eq(0)
+
+              # Routes traffic to the bad shard if the shard_id is specified
+              25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
+              bad_database = processes.all_databases[1]
+              expect(bad_database.count_select_1_plus_2).to eq(25)
+            end
+          end
+
+          context "with a shard comment" do
+            it "honors the comment" do
+              conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
+              25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
+
+              expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0])
+            end
+          end
+        end
+
+        context "when no_shard_specified_behavior config is set to shard_x" do
+          let(:no_shard_specified_behavior) { "shard_2" }
+
+          context "with no shard comment" do
+            it "sends queries to the specified shard" do
+              conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
+              25.times { conn.async_exec("SELECT 1 + 2") }
+
+              expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 0, 25])
+            end
+          end
+
+          context "with a shard comment" do
+            it "honors the comment" do
+              conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
+              25.times { conn.async_exec("#{comment_to_use} SELECT 1 + 2") }
+
+              expect(processes.all_databases.map(&:count_select_1_plus_2)).to eq([0, 25, 0])
+            end
+          end
+        end
+      end
+    end
+  end
 end

From 425a216d37a55b57b2d539c194aad14e23f6686c Mon Sep 17 00:00:00 2001
From: Mostafa Abdelraouf <mostafa.abdelraouf@instacart.com>
Date: Mon, 4 Sep 2023 08:29:05 -0500
Subject: [PATCH 2/8] fmt + comment

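Run `cargo fmt` over the new routing code and restore the pgcat.toml comment
block documenting the comment-based sharding regexes. The commented-out
example for `no_shard_specified_behavior` now shows `shard_0`, matching the
built-in default (`NoShardSpecifiedHandling::Shard(0)`).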
---
 pgcat.toml    |  7 ++++---
 src/client.rs |  5 +++--
 src/config.rs |  9 +++------
 src/pool.rs   | 29 +++++++++++++----------------
 4 files changed, 23 insertions(+), 27 deletions(-)

diff --git a/pgcat.toml b/pgcat.toml
index 39689cd0..772a1365 100644
--- a/pgcat.toml
+++ b/pgcat.toml
@@ -171,16 +171,17 @@ query_parser_read_write_splitting = true
 # queries. The primary can always be explicitly selected with our custom protocol.
 primary_reads_enabled = true
 
+# Allow sharding commands to be passed as statement comments instead of
+# separate commands. If these are unset this functionality is disabled.
 # sharding_key_regex = '/\* sharding_key: (\d+) \*/'
 # shard_id_regex = '/\* shard_id: (\d+) \*/'
 # regex_search_limit = 1000 # only look at the first 1000 characters of SQL statements
 
-# Defines the behavior when no shard_id or sharding_key is specified for a query against
-# a sharded system with either sharding_key_regex or shard_id_regex specified.
+# Defines the behavior when no shard is selected in a sharded system.
 # `random`: picks a shard at random
 # `random_healthy`: picks a shard at random favoring shards with the least number of recent errors
 # `shard_<number>`: e.g. shard_0, shard_4, etc. picks a specific shard every time
-# no_shard_specified_behavior = "random"
+# no_shard_specified_behavior = "shard_0"
 
 # So what if you wanted to implement a different hashing function,
 # or you've already built one and you want this pooler to use it?
diff --git a/src/client.rs b/src/client.rs
index 56fddbfe..ba035169 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -1095,8 +1095,9 @@ where
 
                     error_response(
                         &mut self.write,
-                        format!("could not get connection from the pool - {}", err).as_str()
-                ).await?;
+                        format!("could not get connection from the pool - {}", err).as_str(),
+                    )
+                    .await?;
 
                     error!(
                         "Could not get connection from pool: \
diff --git a/src/config.rs b/src/config.rs
index 8ac6187b..5814a9fa 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -597,7 +597,7 @@ impl Pool {
         PoolMode::Transaction
     }
 
-    pub fn default_no_shard_specified_behavior()-> NoShardSpecifiedHandling {
+    pub fn default_no_shard_specified_behavior() -> NoShardSpecifiedHandling {
         NoShardSpecifiedHandling::default()
     }
 
@@ -692,13 +692,10 @@ impl Pool {
             None => None,
         };
 
-        match self.no_shard_specified_behavior  {
+        match self.no_shard_specified_behavior {
             NoShardSpecifiedHandling::Shard(shard_number) => {
                 if shard_number >= self.shards.len() {
-                    error!(
-                        "Invalid shard {:?}",
-                        shard_number
-                    );
+                    error!("Invalid shard {:?}", shard_number);
                     return Err(Error::BadConfig);
                 }
             }
diff --git a/src/pool.rs b/src/pool.rs
index 80b2d1e6..9e3a193a 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -618,8 +618,7 @@ impl ConnectionPool {
         role: Option<Role>,         // primary or replica
         client_stats: &ClientStats, // client id
     ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
-
-        let mut effective_shard_id:Option<usize> = shard;
+        let mut effective_shard_id: Option<usize> = shard;
 
         // The base, unsharded case
         if self.shards() == 1 {
@@ -629,22 +628,20 @@ impl ConnectionPool {
         let mut sort_by_error_count = false;
         let mut candidates: Vec<_> = match effective_shard_id {
             Some(shard_id) => self.addresses[shard_id].iter().collect(),
-            None => {
-                match self.settings.no_shard_specified_behavior {
-                    NoShardSpecifiedHandling::Random => self.addresses.iter().flatten().collect(),
-                    NoShardSpecifiedHandling::RandomHealthy => {
-                        sort_by_error_count = true;
-                        self.addresses.iter().flatten().collect()
-                    }
-                    NoShardSpecifiedHandling::Shard(shard) => {
-                        if shard >= self.shards() {
-                            return Err(Error::InvalidShardId(shard));
-                        } else {
-                            self.addresses[shard].iter().collect()
-                        }
+            None => match self.settings.no_shard_specified_behavior {
+                NoShardSpecifiedHandling::Random => self.addresses.iter().flatten().collect(),
+                NoShardSpecifiedHandling::RandomHealthy => {
+                    sort_by_error_count = true;
+                    self.addresses.iter().flatten().collect()
+                }
+                NoShardSpecifiedHandling::Shard(shard) => {
+                    if shard >= self.shards() {
+                        return Err(Error::InvalidShardId(shard));
+                    } else {
+                        self.addresses[shard].iter().collect()
                     }
                 }
-            }
+            },
         };
         candidates.retain(|address| address.role == role);
 

From 7f69133498cebca4cebfb2bd4996ad7d5618460f Mon Sep 17 00:00:00 2001
From: Mostafa Abdelraouf <mostafa.abdelraouf@instacart.com>
Date: Mon, 4 Sep 2023 09:41:18 -0500
Subject: [PATCH 3/8] use a tagged image

---
 .circleci/config.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 9fe3c256..898ac027 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -9,7 +9,7 @@ jobs:
     # Specify the execution environment. You can specify an image from Dockerhub or use one of our Convenience Images from CircleCI's Developer Hub.
     # See: https://circleci.com/docs/2.0/configuration-reference/#docker-machine-macos-windows-executor
     docker:
-      - image: ghcr.io/levkk/pgcat-ci:latest
+      - image: ghcr.io/levkk/pgcat-ci:4f49ec147c040e246b284081d8c0eca076a32f8a
         environment:
           RUST_LOG: info
           LLVM_PROFILE_FILE: /tmp/pgcat-%m-%p.profraw

From b48c08e6805e930e0baabbea070e977394025b29 Mon Sep 17 00:00:00 2001
From: Mostafa Abdelraouf <mostafa.abdelraouf@instacart.com>
Date: Mon, 4 Sep 2023 09:53:25 -0500
Subject: [PATCH 4/8] use pgml ci image

---
 .circleci/config.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 898ac027..c8344911 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -9,7 +9,7 @@ jobs:
     # Specify the execution environment. You can specify an image from Dockerhub or use one of our Convenience Images from CircleCI's Developer Hub.
     # See: https://circleci.com/docs/2.0/configuration-reference/#docker-machine-macos-windows-executor
     docker:
-      - image: ghcr.io/levkk/pgcat-ci:4f49ec147c040e246b284081d8c0eca076a32f8a
+      - image: ghcr.io/postgresml/pgcat-ci:latest
         environment:
           RUST_LOG: info
           LLVM_PROFILE_FILE: /tmp/pgcat-%m-%p.profraw

From 3ab8424f4d5c29c9a05b2f8fed4f90d499f41b3b Mon Sep 17 00:00:00 2001
From: Mostafa Abdelraouf <mostafa.abdelraouf@instacart.com>
Date: Mon, 4 Sep 2023 11:05:36 -0500
Subject: [PATCH 5/8] Re-disable mirror specs

---
 tests/ruby/mirrors_spec.rb | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/ruby/mirrors_spec.rb b/tests/ruby/mirrors_spec.rb
index 91123a16..b6a4514c 100644
--- a/tests/ruby/mirrors_spec.rb
+++ b/tests/ruby/mirrors_spec.rb
@@ -25,7 +25,7 @@
     processes.pgcat.shutdown
   end
 
-  it "can mirror a query" do
+  xit "can mirror a query" do
     conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
     runs = 15
     runs.times { conn.async_exec("SELECT 1 + 2") }
@@ -58,7 +58,7 @@
     end
   end
 
-  context "when mirror server goes down temporarily" do
+  xcontext "when mirror server goes down temporarily" do
     it "continues to transmit queries after recovery" do
       conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
       mirror_pg.take_down do

From f6ae4990fd8d9afcb65990e357e9cbd0771ad063 Mon Sep 17 00:00:00 2001
From: Mostafa Abdelraouf <mostafa.abdelraouf@instacart.com>
Date: Mon, 4 Sep 2023 11:14:37 -0500
Subject: [PATCH 6/8] add config.md

---
 CONFIG.md | 188 ++++++++++++++++++++++++------------------------------
 1 file changed, 85 insertions(+), 103 deletions(-)

diff --git a/CONFIG.md b/CONFIG.md
index fc118cb4..a0899697 100644
--- a/CONFIG.md
+++ b/CONFIG.md
@@ -1,4 +1,4 @@
-# PgCat Configurations
+# PgCat Configurations 
 ## `general` Section
 
 ### host
@@ -57,38 +57,6 @@ default: 86400000 # 24 hours
 
 Max connection lifetime before it's closed, even if actively used.
 
-### server_round_robin
-```
-path: general.server_round_robin
-default: false
-```
-
-Whether to use round robin for server selection or not.
-
-### server_tls
-```
-path: general.server_tls
-default: false
-```
-
-Whether to use TLS for server connections or not.
-
-### verify_server_certificate
-```
-path: general.verify_server_certificate
-default: false
-```
-
-Whether to verify server certificate or not.
-
-### verify_config
-```
-path: general.verify_config
-default: true
-```
-
-Whether to verify config or not.
-
 ### idle_client_in_transaction_timeout
 ```
 path: general.idle_client_in_transaction_timeout
@@ -148,10 +116,10 @@ If we should log client disconnections
 ### autoreload
 ```
 path: general.autoreload
-default: 15000 # milliseconds
+default: 15000
 ```
 
-When set, PgCat automatically reloads its configurations at the specified interval (in milliseconds) if it detects changes in the configuration file. The default interval is 15000 milliseconds or 15 seconds.
+When set to true, PgCat reloads configs if it detects a change in the config file.
 
 ### worker_threads
 ```
@@ -183,19 +151,29 @@ path: general.tcp_keepalives_interval
 default: 5
 ```
 
-### tcp_user_timeout
+Number of seconds between keepalive packets.
+
+### prepared_statements
+```
+path: general.prepared_statements
+default: true
+```
+
+Handle prepared statements.
+
+### prepared_statements_cache_size
 ```
-path: general.tcp_user_timeout
-default: 10000
+path: general.prepared_statements_cache_size
+default: 500
 ```
-A linux-only parameters that defines the amount of time in milliseconds that transmitted data may remain unacknowledged or buffered data may remain untransmitted (due to zero window size) before TCP will forcibly disconnect
 
+Prepared statements server cache size.
 
 ### tls_certificate
 ```
 path: general.tls_certificate
 default: <UNSET>
-example: "server.cert"
+example: ".circleci/server.cert"
 ```
 
 Path to TLS Certificate file to use for TLS connections
@@ -204,11 +182,27 @@ Path to TLS Certificate file to use for TLS connections
 ```
 path: general.tls_private_key
 default: <UNSET>
-example: "server.key"
+example: ".circleci/server.key"
 ```
 
 Path to TLS private key file to use for TLS connections
 
+### server_tls
+```
+path: general.server_tls
+default: false
+```
+
+Enable/disable server TLS
+
+### verify_server_certificate
+```
+path: general.verify_server_certificate
+default: false
+```
+
+Verify server certificate is completely authentic.
+
 ### admin_username
 ```
 path: general.admin_username
@@ -226,70 +220,15 @@ default: "admin_pass"
 
 Password to access the virtual administrative database
 
-### auth_query
-```
-path: general.auth_query
-default: <UNSET>
-example: "SELECT $1"
-```
+## `plugins` Section
 
-Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
-established using the database configured in the pool. This parameter is inherited by every pool
-and can be redefined in pool configuration.
+## `plugins.prewarmer` Section
 
-### auth_query_user
-```
-path: general.auth_query_user
-default: <UNSET>
-example: "sharding_user"
-```
+## `plugins.query_logger` Section
 
-User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
-specified in `auth_query_user`. The connection will be established using the database configured in the pool.
-This parameter is inherited by every pool and can be redefined in pool configuration.
+## `plugins.table_access` Section
 
-### auth_query_password
-```
-path: general.auth_query_password
-default: <UNSET>
-example: "sharding_user"
-```
-
-Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
-specified in `auth_query_user`. The connection will be established using the database configured in the pool.
-This parameter is inherited by every pool and can be redefined in pool configuration.
-
-### prepared_statements
-```
-path: general.prepared_statements
-default: false
-```
-
-Whether to use prepared statements or not.
-
-### prepared_statements_cache_size
-```
-path: general.prepared_statements_cache_size
-default: 500
-```
-
-Size of the prepared statements cache.
-
-### dns_cache_enabled
-```
-path: general.dns_cache_enabled
-default: false
-```
-When enabled, ip resolutions for server connections specified using hostnames will be cached
-and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
-old ip connections are closed (gracefully) and new connections will start using new ip.
-
-### dns_max_ttl
-```
-path: general.dns_max_ttl
-default: 30
-```
-Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
+## `plugins.intercept` Section
 
 ## `pools.<pool_name>` Section
 
@@ -311,7 +250,7 @@ default: "random"
 
 Load balancing mode
 `random` selects the server at random
-`loc` selects the server with the least outstanding busy connections
+`loc` selects the server with the least outstanding busy conncetions
 
 ### default_role
 ```
@@ -335,6 +274,15 @@ every incoming query to determine if it's a read or a write.
 If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
 we'll direct it to the primary.
 
+### query_parser_read_write_splitting
+```
+path: pools.<pool_name>.query_parser_read_write_splitting
+default: true
+```
+
+If the query parser is enabled and this setting is enabled, we'll attempt to
+infer the role from the query itself.
+
 ### primary_reads_enabled
 ```
 path: pools.<pool_name>.primary_reads_enabled
@@ -355,6 +303,18 @@ example: '/\* sharding_key: (\d+) \*/'
 Allow sharding commands to be passed as statement comments instead of
 separate commands. If these are unset this functionality is disabled.
 
+### no_shard_specified_behavior
+```
+path: pools.<pool_name>.no_shard_specified_behavior
+default: <UNSET>
+example: "shard_0"
+```
+
+Defines the behavior when no shard is selected in a sharded system.
+`random`: picks a shard at random
+`random_healthy`: picks a shard at random favoring shards with the least number of recent errors
+`shard_<number>`: e.g. shard_0, shard_4, etc. picks a specific shard, everytime
+
 ### sharding_function
 ```
 path: pools.<pool_name>.sharding_function
@@ -371,7 +331,7 @@ Current options:
 ```
 path: pools.<pool_name>.auth_query
 default: <UNSET>
-example: "SELECT $1"
+example: "SELECT usename, passwd FROM pg_shadow WHERE usename='$1'"
 ```
 
 Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
@@ -425,6 +385,28 @@ default: 3000
 
 Connect timeout can be overwritten in the pool
 
+### dns_cache_enabled
+```
+path: pools.<pool_name>.dns_cache_enabled
+default: <UNSET>
+example: false
+```
+
+When enabled, ip resolutions for server connections specified using hostnames will be cached
+and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
+old ip connections are closed (gracefully) and new connections will start using new ip.
+
+### dns_max_ttl
+```
+path: pools.<pool_name>.dns_max_ttl
+default: <UNSET>
+example: 30
+```
+
+Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
+
+## `pool.<pool_name>.plugins` Section
+
 ## `pools.<pool_name>.users.<user_index>` Section
 
 ### username

From 39674bc20cc7c793d052d11775130cabc29c57df Mon Sep 17 00:00:00 2001
From: Mostafa Abdelraouf <mostafa.abdelraouf@instacart.com>
Date: Mon, 4 Sep 2023 11:21:42 -0500
Subject: [PATCH 7/8] Revert CONFIG.md changes

---
 CONFIG.md | 188 ++++++++++++++++++++++++++++++------------------------
 1 file changed, 103 insertions(+), 85 deletions(-)

diff --git a/CONFIG.md b/CONFIG.md
index a0899697..fc118cb4 100644
--- a/CONFIG.md
+++ b/CONFIG.md
@@ -1,4 +1,4 @@
-# PgCat Configurations 
+# PgCat Configurations
 ## `general` Section
 
 ### host
@@ -57,6 +57,38 @@ default: 86400000 # 24 hours
 
 Max connection lifetime before it's closed, even if actively used.
 
+### server_round_robin
+```
+path: general.server_round_robin
+default: false
+```
+
+Whether to use round robin for server selection or not.
+
+### server_tls
+```
+path: general.server_tls
+default: false
+```
+
+Whether to use TLS for server connections or not.
+
+### verify_server_certificate
+```
+path: general.verify_server_certificate
+default: false
+```
+
+Whether to verify server certificate or not.
+
+### verify_config
+```
+path: general.verify_config
+default: true
+```
+
+Whether to verify config or not.
+
 ### idle_client_in_transaction_timeout
 ```
 path: general.idle_client_in_transaction_timeout
@@ -116,10 +148,10 @@ If we should log client disconnections
 ### autoreload
 ```
 path: general.autoreload
-default: 15000
+default: 15000 # milliseconds
 ```
 
-When set to true, PgCat reloads configs if it detects a change in the config file.
+When set, PgCat automatically reloads its configurations at the specified interval (in milliseconds) if it detects changes in the configuration file. The default interval is 15000 milliseconds or 15 seconds.
 
 ### worker_threads
 ```
@@ -151,29 +183,19 @@ path: general.tcp_keepalives_interval
 default: 5
 ```
 
-Number of seconds between keepalive packets.
-
-### prepared_statements
-```
-path: general.prepared_statements
-default: true
-```
-
-Handle prepared statements.
-
-### prepared_statements_cache_size
+### tcp_user_timeout
 ```
-path: general.prepared_statements_cache_size
-default: 500
+path: general.tcp_user_timeout
+default: 10000
 ```
+A linux-only parameters that defines the amount of time in milliseconds that transmitted data may remain unacknowledged or buffered data may remain untransmitted (due to zero window size) before TCP will forcibly disconnect
 
-Prepared statements server cache size.
 
 ### tls_certificate
 ```
 path: general.tls_certificate
 default: <UNSET>
-example: ".circleci/server.cert"
+example: "server.cert"
 ```
 
 Path to TLS Certificate file to use for TLS connections
@@ -182,27 +204,11 @@ Path to TLS Certificate file to use for TLS connections
 ```
 path: general.tls_private_key
 default: <UNSET>
-example: ".circleci/server.key"
+example: "server.key"
 ```
 
 Path to TLS private key file to use for TLS connections
 
-### server_tls
-```
-path: general.server_tls
-default: false
-```
-
-Enable/disable server TLS
-
-### verify_server_certificate
-```
-path: general.verify_server_certificate
-default: false
-```
-
-Verify server certificate is completely authentic.
-
 ### admin_username
 ```
 path: general.admin_username
@@ -220,15 +226,70 @@ default: "admin_pass"
 
 Password to access the virtual administrative database
 
-## `plugins` Section
+### auth_query
+```
+path: general.auth_query
+default: <UNSET>
+example: "SELECT $1"
+```
 
-## `plugins.prewarmer` Section
+Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
+established using the database configured in the pool. This parameter is inherited by every pool
+and can be redefined in pool configuration.
 
-## `plugins.query_logger` Section
+### auth_query_user
+```
+path: general.auth_query_user
+default: <UNSET>
+example: "sharding_user"
+```
 
-## `plugins.table_access` Section
+User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
+specified in `auth_query_user`. The connection will be established using the database configured in the pool.
+This parameter is inherited by every pool and can be redefined in pool configuration.
 
-## `plugins.intercept` Section
+### auth_query_password
+```
+path: general.auth_query_password
+default: <UNSET>
+example: "sharding_user"
+```
+
+Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query
+specified in `auth_query_user`. The connection will be established using the database configured in the pool.
+This parameter is inherited by every pool and can be redefined in pool configuration.
+
+### prepared_statements
+```
+path: general.prepared_statements
+default: false
+```
+
+Whether to use prepared statements or not.
+
+### prepared_statements_cache_size
+```
+path: general.prepared_statements_cache_size
+default: 500
+```
+
+Size of the prepared statements cache.
+
+### dns_cache_enabled
+```
+path: general.dns_cache_enabled
+default: false
+```
+When enabled, ip resolutions for server connections specified using hostnames will be cached
+and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
+old ip connections are closed (gracefully) and new connections will start using new ip.
+
+### dns_max_ttl
+```
+path: general.dns_max_ttl
+default: 30
+```
+Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
 
 ## `pools.<pool_name>` Section
 
@@ -250,7 +311,7 @@ default: "random"
 
 Load balancing mode
 `random` selects the server at random
-`loc` selects the server with the least outstanding busy conncetions
+`loc` selects the server with the least outstanding busy connections
 
 ### default_role
 ```
@@ -274,15 +335,6 @@ every incoming query to determine if it's a read or a write.
 If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
 we'll direct it to the primary.
 
-### query_parser_read_write_splitting
-```
-path: pools.<pool_name>.query_parser_read_write_splitting
-default: true
-```
-
-If the query parser is enabled and this setting is enabled, we'll attempt to
-infer the role from the query itself.
-
 ### primary_reads_enabled
 ```
 path: pools.<pool_name>.primary_reads_enabled
@@ -303,18 +355,6 @@ example: '/\* sharding_key: (\d+) \*/'
 Allow sharding commands to be passed as statement comments instead of
 separate commands. If these are unset this functionality is disabled.
 
-### no_shard_specified_behavior
-```
-path: pools.<pool_name>.no_shard_specified_behavior
-default: <UNSET>
-example: "shard_0"
-```
-
-Defines the behavior when no shard is selected in a sharded system.
-`random`: picks a shard at random
-`random_healthy`: picks a shard at random favoring shards with the least number of recent errors
-`shard_<number>`: e.g. shard_0, shard_4, etc. picks a specific shard, everytime
-
 ### sharding_function
 ```
 path: pools.<pool_name>.sharding_function
@@ -331,7 +371,7 @@ Current options:
 ```
 path: pools.<pool_name>.auth_query
 default: <UNSET>
-example: "SELECT usename, passwd FROM pg_shadow WHERE usename='$1'"
+example: "SELECT $1"
 ```
 
 Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be
@@ -385,28 +425,6 @@ default: 3000
 
 Connect timeout can be overwritten in the pool
 
-### dns_cache_enabled
-```
-path: pools.<pool_name>.dns_cache_enabled
-default: <UNSET>
-example: false
-```
-
-When enabled, ip resolutions for server connections specified using hostnames will be cached
-and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
-old ip connections are closed (gracefully) and new connections will start using new ip.
-
-### dns_max_ttl
-```
-path: pools.<pool_name>.dns_max_ttl
-default: <UNSET>
-example: 30
-```
-
-Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
-
-## `pool.<pool_name>.plugins` Section
-
 ## `pools.<pool_name>.users.<user_index>` Section
 
 ### username

From 433bf6eb4ace8909d17f12ab0ba9ef1aa837fc98 Mon Sep 17 00:00:00 2001
From: Mostafa Abdelraouf <mostafa.abdelraouf@instacart.com>
Date: Thu, 7 Sep 2023 13:56:02 -0500
Subject: [PATCH 8/8] clean up code a bit

---
 src/client.rs               | 38 +++++++++-------
 src/config.rs               | 36 +++++++--------
 src/pool.rs                 | 88 +++++++++++++++++++------------------
 src/query_router.rs         |  4 +-
 tests/ruby/sharding_spec.rb |  6 +--
 5 files changed, 89 insertions(+), 83 deletions(-)

diff --git a/src/client.rs b/src/client.rs
index ba035169..4b281121 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -1009,23 +1009,27 @@ where
 
                 // SET SHARD TO
                 Some((Command::SetShard, _)) => {
-                    // Selected shard is not configured.
-                    if query_router.shard().unwrap_or(0) >= pool.shards() {
-                        // Set the shard back to what it was.
-                        query_router.set_shard(current_shard);
-
-                        error_response(
-                            &mut self.write,
-                            &format!(
-                                "shard {:?} is more than configured {}, staying on shard {:?} (shard numbers start at 0)",
-                                query_router.shard(),
-                                pool.shards(),
-                                current_shard,
-                            ),
-                        )
-                            .await?;
-                    } else {
-                        custom_protocol_response_ok(&mut self.write, "SET SHARD").await?;
+                    match query_router.shard() {
+                        None => (),
+                        Some(selected_shard) => {
+                            if selected_shard >= pool.shards() {
+                                // Bad shard number, send error message to client.
+                                query_router.set_shard(current_shard);
+
+                                error_response(
+                                    &mut self.write,
+                                    &format!(
+                                        "shard {} is not configured, only {} shards are configured, staying on shard {:?} (shard numbers start at 0)",
+                                        selected_shard,
+                                        pool.shards(),
+                                        current_shard,
+                                    ),
+                                )
+                                    .await?;
+                            } else {
+                                custom_protocol_response_ok(&mut self.write, "SET SHARD").await?;
+                            }
+                        }
                     }
                     continue;
                 }
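
As a minimal, self-contained sketch of the check the `SetShard` arm performs after this change (illustrative only; `validate_shard_selection` is not part of the pooler's API, and `None` stands for "no shard selected"):

```
// Sketch: validate an optional shard selection against the configured shard count.
fn validate_shard_selection(selected: Option<usize>, shards: usize) -> Result<(), String> {
    match selected {
        // No shard selected: nothing to validate, keep the current routing.
        None => Ok(()),
        // Shard numbers start at 0, so anything >= the shard count is invalid.
        Some(shard) if shard >= shards => Err(format!(
            "shard {} is not configured, only {} shards are configured",
            shard, shards
        )),
        Some(_) => Ok(()),
    }
}

fn main() {
    assert!(validate_shard_selection(None, 3).is_ok());
    assert!(validate_shard_selection(Some(2), 3).is_ok());
    assert!(validate_shard_selection(Some(3), 3).is_err());
}
```
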
diff --git a/src/config.rs b/src/config.rs
index 5814a9fa..0404abc9 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -559,8 +559,8 @@ pub struct Pool {
     pub shard_id_regex: Option<String>,
     pub regex_search_limit: Option<usize>,
 
-    #[serde(default = "Pool::default_no_shard_specified_behavior")]
-    pub no_shard_specified_behavior: NoShardSpecifiedHandling,
+    #[serde(default = "Pool::default_default_shard")]
+    pub default_shard: DefaultShard,
 
     pub auth_query: Option<String>,
     pub auth_query_user: Option<String>,
@@ -597,8 +597,8 @@ impl Pool {
         PoolMode::Transaction
     }
 
-    pub fn default_no_shard_specified_behavior() -> NoShardSpecifiedHandling {
-        NoShardSpecifiedHandling::default()
+    pub fn default_default_shard() -> DefaultShard {
+        DefaultShard::default()
     }
 
     pub fn default_load_balancing_mode() -> LoadBalancingMode {
@@ -692,8 +692,8 @@ impl Pool {
             None => None,
         };
 
-        match self.no_shard_specified_behavior {
-            NoShardSpecifiedHandling::Shard(shard_number) => {
+        match self.default_shard {
+            DefaultShard::Shard(shard_number) => {
                 if shard_number >= self.shards.len() {
                     error!("Invalid shard {:?}", shard_number);
                     return Err(Error::BadConfig);
@@ -729,7 +729,7 @@ impl Default for Pool {
             sharding_key_regex: None,
             shard_id_regex: None,
             regex_search_limit: Some(1000),
-            no_shard_specified_behavior: Self::default_no_shard_specified_behavior(),
+            default_shard: Self::default_default_shard(),
             auth_query: None,
             auth_query_user: None,
             auth_query_password: None,
@@ -750,28 +750,28 @@ pub struct ServerConfig {
 
 // No Shard Specified handling.
 #[derive(Debug, PartialEq, Clone, Eq, Hash, Copy)]
-pub enum NoShardSpecifiedHandling {
+pub enum DefaultShard {
     Shard(usize),
     Random,
     RandomHealthy,
 }
-impl Default for NoShardSpecifiedHandling {
+impl Default for DefaultShard {
     fn default() -> Self {
-        NoShardSpecifiedHandling::Shard(0)
+        DefaultShard::Shard(0)
     }
 }
-impl serde::Serialize for NoShardSpecifiedHandling {
+impl serde::Serialize for DefaultShard {
     fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
         match self {
-            NoShardSpecifiedHandling::Shard(shard) => {
+            DefaultShard::Shard(shard) => {
                 serializer.serialize_str(&format!("shard_{}", &shard.to_string()))
             }
-            NoShardSpecifiedHandling::Random => serializer.serialize_str("random"),
-            NoShardSpecifiedHandling::RandomHealthy => serializer.serialize_str("random_healthy"),
+            DefaultShard::Random => serializer.serialize_str("random"),
+            DefaultShard::RandomHealthy => serializer.serialize_str("random_healthy"),
         }
     }
 }
-impl<'de> serde::Deserialize<'de> for NoShardSpecifiedHandling {
+impl<'de> serde::Deserialize<'de> for DefaultShard {
     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     where
         D: Deserializer<'de>,
@@ -779,12 +779,12 @@ impl<'de> serde::Deserialize<'de> for NoShardSpecifiedHandling {
         let s = String::deserialize(deserializer)?;
         if s.starts_with("shard_") {
             let shard = s[6..].parse::<usize>().map_err(serde::de::Error::custom)?;
-            return Ok(NoShardSpecifiedHandling::Shard(shard));
+            return Ok(DefaultShard::Shard(shard));
         }
 
         match s.as_str() {
-            "random" => Ok(NoShardSpecifiedHandling::Random),
-            "random_healthy" => Ok(NoShardSpecifiedHandling::RandomHealthy),
+            "random" => Ok(DefaultShard::Random),
+            "random_healthy" => Ok(DefaultShard::RandomHealthy),
             _ => Err(serde::de::Error::custom(
                 "invalid value for no_shard_specified_behavior",
             )),
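
For reference, a minimal sketch of the string convention the custom `Serialize`/`Deserialize` impls above agree on for `default_shard` (`shard_<number>`, `random`, `random_healthy`); this is a standalone illustration without serde, not the configuration code itself:

```
// Sketch: parse the accepted `default_shard` values.
#[derive(Debug, PartialEq)]
enum DefaultShard {
    Shard(usize),
    Random,
    RandomHealthy,
}

fn parse_default_shard(s: &str) -> Result<DefaultShard, String> {
    // "shard_0", "shard_4", etc. route to one specific shard every time.
    if let Some(rest) = s.strip_prefix("shard_") {
        return rest
            .parse::<usize>()
            .map(DefaultShard::Shard)
            .map_err(|e| e.to_string());
    }

    match s {
        "random" => Ok(DefaultShard::Random),
        "random_healthy" => Ok(DefaultShard::RandomHealthy),
        _ => Err("invalid value for default_shard".to_string()),
    }
}

fn main() {
    assert_eq!(parse_default_shard("shard_3"), Ok(DefaultShard::Shard(3)));
    assert_eq!(parse_default_shard("random"), Ok(DefaultShard::Random));
    assert_eq!(parse_default_shard("random_healthy"), Ok(DefaultShard::RandomHealthy));
    assert!(parse_default_shard("primary").is_err());
}
```
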
diff --git a/src/pool.rs b/src/pool.rs
index 9e3a193a..18123407 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -19,8 +19,7 @@ use std::time::Instant;
 use tokio::sync::Notify;
 
 use crate::config::{
-    get_config, Address, General, LoadBalancingMode, NoShardSpecifiedHandling, Plugins, PoolMode,
-    Role, User,
+    get_config, Address, DefaultShard, General, LoadBalancingMode, Plugins, PoolMode, Role, User,
 };
 use crate::errors::Error;
 
@@ -143,7 +142,7 @@ pub struct PoolSettings {
     pub shard_id_regex: Option<Regex>,
 
     // What to do when no shard is selected in a sharded system
-    pub no_shard_specified_behavior: NoShardSpecifiedHandling,
+    pub default_shard: DefaultShard,
 
     // Limit how much of each query is searched for a potential shard regex match
     pub regex_search_limit: usize,
@@ -178,7 +177,7 @@ impl Default for PoolSettings {
             sharding_key_regex: None,
             shard_id_regex: None,
             regex_search_limit: 1000,
-            no_shard_specified_behavior: NoShardSpecifiedHandling::Shard(0),
+            default_shard: DefaultShard::Shard(0),
             auth_query: None,
             auth_query_user: None,
             auth_query_password: None,
@@ -490,9 +489,7 @@ impl ConnectionPool {
                             .clone()
                             .map(|regex| Regex::new(regex.as_str()).unwrap()),
                         regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
-                        no_shard_specified_behavior: pool_config
-                            .no_shard_specified_behavior
-                            .clone(),
+                        default_shard: pool_config.default_shard.clone(),
                         auth_query: pool_config.auth_query.clone(),
                         auth_query_user: pool_config.auth_query_user.clone(),
                         auth_query_password: pool_config.auth_query_password.clone(),
@@ -618,48 +615,46 @@ impl ConnectionPool {
         role: Option<Role>,         // primary or replica
         client_stats: &ClientStats, // client id
     ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
-        let mut effective_shard_id: Option<usize> = shard;
+        let effective_shard_id = if self.shards() == 1 {
+            // The base, unsharded case
+            Some(0)
+        } else {
+            if !self.valid_shard_id(shard) {
+                // `None` is always considered valid, so if we get here the shard must be `Some` and unwrap is safe
+                return Err(Error::InvalidShardId(shard.unwrap()));
+            }
+            shard
+        };
 
-        // The base, unsharded case
-        if self.shards() == 1 {
-            effective_shard_id = Some(0);
-        }
+        let mut candidates = self
+            .addresses
+            .iter()
+            .flatten()
+            .filter(|address| address.role == role)
+            .collect::<Vec<&Address>>();
 
-        let mut sort_by_error_count = false;
-        let mut candidates: Vec<_> = match effective_shard_id {
-            Some(shard_id) => self.addresses[shard_id].iter().collect(),
-            None => match self.settings.no_shard_specified_behavior {
-                NoShardSpecifiedHandling::Random => self.addresses.iter().flatten().collect(),
-                NoShardSpecifiedHandling::RandomHealthy => {
-                    sort_by_error_count = true;
-                    self.addresses.iter().flatten().collect()
+        // We start with a shuffled list of addresses even if we end up re-sorting;
+        // this avoids hitting instance 0 every time when the sorting metric
+        // ends up being the same for all instances.
+        candidates.shuffle(&mut thread_rng());
+
+        match effective_shard_id {
+            Some(shard_id) => candidates.retain(|address| address.shard == shard_id),
+            None => match self.settings.default_shard {
+                DefaultShard::Shard(shard_id) => {
+                    candidates.retain(|address| address.shard == shard_id)
                 }
-                NoShardSpecifiedHandling::Shard(shard) => {
-                    if shard >= self.shards() {
-                        return Err(Error::InvalidShardId(shard));
-                    } else {
-                        self.addresses[shard].iter().collect()
-                    }
+                DefaultShard::Random => (),
+                DefaultShard::RandomHealthy => {
+                    candidates.sort_by(|a, b| {
+                        b.error_count
+                            .load(Ordering::Relaxed)
+                            .partial_cmp(&a.error_count.load(Ordering::Relaxed))
+                            .unwrap()
+                    });
                 }
             },
         };
-        candidates.retain(|address| address.role == role);
-
-        // We shuffle even if least_outstanding_queries is used to avoid imbalance
-        // in cases where all candidates have more or less the same number of outstanding
-        // queries
-        candidates.shuffle(&mut thread_rng());
-
-        // The branch should only be hit if no shard is specified and we are using
-        // random healthy routing mode
-        if sort_by_error_count {
-            candidates.sort_by(|a, b| {
-                b.error_count
-                    .load(Ordering::Relaxed)
-                    .partial_cmp(&a.error_count.load(Ordering::Relaxed))
-                    .unwrap()
-            });
-        }
 
         if self.settings.load_balancing_mode == LoadBalancingMode::LeastOutstandingConnections {
             candidates.sort_by(|a, b| {
@@ -994,6 +989,13 @@ impl ConnectionPool {
         debug!("{:?} has {:?} busy connections", address, busy);
         return busy;
     }
+
+    fn valid_shard_id(&self, shard: Option<usize>) -> bool {
+        match shard {
+            None => true,
+            Some(shard) => shard < self.shards(),
+        }
+    }
 }
 
 /// Wrapper for the bb8 connection pool.
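
And a minimal sketch of the candidate ordering used above when no shard is selected and `default_shard` is `random_healthy`: shuffle first so ties don't always favor the same address, then sort by recent error count with the highest count first. `Addr` is a simplified stand-in for the pool's `Address`, and the example assumes the `rand` crate the pool code itself already uses:

```
use rand::seq::SliceRandom;
use rand::thread_rng;

// Simplified stand-in for an address: just a shard number and an error count.
#[derive(Debug)]
struct Addr {
    shard: usize,
    error_count: u64,
}

fn order_candidates(mut candidates: Vec<Addr>) -> Vec<Addr> {
    // Shuffle first so addresses with identical error counts are not always
    // visited in the same order.
    candidates.shuffle(&mut thread_rng());

    // Sort by error count, highest first: the healthiest addresses end up
    // at the back of the list.
    candidates.sort_by(|a, b| b.error_count.cmp(&a.error_count));
    candidates
}

fn main() {
    let ordered = order_candidates(vec![
        Addr { shard: 0, error_count: 5 },
        Addr { shard: 1, error_count: 0 },
        Addr { shard: 2, error_count: 2 },
    ]);
    // After the descending sort, the address with the fewest errors is last.
    assert_eq!(ordered.last().unwrap().error_count, 0);
    println!("{:?}", ordered);
}
```
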
diff --git a/src/query_router.rs b/src/query_router.rs
index f2f20b93..9d7a106a 100644
--- a/src/query_router.rs
+++ b/src/query_router.rs
@@ -1207,7 +1207,7 @@ mod test {
             ban_time: PoolSettings::default().ban_time,
             sharding_key_regex: None,
             shard_id_regex: None,
-            no_shard_specified_behavior: crate::config::NoShardSpecifiedHandling::Shard(0),
+            default_shard: crate::config::DefaultShard::Shard(0),
             regex_search_limit: 1000,
             auth_query: None,
             auth_query_password: None,
@@ -1285,7 +1285,7 @@ mod test {
             ban_time: PoolSettings::default().ban_time,
             sharding_key_regex: Some(Regex::new(r"/\* sharding_key: (\d+) \*/").unwrap()),
             shard_id_regex: Some(Regex::new(r"/\* shard_id: (\d+) \*/").unwrap()),
-            no_shard_specified_behavior: crate::config::NoShardSpecifiedHandling::Shard(0),
+            default_shard: crate::config::DefaultShard::Shard(0),
             regex_search_limit: 1000,
             auth_query: None,
             auth_query_password: None,
diff --git a/tests/ruby/sharding_spec.rb b/tests/ruby/sharding_spec.rb
index 47bc2a12..746627d1 100644
--- a/tests/ruby/sharding_spec.rb
+++ b/tests/ruby/sharding_spec.rb
@@ -56,7 +56,7 @@
         admin_conn = PG::connect(processes.pgcat.admin_connection_string)
 
         current_configs = processes.pgcat.current_config
-        current_configs["pools"]["sharded_db"]["no_shard_specified_behavior"] = "shard_99"
+        current_configs["pools"]["sharded_db"]["default_shard"] = "shard_99"
 
         processes.pgcat.update_config(current_configs)
 
@@ -91,9 +91,9 @@
           current_configs = processes.pgcat.current_config
           current_configs["pools"]["sharded_db"][config_name] = config_value
           if no_shard_specified_behavior
-            current_configs["pools"]["sharded_db"]["no_shard_specified_behavior"] = no_shard_specified_behavior
+            current_configs["pools"]["sharded_db"]["default_shard"] = no_shard_specified_behavior
           else
-            current_configs["pools"]["sharded_db"].delete("no_shard_specified_behavior")
+            current_configs["pools"]["sharded_db"].delete("default_shard")
           end
 
           processes.pgcat.update_config(current_configs)