Skip to content

Commit 9b7ec8b

Browse files
authored
Merge pull request #229 from aws-samples/v1.5.0_spy
change code for multi chat
2 parents f559171 + 976f428 commit 9b7ec8b

File tree

17 files changed

+570
-66
lines changed

17 files changed

+570
-66
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ The deployment guide here is CDK only. For manual deployment or detailed guide,
66

77
## Introduction
88

9-
A NLQ(Natural Language Query) demo using Amazon Bedrock, Amazon OpenSearch with RAG technique.
9+
A Generative BI demo using Amazon Bedrock, Amazon OpenSearch with RAG technique.
1010

1111
![Screenshot](./assets/aws_architecture.png)
1212
*Reference Architecture on AWS*

application/api/main.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
from nlq.business.profile import ProfileManagement
66
from .enum import ContentEnum
7-
from .schemas import Question, Answer, Option, CustomQuestion, FeedBackInput
7+
from .schemas import Question, Answer, Option, CustomQuestion, FeedBackInput, HistoryRequest
88
from . import service
99
from nlq.business.nlq_chain import NLQChain
1010
from dotenv import load_dotenv
@@ -38,6 +38,13 @@ def ask(question: Question):
3838
return service.ask(question)
3939

4040

41+
@router.post("/get_history_by_user_profile")
42+
def get_history_by_user_profile(history_request : HistoryRequest):
43+
user_id = history_request.user_id
44+
profile_name = history_request.profile_name
45+
return service.get_history_by_user_profile(user_id, profile_name)
46+
47+
4148
@router.post("/user_feedback")
4249
def user_feedback(input_data: FeedBackInput):
4350
feedback_type = input_data.feedback_type

application/api/schemas.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any
1+
from typing import Any, Union
22
from pydantic import BaseModel
33

44

@@ -17,7 +17,7 @@ class Question(BaseModel):
1717
top_p: float = 0.9
1818
max_tokens: int = 2048
1919
temperature: float = 0.01
20-
context_window: int = 3
20+
context_window: int = 5
2121
session_id: str = "-1"
2222
user_id: str = "admin"
2323

@@ -28,6 +28,11 @@ class Example(BaseModel):
2828
answer: str
2929

3030

31+
class HistoryRequest(BaseModel):
32+
user_id: str
33+
profile_name: str
34+
35+
3136
class QueryEntity(BaseModel):
3237
query: str
3338
sql: str
@@ -80,10 +85,30 @@ class AgentSearchResult(BaseModel):
8085
agent_summary: str
8186

8287

88+
class AskReplayResult(BaseModel):
89+
query_rewrite: str
90+
91+
8392
class Answer(BaseModel):
8493
query: str
94+
query_rewrite: str = ""
8595
query_intent: str
8696
knowledge_search_result: KnowledgeSearchResult
8797
sql_search_result: SQLSearchResult
8898
agent_search_result: AgentSearchResult
99+
ask_rewrite_result: AskReplayResult
89100
suggested_question: list[str]
101+
102+
103+
class Message(BaseModel):
104+
type: str
105+
content: Union[str, Answer]
106+
107+
108+
class HistoryMessage(BaseModel):
109+
session_id: str
110+
messages: list[Message]
111+
112+
113+
class ChatHistory(BaseModel):
114+
messages: list[HistoryMessage]

application/api/service.py

Lines changed: 122 additions & 30 deletions
Large diffs are not rendered by default.

application/main.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,68 @@
1+
import json
2+
import logging
3+
14
from fastapi import FastAPI, status
25
from fastapi.staticfiles import StaticFiles
36
from fastapi.responses import RedirectResponse
47
from api.exception_handler import biz_exception
58
from api.main import router
69
from fastapi.middleware.cors import CORSMiddleware
710
from api import service
8-
from api.schemas import Option
11+
from api.schemas import Option, Message
12+
from nlq.business.log_store import LogManagement
13+
from utils.tool import set_share_data, get_share_data
914

15+
MAX_CHAT_WINDOW_SIZE = 10 * 2
1016
app = FastAPI(title='GenBI')
1117

