diff --git a/.circleci/pgcat.toml b/.circleci/pgcat.toml index 377680a0..8b87aa03 100644 --- a/.circleci/pgcat.toml +++ b/.circleci/pgcat.toml @@ -74,6 +74,10 @@ default_role = "any" # we'll direct it to the primary. query_parser_enabled = true +# If the query parser is enabled and this setting is enabled, we'll attempt to +# infer the role from the query itself. +query_parser_read_write_splitting = true + # If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for # load balancing of read queries. Otherwise, the primary will only be used for write # queries. The primary can always be explicitely selected with our custom protocol. @@ -134,6 +138,7 @@ database = "shard2" pool_mode = "session" default_role = "primary" query_parser_enabled = true +query_parser_read_write_splitting = true primary_reads_enabled = true sharding_function = "pg_bigint_hash" diff --git a/examples/docker/pgcat.toml b/examples/docker/pgcat.toml index 5fd929de..cfd94a1a 100644 --- a/examples/docker/pgcat.toml +++ b/examples/docker/pgcat.toml @@ -71,6 +71,10 @@ default_role = "any" # we'll direct it to the primary. query_parser_enabled = true +# If the query parser is enabled and this setting is enabled, we'll attempt to +# infer the role from the query itself. +query_parser_read_write_splitting = true + # If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for # load balancing of read queries. Otherwise, the primary will only be used for write # queries. The primary can always be explicitly selected with our custom protocol. diff --git a/pgcat.toml b/pgcat.toml index 3e8801b6..2c4441ba 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -162,6 +162,10 @@ default_role = "any" # we'll direct it to the primary. query_parser_enabled = true +# If the query parser is enabled and this setting is enabled, we'll attempt to +# infer the role from the query itself. 
+query_parser_read_write_splitting = true + # If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for # load balancing of read queries. Otherwise, the primary will only be used for write # queries. The primary can always be explicitly selected with our custom protocol. diff --git a/src/client.rs b/src/client.rs index 7d5e9798..4f5e6c96 100644 --- a/src/client.rs +++ b/src/client.rs @@ -774,6 +774,9 @@ where let mut prepared_statement = None; let mut will_prepare = false; + let client_identifier = + ClientIdentifier::new(&self.application_name, &self.username, &self.pool_name); + // Our custom protocol loop. // We expect the client to either start a transaction with regular queries // or issue commands for our sharding and server selection protocol. @@ -812,6 +815,21 @@ where message_result = read_message(&mut self.read) => message_result? }; + // Handle admin database queries. + if self.admin { + debug!("Handling admin command"); + handle_admin(&mut self.write, message, self.client_server_map.clone()).await?; + continue; + } + + // Get a pool instance referenced by the most up-to-date + // pointer. This ensures we always read the latest config + // when starting a query. + let mut pool = self.get_pool().await?; + query_router.update_pool_settings(pool.settings.clone()); + + let mut initial_parsed_ast = None; + match message[0] as char { // Buffer extended protocol messages even if we do not have // a server connection yet. 
Hopefully, when we get the S message @@ -841,24 +859,34 @@ where 'Q' => { if query_router.query_parser_enabled() { - if let Ok(ast) = QueryRouter::parse(&message) { - let plugin_result = query_router.execute_plugins(&ast).await; + match query_router.parse(&message) { + Ok(ast) => { + let plugin_result = query_router.execute_plugins(&ast).await; - match plugin_result { - Ok(PluginOutput::Deny(error)) => { - error_response(&mut self.write, &error).await?; - continue; - } + match plugin_result { + Ok(PluginOutput::Deny(error)) => { + error_response(&mut self.write, &error).await?; + continue; + } - Ok(PluginOutput::Intercept(result)) => { - write_all(&mut self.write, result).await?; - continue; - } + Ok(PluginOutput::Intercept(result)) => { + write_all(&mut self.write, result).await?; + continue; + } - _ => (), - }; + _ => (), + }; + + let _ = query_router.infer(&ast); - let _ = query_router.infer(&ast); + initial_parsed_ast = Some(ast); + } + Err(error) => { + warn!( + "Query parsing error: {} (client: {})", + error, client_identifier + ); + } } } } @@ -872,13 +900,21 @@ where self.buffer.put(&message[..]); if query_router.query_parser_enabled() { - if let Ok(ast) = QueryRouter::parse(&message) { - if let Ok(output) = query_router.execute_plugins(&ast).await { - plugin_output = Some(output); - } + match query_router.parse(&message) { + Ok(ast) => { + if let Ok(output) = query_router.execute_plugins(&ast).await { + plugin_output = Some(output); + } - let _ = query_router.infer(&ast); - } + let _ = query_router.infer(&ast); + } + Err(error) => { + warn!( + "Query parsing error: {} (client: {})", + error, client_identifier + ); + } + }; } continue; @@ -922,13 +958,6 @@ where _ => (), } - // Handle admin database queries. - if self.admin { - debug!("Handling admin command"); - handle_admin(&mut self.write, message, self.client_server_map.clone()).await?; - continue; - } - // Check on plugin results. 
match plugin_output { Some(PluginOutput::Deny(error)) => { @@ -941,11 +970,6 @@ where _ => (), }; - // Get a pool instance referenced by the most up-to-date - // pointer. This ensures we always read the latest config - // when starting a query. - let mut pool = self.get_pool().await?; - // Check if the pool is paused and wait until it's resumed. if pool.wait_paused().await { // Refresh pool information, something might have changed. @@ -1165,6 +1189,9 @@ where None => { trace!("Waiting for message inside transaction or in session mode"); + // This is not an initial message so discard the initial_parsed_ast + initial_parsed_ast.take(); + match tokio::time::timeout( idle_client_timeout_duration, read_message(&mut self.read), @@ -1221,7 +1248,22 @@ where // Query 'Q' => { if query_router.query_parser_enabled() { - if let Ok(ast) = QueryRouter::parse(&message) { + // We don't want to parse again if we already parsed it as the initial message + let ast = match initial_parsed_ast { + Some(_) => Some(initial_parsed_ast.take().unwrap()), + None => match query_router.parse(&message) { + Ok(ast) => Some(ast), + Err(error) => { + warn!( + "Query parsing error: {} (client: {})", + error, client_identifier + ); + None + } + }, + }; + + if let Some(ast) = ast { let plugin_result = query_router.execute_plugins(&ast).await; match plugin_result { @@ -1237,8 +1279,6 @@ where _ => (), }; - - let _ = query_router.infer(&ast); } } debug!("Sending query to server"); @@ -1290,7 +1330,7 @@ where } if query_router.query_parser_enabled() { - if let Ok(ast) = QueryRouter::parse(&message) { + if let Ok(ast) = query_router.parse(&message) { if let Ok(output) = query_router.execute_plugins(&ast).await { plugin_output = Some(output); } diff --git a/src/config.rs b/src/config.rs index 9228b9bb..0e4b8c7d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -511,6 +511,11 @@ pub struct Pool { #[serde(default)] // False pub query_parser_enabled: bool, + pub query_parser_max_length: Option<usize>, + + 
#[serde(default)] // False + pub query_parser_read_write_splitting: bool, + #[serde(default)] // False pub primary_reads_enabled: bool, @@ -627,6 +632,18 @@ impl Pool { } } + if self.query_parser_read_write_splitting && !self.query_parser_enabled { + error!( + "query_parser_read_write_splitting is only valid when query_parser_enabled is true" + ); + return Err(Error::BadConfig); + } + + if self.plugins.is_some() && !self.query_parser_enabled { + error!("plugins are only valid when query_parser_enabled is true"); + return Err(Error::BadConfig); + } + self.automatic_sharding_key = match &self.automatic_sharding_key { Some(key) => { // No quotes in the key so we don't have to compare quoted @@ -663,6 +680,8 @@ impl Default for Pool { users: BTreeMap::default(), default_role: String::from("any"), query_parser_enabled: false, + query_parser_max_length: None, + query_parser_read_write_splitting: false, primary_reads_enabled: false, sharding_function: ShardingFunction::PgBigintHash, automatic_sharding_key: None, @@ -914,6 +933,17 @@ impl From<&Config> for std::collections::HashMap<String, String> { ( format!("pools.{}.query_parser_enabled", pool_name), pool.query_parser_enabled.to_string(), ), + ( + format!("pools.{}.query_parser_max_length", pool_name), + match pool.query_parser_max_length { + Some(max_length) => max_length.to_string(), + None => String::from("unlimited"), + }, + ), + ( + format!("pools.{}.query_parser_read_write_splitting", pool_name), + pool.query_parser_read_write_splitting.to_string(), + ), ( format!("pools.{}.default_role", pool_name), pool.default_role.clone(), @@ -1096,6 +1126,15 @@ impl Config { "[pool: {}] Query router: {}", pool_name, pool_config.query_parser_enabled ); + + info!( + "[pool: {}] Query parser max length: {:?}", + pool_name, pool_config.query_parser_max_length + ); + info!( + "[pool: {}] Infer role from query: {}", + pool_name, pool_config.query_parser_read_write_splitting + ); info!( "[pool: {}] Number of shards: {}", pool_name, diff --git 
a/src/pool.rs b/src/pool.rs index b9293521..dddb3ebe 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -111,6 +111,12 @@ pub struct PoolSettings { // Enable/disable query parser. pub query_parser_enabled: bool, + // Max length of query the parser will parse. + pub query_parser_max_length: Option<usize>, + + // Infer role + pub query_parser_read_write_splitting: bool, + // Read from the primary as well or not. pub primary_reads_enabled: bool, @@ -157,6 +163,8 @@ impl Default for PoolSettings { db: String::default(), default_role: None, query_parser_enabled: false, + query_parser_max_length: None, + query_parser_read_write_splitting: false, primary_reads_enabled: true, sharding_function: ShardingFunction::PgBigintHash, automatic_sharding_key: None, @@ -456,6 +464,9 @@ impl ConnectionPool { _ => unreachable!(), }, query_parser_enabled: pool_config.query_parser_enabled, + query_parser_max_length: pool_config.query_parser_max_length, + query_parser_read_write_splitting: pool_config + .query_parser_read_write_splitting, primary_reads_enabled: pool_config.primary_reads_enabled, sharding_function: pool_config.sharding_function, automatic_sharding_key: pool_config.automatic_sharding_key.clone(), diff --git a/src/query_router.rs b/src/query_router.rs index 126b8138..9676a26f 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -331,11 +331,23 @@ impl QueryRouter { Some((command, value)) } - pub fn parse(message: &BytesMut) -> Result<Vec<sqlparser::ast::Statement>, Error> { + pub fn parse(&self, message: &BytesMut) -> Result<Vec<sqlparser::ast::Statement>, Error> { let mut message_cursor = Cursor::new(message); let code = message_cursor.get_u8() as char; - let _len = message_cursor.get_i32() as usize; + let len = message_cursor.get_i32() as usize; + + match self.pool_settings.query_parser_max_length { + Some(max_length) => { + if len > max_length { + return Err(Error::QueryRouterParserError(format!( + "Query too long for parser: {} > {}", + len, max_length + ))); + } + } + None => (), + }; let query = match code { // Query @@ -372,6 
+384,10 @@ impl QueryRouter { /// Try to infer which server to connect to based on the contents of the query. pub fn infer(&mut self, ast: &Vec<sqlparser::ast::Statement>) -> Result<(), Error> { + if !self.pool_settings.query_parser_read_write_splitting { + return Ok(()); // Nothing to do + } + debug!("Inferring role"); if ast.is_empty() { @@ -433,6 +449,10 @@ impl QueryRouter { /// N.B.: Only supports anonymous prepared statements since we don't /// keep a cache of them in PgCat. pub fn infer_shard_from_bind(&mut self, message: &BytesMut) -> bool { + if !self.pool_settings.query_parser_read_write_splitting { + return false; // Nothing to do + } + debug!("Parsing bind message"); let mut message_cursor = Cursor::new(message); @@ -910,6 +930,7 @@ mod test { fn test_infer_replica() { QueryRouter::setup(); let mut qr = QueryRouter::new(); + qr.pool_settings.query_parser_read_write_splitting = true; assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None); assert!(qr.query_parser_enabled()); @@ -925,7 +946,7 @@ for query in queries { // It's a recognized query - assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); + assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), Some(Role::Replica)); } } @@ -934,6 +955,7 @@ fn test_infer_primary() { QueryRouter::setup(); let mut qr = QueryRouter::new(); + qr.pool_settings.query_parser_read_write_splitting = true; let queries = vec![ simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"), @@ -944,7 +966,7 @@ for query in queries { // It's a recognized query - assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); + assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), Some(Role::Primary)); } } @@ -956,7 +978,7 @@ let query = simple_query("SELECT * FROM items WHERE id = 5"); assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None); - 
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); + assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), None); } @@ -964,6 +986,8 @@ mod test { fn test_infer_parse_prepared() { QueryRouter::setup(); let mut qr = QueryRouter::new(); + qr.pool_settings.query_parser_read_write_splitting = true; + qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")); assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); @@ -976,7 +1000,7 @@ mod test { res.put(prepared_stmt); res.put_i16(0); - assert!(qr.infer(&QueryRouter::parse(&res).unwrap()).is_ok()); + assert!(qr.infer(&qr.parse(&res).unwrap()).is_ok()); assert_eq!(qr.role(), Some(Role::Replica)); } @@ -1132,6 +1156,8 @@ mod test { fn test_enable_query_parser() { QueryRouter::setup(); let mut qr = QueryRouter::new(); + qr.pool_settings.query_parser_read_write_splitting = true; + let query = simple_query("SET SERVER ROLE TO 'auto'"); assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); @@ -1140,11 +1166,11 @@ mod test { assert_eq!(qr.role(), None); let query = simple_query("INSERT INTO test_table VALUES (1)"); - assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); + assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), Some(Role::Primary)); let query = simple_query("SELECT * FROM test_table"); - assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); + assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), Some(Role::Replica)); assert!(qr.query_parser_enabled()); @@ -1164,6 +1190,8 @@ mod test { user: crate::config::User::default(), default_role: Some(Role::Replica), query_parser_enabled: true, + query_parser_max_length: None, + query_parser_read_write_splitting: true, primary_reads_enabled: false, sharding_function: ShardingFunction::PgBigintHash, automatic_sharding_key: Some(String::from("test.id")), @@ -1208,18 +1236,18 @@ mod test { let mut qr = 
QueryRouter::new(); assert!(qr - .infer(&QueryRouter::parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap()) + .infer(&qr.parse(&simple_query("BEGIN; SELECT 1; COMMIT;")).unwrap()) .is_ok()); assert_eq!(qr.role(), Role::Primary); assert!(qr - .infer(&QueryRouter::parse(&simple_query("SELECT 1; SELECT 2;")).unwrap()) + .infer(&qr.parse(&simple_query("SELECT 1; SELECT 2;")).unwrap()) .is_ok()); assert_eq!(qr.role(), Role::Replica); assert!(qr .infer( - &QueryRouter::parse(&simple_query( + &qr.parse(&simple_query( "SELECT 123; INSERT INTO t VALUES (5); SELECT 1;" )) .unwrap() @@ -1239,6 +1267,8 @@ mod test { user: crate::config::User::default(), default_role: Some(Role::Replica), query_parser_enabled: true, + query_parser_max_length: None, + query_parser_read_write_splitting: true, primary_reads_enabled: false, sharding_function: ShardingFunction::PgBigintHash, automatic_sharding_key: None, @@ -1284,15 +1314,19 @@ mod test { let mut qr = QueryRouter::new(); qr.pool_settings.automatic_sharding_key = Some("data.id".to_string()); qr.pool_settings.shards = 3; + qr.pool_settings.query_parser_read_write_splitting = true; assert!(qr - .infer(&QueryRouter::parse(&simple_query("SELECT * FROM data WHERE id = 5")).unwrap()) + .infer( + &qr.parse(&simple_query("SELECT * FROM data WHERE id = 5")) + .unwrap(), + ) .is_ok()); assert_eq!(qr.shard(), 2); assert!(qr .infer( - &QueryRouter::parse(&simple_query( + &qr.parse(&simple_query( "SELECT one, two, three FROM public.data WHERE id = 6" )) .unwrap() @@ -1302,7 +1336,7 @@ mod test { assert!(qr .infer( - &QueryRouter::parse(&simple_query( + &qr.parse(&simple_query( "SELECT * FROM data INNER JOIN t2 ON data.id = 5 AND t2.data_id = data.id @@ -1317,7 +1351,7 @@ mod test { // in the query. 
assert!(qr .infer( - &QueryRouter::parse(&simple_query( + &qr.parse(&simple_query( "SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id" )) .unwrap() @@ -1327,7 +1361,7 @@ mod test { assert!(qr .infer( - &QueryRouter::parse(&simple_query( + &qr.parse(&simple_query( r#"SELECT * FROM "public"."data" WHERE "id" = 6"# )) .unwrap() @@ -1337,7 +1371,7 @@ mod test { assert!(qr .infer( - &QueryRouter::parse(&simple_query( + &qr.parse(&simple_query( r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"# )) .unwrap() @@ -1349,7 +1383,7 @@ mod test { qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string()); assert!(qr .infer( - &QueryRouter::parse(&simple_query( + &qr.parse(&simple_query( "SELECT * FROM table_x WHERE unique_enough_column_name = 6" )) .unwrap() @@ -1359,7 +1393,7 @@ mod test { assert!(qr .infer( - &QueryRouter::parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5")) + &qr.parse(&simple_query("SELECT * FROM table_y WHERE another_key = 5")) .unwrap() ) .is_ok()); @@ -1385,10 +1419,9 @@ mod test { let mut qr = QueryRouter::new(); qr.pool_settings.automatic_sharding_key = Some("data.id".to_string()); qr.pool_settings.shards = 3; + qr.pool_settings.query_parser_read_write_splitting = true; - assert!(qr - .infer(&QueryRouter::parse(&simple_query(stmt)).unwrap()) - .is_ok()); + assert!(qr.infer(&qr.parse(&simple_query(stmt)).unwrap()).is_ok()); assert_eq!(qr.placeholders.len(), 1); assert!(qr.infer_shard_from_bind(&bind)); @@ -1419,7 +1452,7 @@ mod test { qr.update_pool_settings(pool_settings); let query = simple_query("SELECT * FROM pg_database"); - let ast = QueryRouter::parse(&query).unwrap(); + let ast = qr.parse(&query).unwrap(); let res = qr.execute_plugins(&ast).await; @@ -1437,7 +1470,7 @@ mod test { let qr = QueryRouter::new(); let query = simple_query("SELECT * FROM pg_database"); - let ast = QueryRouter::parse(&query).unwrap(); + let ast = qr.parse(&query).unwrap(); let res = 
qr.execute_plugins(&ast).await; diff --git a/tests/ruby/helpers/pgcat_helper.rb b/tests/ruby/helpers/pgcat_helper.rb index 7a5bd71d..9b764d87 100644 --- a/tests/ruby/helpers/pgcat_helper.rb +++ b/tests/ruby/helpers/pgcat_helper.rb @@ -34,6 +34,7 @@ def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mod "load_balancing_mode" => lb_mode, "primary_reads_enabled" => true, "query_parser_enabled" => true, + "query_parser_read_write_splitting" => true, "automatic_sharding_key" => "data.id", "sharding_function" => "pg_bigint_hash", "shards" => {