Skip to content

Commit e161f37

Browse files
committed
Plugins!!
1 parent 7d2b695 commit e161f37

12 files changed

+232
-43
lines changed

Cargo.lock

+21
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ rand = "0.8"
1414
chrono = "0.4"
1515
sha-1 = "0.10"
1616
toml = "0.7"
17-
serde = "1"
17+
serde = { version = "1", features = ["derive"] }
1818
serde_derive = "1"
1919
regex = "1"
2020
num_cpus = "1"
@@ -42,6 +42,8 @@ fallible-iterator = "0.2"
4242
pin-project = "1"
4343
webpki-roots = "0.23"
4444
rustls = { version = "0.21", features = ["dangerous_configuration"] }
45+
serde_json = "1"
46+
# serde = "*"
4547

4648
[target.'cfg(not(target_env = "msvc"))'.dependencies]
4749
jemallocator = "0.5.0"

pgcat.toml

+3
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ admin_username = "admin_user"
7777
# Password to access the virtual administrative database
7878
admin_password = "admin_pass"
7979

80+
# Plugins!!
81+
# plugins = ["pg_table_access", "intercept"]
82+
8083
# pool configs are structured as pool.<pool_name>
8184
# the pool_name is what clients use as database name when connecting.
8285
# For a pool named `sharded_db`, clients access that pool using connection string like

src/admin.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ use tokio::time::Instant;
1212
use crate::config::{get_config, reload_config, VERSION};
1313
use crate::errors::Error;
1414
use crate::messages::*;
15+
use crate::pool::ClientServerMap;
1516
use crate::pool::{get_all_pools, get_pool};
1617
use crate::stats::{get_client_stats, get_pool_stats, get_server_stats, ClientState, ServerState};
17-
use crate::ClientServerMap;
1818