12-
# 配置CORS中间件
1318
app.add_middleware(
1419
CORSMiddleware,
15-
allow_origins=['*'], # 允许所有源访问,可以根据需求进行修改
16-
allow_credentials=True, # 允许发送凭据(如Cookie)
17-
allow_methods=['*'], # 允许所有HTTP方法
18-
allow_headers=['*'], # 允许所有请求头
20+
allow_origins=['*'],
21+
allow_credentials=True,
22+
allow_methods=['*'],
23+
allow_headers=['*'],
1924
)
2025

2126
# Global exception capture
2227
biz_exception(app)
2328
app.mount("/static", StaticFiles(directory="static"), name="static")
2429
app.include_router(router)
2530

31+
2632
# changed from "/" to "/test" to avoid health check fails in ECS
2733
@app.get("/test", status_code=status.HTTP_302_FOUND)
2834
def index():
2935
return RedirectResponse("static/WebSocket.html")
3036

37+
3138
# health check
3239
@app.get("/")
3340
def health():
3441
return {"status": "ok"}
3542

43+
3644
@app.get("/option", response_model=Option)
3745
def option():
38-
return service.get_option()
46+
return service.get_option()
47+
48+
49+
@app.on_event("startup")
50+
def set_history_in_share():
51+
logging.info("Setting history in share data")
52+
history_list = LogManagement.get_all_history()
53+
chat_history_session = {}
54+
for item in history_list:
55+
session_id = item['session_id']
56+
if session_id not in chat_history_session:
57+
chat_history_session[session_id] = []
58+
log_info = item['log_info']
59+
query = item['query']
60+
human_message = Message(type="human", content=query)
61+
bot_message = Message(type="AI", content=json.loads(log_info))
62+
chat_history_session[session_id].append(human_message)
63+
chat_history_session[session_id].append(bot_message)
64+
65+
for key, value in chat_history_session.items():
66+
value = value[-MAX_CHAT_WINDOW_SIZE:]
67+
set_share_data(key, value)
68+
logging.info("Setting history in share data done")

application/nlq/business/log_store.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,18 @@ class LogManagement:
99
query_log_dao = DynamoQueryLogDao()
1010

1111
@classmethod
12-
def add_log_to_database(cls, log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str):
12+
def add_log_to_database(cls, log_id, user_id, session_id, profile_name, sql, query, intent, log_info, time_str,
13+
log_type="normal_log", ):
1314
cls.query_log_dao.add_log(log_id=log_id, profile_name=profile_name, user_id=user_id, session_id=session_id,
14-
sql=sql, query=query, intent=intent, log_info=log_info, time_str=time_str)
15+
sql=sql, query=query, intent=intent, log_info=log_info, log_type=log_type,
16+
time_str=time_str)
17+
18+
@classmethod
19+
def get_history(cls, user_id, profile_name):
20+
history_list = cls.query_log_dao.get_history_by_user_profile(user_id, profile_name)
21+
return history_list
22+
23+
@classmethod
24+
def get_all_history(cls):
25+
history_list = cls.query_log_dao.get_all_history()
26+
return history_list

application/nlq/business/vector_store.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,30 @@ def add_sample(cls, profile_name, question, answer):
8686
logger.info('Sample added')
8787

