Skip to content

Commit 5375bb5

Browse files
authored
Merge pull request #314 from aws-samples/v1.8.0_dev
add bigquery support
2 parents 11b27eb + 6d21153 commit 5375bb5

File tree

4 files changed

+67
-18
lines changed

4 files changed

+67
-18
lines changed

application/nlq/business/connection.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ def get_db_url_by_name(cls, conn_name):
5353
conn_config = cls.get_conn_config_by_name(conn_name)
5454
return RelationDatabase.get_db_url_by_connection(conn_config)
5555

56+
@classmethod
57+
def get_db_password_host_by_name(cls, conn_name):
58+
conn_config = cls.get_conn_config_by_name(conn_name)
59+
return RelationDatabase.get_password_host_by_connection(conn_config)
60+
5661
@classmethod
5762
def get_db_type_by_name(cls, conn_name):
5863
conn_config = cls.get_conn_config_by_name(conn_name)

application/nlq/data_access/database.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12

23
import sqlalchemy as db
34
from sqlalchemy import text, Column, inspect
@@ -39,10 +40,11 @@ def get_db_url(cls, db_type, user, password, host, port, db_name):
3940
)
4041
logger.info(f"db_url: {db_url}")
4142
elif db_type == 'bigquery':
43+
password = json.loads(password)
4244
db_url = db.engine.URL.create(
4345
drivername=cls.db_mapping[db_type],
4446
host=host, # BigQuery project. Note: without dataset
45-
query={'credentials_path': password}
47+
query={'credentials_path': json.dumps(password)}
4648
)
4749
else:
4850
db_url = db.engine.URL.create(
@@ -67,7 +69,11 @@ def test_connection(cls, db_type, user, password, host, port, db_name) -> bool:
6769
connect_args={'s3_staging_dir': db_name}
6870
)
6971
else:
70-
engine = db.create_engine(cls.get_db_url(db_type, user, password, host, port, db_name))
72+
if db_type == "bigquery":
73+
password = json.loads(password)
74+
engine = db.create_engine(url=host, credentials_info=password)
75+
else:
76+
engine = db.create_engine(cls.get_db_url(db_type, user, password, host, port, db_name))
7177
connection = engine.connect()
7278
return True
7379
except Exception as e:
@@ -80,7 +86,11 @@ def get_all_schema_names_by_connection(cls, connection: ConnectConfigEntity):
8086
db_type = connection.db_type
8187
db_url = cls.get_db_url(db_type, connection.db_user, connection.db_pwd, connection.db_host, connection.db_port,
8288
connection.db_name)
83-
engine = db.create_engine(db_url)
89+
if db_type == "bigquery":
90+
password = json.loads(connection.db_pwd)
91+
engine = db.create_engine(url=connection.db_host, credentials_info=password)
92+
else:
93+
engine = db.create_engine(db_url)
8494
inspector = inspect(engine)
8595

8696
if db_type == 'postgresql':
@@ -104,7 +114,11 @@ def get_all_tables_by_connection(cls, connection: ConnectConfigEntity, schemas=N
104114
def get_metadata_by_connection(cls, connection, schemas):
105115
db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host,
106116
connection.db_port, connection.db_name)
107-
engine = db.create_engine(db_url)
117+
if connection.db_type == "bigquery":
118+
password = json.loads(connection.db_pwd)
119+
engine = db.create_engine(url=connection.db_host, credentials_info=password)
120+
else:
121+
engine = db.create_engine(db_url)
108122
# connection = engine.connect()
109123
metadata = db.MetaData()
110124
if connection.db_type == 'bigquery':
@@ -179,3 +193,7 @@ def get_db_url_by_connection(cls, connection: ConnectConfigEntity):
179193
db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host,
180194
connection.db_port, connection.db_name)
181195
return db_url
196+
197+
@classmethod
198+
def get_password_host_by_connection(cls, connection: ConnectConfigEntity):
199+
return connection.db_pwd, connection.db_host

