Skip to content

Commit d8e2472

Browse files
committed
Add AI powered SQL generation using ClickHouse client's API.
1 parent c9103f8 commit d8e2472

File tree

8 files changed

+353
-12
lines changed

8 files changed

+353
-12
lines changed

chdb/build.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ CMAKE_ARGS="-DCMAKE_BUILD_TYPE=${build_type} -DENABLE_THINLTO=0 -DENABLE_TESTS=0
9595
-DENABLE_KAFKA=1 -DENABLE_LIBPQXX=1 -DENABLE_NATS=0 -DENABLE_AMQPCPP=0 -DENABLE_NURAFT=0 \
9696
-DENABLE_CASSANDRA=0 -DENABLE_ODBC=0 -DENABLE_NLP=0 \
9797
-DENABLE_LDAP=0 \
98+
-DENABLE_CLIENT_AI=1 \
9899
${MYSQL} \
99100
${HDFS} \
100101
-DENABLE_LIBRARIES=0 ${RUST_FEATURES} \
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
#include "AIQueryProcessor.h"
2+
3+
#include "chdb-internal.h"
4+
5+
#include <pybind11/pybind11.h>
6+
#include <Poco/String.h>
7+
8+
#if USE_CLIENT_AI
9+
# include <Client/AI/AIClientFactory.h>
10+
# include <Client/AI/AISQLGenerator.h>
11+
#endif
12+
13+
#include <cstdlib>
14+
#include <iostream>
15+
#include <stdexcept>
16+
17+
namespace py = pybind11;
18+
19+
#if USE_CLIENT_AI
20+
21+
AIQueryProcessor::AIQueryProcessor(chdb_connection * connection_) : connection(connection_) { }
22+
23+
AIQueryProcessor::~AIQueryProcessor() = default;
24+
25+
namespace
26+
{
27+
DB::AIConfiguration loadAIConfigFromEnv()
28+
{
29+
DB::AIConfiguration config;
30+
31+
if (const char * api_key = std::getenv("AI_API_KEY"))
32+
config.api_key = api_key;
33+
else if (const char * openai_key = std::getenv("OPENAI_API_KEY"))
34+
config.api_key = openai_key;
35+
else if (const char * anthropic_key = std::getenv("ANTHROPIC_API_KEY"))
36+
config.api_key = anthropic_key;
37+
38+
if (const char * base_url = std::getenv("AI_BASE_URL"))
39+
{
40+
config.base_url = base_url;
41+
}
42+
else if (const char * openai_base = std::getenv("OPENAI_API_BASE"))
43+
{
44+
config.base_url = openai_base;
45+
config.provider = "openai";
46+
}
47+
else if (const char * anthropic_base = std::getenv("ANTHROPIC_API_URL"))
48+
{
49+
config.base_url = anthropic_base;
50+
config.provider = "anthropic";
51+
}
52+
53+
if (const char * model = std::getenv("AI_MODEL"))
54+
config.model = model;
55+
if (const char * provider = std::getenv("AI_PROVIDER"))
56+
config.provider = provider;
57+
58+
std::cerr << "[chdb] Using AI config: "
59+
<< "provider=" << (config.provider.empty() ? "<auto>" : config.provider) << ", "
60+
<< "model=" << (config.model.empty() ? "<default>" : config.model) << ", "
61+
<< "base_url=" << (config.base_url.empty() ? "<default>" : config.base_url) << std::endl;
62+
63+
return config;
64+
}
65+
}
66+
67+
std::string AIQueryProcessor::executeQueryForAI(const std::string & query)
68+
{
69+
auto run_query = [this, &query]()
70+
{
71+
return chdb_query_n(*connection, query.data(), query.size(), "TSV", 3);
72+
};
73+
74+
chdb_result * result = nullptr;
75+
if (PyGILState_Check())
76+
{
77+
py::gil_scoped_release release;
78+
result = run_query();
79+
}
80+
else
81+
{
82+
result = run_query();
83+
}
84+
85+
const auto & error_msg = CHDB::chdb_result_error_string(result);
86+
if (!error_msg.empty())
87+
{
88+
std::string msg_copy(error_msg);
89+
chdb_destroy_query_result(result);
90+
throw std::runtime_error(msg_copy);
91+
}
92+
93+
std::string data(chdb_result_buffer(result), chdb_result_length(result));
94+
chdb_destroy_query_result(result);
95+
return data;
96+
}
97+
98+
void AIQueryProcessor::initializeGenerator()
99+
{
100+
if (generator)
101+
return;
102+
103+
DB::AIConfiguration ai_config = loadAIConfigFromEnv();
104+
auto ai_result = DB::AIClientFactory::createClient(ai_config);
105+
106+
if (ai_result.no_configuration_found || !ai_result.client.has_value())
107+
throw std::runtime_error("AI SQL generator is not configured. Set OPENAI_API_KEY or ANTHROPIC_API_KEY to enable it.");
108+
109+
auto query_executor = [this](const std::string & query_text) { return executeQueryForAI(query_text); };
110+
generator = std::make_unique<DB::AISQLGenerator>(ai_config, std::move(ai_result.client.value()), query_executor, std::cerr);
111+
}
112+
113+
std::string AIQueryProcessor::generateSQLFromPrompt(const std::string & prompt)
114+
{
115+
initializeGenerator();
116+
117+
if (!generator)
118+
throw std::runtime_error("AI SQL generator is not configured. Set OPENAI_API_KEY or ANTHROPIC_API_KEY to enable it.");
119+
120+
auto run_generation = [this, &prompt]() { return generator->generateSQL(prompt); };
121+
122+
std::string sql;
123+
if (PyGILState_Check())
124+
{
125+
py::gil_scoped_release release;
126+
sql = run_generation();
127+
}
128+
else
129+
{
130+
sql = run_generation();
131+
}
132+
133+
if (sql.empty())
134+
throw std::runtime_error("AI did not return a SQL query.");
135+
136+
return sql;
137+
}
138+
139+
std::string AIQueryProcessor::preprocess(const std::string & query)
140+
{
141+
std::string prompt;
142+
if (!extractAIPrompt(query, prompt))
143+
return query;
144+
145+
return generateSQLFromPrompt(prompt);
146+
}
147+
148+
#else
149+
150+
AIQueryProcessor::AIQueryProcessor(chdb_connection *) : connection(nullptr) { }
151+
AIQueryProcessor::~AIQueryProcessor() = default;
152+
std::string AIQueryProcessor::executeQueryForAI(const std::string &) { return {}; }
153+
void AIQueryProcessor::initializeGenerator() { }
154+
std::string AIQueryProcessor::generateSQLFromPrompt(const std::string &) { return {}; }
155+
std::string AIQueryProcessor::preprocess(const std::string & query) { return query; }
156+
157+
#endif
158+
159+
bool extractAIPrompt(const std::string & query, std::string & prompt_out)
160+
{
161+
auto trimmed = Poco::trimLeft(query);
162+
if (!trimmed.starts_with("??"))
163+
return false;
164+
165+
auto prompt = Poco::trimLeft(trimmed.substr(2));
166+
if (prompt.empty())
167+
throw std::runtime_error("Please provide a natural language query after ??");
168+
169+
prompt_out = prompt;
170+
return true;
171+
}

