Skip to content

Commit 71614b8

Browse files
Author: Pinyu Su
Commit message: add big query
1 parent a1a5f24 commit 71614b8

File tree

6 files changed: 35 additions (+35) and 10 deletions (-10)

6 files changed: 35 additions (+35) and 10 deletions (-10)

application/nlq/data_access/database.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ class RelationDatabase():
1515
'starrocks': 'starrocks',
1616
'clickhouse': 'clickhouse',
1717
'hive': 'hive',
18-
'athena': 'awsathena+rest'
18+
'athena': 'awsathena+rest',
19+
'bigquery': 'bigquery',
1920
# Add more mappings here for other databases
2021
}
2122

@@ -37,6 +38,12 @@ def get_db_url(cls, db_type, user, password, host, port, db_name):
3738
query={'s3_staging_dir': db_name}
3839
)
3940
logger.info(f"db_url: {db_url}")
41+
elif db_type == 'bigquery':
42+
db_url = db.engine.URL.create(
43+
drivername=cls.db_mapping[db_type],
44+
host=host, # BigQuery project. Note: without dataset
45+
query={'credentials_path': password}
46+
)
4047
else:
4148
db_url = db.engine.URL.create(
4249
drivername=cls.db_mapping[db_type],
@@ -79,7 +86,7 @@ def get_all_schema_names_by_connection(cls, connection: ConnectConfigEntity):
7986
if db_type == 'postgresql':
8087
schemas = [schema for schema in inspector.get_schema_names() if
8188
schema not in ('pg_catalog', 'information_schema', 'public')]
82-
elif db_type in ('redshift', 'mysql', 'starrocks', 'clickhouse', 'hive', 'athena'):
89+
elif db_type in ('redshift', 'mysql', 'starrocks', 'clickhouse', 'hive', 'athena', 'bigquery'):
8390
schemas = inspector.get_schema_names()
8491
else:
8592
raise ValueError("Unsupported database type")
@@ -100,10 +107,14 @@ def get_metadata_by_connection(cls, connection, schemas):
100107
engine = db.create_engine(db_url)
101108
# connection = engine.connect()
102109
metadata = db.MetaData()
103-
for s in schemas:
104-
metadata.reflect(bind=engine, schema=s, views=True)
105-
# metadata.reflect(bind=engine)
106-
return metadata
110+
if connection.db_type == 'bigquery':
111+
metadata.reflect(bind=engine)
112+
return metadata
113+
else:
114+
for s in schemas:
115+
metadata.reflect(bind=engine, schema=s, views=True)
116+
# metadata.reflect(bind=engine)
117+
return metadata
107118

108119
@classmethod
109120
def get_table_definition_by_connection(cls, connection: ConnectConfigEntity, schemas, table_names):

application/pages/2_🪙_Data_Connection_Management.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
'starrocks': 'StarRocks',
1515
'clickhouse': 'Clickhouse',
1616
'hive': 'Hive',
17-
'athena': 'Athena'
17+
'athena': 'Athena',
18+
'bigquery': 'BigQuery'
1819
}
1920

2021

application/requirements-api.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@ sqlalchemy-redshift~=0.8.14
2323
numpy==1.26.4
2424
pyhive==0.7.0
2525
thrift==0.20.0
26-
thrift-sasl==0.4.3
26+
thrift-sasl==0.4.3
27+
sqlalchemy-bigquery==1.11.0

application/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@ sqlalchemy-redshift~=0.8.14
2222
numpy==1.26.4
2323
pyhive==0.7.0
2424
thrift==0.20.0
25-
thrift-sasl==0.4.3
25+
thrift-sasl==0.4.3
26+
sqlalchemy-bigquery==1.11.0

application/utils/prompt.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,12 @@
245245
The data is:{data}
246246
247247
"""
248+
249+
250+
BIGQUERY_DIALECT_PROMPT_CLAUDE3 = """
251+
You are a data analysis expert and proficient in Google BigQuery. Given an input question, first create a syntactically correct BigQuery SQL query to run.
252+
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per BigQuery.
253+
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Use backticks (`) to denote table and column names as delimited identifiers.
254+
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
255+
Pay attention to use CURRENT_DATE() function to get the current date, if the question involves "today". Aside from giving the SQL answer, concisely explain yourself after giving the answer in the same language as the question.
256+
""".format(top_k=TOP_K)

application/utils/prompts/generate_prompt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from utils.logging import getLogger
22
from utils.prompt import POSTGRES_DIALECT_PROMPT_CLAUDE3, MYSQL_DIALECT_PROMPT_CLAUDE3, \
33
DEFAULT_DIALECT_PROMPT, AGENT_COT_EXAMPLE, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3, STARROCKS_DIALECT_PROMPT_CLAUDE3, \
4-
CLICKHOUSE_DIALECT_PROMPT_CLAUDE3, HIVE_DIALECT_PROMPT_CLAUDE3
4+
CLICKHOUSE_DIALECT_PROMPT_CLAUDE3, HIVE_DIALECT_PROMPT_CLAUDE3, BIGQUERY_DIALECT_PROMPT_CLAUDE3
55
from utils.prompts import guidance_prompt
66
from utils.prompts import table_prompt
77

@@ -2209,6 +2209,8 @@ def generate_llm_prompt(ddl, hints, prompt_map, search_box, sql_examples=None, n
22092209
dialect_prompt = CLICKHOUSE_DIALECT_PROMPT_CLAUDE3
22102210
elif dialect == 'hive':
22112211
dialect_prompt = HIVE_DIALECT_PROMPT_CLAUDE3
2212+
elif dialect == 'bigquery':
2213+
dialect_prompt = BIGQUERY_DIALECT_PROMPT_CLAUDE3
22122214
else:
22132215
dialect_prompt = DEFAULT_DIALECT_PROMPT
22142216

0 commit comments

Comments
 (0)