diff --git a/src/query_router.rs b/src/query_router.rs index 9d7a106a..189f2dcc 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -4,10 +4,10 @@ use bytes::{Buf, BytesMut}; use log::{debug, error}; use once_cell::sync::OnceCell; use regex::{Regex, RegexSet}; -use sqlparser::ast::Statement::{Query, StartTransaction}; +use sqlparser::ast::Statement::{Delete, Insert, Query, StartTransaction, Update}; use sqlparser::ast::{ - BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement, TableFactor, - Value, + Assignment, BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, Statement, + TableFactor, TableWithJoins, Value, }; use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; @@ -403,6 +403,9 @@ impl QueryRouter { return Err(Error::QueryRouterParserError("empty query".into())); } + let mut visited_write_statement = false; + let mut prev_inferred_shard = None; + for q in ast { match q { // All transactions go to the primary, probably a write. @@ -420,29 +423,38 @@ impl QueryRouter { // or discard shard selection. If they point to the same shard though, // we can let them through as-is. // This is basically building a database now :) - match self.infer_shard(query) { - Some(shard) => { - self.active_shard = Some(shard); - debug!("Automatically using shard: {:?}", self.active_shard); - } - - None => (), - }; + let inferred_shard = self.infer_shard(query); + self.handle_inferred_shard(inferred_shard, &mut prev_inferred_shard)?; } None => (), }; - self.active_role = match self.primary_reads_enabled() { - false => Some(Role::Replica), // If primary should not be receiving reads, use a replica. - true => None, // Any server role is fine in this case. + // If we already visited a write statement, we should be going to the primary. + if !visited_write_statement { + self.active_role = match self.primary_reads_enabled() { + false => Some(Role::Replica), // If primary should not be receiving reads, use a replica. + true => None, // Any server role is fine in this case. + } } } // Likely a write _ => { + match &self.pool_settings.automatic_sharding_key { + Some(_) => { + // TODO: similar to the above, if we have multiple queries in the + // same message, we can either split them and execute them individually + // or discard shard selection. If they point to the same shard though, + // we can let them through as-is. + let inferred_shard = self.infer_shard_on_write(q)?; + self.handle_inferred_shard(inferred_shard, &mut prev_inferred_shard)?; + } + + None => (), + }; + visited_write_statement = true; self.active_role = Some(Role::Primary); - break; } }; } @@ -450,6 +462,208 @@ impl QueryRouter { Ok(()) } + fn handle_inferred_shard( + &mut self, + inferred_shard: Option, + prev_inferred_shard: &mut Option, + ) -> Result<(), Error> { + match inferred_shard { + Some(shard) => { + if let Some(prev_shard) = *prev_inferred_shard { + if prev_shard != shard { + debug!("Found more than one shard in the query, not supported yet"); + return Err(Error::QueryRouterParserError( + "multiple shards in query".into(), + )); + } + } + *prev_inferred_shard = Some(shard); + self.active_shard = Some(shard); + debug!("Automatically using shard: {:?}", self.active_shard); + } + + None => (), + }; + Ok(()) + } + + fn infer_shard_on_write(&mut self, q: &Statement) -> Result, Error> { + let mut exprs = Vec::new(); + + // Collect all table names from the query. + let mut table_names = Vec::new(); + + match q { + Insert { + or, + into: _, + table_name, + columns, + overwrite: _, + source, + partitioned, + after_columns, + table: _, + on: _, + returning: _, + } => { + // Not supported in postgres. + assert!(or.is_none()); + assert!(partitioned.is_none()); + assert!(after_columns.is_empty()); + + Self::process_table(table_name, &mut table_names); + Self::process_query(&*source, &mut exprs, &mut table_names, &Some(columns)); + } + Delete { + tables, + from, + using, + selection, + returning: _, + } => { + if let Some(expr) = selection { + exprs.push(expr.clone()); + } + + // Multi tables delete are not supported in postgres. + assert!(tables.is_empty()); + + Self::process_tables_with_join(&from, &mut exprs, &mut table_names); + if let Some(using_tbl_with_join) = using { + Self::process_tables_with_join( + using_tbl_with_join, + &mut exprs, + &mut table_names, + ); + } + Self::process_selection(selection, &mut exprs); + } + Update { + table, + assignments, + from, + selection, + returning: _, + } => { + Self::process_table_with_join(table, &mut exprs, &mut table_names); + if let Some(from_tbl) = from { + Self::process_table_with_join(from_tbl, &mut exprs, &mut table_names); + } + Self::process_selection(selection, &mut exprs); + self.assignment_parser(assignments)?; + } + _ => { + return Ok(None); + } + }; + + Ok(self.infer_shard_from_exprs(exprs, table_names)) + } + + fn process_query( + query: &sqlparser::ast::Query, + exprs: &mut Vec, + table_names: &mut Vec>, + columns: &Option<&Vec>, + ) { + match &*query.body { + SetExpr::Query(query) => { + Self::process_query(&*query, exprs, table_names, columns); + } + + // SELECT * FROM ... + // We understand that pretty well. + SetExpr::Select(select) => { + Self::process_tables_with_join(&select.from, exprs, table_names); + + // Parse the actual "FROM ..." + Self::process_selection(&select.selection, exprs); + } + + SetExpr::Values(values) => { + if let Some(cols) = columns { + for row in values.rows.iter() { + for (i, expr) in row.iter().enumerate() { + if cols.len() > i { + exprs.push(Expr::BinaryOp { + left: Box::new(Expr::Identifier(cols[i].clone())), + op: BinaryOperator::Eq, + right: Box::new(expr.clone()), + }); + } + } + } + } + } + _ => (), + }; + } + + fn process_selection(selection: &Option, exprs: &mut Vec) { + match selection { + Some(selection) => { + exprs.push(selection.clone()); + } + + None => (), + }; + } + + fn process_tables_with_join( + tables: &Vec, + exprs: &mut Vec, + table_names: &mut Vec>, + ) { + for table in tables.iter() { + Self::process_table_with_join(table, exprs, table_names); + } + } + + fn process_table_with_join( + table: &TableWithJoins, + exprs: &mut Vec, + table_names: &mut Vec>, + ) { + match &table.relation { + TableFactor::Table { name, .. } => { + Self::process_table(name, table_names); + } + + _ => (), + }; + + // Get table names from all the joins. + for join in table.joins.iter() { + match &join.relation { + TableFactor::Table { name, .. } => { + Self::process_table(name, table_names); + } + + _ => (), + }; + + // We can filter results based on join conditions, e.g. + // SELECT * FROM t INNER JOIN B ON B.sharding_key = 5; + match &join.join_operator { + JoinOperator::Inner(inner_join) => match &inner_join { + JoinConstraint::On(expr) => { + // Parse the selection criteria later. + exprs.push(expr.clone()); + } + + _ => (), + }, + + _ => (), + }; + } + } + + fn process_table(name: &sqlparser::ast::ObjectName, table_names: &mut Vec>) { + table_names.push(name.0.clone()) + } + /// Parse the shard number from the Bind message /// which contains the arguments for a prepared statement. /// @@ -592,6 +806,33 @@ impl QueryRouter { } } + /// An `assignments` exists in the `UPDATE` statements. This parses the assignments and makes + /// sure that we are not updating the sharding key. It's not supported yet. + fn assignment_parser(&self, assignments: &Vec) -> Result<(), Error> { + let sharding_key = self + .pool_settings + .automatic_sharding_key + .as_ref() + .unwrap() + .split(".") + .map(|ident| Ident::new(ident.to_lowercase())) + .collect::>(); + + // Sharding key must be always fully qualified + assert_eq!(sharding_key.len(), 2); + + for a in assignments { + if sharding_key[0].value == "*" { + if sharding_key[1].value == a.id.last().unwrap().value.to_lowercase() { + return Err(Error::QueryRouterParserError( + "Sharding key cannot be updated.".into(), + )); + } + } + } + Ok(()) + } + /// A `selection` is the `WHERE` clause. This parses /// the clause and extracts the sharding key, if present. fn selection_parser(&self, expr: &Expr, table_names: &Vec>) -> Vec { @@ -604,7 +845,7 @@ impl QueryRouter { .as_ref() .unwrap() .split(".") - .map(|ident| Ident::new(ident)) + .map(|ident| Ident::new(ident.to_lowercase())) .collect::>(); // Sharding key must be always fully qualified @@ -620,7 +861,7 @@ impl QueryRouter { Expr::Identifier(ident) => { // Only if we're dealing with only one table // and there is no ambiguity - if &ident.value == &sharding_key[1].value { + if &ident.value.to_lowercase() == &sharding_key[1].value { // Sharding key is unique enough, don't worry about // table names. if &sharding_key[0].value == "*" { @@ -633,13 +874,13 @@ impl QueryRouter { // SELECT * FROM t WHERE sharding_key = 5 // Make sure the table name from the sharding key matches // the table name from the query. - found = &sharding_key[0].value == &table[0].value; + found = &sharding_key[0].value == &table[0].value.to_lowercase(); } else if table.len() == 2 { // Table name is fully qualified with the schema: e.g. // SELECT * FROM public.t WHERE sharding_key = 5 // Ignore the schema (TODO: at some point, we want schema support) // and use the table name only. - found = &sharding_key[0].value == &table[1].value; + found = &sharding_key[0].value == &table[1].value.to_lowercase(); } else { debug!("Got table name with more than two idents, which is not possible"); } @@ -651,8 +892,9 @@ impl QueryRouter { // The key is fully qualified in the query, // it will exist or Postgres will throw an error. if idents.len() == 2 { - found = &sharding_key[0].value == &idents[0].value - && &sharding_key[1].value == &idents[1].value; + found = (&sharding_key[0].value == "*" + || &sharding_key[0].value == &idents[0].value.to_lowercase()) + && &sharding_key[1].value == &idents[1].value.to_lowercase(); } // TODO: key can have schema as well, e.g. public.data.id (len == 3) } @@ -705,100 +947,48 @@ impl QueryRouter { /// Try to figure out which shard the query should go to. fn infer_shard(&mut self, query: &sqlparser::ast::Query) -> Option { - let mut shards = BTreeSet::new(); let mut exprs = Vec::new(); - match &*query.body { - SetExpr::Query(query) => { - match self.infer_shard(&*query) { - Some(shard) => { - shards.insert(shard); - } - None => (), - }; - } + // Collect all table names from the query. + let mut table_names = Vec::new(); - // SELECT * FROM ... - // We understand that pretty well. - SetExpr::Select(select) => { - // Collect all table names from the query. - let mut table_names = Vec::new(); - - for table in select.from.iter() { - match &table.relation { - TableFactor::Table { name, .. } => { - table_names.push(name.0.clone()); - } - - _ => (), - }; + Self::process_query(query, &mut exprs, &mut table_names, &None); + self.infer_shard_from_exprs(exprs, table_names) + } - // Get table names from all the joins. - for join in table.joins.iter() { - match &join.relation { - TableFactor::Table { name, .. } => { - table_names.push(name.0.clone()); - } + fn infer_shard_from_exprs( + &mut self, + exprs: Vec, + table_names: Vec>, + ) -> Option { + let mut shards = BTreeSet::new(); - _ => (), - }; + let sharder = Sharder::new( + self.pool_settings.shards, + self.pool_settings.sharding_function, + ); - // We can filter results based on join conditions, e.g. - // SELECT * FROM t INNER JOIN B ON B.sharding_key = 5; - match &join.join_operator { - JoinOperator::Inner(inner_join) => match &inner_join { - JoinConstraint::On(expr) => { - // Parse the selection criteria later. - exprs.push(expr.clone()); - } + // Look for sharding keys in either the join condition + // or the selection. + for expr in exprs.iter() { + let sharding_keys = self.selection_parser(expr, &table_names); - _ => (), - }, + // TODO: Add support for prepared statements here. + // This should just give us the position of the value in the `B` message. - _ => (), - }; + for value in sharding_keys { + match value { + ShardingKey::Value(value) => { + let shard = sharder.shard(value); + shards.insert(shard); } - } - // Parse the actual "FROM ..." - match &select.selection { - Some(selection) => { - exprs.push(selection.clone()); + ShardingKey::Placeholder(position) => { + self.placeholders.push(position); } - - None => (), }; - - let sharder = Sharder::new( - self.pool_settings.shards, - self.pool_settings.sharding_function, - ); - - // Look for sharding keys in either the join condition - // or the selection. - for expr in exprs.iter() { - let sharding_keys = self.selection_parser(expr, &table_names); - - // TODO: Add support for prepared statements here. - // This should just give us the position of the value in the `B` message. - - for value in sharding_keys { - match value { - ShardingKey::Value(value) => { - let shard = sharder.shard(value); - shards.insert(shard); - } - - ShardingKey::Placeholder(position) => { - self.placeholders.push(position); - } - }; - } - } } - _ => (), - }; - + } match shards.len() { // Didn't find a sharding key, you're on your own. 0 => { @@ -1414,6 +1604,221 @@ mod test { assert_eq!(qr.shard().unwrap(), 0); } + fn auto_shard_wrapper(qry: &str, should_succeed: bool) -> Option { + let mut qr = QueryRouter::new(); + qr.pool_settings.automatic_sharding_key = Some("*.w_id".to_string()); + qr.pool_settings.shards = 3; + qr.pool_settings.query_parser_read_write_splitting = true; + assert_eq!(qr.shard(), None); + let infer_res = qr.infer(&qr.parse(&simple_query(qry)).unwrap()); + assert_eq!(infer_res.is_ok(), should_succeed); + qr.shard() + } + + fn auto_shard(qry: &str) -> Option { + auto_shard_wrapper(qry, true) + } + + fn auto_shard_fails(qry: &str) -> Option { + auto_shard_wrapper(qry, false) + } + + #[test] + fn test_automatic_sharding_insert_update_delete() { + QueryRouter::setup(); + + assert_eq!( + auto_shard_fails( + "UPDATE ORDERS SET w_id = 3 WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5" + ), + None + ); + + assert_eq!( + auto_shard_fails( + "UPDATE ORDERS o SET o.W_ID = 3 WHERE o.O_ID = 3 AND o.O_D_ID = 3 AND o.W_ID = 5" + ), + None + ); + + assert_eq!( + auto_shard( + "UPDATE ORDERS o SET o.O_CARRIER_ID = 3 WHERE o.O_ID = 3 AND o.O_D_ID = 3 AND o.W_ID = 5" + ), + Some(2) + ); + } + + #[test] + fn test_automatic_sharding_key_tpcc() { + QueryRouter::setup(); + + assert_eq!(auto_shard("SELECT * FROM my_tbl WHERE w_id = 5"), Some(2)); + assert_eq!( + auto_shard("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ"), + None + ); + assert_eq!(auto_shard("COMMIT"), None); + assert_eq!(auto_shard("ROLLBACK"), None); + + assert_eq!(auto_shard("SELECT NO_O_ID FROM NEW_ORDER WHERE NO_D_ID = 7 AND W_ID = 5 AND NO_O_ID > 3 LIMIT 3"), Some(2)); + assert_eq!(auto_shard("SELECT NO_O_ID FROM NEW_ORDER no WHERE no.NO_D_ID = 7 AND no.W_ID = 5 AND no.NO_O_ID > 3 LIMIT 3"), Some(2)); + + assert_eq!( + auto_shard("DELETE FROM NEW_ORDER WHERE NO_D_ID = 7 AND W_ID = 5 AND NO_O_ID = 3"), + Some(2) + ); + + assert_eq!( + auto_shard("SELECT O_C_ID FROM ORDERS WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5"), + Some(2) + ); + assert_eq!( + auto_shard( + "UPDATE ORDERS SET O_CARRIER_ID = 3 WHERE O_ID = 3 AND O_D_ID = 3 AND W_ID = 5" + ), + Some(2) + ); + + assert_eq!( + auto_shard("UPDATE ORDER_LINE SET OL_DELIVERY_D = 3 WHERE OL_O_ID = 3 AND OL_D_ID = 3 AND W_ID = 5"), + Some(2) + ); + + assert_eq!( + auto_shard("SELECT SUM(OL_AMOUNT) FROM ORDER_LINE WHERE OL_O_ID = 3 AND OL_D_ID = 3 AND W_ID = 5"), + Some(2) + ); + + assert_eq!( + auto_shard("UPDATE CUSTOMER SET C_BALANCE = C_BALANCE + 3 WHERE C_ID = 3 AND C_D_ID = 3 AND W_ID = 5"), + Some(2) + ); + + assert_eq!( + auto_shard("SELECT W_TAX FROM WAREHOUSE WHERE W_ID = 5"), + Some(2) + ); + assert_eq!( + auto_shard("SELECT D_TAX, D_NEXT_O_ID FROM DISTRICT WHERE D_ID = 3 AND W_ID = 5"), + Some(2) + ); + assert_eq!( + auto_shard("UPDATE DISTRICT SET D_NEXT_O_ID = 3 WHERE D_ID = 3 AND W_ID = 5"), + Some(2) + ); + assert_eq!( + auto_shard("SELECT C_DISCOUNT, C_LAST, C_CREDIT FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"), + Some(2) + ); + assert_eq!( + auto_shard("INSERT INTO ORDERS (O_ID, O_D_ID, W_ID, O_C_ID, O_ENTRY_D, O_CARRIER_ID, O_OL_CNT, O_ALL_LOCAL) VALUES (3, 3, 5, 3, 3, 3, 3, 3)"), + Some(2) + ); + assert_eq!( + auto_shard("INSERT INTO NEW_ORDER (NO_O_ID, NO_D_ID, W_ID) VALUES (3, 3, 5)"), + Some(2) + ); + assert_eq!( + auto_shard("SELECT I_PRICE, I_NAME, I_DATA FROM ITEM WHERE I_ID = 3"), + None + ); + assert_eq!( + auto_shard("SELECT S_QUANTITY, S_DATA, S_YTD, S_ORDER_CNT, S_REMOTE_CNT, S_DIST_03 FROM STOCK WHERE S_I_ID = 3 AND W_ID = 5"), + Some(2) + ); + assert_eq!( + auto_shard("UPDATE STOCK SET S_QUANTITY = 3, S_YTD = 3, S_ORDER_CNT = 3, S_REMOTE_CNT = 3 WHERE S_I_ID = 3 AND W_ID = 5"), + Some(2) + ); + assert_eq!( + auto_shard("INSERT INTO ORDER_LINE (OL_O_ID, OL_D_ID, W_ID, OL_NUMBER, OL_I_ID, OL_SUPPLY_W_ID, OL_DELIVERY_D, OL_QUANTITY, OL_AMOUNT, OL_DIST_INFO) VALUES (3, 3, 5, 3, 3, 3, 3, 3, 3, 3)"), + Some(2) + ); + assert_eq!( + auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_BALANCE FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"), + Some(2) + ); + assert_eq!( + auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_BALANCE FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_LAST = '3' ORDER BY C_FIRST"), + Some(2) + ); + assert_eq!( + auto_shard("SELECT O_ID, O_CARRIER_ID, O_ENTRY_D FROM ORDERS WHERE W_ID = 5 AND O_D_ID = 3 AND O_C_ID = 3 ORDER BY O_ID DESC LIMIT 3"), + Some(2) + ); + assert_eq!( + auto_shard("SELECT OL_SUPPLY_W_ID, OL_I_ID, OL_QUANTITY, OL_AMOUNT, OL_DELIVERY_D FROM ORDER_LINE WHERE W_ID = 5 AND OL_D_ID = 3 AND OL_O_ID = 3"), + Some(2) + ); + assert_eq!( + auto_shard("SELECT W_NAME, W_STREET_1, W_STREET_2, W_CITY, W_STATE, W_ZIP FROM WAREHOUSE WHERE W_ID = 5"), + Some(2) + ); + assert_eq!( + auto_shard("UPDATE WAREHOUSE SET W_YTD = W_YTD + 3 WHERE W_ID = 5"), + Some(2) + ); + assert_eq!( + auto_shard("SELECT D_NAME, D_STREET_1, D_STREET_2, D_CITY, D_STATE, D_ZIP FROM DISTRICT WHERE W_ID = 5 AND D_ID = 3"), + Some(2) + ); + assert_eq!( + auto_shard("UPDATE DISTRICT SET D_YTD = D_YTD + 3 WHERE W_ID = 5 AND D_ID = 3"), + Some(2) + ); + assert_eq!( + auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_STREET_1, C_STREET_2, C_CITY, C_STATE, C_ZIP, C_PHONE, C_SINCE, C_CREDIT, C_CREDIT_LIM, C_DISCOUNT, C_BALANCE, C_YTD_PAYMENT, C_PAYMENT_CNT, C_DATA FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"), + Some(2) + ); + assert_eq!( + auto_shard("SELECT C_ID, C_FIRST, C_MIDDLE, C_LAST, C_STREET_1, C_STREET_2, C_CITY, C_STATE, C_ZIP, C_PHONE, C_SINCE, C_CREDIT, C_CREDIT_LIM, C_DISCOUNT, C_BALANCE, C_YTD_PAYMENT, C_PAYMENT_CNT, C_DATA FROM CUSTOMER WHERE W_ID = 5 AND C_D_ID = 3 AND C_LAST = '3' ORDER BY C_FIRST"), + Some(2) + ); + assert_eq!( + auto_shard("UPDATE CUSTOMER SET C_BALANCE = 3, C_YTD_PAYMENT = 3, C_PAYMENT_CNT = 3, C_DATA = 3 WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"), + Some(2) + ); + assert_eq!( + auto_shard("UPDATE CUSTOMER SET C_BALANCE = 3, C_YTD_PAYMENT = 3, C_PAYMENT_CNT = 3 WHERE W_ID = 5 AND C_D_ID = 3 AND C_ID = 3"), + Some(2) + ); + + assert_eq!(auto_shard("INSERT INTO HISTORY (H_C_ID, H_C_D_ID, H_C_W_ID, H_D_ID, W_ID, H_DATE, H_AMOUNT, H_DATA) VALUES (3, 3, 5, 3, 5, 3, 3, 3)"), Some(2)); + assert_eq!( + auto_shard("SELECT D_NEXT_O_ID FROM DISTRICT WHERE W_ID = 5 AND D_ID = 3"), + Some(2) + ); + assert_eq!( + auto_shard( + "SELECT COUNT(DISTINCT(OL_I_ID)) FROM ORDER_LINE, STOCK + WHERE ORDER_LINE.W_ID = 5 + AND OL_D_ID = 3 + AND OL_O_ID < 3 + AND OL_O_ID >= 3 + AND STOCK.W_ID = 5 + AND S_I_ID = OL_I_ID + AND S_QUANTITY < 3" + ), + Some(2) + ); + + // This is a distributed query and contains two shards + assert_eq!( + auto_shard( + "SELECT COUNT(DISTINCT(OL_I_ID)) FROM ORDER_LINE, STOCK + WHERE ORDER_LINE.W_ID = 5 + AND OL_D_ID = 3 + AND OL_O_ID < 3 + AND OL_O_ID >= 3 + AND STOCK.W_ID = 7 + AND S_I_ID = OL_I_ID + AND S_QUANTITY < 3" + ), + None + ); + } + #[test] fn test_prepared_statements() { let stmt = "SELECT * FROM data WHERE id = $1";