programs/local/AIQueryProcessor.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once
2+
3+
#include "chdb.h"
4+
#include <Client/AI/AISQLGenerator.h>
5+
6+
#include <memory>
7+
#include <string>
8+
9+
/// AI query processor that handles "??" prompts and delegates to AISQLGenerator.
10+
class AIQueryProcessor
11+
{
12+
public:
13+
explicit AIQueryProcessor(chdb_connection * connection_);
14+
~AIQueryProcessor();
15+
16+
/// If query starts with "??", generate SQL via AI; otherwise return original.
17+
std::string preprocess(const std::string & query);
18+
19+
private:
20+
chdb_connection * connection;
21+
std::unique_ptr<DB::AISQLGenerator> generator;
22+
23+
std::string executeQueryForAI(const std::string & query);
24+
std::string generateSQLFromPrompt(const std::string & prompt);
25+
void initializeGenerator();
26+
};
27+
28+
/// Utility to extract prompt after "??" or return false if not present.
29+
bool extractAIPrompt(const std::string & query, std::string & prompt_out);

programs/local/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ if (USE_PYTHON)
4444
PandasDataFrame.cpp
4545
PandasDataFrameBuilder.cpp
4646
PandasScan.cpp
47+
AIQueryProcessor.cpp
4748
PyArrowStreamFactory.cpp
4849
PyArrowTable.cpp
4950
PybindWrapper.cpp

