Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 62 additions & 7 deletions rag/llm/embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def encode(self, texts: list[str|bytes], task="retrieval.passage"):
data = {"model": self.model_name, "input": input[i : i + batch_size]}
if "v4" in self.model_name:
data["return_multivector"] = True

if "v3" in self.model_name or "v4" in self.model_name:
data['task'] = task
data['truncate'] = True
Expand All @@ -391,7 +391,7 @@ def encode(self, texts: list[str|bytes], task="retrieval.passage"):
if data.get("return_multivector", False): # v4
token_embs = np.asarray(d['embeddings'], dtype=np.float32)
chunk_emb = token_embs.mean(axis=0)

else:
# v2/v3
chunk_emb = np.asarray(d['embedding'], dtype=np.float32)
Expand Down Expand Up @@ -481,7 +481,7 @@ def __init__(self, key, model_name, **kwargs):
self.model_name = model_name
self.is_amazon = self.model_name.split(".")[0] == "amazon"
self.is_cohere = self.model_name.split(".")[0] == "cohere"

if mode == "access_key_secret":
self.bedrock_ak = key.get("bedrock_ak")
self.bedrock_sk = key.get("bedrock_sk")
Expand Down Expand Up @@ -885,15 +885,70 @@ def encode_queries(self, text: str):
raise Exception(f"Error: {response.status_code} - {response.text}")


class VolcEngineEmbed(Base):
    """Embedding client for the VolcEngine (Ark) multimodal embeddings API.

    ``key`` is a JSON string holding credentials; only ``ark_api_key`` is
    read here. Each text is sent individually to the
    ``/embeddings/multimodal`` route of ``base_url``.
    """

    _FACTORY_NAME = "VolcEngine"

    def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
        # Callers may pass an explicit empty/None base_url; fall back to the
        # public Beijing endpoint in that case.
        if not base_url:
            base_url = "https://ark.cn-beijing.volces.com/api/v3"
        self.base_url = base_url

        cfg = json.loads(key)
        self.ark_api_key = cfg.get("ark_api_key", "")
        self.model_name = model_name

    @staticmethod
    def _extract_embedding(result: dict) -> list[float]:
        """Pull the embedding vector out of an Ark API response.

        Accepts either ``{"data": [{"embedding": [...]}, ...]}`` (first item
        wins) or ``{"data": {"embedding": [...]}}``.

        Raises:
            TypeError: response or its ``data`` has an unexpected shape.
            KeyError: ``data`` or ``embedding`` is missing.
            ValueError: ``data`` is an empty list.
        """
        if not isinstance(result, dict):
            raise TypeError(f"Unexpected response type: {type(result)}")

        data = result.get("data")
        if data is None:
            raise KeyError("Missing 'data' in response")

        if isinstance(data, list):
            if not data:
                raise ValueError("Empty 'data' in response")
            item = data[0]
        elif isinstance(data, dict):
            item = data
        else:
            raise TypeError(f"Unexpected 'data' type: {type(data)}")

        if not isinstance(item, dict):
            raise TypeError("Unexpected item shape in 'data'")
        if "embedding" not in item:
            raise KeyError("Missing 'embedding' in response item")
        return item["embedding"]

    def _encode_texts(self, texts: list[str]):
        """POST each text to the multimodal embeddings route.

        Returns ``(embeddings, total_tokens)`` where ``embeddings`` is an
        ``np.ndarray`` with one row per successfully parsed response.
        A non-200 HTTP status aborts the whole batch; a malformed 200
        response is logged and skipped, so the array may have fewer rows
        than ``len(texts)``.
        """
        from common.http_client import sync_request

        url = f"{self.base_url}/embeddings/multimodal"
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.ark_api_key}"}

        ress: list[list[float]] = []
        total_tokens = 0
        for text in texts:
            request_body = {"model": self.model_name, "input": [{"type": "text", "text": text}]}
            response = sync_request(method="POST", url=url, headers=headers, json=request_body, timeout=60)
            if response.status_code != 200:
                raise Exception(f"Error: {response.status_code} - {response.text}")
            result = response.json()
            try:
                ress.append(self._extract_embedding(result))
                total_tokens += total_token_count_from_response(result)
            except Exception as _e:
                # Best-effort: skip a malformed item rather than failing the
                # whole batch; log so the bad payload can be diagnosed.
                log_exception(_e)

        return np.array(ress), total_tokens

    def encode(self, texts: list):
        """Embed a batch of texts; returns (2-D ndarray, total token count)."""
        return self._encode_texts(texts)

    def encode_queries(self, text: str):
        """Embed a single query text; returns (1-D ndarray, token count)."""
        embeddings, tokens = self._encode_texts([text])
        return embeddings[0], tokens


class GPUStackEmbed(OpenAIEmbed):
Expand Down