8888
@classmethod
89-
def add_entity_sample(cls, profile_name, entity, comment):
89+
def add_entity_sample(cls, profile_name, entity, comment, entity_type="metrics", entity_info_dict=None):
90+
if entity_type == "metrics" or entity_info_dict is None:
91+
entity_table_info = []
92+
else:
93+
entity_table_info = [entity_info_dict]
9094
logger.info(f'add sample entity: {entity} to profile {profile_name}')
9195
if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "":
9296
embedding = cls.create_vector_embedding_with_sagemaker(entity)
9397
else:
9498
embedding = cls.create_vector_embedding_with_bedrock(entity)
95-
has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['ner_index'], embedding)
96-
if has_same_sample:
97-
logger.info(f'delete sample sample entity: {entity} to profile {profile_name}')
98-
if cls.opensearch_dao.add_entity_sample(opensearch_info['ner_index'], profile_name, entity, comment, embedding):
99+
if entity_type == "metrics":
100+
has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['ner_index'], embedding)
101+
if has_same_sample:
102+
logger.info(f'delete sample sample entity: {entity} to profile {profile_name}')
103+
else:
104+
same_dimension_value = cls.search_same_dimension_entity(profile_name, 1, opensearch_info['ner_index'],
105+
embedding)
106+
if len(same_dimension_value) > 0:
107+
for item in same_dimension_value:
108+
entity_table_info.append(item)
109+
logger.info("entity_table_info: " + str(entity_table_info))
110+
has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['ner_index'], embedding)
111+
if cls.opensearch_dao.add_entity_sample(opensearch_info['ner_index'], profile_name, entity, comment, embedding,
112+
entity_type, entity_table_info):
99113
logger.info('Sample added')
100114

101115
@classmethod
@@ -108,7 +122,8 @@ def add_agent_cot_sample(cls, profile_name, entity, comment):
108122
has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['agent_index'], embedding)
109123
if has_same_sample:
110124
logger.info(f'delete agent sample sample query: {entity} to profile {profile_name}')
111-
if cls.opensearch_dao.add_agent_cot_sample(opensearch_info['agent_index'], profile_name, entity, comment, embedding):
125+
if cls.opensearch_dao.add_agent_cot_sample(opensearch_info['agent_index'], profile_name, entity, comment,
126+
embedding):
112127
logger.info('Sample added')
113128

114129
@classmethod
@@ -192,3 +207,15 @@ def search_same_query(cls, profile_name, top_k, index_name, embedding):
192207
else:
193208
return False
194209
return False
210+
211+
@classmethod
212+
def search_same_dimension_entity(cls, profile_name, top_k, index_name, embedding):
213+
search_res = cls.search_sample_with_embedding(profile_name, top_k, index_name, embedding)
214+
same_dimension_value = []
215+
if len(search_res) > 0:
216+
similarity_sample = search_res[0]
217+
similarity_score = similarity_sample["_score"]
218+
if similarity_score == 1.0:
219+
if index_name == opensearch_info['ner_index']:
220+
same_dimension_value = similarity_sample["_source"]["entity_table_info"]
221+
return same_dimension_value

application/nlq/data_access/dynamo_query_log.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import boto3
55
from botocore.exceptions import ClientError
6+
from boto3.dynamodb.conditions import Key
67

78
logger = logging.getLogger(__name__)
89

@@ -12,7 +13,7 @@
1213

1314

1415
class DynamoQueryLogEntity:
15-
def __init__(self, log_id, profile_name, user_id, session_id, sql, query, intent, log_info, time_str):
16+
def __init__(self, log_id, profile_name, user_id, session_id, sql, query, intent, log_info, log_type, time_str):
1617
self.log_id = log_id
1718
self.profile_name = profile_name
1819
self.user_id = user_id
@@ -21,6 +22,7 @@ def __init__(self, log_id, profile_name, user_id, session_id, sql, query, intent
2122
self.query = query
2223
self.intent = intent
2324
self.log_info = log_info
25+
self.log_type = log_type
2426
self.time_str = time_str
2527

2628
def to_dict(self):
@@ -34,6 +36,7 @@ def to_dict(self):
3436
'query': self.query,
3537
'intent': self.intent,
3638
'log_info': self.log_info,
39+
'log_type': self.log_type,
3740
'time_str': self.time_str
3841
}
3942

@@ -113,6 +116,68 @@ def add(self, entity):
113116
def update(self, entity):
114117
self.table.put_item(Item=entity.to_dict())
115118