programs/local/LocalChdb.cpp

Lines changed: 82 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "LocalChdb.h"
2+
#include "AIQueryProcessor.h"
23
#include "chdb-internal.h"
34
#include "PandasDataFrameBuilder.h"
45
#include "ChunkCollectorOutputFormat.h"
@@ -9,10 +10,15 @@
910
#include <pybind11/pybind11.h>
1011
#include <Poco/String.h>
1112
#include <Common/logger_useful.h>
13+
#include <Common/quoteString.h>
1214
#if USE_JEMALLOC
1315
# include <Common/memory.h>
1416
#endif
1517

18+
#include <iostream>
19+
#include <sstream>
20+
#include <stdexcept>
21+
1622
namespace py = pybind11;
1723

1824
extern bool inside_main = true;
@@ -94,6 +100,27 @@ memoryview_wrapper * query_result::get_memview()
94100
return new memoryview_wrapper(this->result_wrapper);
95101
}
96102

103+
#if USE_CLIENT_AI
104+
namespace
105+
{
106+
107+
bool extractAIPrompt(const std::string & query, std::string & prompt_out)
108+
{
109+
auto trimmed = Poco::trimLeft(query);
110+
if (!trimmed.starts_with("??"))
111+
return false;
112+
113+
auto prompt = Poco::trimLeft(trimmed.substr(2));
114+
if (prompt.empty())
115+
throw std::runtime_error("Please provide a natural language query after ??");
116+
117+
prompt_out = prompt;
118+
return true;
119+
}
120+
121+
}
122+
#endif
123+
97124

98125
// Parse SQLite-style connection string
99126
std::pair<std::string, std::map<std::string, std::string>> connection_wrapper::parse_connection_string(const std::string & conn_str)
@@ -220,6 +247,27 @@ connection_wrapper::build_clickhouse_args(const std::string & path, const std::m
220247
return argv;
221248
}
222249

250+
std::string connection_wrapper::preprocessQuery(const std::string & query_str)
251+
{
252+
#if USE_CLIENT_AI
253+
try
254+
{
255+
if (!ai_processor)
256+
ai_processor = std::make_unique<AIQueryProcessor>(conn);
257+
return ai_processor->preprocess(query_str);
258+
}
259+
catch (const std::exception & e)
260+
{
261+
throw std::runtime_error(std::string("AI SQL generation failed: ") + e.what());
262+
}
263+
#else
264+
auto trimmed = Poco::trimLeft(query_str);
265+
if (trimmed.starts_with("??"))
266+
throw std::runtime_error("AI SQL generation is not available in this build. Rebuild with USE_CLIENT_AI enabled.");
267+
return query_str;
268+
#endif
269+
}
270+
223271
connection_wrapper::connection_wrapper(const std::string & conn_str)
224272
{
225273
auto [path, params] = parse_connection_string(conn_str);
@@ -263,15 +311,37 @@ void connection_wrapper::commit()
263311
// do nothing
264312
}
265313

