diff --git a/Cargo.lock b/Cargo.lock index 5eb570f..6cd635a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1467,6 +1467,7 @@ dependencies = [ "chrono", "datafusion", "futures", + "getset", "log", "pgwire", "postgres-types", @@ -1775,6 +1776,18 @@ dependencies = [ "wasi 0.14.2+wasi-0.2.4", ] +[[package]] +name = "getset" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3586f256131df87204eb733da72e3d3eb4f343c639f4b7be279ac7c48baeafe" +dependencies = [ + "proc-macro-error2", + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "gimli" version = "0.31.1" @@ -2722,6 +2735,28 @@ dependencies = [ "version_check", ] +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "proc-macro2" version = "1.0.95" diff --git a/datafusion-postgres-cli/src/main.rs b/datafusion-postgres-cli/src/main.rs index acdee6d..a9bfe4a 100644 --- a/datafusion-postgres-cli/src/main.rs +++ b/datafusion-postgres-cli/src/main.rs @@ -1,13 +1,9 @@ -use std::sync::Arc; - use datafusion::execution::options::{ ArrowReadOptions, AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions, }; use datafusion::prelude::SessionContext; -use datafusion_postgres::{DfSessionService, HandlerFactory}; // Assuming the crate name is `datafusion_postgres` -use pgwire::tokio::process_socket; +use datafusion_postgres::{serve, ServerOptions}; // Assuming the crate name is `datafusion_postgres` use structopt::StructOpt; -use tokio::net::TcpListener; #[derive(Debug, StructOpt)] #[structopt( @@ -103,33 +99,13 @@ async fn main() -> Result<(), Box> { println!("Loaded {} as table {}", table_path, table_name); } - // Get the first catalog name from the session context - let catalog_name = session_context - .catalog_names() // Fixed: Removed .catalog_list() - .first() - .cloned(); - - // Create the handler factory with the session context and catalog name - let factory = Arc::new(HandlerFactory(Arc::new(DfSessionService::new( - session_context, - catalog_name, - )))); + let server_options = ServerOptions::new() + .with_host(opts.host) + .with_port(opts.port); - // Bind to the specified host and port - let server_addr = format!("{}:{}", opts.host, opts.port); - let listener = TcpListener::bind(&server_addr).await?; - println!("Listening on {}", server_addr); + serve(session_context, &server_options) + .await + .map_err(|e| format!("Failed to run server: {}", e))?; - // Accept incoming connections - loop { - let (socket, addr) = listener.accept().await?; - let factory_ref = factory.clone(); - println!("Accepted connection from {}", addr); - - tokio::spawn(async move { - if let Err(e) = process_socket(socket, None, factory_ref).await { - eprintln!("Error processing socket: {}", e); - } - }); - } + Ok(()) } diff --git a/datafusion-postgres/Cargo.toml b/datafusion-postgres/Cargo.toml index 5f9aec4..b66bb72 100644 --- a/datafusion-postgres/Cargo.toml +++ b/datafusion-postgres/Cargo.toml @@ -21,8 +21,9 @@ bytes = "1.10.1" chrono = { version = "0.4", features = ["std"] } datafusion = { workspace = true } futures = "0.3" +getset = "0.1" log = "0.4" pgwire = { workspace = true } postgres-types = "0.2" rust_decimal = { version = "1.37", features = ["db-postgres"] } -tokio = { version = "1.45", features = ["sync"] } +tokio = { version = "1.45", features = ["sync", "net"] } diff --git a/datafusion-postgres/src/lib.rs b/datafusion-postgres/src/lib.rs index b677a2a..7f16ba2 100644 --- a/datafusion-postgres/src/lib.rs +++ b/datafusion-postgres/src/lib.rs @@ -4,3 +4,69 @@ mod handlers; mod information_schema; pub use handlers::{DfSessionService, HandlerFactory, Parser}; + +use std::sync::Arc; + +use datafusion::prelude::SessionContext; +use getset::{Getters, Setters, WithSetters}; +use pgwire::tokio::process_socket; +use tokio::net::TcpListener; + +#[derive(Getters, Setters, WithSetters)] +#[getset(get = "pub", set = "pub", set_with = "pub")] +pub struct ServerOptions { + host: String, + port: u16, +} + +impl ServerOptions { + pub fn new() -> ServerOptions { + ServerOptions::default() + } +} + +impl Default for ServerOptions { + fn default() -> Self { + ServerOptions { + host: "127.0.0.1".to_string(), + port: 5432, + } + } +} + +/// Serve the Datafusion `SessionContext` with Postgres protocol. +pub async fn serve( + session_context: SessionContext, + opts: &ServerOptions, +) -> Result<(), std::io::Error> { + // Get the first catalog name from the session context + let catalog_name = session_context + .catalog_names() // Fixed: Removed .catalog_list() + .first() + .cloned(); + + // Create the handler factory with the session context and catalog name + let factory = Arc::new(HandlerFactory(Arc::new(DfSessionService::new( + session_context, + catalog_name, + )))); + + // Bind to the specified host and port + let server_addr = format!("{}:{}", opts.host, opts.port); + let listener = TcpListener::bind(&server_addr).await?; + println!("Listening on {}", server_addr); + + // Accept incoming connections + loop { + if let Ok((socket, addr)) = listener.accept().await { + let factory_ref = factory.clone(); + println!("Accepted connection from {}", addr); + + tokio::spawn(async move { + if let Err(e) = process_socket(socket, None, factory_ref).await { + eprintln!("Error processing socket: {}", e); + } + }); + }; + } +}