116-
def add_log(self, log_id, profile_name, user_id, session_id, sql, query, intent, log_info, time_str):
117-
entity = DynamoQueryLogEntity(log_id, profile_name, user_id, session_id, sql, query, intent, log_info, time_str)
119+
def add_log(self, log_id, profile_name, user_id, session_id, sql, query, intent, log_info, log_type, time_str):
120+
entity = DynamoQueryLogEntity(log_id, profile_name, user_id, session_id, sql, query, intent, log_info, log_type, time_str)
118121
self.add(entity)
122+
123+
def get_history_by_user_profile(self, user_id, profile_name):
124+
try:
125+
# First, we need to scan the table to find all items for the user and profile
126+
response = self.table.scan(
127+
FilterExpression=Key('user_id').eq(user_id) & Key('profile_name').eq(profile_name) & Key('log_type').eq("chat_history")
128+
)
129+
130+
items = response['Items']
131+
132+
# DynamoDB might not return all items in a single response if the data set is large
133+
while 'LastEvaluatedKey' in response:
134+
response = self.table.scan(
135+
FilterExpression=Key('user_id').eq(user_id) & Key('profile_name').eq(profile_name) & Key('log_type').eq("chat_history"),
136+
ExclusiveStartKey=response['LastEvaluatedKey']
137+
)
138+
items.extend(response['Items'])
139+
140+
# Sort the items by time_str to get them in chronological order
141+
sorted_items = sorted(items, key=lambda x: x['time_str'])
142+
143+
return sorted_items
144+
145+
except ClientError as err:
146+
logger.error(
147+
"Couldn't get history for user %s and profile %s. Here's why: %s: %s",
148+
user_id,
149+
profile_name,
150+
err.response["Error"]["Code"],
151+
err.response["Error"]["Message"],
152+
)
153+
return []
154+
155+
def get_all_history(self):
156+
try:
157+
# First, we need to scan the table to find all items for the user and profile
158+
response = self.table.scan(
159+
FilterExpression=Key('log_type').eq("chat_history")
160+
)
161+
162+
items = response['Items']
163+
164+
# DynamoDB might not return all items in a single response if the data set is large
165+
while 'LastEvaluatedKey' in response:
166+
response = self.table.scan(
167+
FilterExpression=Key('log_type').eq("chat_history"),
168+
ExclusiveStartKey=response['LastEvaluatedKey']
169+
)
170+
items.extend(response['Items'])
171+
172+
# Sort the items by time_str to get them in chronological order
173+
sorted_items = sorted(items, key=lambda x: x['time_str'])
174+
175+
return sorted_items
176+
177+
except ClientError as err:
178+
logger.error(
179+
"Couldn't get history Here's why: %s: %s",
180+
err.response["Error"]["Code"],
181+
err.response["Error"]["Message"],
182+
)
183+
return []

application/nlq/data_access/opensearch.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,13 +162,30 @@ def add_sample(self, index_name, profile_name, question, answer, embedding):
162162
success, failed = put_bulk_in_opensearch([record], self.opensearch_client)
163163
return success == 1
164164

165-
def add_entity_sample(self, index_name, profile_name, entity, comment, embedding):
165+
def add_entity_sample(self, index_name, profile_name, entity, comment, embedding, entity_type="", entity_table_info=[]):
166+
entity_count = len(entity_table_info)
167+
comment_value = []
168+
item_comment_format = "{entity} is located in table {table_name}, column {column_name}, the dimension value is {value}."
169+
if entity_type == "dimension":
170+
if entity_count > 0:
171+
for item in entity_table_info:
172+
table_name = item["table_name"]
173+
column_name = item["column_name"]
174+
value = item["value"]
175+
comment_format = item_comment_format.format(entity=entity, table_name=table_name,
176+
column_name=column_name, value=value)
177+
comment_value.append(comment_format)
178+
comment = ";".join(comment_value)
179+
166180
record = {
167181
'_index': index_name,
168182
'entity': entity,
169183
'comment': comment,
170184
'profile': profile_name,
171-
'vector_field': embedding
185+
'vector_field': embedding,
186+
'entity_type': entity_type,
187+
'entity_count': entity_count,
188+
'entity_table_info': entity_table_info
172189
}
173190

174191
success, failed = put_bulk_in_opensearch([record], self.opensearch_client)

0 commit comments

Comments (0)