314+
static bool isAIGenSqlQuery(const std::string & original, const std::string & processed)
315+
{
316+
#if USE_CLIENT_AI
317+
auto trimmed = Poco::trimLeft(original);
318+
return trimmed.starts_with("??") && processed != original;
319+
#else
320+
return false;
321+
#endif
322+
}
323+
266324
query_result * connection_wrapper::query(const std::string & query_str, const std::string & format)
267325
{
268326
if (Poco::toLower(format) == "dataframe")
269327
throw std::runtime_error("Unsupported output format dataframe, please use 'query_df' function");
270328

271-
CHDB::cachePythonTablesFromQuery(reinterpret_cast<chdb_conn *>(*conn), query_str);
329+
auto query = preprocessQuery(query_str);
330+
331+
#if USE_CLIENT_AI
332+
if (isAIGenSqlQuery(query_str, query))
333+
{
334+
// Return generated SQL as plain text without executing
335+
auto sql_literal = "SELECT " + DB::quoteString(query) + " AS query FORMAT Raw";
336+
auto * result = chdb_query_n(*conn, sql_literal.data(), sql_literal.size(), "Raw", 3);
337+
return new query_result(result, false);
338+
}
339+
#endif
340+
341+
CHDB::cachePythonTablesFromQuery(reinterpret_cast<chdb_conn *>(*conn), query);
272342
py::gil_scoped_release release;
273343

274-
auto * result = chdb_query_n(*conn, query_str.data(), query_str.size(), format.data(), format.size());
344+
auto * result = chdb_query_n(*conn, query.data(), query.size(), format.data(), format.size());
275345

276346
const auto & error_msg = CHDB::chdb_result_error_string(result);
277347
if (!error_msg.empty())
@@ -291,12 +361,14 @@ py::object connection_wrapper::query_df(const std::string & query_str)
291361
chdb_result * result = nullptr;
292362
CHDB::ChunkQueryResult * chunk_result = nullptr;
293363

294-
CHDB::cachePythonTablesFromQuery(reinterpret_cast<chdb_conn *>(*conn), query_str);
364+
auto query = preprocessQuery(query_str);
365+
366+
CHDB::cachePythonTablesFromQuery(reinterpret_cast<chdb_conn *>(*conn), query);
295367

296368
{
297369
py::gil_scoped_release release;
298370

299-
result = chdb_query_n(*conn, query_str.data(), query_str.size(), format.data(), format.size());
371+
result = chdb_query_n(*conn, query.data(), query.size(), format.data(), format.size());
300372

301373
const auto & error_msg = CHDB::chdb_result_error_string(result);
302374
if (!error_msg.empty())
@@ -319,9 +391,10 @@ py::object connection_wrapper::query_df(const std::string & query_str)
319391

320392
streaming_query_result * connection_wrapper::send_query(const std::string & query_str, const std::string & format)
321393
{
322-
CHDB::cachePythonTablesFromQuery(reinterpret_cast<chdb_conn *>(*conn), query_str);
394+
auto query = preprocessQuery(query_str);
395+
CHDB::cachePythonTablesFromQuery(reinterpret_cast<chdb_conn *>(*conn), query);
323396
py::gil_scoped_release release;
324-
auto * result = chdb_stream_query_n(*conn, query_str.data(), query_str.size(), format.data(), format.size());
397+
auto * result = chdb_stream_query_n(*conn, query.data(), query.size(), format.data(), format.size());
325398
const auto & error_msg = CHDB::chdb_result_error_string(result);
326399
if (!error_msg.empty())
327400
{
@@ -397,10 +470,11 @@ void connection_wrapper::streaming_cancel_query(streaming_query_result * streami
397470
void cursor_wrapper::execute(const std::string & query_str)
398471
{
399472
release_result();
400-
CHDB::cachePythonTablesFromQuery(reinterpret_cast<chdb_conn *>(conn->get_conn()), query_str);
473+
auto query = conn->preprocessQuery(query_str);
474+
CHDB::cachePythonTablesFromQuery(reinterpret_cast<chdb_conn *>(conn->get_conn()), query);
401475
// Use JSONCompactEachRowWithNamesAndTypes format for better type support
402476
py::gil_scoped_release release;
403-
current_result = chdb_query_n(conn->get_conn(), query_str.data(), query_str.size(), CURSOR_DEFAULT_FORMAT, CURSOR_DEFAULT_FORMAT_LEN);
477+
current_result = chdb_query_n(conn->get_conn(), query.data(), query.size(), CURSOR_DEFAULT_FORMAT, CURSOR_DEFAULT_FORMAT_LEN);
404478
}
405479

406480

0 commit comments

Comments
 (0)