Skip to content

Commit e76d720

Browse files
authored
Dont cache prepared statement with errors (#647)
* Fix prepared statement not found when prepared stmt has error * cleanup debug * remove more debug msgs * sure debugged this.. * version bump * add rust tests
1 parent 998cc16 commit e76d720

File tree

7 files changed

+88
-13
lines changed

7 files changed

+88
-13
lines changed

.circleci/run_tests.sh

+7
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,13 @@ popd
119119

120120
start_pgcat "info"
121121

122+
#
123+
# Rust tests
124+
#
125+
cd tests/rust
126+
cargo run
127+
cd ../../
128+
122129
# Admin tests
123130
export PGPASSWORD=admin_pass
124131
psql -U admin_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null

Cargo.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgcat"
3-
version = "1.1.2-dev2"
3+
version = "1.1.2-dev4"
44
edition = "2021"
55

66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

src/client.rs

+30-5
Original file line numberDiff line numberDiff line change
@@ -1149,7 +1149,7 @@ where
11491149
// This reads the first byte without advancing the internal pointer and mutating the bytes
11501150
let code = *message.first().unwrap() as char;
11511151

1152-
trace!("Message: {}", code);
1152+
trace!("Client message: {}", code);
11531153

11541154
match code {
11551155
// Query
@@ -1188,6 +1188,7 @@ where
11881188
};
11891189
}
11901190
}
1191+
11911192
debug!("Sending query to server");
11921193

