README.md: 2 changes (1 addition & 1 deletion)

@@ -6,7 +6,7 @@ The deployment guide here is CDK only. For manual deployment or detailed guide,

## Introduction

-A NLQ(Natural Language Query) demo using Amazon Bedrock, Amazon OpenSearch with RAG technique.
+A Generative BI demo using Amazon Bedrock and Amazon OpenSearch with the RAG technique.

![Screenshot](./assets/aws_architecture.png)
*Reference Architecture on AWS*
application/api/main.py: 9 changes (8 additions & 1 deletion)

@@ -4,7 +4,7 @@
import logging
from nlq.business.profile import ProfileManagement
from .enum import ContentEnum
-from .schemas import Question, Answer, Option, CustomQuestion, FeedBackInput
+from .schemas import Question, Answer, Option, CustomQuestion, FeedBackInput, HistoryRequest
from . import service
from nlq.business.nlq_chain import NLQChain
from dotenv import load_dotenv
@@ -38,6 +38,13 @@ def ask(question: Question):
    return service.ask(question)


+@router.post("/get_history_by_user_profile")
+def get_history_by_user_profile(history_request: HistoryRequest):
+    user_id = history_request.user_id
+    profile_name = history_request.profile_name
+    return service.get_history_by_user_profile(user_id, profile_name)
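A quick way to exercise the new history endpoint, assuming the API is served locally on port 8000 and the router is mounted without a prefix (neither is shown in this diff):

```python
import requests  # hypothetical local setup; adjust host/port/prefix to your deployment

resp = requests.post(
    "http://localhost:8000/get_history_by_user_profile",
    json={"user_id": "admin", "profile_name": "demo-profile"},  # matches HistoryRequest
)
print(resp.json())
```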


@router.post("/user_feedback")
def user_feedback(input_data: FeedBackInput):
    feedback_type = input_data.feedback_type
application/api/schemas.py: 29 changes (27 additions & 2 deletions)

@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Union
from pydantic import BaseModel


@@ -17,7 +17,7 @@ class Question(BaseModel):
    top_p: float = 0.9
    max_tokens: int = 2048
    temperature: float = 0.01
-    context_window: int = 3
+    context_window: int = 5
    session_id: str = "-1"
    user_id: str = "admin"

@@ -28,6 +28,11 @@ class Example(BaseModel):
    answer: str


+class HistoryRequest(BaseModel):
+    user_id: str
+    profile_name: str


class QueryEntity(BaseModel):
    query: str
    sql: str

@@ -80,10 +85,30 @@ class AgentSearchResult(BaseModel):
    agent_summary: str


+class AskReplayResult(BaseModel):
+    query_rewrite: str


class Answer(BaseModel):
    query: str
+    query_rewrite: str = ""
    query_intent: str
    knowledge_search_result: KnowledgeSearchResult
    sql_search_result: SQLSearchResult
    agent_search_result: AgentSearchResult
+    ask_rewrite_result: AskReplayResult
    suggested_question: list[str]


+class Message(BaseModel):
+    type: str
+    content: Union[str, Answer]


+class HistoryMessage(BaseModel):
+    session_id: str
+    messages: list[Message]


+class ChatHistory(BaseModel):
+    messages: list[HistoryMessage]
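The new models nest one inside another: a ChatHistory holds one HistoryMessage per session, and each Message carries either a raw user query (str) or a full Answer. A minimal sketch of building such a payload (field values are illustrative):

```python
# Illustrative only: shows how Message / HistoryMessage / ChatHistory nest.
history = ChatHistory(
    messages=[
        HistoryMessage(
            session_id="session-1",
            messages=[
                Message(type="human", content="Top 10 products by revenue?"),
                # An AI turn would carry a full Answer object as its content.
            ],
        )
    ]
)
print(history.dict())  # .model_dump() on pydantic v2
```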
application/api/service.py: 152 changes (122 additions & 30 deletions)

(Large diffs are not rendered by default.)
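Since the service-layer diff is elided, here is a purely speculative sketch of what service.get_history_by_user_profile could look like, inferred only from the router above and the new schemas; the actual implementation in this PR is not shown:

```python
# Speculative sketch -- the real application/api/service.py diff is not rendered here.
# Assumes chat-history rows expose 'session_id', 'query', and 'log_info',
# as the startup hook in application/main.py suggests.
import json
from collections import defaultdict

from api.schemas import Answer, ChatHistory, HistoryMessage, Message
from nlq.business.log_store import LogManagement


def get_history_by_user_profile(user_id: str, profile_name: str) -> ChatHistory:
    rows = LogManagement.get_history(user_id, profile_name)
    sessions = defaultdict(list)
    for row in rows:
        sessions[row['session_id']].append(Message(type="human", content=row['query']))
        sessions[row['session_id']].append(
            Message(type="AI", content=Answer(**json.loads(row['log_info']))))
    return ChatHistory(messages=[
        HistoryMessage(session_id=sid, messages=msgs)
        for sid, msgs in sessions.items()
    ])
```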

application/main.py: 44 changes (37 additions & 7 deletions)

@@ -1,38 +1,68 @@
import json
import logging

from fastapi import FastAPI, status
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse
from api.exception_handler import biz_exception
from api.main import router
from fastapi.middleware.cors import CORSMiddleware
from api import service
-from api.schemas import Option
+from api.schemas import Option, Message
+from nlq.business.log_store import LogManagement
+from utils.tool import set_share_data, get_share_data

+MAX_CHAT_WINDOW_SIZE = 10 * 2
app = FastAPI(title='GenBI')

# Configure the CORS middleware
app.add_middleware(
    CORSMiddleware,
-    allow_origins=['*'],  # allow all origins; adjust as needed
-    allow_credentials=True,  # allow sending credentials (e.g. cookies)
-    allow_methods=['*'],  # allow all HTTP methods
-    allow_headers=['*'],  # allow all request headers
+    allow_origins=['*'],
+    allow_credentials=True,
+    allow_methods=['*'],
+    allow_headers=['*'],
)

# Global exception capture
biz_exception(app)
app.mount("/static", StaticFiles(directory="static"), name="static")
app.include_router(router)


+# changed from "/" to "/test" to avoid health check failures in ECS
+@app.get("/test", status_code=status.HTTP_302_FOUND)
def index():
    return RedirectResponse("static/WebSocket.html")


+# health check
+@app.get("/")
+def health():
+    return {"status": "ok"}


@app.get("/option", response_model=Option)
def option():
return service.get_option()
return service.get_option()


+@app.on_event("startup")
+def set_history_in_share():
+    logging.info("Setting history in share data")
+    history_list = LogManagement.get_all_history()
+    chat_history_session = {}
+    for item in history_list:
+        session_id = item['session_id']
+        if session_id not in chat_history_session:
+            chat_history_session[session_id] = []
+        log_info = item['log_info']
+        query = item['query']
+        human_message = Message(type="human", content=query)
+        bot_message = Message(type="AI", content=json.loads(log_info))
+        chat_history_session[session_id].append(human_message)
+        chat_history_session[session_id].append(bot_message)
+
+    for key, value in chat_history_session.items():
+        value = value[-MAX_CHAT_WINDOW_SIZE:]
+        set_share_data(key, value)
+    logging.info("Setting history in share data done")
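The set_share_data / get_share_data helpers from utils.tool are not part of this diff. A minimal sketch consistent with how they are used above (a process-local dict keyed by session id) might be:

```python
# Hypothetical sketch of the utils/tool.py helpers, inferred from their usage;
# the real implementation may differ (e.g., it could be backed by an external cache).
_share_data: dict = {}


def set_share_data(key, value):
    # Store the trimmed chat history (a list of Message objects) for a session id.
    _share_data[key] = value


def get_share_data(key, default=None):
    # Return the stored history for a session id, or default when absent.
    return _share_data.get(key, default)
```

Note that the startup hook caps each session at MAX_CHAT_WINDOW_SIZE entries (10 question/answer pairs), so the shared store holds at most the last 10 turns per session.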
application/nlq/business/log_store.py: 16 changes (14 additions & 2 deletions)

@@ -9,6 +9,18 @@ class LogManagement:
    query_log_dao = DynamoQueryLogDao()

    @classmethod
-    def add_log_to_database(cls, log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str):
+    def add_log_to_database(cls, log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str,
+                            log_type="normal_log"):
        cls.query_log_dao.add_log(log_id=log_id, profile_name=profile_name, user_id=user_id, session_id=session_id,
-                                  sql=sql, query=query, intent=intent, log_info=log_info, time_str=time_str)
+                                  sql=sql, query=query, intent=intent, log_info=log_info, log_type=log_type,
+                                  time_str=time_str)

+    @classmethod
+    def get_history(cls, user_id, profile_name):
+        history_list = cls.query_log_dao.get_history_by_user_profile(user_id, profile_name)
+        return history_list
+
+    @classmethod
+    def get_all_history(cls):
+        history_list = cls.query_log_dao.get_all_history()
+        return history_list
application/nlq/business/vector_store.py: 39 changes (33 additions & 6 deletions)

@@ -86,16 +86,30 @@ def add_sample(cls, profile_name, question, answer):
        logger.info('Sample added')

    @classmethod
-    def add_entity_sample(cls, profile_name, entity, comment):
+    def add_entity_sample(cls, profile_name, entity, comment, entity_type="metrics", entity_info_dict=None):
+        if entity_type == "metrics" or entity_info_dict is None:
+            entity_table_info = []
+        else:
+            entity_table_info = [entity_info_dict]
        logger.info(f'add sample entity: {entity} to profile {profile_name}')
        if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "":
            embedding = cls.create_vector_embedding_with_sagemaker(entity)
        else:
            embedding = cls.create_vector_embedding_with_bedrock(entity)
-        has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['ner_index'], embedding)
-        if has_same_sample:
-            logger.info(f'delete same entity sample: {entity} from profile {profile_name}')
-        if cls.opensearch_dao.add_entity_sample(opensearch_info['ner_index'], profile_name, entity, comment, embedding):
+        if entity_type == "metrics":
+            has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['ner_index'], embedding)
+            if has_same_sample:
+                logger.info(f'delete same entity sample: {entity} from profile {profile_name}')
+        else:
+            same_dimension_value = cls.search_same_dimension_entity(profile_name, 1, opensearch_info['ner_index'],
+                                                                    embedding)
+            if len(same_dimension_value) > 0:
+                for item in same_dimension_value:
+                    entity_table_info.append(item)
+            logger.info("entity_table_info: " + str(entity_table_info))
+            has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['ner_index'], embedding)
+        if cls.opensearch_dao.add_entity_sample(opensearch_info['ner_index'], profile_name, entity, comment, embedding,
+                                                entity_type, entity_table_info):
            logger.info('Sample added')

    @classmethod
@@ -108,7 +122,8 @@ def add_agent_cot_sample(cls, profile_name, entity, comment):
        has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['agent_index'], embedding)
        if has_same_sample:
            logger.info(f'delete same agent sample query: {entity} from profile {profile_name}')
-        if cls.opensearch_dao.add_agent_cot_sample(opensearch_info['agent_index'], profile_name, entity, comment, embedding):
+        if cls.opensearch_dao.add_agent_cot_sample(opensearch_info['agent_index'], profile_name, entity, comment,
+                                                   embedding):
            logger.info('Sample added')

    @classmethod
@@ -192,3 +207,15 @@ def search_same_query(cls, profile_name, top_k, index_name, embedding):
        else:
            return False
        return False

+    @classmethod
+    def search_same_dimension_entity(cls, profile_name, top_k, index_name, embedding):
+        search_res = cls.search_sample_with_embedding(profile_name, top_k, index_name, embedding)
+        same_dimension_value = []
+        if len(search_res) > 0:
+            similarity_sample = search_res[0]
+            similarity_score = similarity_sample["_score"]
+            if similarity_score == 1.0:
+                if index_name == opensearch_info['ner_index']:
+                    same_dimension_value = similarity_sample["_source"]["entity_table_info"]
+        return same_dimension_value
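A usage sketch for the extended add_entity_sample (the class name VectorStore, the profile, and the table/column values are assumptions for illustration; the entity_info_dict keys follow the opensearch.py DAO below):

```python
# Hypothetical call: registers a dimension-type entity with its table/column location.
from nlq.business.vector_store import VectorStore  # assumed export name

VectorStore.add_entity_sample(
    profile_name="demo-profile",
    entity="EMEA",
    comment="sales region dimension value",
    entity_type="dimension",
    entity_info_dict={"table_name": "sales", "column_name": "region", "value": "EMEA"},
)
```

For dimension entities, an existing entry with an identical embedding (score 1.0) contributes its entity_table_info, so the same dimension value appearing in several tables accumulates into a single sample.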
application/nlq/data_access/dynamo_query_log.py: 71 changes (68 additions & 3 deletions)

@@ -3,6 +3,7 @@

import boto3
from botocore.exceptions import ClientError
+from boto3.dynamodb.conditions import Key

logger = logging.getLogger(__name__)

@@ -12,7 +13,7 @@


class DynamoQueryLogEntity:
-    def __init__(self, log_id, profile_name, user_id, session_id, sql, query, intent, log_info, time_str):
+    def __init__(self, log_id, profile_name, user_id, session_id, sql, query, intent, log_info, log_type, time_str):
        self.log_id = log_id
        self.profile_name = profile_name
        self.user_id = user_id
@@ -21,6 +22,7 @@ def __init__(self, log_id, profile_name, user_id, session_id, sql, query, intent
        self.query = query
        self.intent = intent
        self.log_info = log_info
+        self.log_type = log_type
        self.time_str = time_str

    def to_dict(self):
@@ -34,6 +36,7 @@ def to_dict(self):
            'query': self.query,
            'intent': self.intent,
            'log_info': self.log_info,
+            'log_type': self.log_type,
            'time_str': self.time_str
        }

@@ -113,6 +116,68 @@ def add(self, entity):
    def update(self, entity):
        self.table.put_item(Item=entity.to_dict())

-    def add_log(self, log_id, profile_name, user_id, session_id, sql, query, intent, log_info, time_str):
-        entity = DynamoQueryLogEntity(log_id, profile_name, user_id, session_id, sql, query, intent, log_info, time_str)
+    def add_log(self, log_id, profile_name, user_id, session_id, sql, query, intent, log_info, log_type, time_str):
+        entity = DynamoQueryLogEntity(log_id, profile_name, user_id, session_id, sql, query, intent, log_info, log_type, time_str)
        self.add(entity)

+    def get_history_by_user_profile(self, user_id, profile_name):
+        try:
+            # Scan the table for all chat-history items belonging to this user and profile
+            response = self.table.scan(
+                FilterExpression=Key('user_id').eq(user_id) & Key('profile_name').eq(profile_name) & Key('log_type').eq("chat_history")
+            )
+
+            items = response['Items']
+
+            # DynamoDB might not return all items in a single response if the data set is large
+            while 'LastEvaluatedKey' in response:
+                response = self.table.scan(
+                    FilterExpression=Key('user_id').eq(user_id) & Key('profile_name').eq(profile_name) & Key('log_type').eq("chat_history"),
+                    ExclusiveStartKey=response['LastEvaluatedKey']
+                )
+                items.extend(response['Items'])
+
+            # Sort the items by time_str to get them in chronological order
+            sorted_items = sorted(items, key=lambda x: x['time_str'])
+
+            return sorted_items
+
+        except ClientError as err:
+            logger.error(
+                "Couldn't get history for user %s and profile %s. Here's why: %s: %s",
+                user_id,
+                profile_name,
+                err.response["Error"]["Code"],
+                err.response["Error"]["Message"],
+            )
+            return []

+    def get_all_history(self):
+        try:
+            # Scan the table for all chat-history items, regardless of user or profile
+            response = self.table.scan(
+                FilterExpression=Key('log_type').eq("chat_history")
+            )
+
+            items = response['Items']
+
+            # DynamoDB might not return all items in a single response if the data set is large
+            while 'LastEvaluatedKey' in response:
+                response = self.table.scan(
+                    FilterExpression=Key('log_type').eq("chat_history"),
+                    ExclusiveStartKey=response['LastEvaluatedKey']
+                )
+                items.extend(response['Items'])
+
+            # Sort the items by time_str to get them in chronological order
+            sorted_items = sorted(items, key=lambda x: x['time_str'])
+
+            return sorted_items
+
+        except ClientError as err:
+            logger.error(
+                "Couldn't get history. Here's why: %s: %s",
+                err.response["Error"]["Code"],
+                err.response["Error"]["Message"],
+            )
+            return []
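Both methods rely on Scan with a FilterExpression, which reads the entire table and filters afterwards; that is fine for modest log volumes, but a GSI-backed Query would scale better. A minimal usage sketch through the business layer (AWS credentials, the DynamoDB table, and the profile name are assumed):

```python
# Usage sketch: fetch one user's chat history via the business-layer wrapper.
from nlq.business.log_store import LogManagement

history = LogManagement.get_history(user_id="admin", profile_name="demo-profile")
for item in history:
    print(item['time_str'], item['query'])
```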
application/nlq/data_access/opensearch.py: 21 changes (19 additions & 2 deletions)

@@ -162,13 +162,30 @@ def add_sample(self, index_name, profile_name, question, answer, embedding):
        success, failed = put_bulk_in_opensearch([record], self.opensearch_client)
        return success == 1

-    def add_entity_sample(self, index_name, profile_name, entity, comment, embedding):
+    def add_entity_sample(self, index_name, profile_name, entity, comment, embedding, entity_type="", entity_table_info=[]):
+        entity_count = len(entity_table_info)
+        comment_value = []
+        item_comment_format = "{entity} is located in table {table_name}, column {column_name}; the dimension value is {value}."
+        if entity_type == "dimension":
+            if entity_count > 0:
+                for item in entity_table_info:
+                    table_name = item["table_name"]
+                    column_name = item["column_name"]
+                    value = item["value"]
+                    comment_format = item_comment_format.format(entity=entity, table_name=table_name,
+                                                                column_name=column_name, value=value)
+                    comment_value.append(comment_format)
+                comment = ";".join(comment_value)
+
        record = {
            '_index': index_name,
            'entity': entity,
            'comment': comment,
            'profile': profile_name,
-            'vector_field': embedding
+            'vector_field': embedding,
+            'entity_type': entity_type,
+            'entity_count': entity_count,
+            'entity_table_info': entity_table_info
        }

        success, failed = put_bulk_in_opensearch([record], self.opensearch_client)
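For reference, the generated comment for a dimension entity concatenates one formatted sentence per table/column match; a standalone demonstration with invented values:

```python
# Demonstrates the comment string built above for a dimension entity.
item_comment_format = ("{entity} is located in table {table_name}, "
                       "column {column_name}; the dimension value is {value}.")

entity_table_info = [
    {"table_name": "sales", "column_name": "region", "value": "EMEA"},
    {"table_name": "sales_archive", "column_name": "region", "value": "EMEA"},
]

comment = ";".join(
    item_comment_format.format(entity="EMEA", **item) for item in entity_table_info
)
print(comment)
```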