1919
pub fn generate_server_info_for_admin() -> BytesMut {
2020
let mut server_info = BytesMut::new();

src/client.rs

+83
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use crate::auth_passthrough::refetch_auth_hash;
1616
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
1717
use crate::constants::*;
1818
use crate::messages::*;
19+
use crate::plugins::PluginOutput;
1920
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
2021
use crate::query_router::{Command, QueryRouter};
2122
use crate::server::Server;
@@ -765,6 +766,9 @@ where
765766

766767
self.stats.register(self.stats.clone());
767768

769+
// Error returned by one of the plugins.
770+
let mut plugin_output = None;
771+
768772
// Our custom protocol loop.
769773
// We expect the client to either start a transaction with regular queries
770774
// or issue commands for our sharding and server selection protocol.
@@ -816,6 +820,22 @@ where
816820
'Q' => {
817821
if query_router.query_parser_enabled() {
818822
if let Ok(ast) = QueryRouter::parse(&message) {
823+
let plugin_result = query_router.execute_plugins(&ast).await;
824+
825+
match plugin_result {
826+
Ok(PluginOutput::Deny(error)) => {
827+
error_response(&mut self.write, &error).await?;
828+
continue;
829+
}
830+
831+
Ok(PluginOutput::Intercept(result)) => {
832+
write_all(&mut self.write, result).await?;
833+
continue;
834+
}
835+
836+
_ => (),
837+
};
838+
819839
let _ = query_router.infer(&ast);
820840
}
821841
}
@@ -826,6 +846,10 @@ where
826846

827847
if query_router.query_parser_enabled() {
828848
if let Ok(ast) = QueryRouter::parse(&message) {
849+
if let Ok(output) = query_router.execute_plugins(&ast).await {
850+
plugin_output = Some(output);
851+
}
852+
829853
let _ = query_router.infer(&ast);
830854
}
831855
}
@@ -861,6 +885,18 @@ where
861885
continue;
862886
}
863887

888+
// Check on plugin results.
889+
match plugin_output {
890+
Some(PluginOutput::Deny(error)) => {
891+
self.buffer.clear();
892+
error_response(&mut self.write, &error).await?;
893+
plugin_output = None;
894+
continue;
895+
}
896+
897+
_ => (),
898+
};
899+
864900
// Get a pool instance referenced by the most up-to-date
865901
// pointer. This ensures we always read the latest config
866902
// when starting a query.
@@ -1089,6 +1125,27 @@ where
10891125
match code {
10901126
// Query
10911127
'Q' => {
1128+
if query_router.query_parser_enabled() {
1129+
if let Ok(ast) = QueryRouter::parse(&message) {
1130+
let plugin_result = query_router.execute_plugins(&ast).await;
1131+
1132+
match plugin_result {
1133+
Ok(PluginOutput::Deny(error)) => {
1134+
error_response(&mut self.write, &error).await?;
1135+
continue;
1136+
}
1137+
1138+
Ok(PluginOutput::Intercept(result)) => {
1139+
write_all(&mut self.write, result).await?;
1140+
continue;
1141+
}
1142+
1143+
_ => (),
1144+
};
1145+
1146+
let _ = query_router.infer(&ast);
1147+
}
1148+
}
10921149
debug!("Sending query to server");
10931150

10941151
self.send_and_receive_loop(
@@ -1128,6 +1185,14 @@ where
11281185
// Parse
11291186
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
11301187
'P' => {
1188+
if query_router.query_parser_enabled() {
1189+
if let Ok(ast) = QueryRouter::parse(&message) {
1190+
if let Ok(output) = query_router.execute_plugins(&ast).await {
1191+
plugin_output = Some(output);
1192+
}
1193+
}
1194+
}
1195+
11311196
self.buffer.put(&message[..]);
11321197
}
11331198

@@ -1159,6 +1224,24 @@ where
11591224
'S' => {
11601225
debug!("Sending query to server");
11611226

1227+
match plugin_output {
1228+
Some(PluginOutput::Deny(error)) => {
1229+
error_response(&mut self.write, &error).await?;
1230+
plugin_output = None;
1231+
self.buffer.clear();
1232+
continue;
1233+
}
1234+
1235+
Some(PluginOutput::Intercept(result)) => {
1236+
write_all(&mut self.write, result).await?;
1237+
plugin_output = None;
1238+
self.buffer.clear();
1239+
continue;
1240+
}
1241+
1242+
_ => (),
1243+
};
1244+
11621245
self.buffer.put(&message[..]);
11631246

11641247
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;

src/config.rs

+3
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@ pub struct General {
295295
pub auth_query: Option<String>,
296296
pub auth_query_user: Option<String>,
297297
pub auth_query_password: Option<String>,
298+
299+
pub query_router_plugins: Option<Vec<String>>,
298300
}
299301

300302
impl General {
@@ -389,6 +391,7 @@ impl Default for General {
389391
auth_query_user: None,
390392
auth_query_password: None,
391393
server_lifetime: 1000 * 3600 * 24, // 24 hours,
394+
query_router_plugins: None,
392395
}
393396
}
394397
}

src/errors.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Errors.
22
33
/// Various errors.
4-
#[derive(Debug, PartialEq)]
4+
#[derive(Debug, PartialEq, Clone)]
55
pub enum Error {
66
SocketError(String),
77
ClientSocketError(String, ClientIdentifier),
@@ -25,7 +25,9 @@ pub enum Error {
2525
AuthPassthroughError(String),
2626
UnsupportedStatement,
2727
QueryRouterParserError(String),
28+
PermissionDenied(String),
2829
PermissionDeniedTable(String),
30+
QueryDenied(String),
2931
}
3032

3133
#[derive(Clone, PartialEq, Debug)]

src/lib.rs

+5
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1+
pub mod admin;
12
pub mod auth_passthrough;
3+
pub mod client;
24
pub mod config;
35
pub mod constants;
46
pub mod errors;
57
pub mod messages;
68
pub mod mirrors;
79
pub mod multi_logger;
10+
pub mod plugins;
811
pub mod pool;
12+
pub mod prometheus;
13+
pub mod query_router;
914
pub mod scram;
1015
pub mod server;
1116
pub mod sharding;

src/main.rs

+10-28
Original file line numberDiff line numberDiff line change
@@ -60,36 +60,18 @@ use std::str::FromStr;
6060
use std::sync::Arc;
6161
use tokio::sync::broadcast;
6262

63-
mod admin;
64-
mod auth_passthrough;
65-
mod client;
66-
mod config;
67-
mod constants;
68-
mod errors;
69-
mod messages;
70-
mod mirrors;
71-
mod multi_logger;
72-
mod pool;
73-
mod prometheus;
74-
mod query_router;
75-
mod scram;
76-
mod server;
77-
mod sharding;
78-
mod stats;
79-
mod tls;
80-
81-
use crate::config::{get_config, reload_config, VERSION};
82-
use crate::messages::configure_socket;
83-
use crate::pool::{ClientServerMap, ConnectionPool};
84-
use crate::prometheus::start_metric_server;
85-
use crate::stats::{Collector, Reporter, REPORTER};
63+
use pgcat::config::{get_config, reload_config, VERSION};
64+
use pgcat::messages::configure_socket;
65+
use pgcat::pool::{ClientServerMap, ConnectionPool};
66+
use pgcat::prometheus::start_metric_server;
67+
use pgcat::stats::{Collector, Reporter, REPORTER};
8668

8769
fn main() -> Result<(), Box<dyn std::error::Error>> {
88-
multi_logger::MultiLogger::init().unwrap();
70+
pgcat::multi_logger::MultiLogger::init().unwrap();
8971

9072
info!("Welcome to PgCat! Meow. (Version {})", VERSION);
9173

92-
if !query_router::QueryRouter::setup() {
74+
if !pgcat::query_router::QueryRouter::setup() {
9375
error!("Could not setup query router");
9476
std::process::exit(exitcode::CONFIG);
9577
}
@@ -107,7 +89,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
10789
let runtime = Builder::new_multi_thread().worker_threads(1).build()?;
10890

10991
runtime.block_on(async {
110-
match config::parse(&config_file).await {
92+
match pgcat::config::parse(&config_file).await {
11193
Ok(_) => (),
11294
Err(err) => {
11395
error!("Config parse error: {:?}", err);
@@ -295,7 +277,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
295277
tokio::task::spawn(async move {
296278
let start = chrono::offset::Utc::now().naive_utc();
297279

298-
match client::client_entrypoint(
280+
match pgcat::client::client_entrypoint(
299281
socket,
300282
client_server_map,
301283
shutdown_rx,
@@ -326,7 +308,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
326308

327309
Err(err) => {
328310
match err {
329-
errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
311+
pgcat::errors::Error::ClientBadStartup => debug!("Client disconnected with error {:?}", err),
330312
_ => warn!("Client disconnected with error {:?}", err),
331313
}
332314

0 commit comments

Comments
 (0)