Skip to content

Allow configuring routing decision when no shard is selected #578

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
# Specify the execution environment. You can specify an image from Dockerhub or use one of our Convenience Images from CircleCI's Developer Hub.
# See: https://circleci.com/docs/2.0/configuration-reference/#docker-machine-macos-windows-executor
docker:
- image: ghcr.io/levkk/pgcat-ci:1.67
- image: ghcr.io/postgresml/pgcat-ci:latest
environment:
RUST_LOG: info
LLVM_PROFILE_FILE: /tmp/pgcat-%m-%p.profraw
Expand Down
6 changes: 6 additions & 0 deletions pgcat.toml
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,12 @@ primary_reads_enabled = true
# shard_id_regex = '/\* shard_id: (\d+) \*/'
# regex_search_limit = 1000 # only look at the first 1000 characters of SQL statements

# Defines the behavior when no shard is selected in a sharded system.
# `random`: picks a shard at random
# `random_healthy`: picks a shard at random favoring shards with the least number of recent errors
# `shard_<number>`: e.g. shard_0, shard_4, etc. picks a specific shard, everytime
# no_shard_specified_behavior = "shard_0"

# So what if you wanted to implement a different hashing function,
# or you've already built one and you want this pooler to use it?
# Current options:
Expand Down
47 changes: 27 additions & 20 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1009,23 +1009,27 @@ where

// SET SHARD TO
Some((Command::SetShard, _)) => {
// Selected shard is not configured.
if query_router.shard() >= pool.shards() {
// Set the shard back to what it was.
query_router.set_shard(current_shard);

error_response(
&mut self.write,
&format!(
"shard {} is more than configured {}, staying on shard {} (shard numbers start at 0)",
query_router.shard(),
pool.shards(),
current_shard,
),
)
.await?;
} else {
custom_protocol_response_ok(&mut self.write, "SET SHARD").await?;
match query_router.shard() {
None => (),
Some(selected_shard) => {
if selected_shard >= pool.shards() {
// Bad shard number, send error message to client.
query_router.set_shard(current_shard);

error_response(
&mut self.write,
&format!(
"shard {} is not configured {}, staying on shard {:?} (shard numbers start at 0)",
selected_shard,
pool.shards(),
current_shard,
),
)
.await?;
} else {
custom_protocol_response_ok(&mut self.write, "SET SHARD").await?;
}
}
}
continue;
}
Expand Down Expand Up @@ -1093,8 +1097,11 @@ where
self.buffer.clear();
}

error_response(&mut self.write, "could not get connection from the pool")
.await?;
error_response(
&mut self.write,
format!("could not get connection from the pool - {}", err).as_str(),
)
.await?;

error!(
"Could not get connection from pool: \
Expand Down Expand Up @@ -1234,7 +1241,7 @@ where
{{ \
pool_name: {}, \
username: {}, \
shard: {}, \
shard: {:?}, \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Debug output going to the client here.

role: \"{:?}\" \
}}",
self.pool_name,
Expand Down
81 changes: 81 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ use arc_swap::ArcSwap;
use log::{error, info};
use once_cell::sync::Lazy;
use regex::Regex;
use serde::{Deserializer, Serializer};
use serde_derive::{Deserialize, Serialize};

use std::collections::hash_map::DefaultHasher;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::path::Path;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::fs::File;
use tokio::io::AsyncReadExt;
Expand Down Expand Up @@ -101,6 +104,9 @@ pub struct Address {

/// Address stats
pub stats: Arc<AddressStats>,

/// Number of errors encountered since last successful checkout
pub error_count: Arc<AtomicU64>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we show this in admin? This would be a really cool metric to expose.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't be too difficult to expose it. It gets reset with every successful checkout so I am not sure how useful it would be.

Copy link
Contributor

@levkk levkk Sep 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see. Not so useful then, unless the replica is banned? Could be useful to see how many errors it took before it got successfully banned. Alternatively, we don't reset that number? If it's an i64, it can be incremented ...forever without worrying about overflows.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about not resetting it too but that meant errors that occurred 2 days ago can affect routing decisions today

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point.

}

impl Default for Address {
Expand All @@ -118,6 +124,7 @@ impl Default for Address {
pool_name: String::from("pool_name"),
mirrors: Vec::new(),
stats: Arc::new(AddressStats::default()),
error_count: Arc::new(AtomicU64::new(0)),
}
}
}
Expand Down Expand Up @@ -182,6 +189,18 @@ impl Address {
),
}
}

