33import boto3
44import json
55from nlq .data_access .opensearch import OpenSearchDao
6- from utils .env_var import BEDROCK_REGION , AOS_HOST , AOS_PORT , AOS_USER , AOS_PASSWORD , opensearch_info
6+ from utils .env_var import BEDROCK_REGION , AOS_HOST , AOS_PORT , AOS_USER , AOS_PASSWORD , opensearch_info , \
7+ SAGEMAKER_ENDPOINT_EMBEDDING
78from utils .env_var import bedrock_ak_sk_info
9+ from utils .llm import invoke_model_sagemaker_endpoint
810
911logger = logging .getLogger (__name__ )
1012
@@ -73,7 +75,10 @@ def get_all_agent_cot_samples(cls, profile_name):
7375 @classmethod
7476 def add_sample (cls , profile_name , question , answer ):
7577 logger .info (f'add sample question: { question } to profile { profile_name } ' )
76- embedding = cls .create_vector_embedding_with_bedrock (question )
78+ if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "" :
79+ embedding = cls .create_vector_embedding_with_sagemaker (question )
80+ else :
81+ embedding = cls .create_vector_embedding_with_bedrock (question )
7782 has_same_sample = cls .search_same_query (profile_name , 1 , opensearch_info ['sql_index' ], embedding )
7883 if has_same_sample :
7984 logger .info (f'delete sample sample entity: { question } to profile { profile_name } ' )
@@ -83,7 +88,10 @@ def add_sample(cls, profile_name, question, answer):
8388 @classmethod
8489 def add_entity_sample (cls , profile_name , entity , comment ):
8590 logger .info (f'add sample entity: { entity } to profile { profile_name } ' )
86- embedding = cls .create_vector_embedding_with_bedrock (entity )
91+ if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "" :
92+ embedding = cls .create_vector_embedding_with_sagemaker (entity )
93+ else :
94+ embedding = cls .create_vector_embedding_with_bedrock (entity )
8795 has_same_sample = cls .search_same_query (profile_name , 1 , opensearch_info ['ner_index' ], embedding )
8896 if has_same_sample :
8997 logger .info (f'delete sample sample entity: { entity } to profile { profile_name } ' )
@@ -93,7 +101,10 @@ def add_entity_sample(cls, profile_name, entity, comment):
93101 @classmethod
94102 def add_agent_cot_sample (cls , profile_name , entity , comment ):
95103 logger .info (f'add agent sample query: { entity } to profile { profile_name } ' )
96- embedding = cls .create_vector_embedding_with_bedrock (entity )
104+ if SAGEMAKER_ENDPOINT_EMBEDDING is not None and SAGEMAKER_ENDPOINT_EMBEDDING != "" :
105+ embedding = cls .create_vector_embedding_with_sagemaker (entity )
106+ else :
107+ embedding = cls .create_vector_embedding_with_bedrock (entity )
97108 has_same_sample = cls .search_same_query (profile_name , 1 , opensearch_info ['agent_index' ], embedding )
98109 if has_same_sample :
99110 logger .info (f'delete agent sample sample query: { entity } to profile { profile_name } ' )
@@ -118,9 +129,17 @@ def create_vector_embedding_with_bedrock(cls, text):
118129 return embedding
119130
120131 @classmethod
121- def create_vector_embedding_with_sagemaker (cls ):
122- # to do
123- pass
132+ def create_vector_embedding_with_sagemaker (cls , text ):
133+ try :
134+ model_kwargs = {}
135+ model_kwargs ["batch_size" ] = 12
136+ model_kwargs ["max_length" ] = 512
137+ model_kwargs ["return_type" ] = "dense"
138+ body = json .dumps ({"inputs" : [text ], ** model_kwargs })
139+ embeddings = invoke_model_sagemaker_endpoint (SAGEMAKER_ENDPOINT_EMBEDDING , body )
140+ return embeddings
141+ except Exception as e :
142+ logger .error (f'create_vector_embedding_with_sagemaker is error { e } ' )
124143
125144 @classmethod
126145 def delete_sample (cls , profile_name , doc_id ):
0 commit comments