Skip to content

Commit 34370bb

Browse files
committed
add sagemaker env
1 parent 3a3048c commit 34370bb

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

application/nlq/business/vector_store.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import boto3
44
import json
55
from nlq.data_access.opensearch import OpenSearchDao
6-
from utils.env_var import BEDROCK_REGION, AOS_HOST, AOS_PORT, AOS_USER, AOS_PASSWORD, opensearch_info
6+
from utils.env_var import BEDROCK_REGION, AOS_HOST, AOS_PORT, AOS_USER, AOS_PASSWORD, opensearch_info, \
7+
SAGEMAKER_ENDPOINT_EMBEDDING
78
from utils.env_var import bedrock_ak_sk_info
9+
from utils.llm import invoke_model_sagemaker_endpoint
810

911
logger = logging.getLogger(__name__)
1012

@@ -73,7 +75,10 @@ def get_all_agent_cot_samples(cls, profile_name):
7375
@classmethod
7476
def add_sample(cls, profile_name, question, answer):
7577
logger.info(f'add sample question: {question} to profile {profile_name}')
76-
embedding = cls.create_vector_embedding_with_bedrock(question)
78+
if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "":
79+
embedding = cls.create_vector_embedding_with_sagemaker(question)
80+
else:
81+
embedding = cls.create_vector_embedding_with_bedrock(question)
7782
has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['sql_index'], embedding)
7883
if has_same_sample:
7984
logger.info(f'delete sample sample entity: {question} to profile {profile_name}')
@@ -83,7 +88,10 @@ def add_sample(cls, profile_name, question, answer):
8388
@classmethod
8489
def add_entity_sample(cls, profile_name, entity, comment):
8590
logger.info(f'add sample entity: {entity} to profile {profile_name}')
86-
embedding = cls.create_vector_embedding_with_bedrock(entity)
91+
if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "":
92+
embedding = cls.create_vector_embedding_with_sagemaker(entity)
93+
else:
94+
embedding = cls.create_vector_embedding_with_bedrock(entity)
8795
has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['ner_index'], embedding)
8896
if has_same_sample:
8997
logger.info(f'delete sample sample entity: {entity} to profile {profile_name}')
@@ -93,7 +101,10 @@ def add_entity_sample(cls, profile_name, entity, comment):
93101
@classmethod
94102
def add_agent_cot_sample(cls, profile_name, entity, comment):
95103
logger.info(f'add agent sample query: {entity} to profile {profile_name}')
96-
embedding = cls.create_vector_embedding_with_bedrock(entity)
104+
if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "":
105+
embedding = cls.create_vector_embedding_with_sagemaker(entity)
106+
else:
107+
embedding = cls.create_vector_embedding_with_bedrock(entity)
97108
has_same_sample = cls.search_same_query(profile_name, 1, opensearch_info['agent_index'], embedding)
98109
if has_same_sample:
99110
logger.info(f'delete agent sample sample query: {entity} to profile {profile_name}')
@@ -118,9 +129,17 @@ def create_vector_embedding_with_bedrock(cls, text):
118129
return embedding
119130

120131
@classmethod
121-
def create_vector_embedding_with_sagemaker(cls):
122-
# to do
123-
pass
132+
def create_vector_embedding_with_sagemaker(cls, text):
133+
try:
134+
model_kwargs = {}
135+
model_kwargs["batch_size"] = 12
136+
model_kwargs["max_length"] = 512
137+
model_kwargs["return_type"] = "dense"
138+
body = json.dumps({"inputs": [text], **model_kwargs})
139+
embeddings = invoke_model_sagemaker_endpoint(SAGEMAKER_ENDPOINT_EMBEDDING, body)
140+
return embeddings
141+
except Exception as e:
142+
logger.error(f'create_vector_embedding_with_sagemaker is error {e}')
124143

125144
@classmethod
126145
def delete_sample(cls, profile_name, doc_id):

application/utils/env_var.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@
4646

4747
SAGEMAKER_ENDPOINT_EMBEDDING = os.getenv('SAGEMAKER_ENDPOINT_EMBEDDING', '')
4848

49+
SAGEMAKER_ENDPOINT_SQL = os.getenv('SAGEMAKER_ENDPOINT_SQL', '')
50+
51+
SAGEMAKER_EMBEDDING_REGION = os.getenv('SAGEMAKER_EMBEDDING_REGION', '')
52+
53+
SAGEMAKER_SQL_REGION = os.getenv('SAGEMAKER_SQL_REGION', '')
54+
55+
4956
def get_opensearch_parameter():
5057
try:
5158
session = boto3.session.Session()

0 commit comments

Comments
 (0)