pub fn error_count(&self) -> u64 {
self.error_count.load(Ordering::Relaxed)
}

pub fn increment_error_count(&self) {
self.error_count.fetch_add(1, Ordering::Relaxed);
}

pub fn reset_error_count(&self) {
self.error_count.store(0, Ordering::Relaxed);
}
}

/// PostgreSQL user.
Expand Down Expand Up @@ -540,6 +559,9 @@ pub struct Pool {
pub shard_id_regex: Option<String>,
pub regex_search_limit: Option<usize>,

#[serde(default = "Pool::default_default_shard")]
pub default_shard: DefaultShard,

pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,
Expand Down Expand Up @@ -575,6 +597,10 @@ impl Pool {
PoolMode::Transaction
}

pub fn default_default_shard() -> DefaultShard {
DefaultShard::default()
}

pub fn default_load_balancing_mode() -> LoadBalancingMode {
LoadBalancingMode::Random
}
Expand Down Expand Up @@ -666,6 +692,16 @@ impl Pool {
None => None,
};

match self.default_shard {
DefaultShard::Shard(shard_number) => {
if shard_number >= self.shards.len() {
error!("Invalid shard {:?}", shard_number);
return Err(Error::BadConfig);
}
}
_ => (),
}

for (_, user) in &self.users {
user.validate()?;
}
Expand Down Expand Up @@ -693,6 +729,7 @@ impl Default for Pool {
sharding_key_regex: None,
shard_id_regex: None,
regex_search_limit: Some(1000),
default_shard: Self::default_default_shard(),
auth_query: None,
auth_query_user: None,
auth_query_password: None,
Expand All @@ -711,6 +748,50 @@ pub struct ServerConfig {
pub role: Role,
}

// No Shard Specified handling.
#[derive(Debug, PartialEq, Clone, Eq, Hash, Copy)]
pub enum DefaultShard {
Shard(usize),
Random,
RandomHealthy,
}
impl Default for DefaultShard {
fn default() -> Self {
DefaultShard::Shard(0)
}
}
impl serde::Serialize for DefaultShard {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
match self {
DefaultShard::Shard(shard) => {
serializer.serialize_str(&format!("shard_{}", &shard.to_string()))
}
DefaultShard::Random => serializer.serialize_str("random"),
DefaultShard::RandomHealthy => serializer.serialize_str("random_healthy"),
}
}
}
impl<'de> serde::Deserialize<'de> for DefaultShard {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
if s.starts_with("shard_") {
let shard = s[6..].parse::<usize>().map_err(serde::de::Error::custom)?;
return Ok(DefaultShard::Shard(shard));
}

match s.as_str() {
"random" => Ok(DefaultShard::Random),
"random_healthy" => Ok(DefaultShard::RandomHealthy),
_ => Err(serde::de::Error::custom(
"invalid value for no_shard_specified_behavior",
)),
}
}
}

#[derive(Clone, PartialEq, Serialize, Deserialize, Debug, Hash, Eq)]
pub struct MirrorServerConfig {
pub host: String,
Expand Down
1 change: 1 addition & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub enum Error {
UnsupportedStatement,
QueryRouterParserError(String),
QueryRouterError(String),
InvalidShardId(usize),
}

#[derive(Clone, PartialEq, Debug)]
Expand Down
Loading