11931194
self.send_and_receive_loop(
@@ -1320,6 +1321,7 @@ where
13201321
{
13211322
match protocol_data {
13221323
ExtendedProtocolData::Parse { data, metadata } => {
1324+
debug!("Have parse in extended buffer");
13231325
let (parse, hash) = match metadata {
13241326
Some(metadata) => metadata,
13251327
None => {
@@ -1656,11 +1658,25 @@ where
16561658
) -> Result<(), Error> {
16571659
match self.prepared_statements.get(&client_name) {
16581660
Some((parse, hash)) => {
1659-
debug!("Prepared statement `{}` found in cache", parse.name);
1661+
debug!("Prepared statement `{}` found in cache", client_name);
16601662
// In this case we want to send the parse message to the server
16611663
// since pgcat is initiating the prepared statement on this specific server
1662-
self.register_parse_to_server_cache(true, hash, parse, pool, server, address)
1663-
.await?;
1664+
match self
1665+
.register_parse_to_server_cache(true, hash, parse, pool, server, address)
1666+
.await
1667+
{
1668+
Ok(_) => (),
1669+
Err(err) => match err {
1670+
Error::PreparedStatementError => {
1671+
debug!("Removed {} from client cache", client_name);
1672+
self.prepared_statements.remove(&client_name);
1673+
}
1674+
1675+
_ => {
1676+
return Err(err);
1677+
}
1678+
},
1679+
}
16641680
}
16651681

16661682
None => {
@@ -1689,11 +1705,20 @@ where
16891705
// We want to promote this in the pool's LRU
16901706
pool.promote_prepared_statement_hash(hash);
16911707

1708+
debug!("Checking for prepared statement {}", parse.name);
1709+
16921710
if let Err(err) = server
16931711
.register_prepared_statement(parse, should_send_parse_to_server)
16941712
.await
16951713
{
1696-
pool.ban(address, BanReason::MessageSendFailed, Some(&self.stats));
1714+
match err {
1715+
// Don't ban for this.
1716+
Error::PreparedStatementError => (),
1717+
_ => {
1718+
pool.ban(address, BanReason::MessageSendFailed, Some(&self.stats));
1719+
}
1720+
};
1721+
16971722
return Err(err);
16981723
}
16991724

src/errors.rs

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ pub enum Error {
2929
QueryRouterParserError(String),
3030
QueryRouterError(String),
3131
InvalidShardId(usize),
32+
PreparedStatementError,
3233
}
3334

3435
#[derive(Clone, PartialEq, Debug)]

src/server.rs

+40-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use lru::LruCache;
77
use once_cell::sync::Lazy;
88
use parking_lot::{Mutex, RwLock};
99
use postgres_protocol::message;
10-
use std::collections::{HashMap, HashSet};
10+
use std::collections::{HashMap, HashSet, VecDeque};
1111
use std::mem;
1212
use std::net::IpAddr;
1313
use std::num::NonZeroUsize;
@@ -325,6 +325,9 @@ pub struct Server {
325325

326326
/// Prepared statements
327327
prepared_statement_cache: Option<LruCache<String, ()>>,
328+
329+
/// Prepared statement being currently registered on the server.
330+
registering_prepared_statement: VecDeque<String>,
328331
}
329332

330333
impl Server {
@@ -827,6 +830,7 @@ impl Server {
827830
NonZeroUsize::new(prepared_statement_cache_size).unwrap(),
828831
)),
829832
},
833+
registering_prepared_statement: VecDeque::new(),
830834
};
831835

832836
return Ok(server);
@@ -956,7 +960,6 @@ impl Server {
956960

957961
// There is no more data available from the server.
958962
self.data_available = false;
959-
960963
break;
961964
}
962965

@@ -966,6 +969,23 @@ impl Server {
966969
self.in_copy_mode = false;
967970
}
968971

972+
// Remove the prepared statement from the cache, it has a syntax error or something else bad happened.
973+
if let Some(prepared_stmt_name) =
974+
self.registering_prepared_statement.pop_front()
975+
{
976+
if let Some(ref mut cache) = self.prepared_statement_cache {
977+
if let Some(_removed) = cache.pop(&prepared_stmt_name) {
978+
debug!(
979+
"Removed {} from prepared statement cache",
980+
prepared_stmt_name
981+
);
982+
} else {
983+
// Shouldn't happen.
984+
debug!("Prepared statement {} was not cached", prepared_stmt_name);
985+
}
986+
}
987+
}
988+
969989
if self.prepared_statement_cache.is_some() {
970990
let error_message = PgErrorMsg::parse(&message)?;
971991
if error_message.message == "cached plan must not change result type" {
@@ -1068,6 +1088,11 @@ impl Server {
10681088
// Buffer until ReadyForQuery shows up, so don't exit the loop yet.
10691089
'c' => (),
10701090

1091+
// Parse complete successfully
1092+
'1' => {
1093+
self.registering_prepared_statement.pop_front();
1094+
}
1095+
10711096
// Anything else, e.g. errors, notices, etc.
10721097
// Keep buffering until ReadyForQuery shows up.
10731098
_ => (),
@@ -1107,7 +1132,7 @@ impl Server {
11071132
has_it
11081133
}
11091134

1110-
pub fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option<String> {
1135+
fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option<String> {
11111136
let cache = match &mut self.prepared_statement_cache {
11121137
Some(cache) => cache,
11131138
None => return None,
@@ -1129,7 +1154,7 @@ impl Server {
11291154
None
11301155
}
11311156

1132-
pub fn remove_prepared_statement_from_cache(&mut self, name: &str) {
1157+
fn remove_prepared_statement_from_cache(&mut self, name: &str) {
11331158
let cache = match &mut self.prepared_statement_cache {
11341159
Some(cache) => cache,
11351160
None => return,
@@ -1145,6 +1170,9 @@ impl Server {
11451170
should_send_parse_to_server: bool,
11461171
) -> Result<(), Error> {
11471172
if !self.has_prepared_statement(&parse.name) {
1173+
self.registering_prepared_statement
1174+
.push_back(parse.name.clone());
1175+
11481176
let mut bytes = BytesMut::new();
11491177

11501178
if should_send_parse_to_server {
@@ -1176,7 +1204,13 @@ impl Server {
11761204
}
11771205
};
11781206

1179-
Ok(())
1207+
// If it's not there, something went bad, I'm guessing bad syntax or permissions error
1208+
// on the server.
1209+
if !self.has_prepared_statement(&parse.name) {
1210+
Err(Error::PreparedStatementError)
1211+
} else {
1212+
Ok(())
1213+
}
11801214
}
11811215

11821216
/// If the server is still inside a transaction.
@@ -1186,6 +1220,7 @@ impl Server {
11861220
self.in_transaction
11871221
}
11881222

1223+
/// Currently copying data from client to server or vice-versa.
11891224
pub fn in_copy_mode(&self) -> bool {
11901225
self.in_copy_mode
11911226
}

tests/rust/src/main.rs

+8-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@ async fn test_prepared_statements() {
1616
let pool = pool.clone();
1717
let handle = tokio::task::spawn(async move {
1818
for _ in 0..1000 {
19-
sqlx::query("SELECT 1").fetch_all(&pool).await.unwrap();
19+
match sqlx::query("SELECT one").fetch_all(&pool).await {
20+
Ok(_) => (),
21+
Err(err) => {
22+
if err.to_string().contains("prepared statement") {
23+
panic!("prepared statement error: {}", err);
24+
}
25+
}
26+
}
2027
}
2128
});
2229

0 commit comments

Comments
 (0)