application/pages/2_🪙_Data_Connection_Management.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import json
2+
13
import streamlit as st
24
from dotenv import load_dotenv
35
from nlq.business.connection import ConnectionManagement
46
from nlq.data_access.database import RelationDatabase
7+
from utils.logging import getLogger
58
from utils.navigation import make_sidebar
69

7-
10+
logger = getLogger()
811
# global variables
912

1013
db_type_mapping = {
@@ -78,23 +81,39 @@ def main():
7881
db_type = db_type.lower() # Convert to lowercase for matching with db_mapping keys
7982
if db_type == 'athena':
8083
st.info("Please enter S3 staging directory in the database name field. You can leave other fields empty. Please also make sure that IAM role is able to access Athena and S3.")
81-
host = st.text_input("Enter host")
82-
port = st.text_input("Enter port")
83-
user = st.text_input("Enter username")
84-
password = st.text_input("Enter password", type="password")
85-
db_name = st.text_input("Enter database name")
86-
comment = st.text_input("Enter comment")
87-
88-
test_connection_view(db_type, user, password, host, port, db_name)
8984

90-
if st.button('Add Connection', type='primary'):
91-
if db_name == '':
92-
st.error("Database name is required!")
93-
else:
85+
if db_type == "bigquery":
86+
host = st.text_input("Enter host")
87+
password = st.text_area("Credentials Info", height=200, placeholder="Paste your credentials info here")
88+
# credentials_info = json.loads(credentials_info)
89+
port = ""
90+
user = ""
91+
db_name = ""
92+
comment = st.text_input("Enter comment")
93+
test_connection_view(db_type, user, password, host, port, db_name)
94+
if st.button('Add Connection', type='primary'):
9495
ConnectionManagement.add_connection(connection_name, db_type, host, port, user, password, db_name, comment)
9596
st.success(f"{connection_name} added successfully!")
9697
st.session_state.new_connection_mode = False
9798

99+
else:
100+
host = st.text_input("Enter host")
101+
port = st.text_input("Enter port")
102+
user = st.text_input("Enter username")
103+
password = st.text_input("Enter password", type="password")
104+
db_name = st.text_input("Enter database name")
105+
comment = st.text_input("Enter comment")
106+
107+
test_connection_view(db_type, user, password, host, port, db_name)
108+
109+
if st.button('Add Connection', type='primary'):
110+
if db_name == '':
111+
st.error("Database name is required!")
112+
else:
113+
ConnectionManagement.add_connection(connection_name, db_type, host, port, user, password, db_name, comment)
114+
st.success(f"{connection_name} added successfully!")
115+
st.session_state.new_connection_mode = False
116+
98117
elif st.session_state.update_connection_mode:
99118
st.subheader("Update Database Connection")
100119
current_conn = st.session_state.current_connection

application/utils/apis.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
13
import sqlalchemy as db
24
from sqlalchemy import text
35
from utils.env_var import RDS_MYSQL_HOST, RDS_MYSQL_PORT, RDS_MYSQL_USERNAME, RDS_MYSQL_PASSWORD, RDS_MYSQL_DBNAME, RDS_PQ_SCHEMA
@@ -88,7 +90,12 @@ def get_sql_result_tool(profile, sql):
8890
RDS_MYSQL_DBNAME=RDS_MYSQL_DBNAME,
8991
))
9092
else:
91-
engine = db.create_engine(p_db_url)
93+
if profile['db_type'] == "bigquery":
94+
password, host = ConnectionManagement.get_db_password_host_by_name(profile['conn_name'])
95+
password = json.loads(password)
96+
engine = db.create_engine(url=host, credentials_info=password)
97+
else:
98+
engine = db.create_engine(p_db_url)
9299
with engine.connect() as connection:
93100
logger.info(f'{sql=}')
94101
executed_result_df = pd.read_sql_query(text(sql), connection)

0 commit comments

Comments
 (0)