
Commit f76685a

Add a new embedding MosecEmbedding (#182)
* Add a new embedding MosecEmbedding.

  Signed-off-by: Jincheng Miao <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

---------

Signed-off-by: Jincheng Miao <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 61ead43 commit f76685a

File tree

11 files changed: +434 -0 lines changed
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
# Build the Mosec endpoint Docker image

```
docker build --build-arg http_proxy=$http_proxy --build-arg https_proxy=$https_proxy -t langchain-mosec:latest -f comps/embeddings/langchain-mosec/mosec-docker/Dockerfile .
```

# Build the embedding microservice Docker image

```
docker build --build-arg http_proxy=$http_proxy --build-arg https_proxy=$https_proxy -t opea/embedding-langchain-mosec:latest -f comps/embeddings/langchain-mosec/docker/Dockerfile .
```

# Launch the Mosec endpoint Docker container

```
docker run -d --name="embedding-langchain-mosec-endpoint" -p 6001:8000 langchain-mosec:latest
```

# Launch the embedding microservice Docker container

```
export MOSEC_EMBEDDING_ENDPOINT=http://127.0.0.1:6001
docker run -d --name="embedding-langchain-mosec-server" -e http_proxy=$http_proxy -e https_proxy=$https_proxy -p 6000:6000 --ipc=host -e MOSEC_EMBEDDING_ENDPOINT=$MOSEC_EMBEDDING_ENDPOINT opea/embedding-langchain-mosec:latest
```

# Run a client test

```
curl localhost:6000/v1/embeddings \
  -X POST \
  -d '{"text":"Hello, world!"}' \
  -H 'Content-Type: application/json'
```
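
For programmatic access, here is a minimal Python sketch of the same test. It is an illustration, not part of this commit; it assumes the microservice launched above is reachable on localhost:6000 and that the `requests` package is installed:

```python
# Hypothetical client sketch for the embedding microservice launched above.
import requests

resp = requests.post(
    "http://localhost:6000/v1/embeddings",
    json={"text": "Hello, world!"},
    timeout=30,
)
resp.raise_for_status()
# The response carries the original text plus a 768-dimensional embedding
# (the microservice truncates vectors to fit EmbedDoc768; see embedding_mosec.py).
print(resp.json())
```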
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

FROM langchain/langchain:latest

RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
    libgl1-mesa-glx \
    libjemalloc-dev \
    vim

RUN useradd -m -s /bin/bash user && \
    mkdir -p /home/user && \
    chown -R user /home/user/

USER user

COPY comps /home/user/comps

RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r /home/user/comps/embeddings/langchain-mosec/requirements.txt

ENV PYTHONPATH=$PYTHONPATH:/home/user

WORKDIR /home/user/comps/embeddings/langchain-mosec

ENTRYPOINT ["python", "embedding_mosec.py"]
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

version: "3.8"

services:
  embedding:
    image: opea/embedding-langchain-mosec:latest
    container_name: embedding-langchain-mosec-server
    ports:
      - "6000:6000"
    ipc: host
    environment:
      http_proxy: ${http_proxy}
      https_proxy: ${https_proxy}
      MOSEC_EMBEDDING_ENDPOINT: ${MOSEC_EMBEDDING_ENDPOINT}
      LANGCHAIN_API_KEY: ${LANGCHAIN_API_KEY}
    restart: unless-stopped

networks:
  default:
    driver: bridge
Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import time
from typing import List, Optional

from langchain_community.embeddings import OpenAIEmbeddings
from langsmith import traceable

from comps import (
    EmbedDoc768,
    ServiceType,
    TextDoc,
    opea_microservices,
    register_microservice,
    register_statistics,
    statistics_dict,
)


class MosecEmbeddings(OpenAIEmbeddings):
    def _get_len_safe_embeddings(
        self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
    ) -> List[List[float]]:
        _chunk_size = chunk_size or self.chunk_size
        batched_embeddings: List[List[float]] = []
        response = self.client.create(input=texts, **self._invocation_params)
        if not isinstance(response, dict):
            response = response.model_dump()
        batched_embeddings.extend(r["embedding"] for r in response["data"])

        _cached_empty_embedding: Optional[List[float]] = None

        def empty_embedding() -> List[float]:
            nonlocal _cached_empty_embedding
            if _cached_empty_embedding is None:
                average_embedded = self.client.create(input="", **self._invocation_params)
                if not isinstance(average_embedded, dict):
                    average_embedded = average_embedded.model_dump()
                _cached_empty_embedding = average_embedded["data"][0]["embedding"]
            return _cached_empty_embedding

        return [e if e is not None else empty_embedding() for e in batched_embeddings]


@register_microservice(
    name="opea_service@embedding_mosec",
    service_type=ServiceType.EMBEDDING,
    endpoint="/v1/embeddings",
    host="0.0.0.0",
    port=6000,
    input_datatype=TextDoc,
    output_datatype=EmbedDoc768,
)
@traceable(run_type="embedding")
@register_statistics(names=["opea_service@embedding_mosec"])
def embedding(input: TextDoc) -> EmbedDoc768:
    start = time.time()
    embed_vector = embeddings.embed_query(input.text)
    embed_vector = embed_vector[:768]  # Keep only the first 768 elements
    res = EmbedDoc768(text=input.text, embedding=embed_vector)
    statistics_dict["opea_service@embedding_mosec"].append_latency(time.time() - start, None)
    return res


if __name__ == "__main__":
    MOSEC_EMBEDDING_ENDPOINT = os.environ.get("MOSEC_EMBEDDING_ENDPOINT", "http://127.0.0.1:8080")
    os.environ["OPENAI_API_BASE"] = MOSEC_EMBEDDING_ENDPOINT
    os.environ["OPENAI_API_KEY"] = "Dummy key"
    MODEL_ID = "/root/bge-large-zh"
    embeddings = MosecEmbeddings(model=MODEL_ID)
    print("Mosec Embedding initialized.")
    opea_microservices["opea_service@embedding_mosec"].start()
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

FROM ubuntu:22.04
ARG DEBIAN_FRONTEND=noninteractive

ENV GLIBC_TUNABLES glibc.cpu.x86_shstk=permissive

COPY comps /root/comps

RUN apt update && apt install -y python3 python3-pip
RUN pip3 install torch==2.2.2 torchvision --index-url https://download.pytorch.org/whl/cpu
RUN pip3 install intel-extension-for-pytorch==2.2.0
RUN pip3 install transformers
RUN pip3 install llmspec mosec

RUN cd /root/ && export HF_ENDPOINT=https://hf-mirror.com && huggingface-cli download --resume-download BAAI/bge-large-zh --local-dir /root/bge-large-zh

ENV EMB_MODEL="/root/bge-large-zh/"

WORKDIR /root/comps/embeddings/langchain-mosec/mosec-docker

CMD ["python3", "server-ipex.py"]
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
# Embedding Server

## 1. Introduction

This service exposes an OpenAI-compatible RESTful API for extracting text features.
It is designed to run on Intel Xeon processors to accelerate embedding model serving.
Currently the local model is BGE-large-zh.

## 2. Quick Start

### 2.1 Build the Docker image

```shell
docker build -t embedding:latest .
```

### 2.2 Launch the server

```shell
docker run -itd -p 8000:8000 embedding:latest
```

### 2.3 Client test

- RESTful API via curl

```shell
curl -X POST http://127.0.0.1:8000/v1/embeddings -H "Content-Type: application/json" -d '{ "model": "/root/bge-large-zh/", "input": "hello world"}'
```

- Generate embeddings from Python

```python
from openai import Client

DEFAULT_MODEL = "/root/bge-large-zh/"
SERVICE_URL = "http://127.0.0.1:8000"
INPUT_STR = "Hello world!"

client = Client(api_key="fake", base_url=SERVICE_URL)
emb = client.embeddings.create(
    model=DEFAULT_MODEL,
    input=INPUT_STR,
)
```
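
The server's `forward` handler (see `server-ipex.py` below) also accepts a list of strings per request, so a batched call should work as well. A hypothetical sketch, reusing the `client` and `DEFAULT_MODEL` from the snippet above:

```python
# Batch input: one request carrying several strings (assumption based on
# the server's handling of list inputs in forward()).
emb_batch = client.embeddings.create(
    model=DEFAULT_MODEL,
    input=["hello world", "你好，世界"],
)
print(len(emb_batch.data))  # expected: 2, one embedding per input string
```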
Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import base64
import os
from typing import List, Union

import intel_extension_for_pytorch as ipex
import numpy as np
import torch  # type: ignore
import torch.nn.functional as F  # type: ignore
import transformers  # type: ignore
from llmspec import EmbeddingData, EmbeddingRequest, EmbeddingResponse, TokenUsage
from mosec import ClientError, Runtime, Server, Worker

DEFAULT_MODEL = "/root/bge-large-zh/"


class Embedding(Worker):
    def __init__(self):
        self.model_name = os.environ.get("EMB_MODEL", DEFAULT_MODEL)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        self.model = transformers.AutoModel.from_pretrained(self.model_name)
        self.device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"

        self.model = self.model.to(self.device)
        self.model.eval()

        # JIT-trace the model (optimized with IPEX in bfloat16) on dummy inputs
        self.model = ipex.optimize(self.model, dtype=torch.bfloat16)
        vocab_size = self.model.config.vocab_size
        batch_size = 16
        seq_length = 512
        d = torch.randint(vocab_size, size=[batch_size, seq_length])
        t = torch.randint(0, 1, size=[batch_size, seq_length])
        m = torch.randint(1, 2, size=[batch_size, seq_length])
        self.model = torch.jit.trace(self.model, [d, t, m], check_trace=False, strict=False)
        self.model = torch.jit.freeze(self.model)
        self.model(d, t, m)

    def get_embedding_with_token_count(self, sentences: Union[str, List[Union[str, List[int]]]]):
        # Mean pooling - take the attention mask into account for correct averaging
        def mean_pooling(model_output, attention_mask):
            # First element of model_output contains all token embeddings
            token_embeddings = model_output["last_hidden_state"]
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
                input_mask_expanded.sum(1), min=1e-9
            )

        # Tokenize sentences
        # TODO: support `List[List[int]]` input
        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
        inputs = encoded_input.to(self.device)
        token_count = inputs["attention_mask"].sum(dim=1).tolist()
        # Compute token embeddings
        model_output = self.model(**inputs)
        # Perform pooling
        sentence_embeddings = mean_pooling(model_output, inputs["attention_mask"])
        # Normalize embeddings
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

        return token_count, sentence_embeddings

    def deserialize(self, data: bytes) -> EmbeddingRequest:
        return EmbeddingRequest.from_bytes(data)

    def serialize(self, data: EmbeddingResponse) -> bytes:
        return data.to_json()

    def forward(self, data: List[EmbeddingRequest]) -> List[EmbeddingResponse]:
        inputs = []
        inputs_lens = []
        for d in data:
            inputs.extend(d.input if isinstance(d.input, list) else [d.input])
            inputs_lens.append(len(d.input) if isinstance(d.input, list) else 1)
        token_cnt, embeddings = self.get_embedding_with_token_count(inputs)

        embeddings = embeddings.detach()
        if self.device != "cpu":
            embeddings = embeddings.cpu()
        embeddings = embeddings.numpy()
        embeddings = [emb.tolist() for emb in embeddings]

        resp = []
        emb_idx = 0
        for lens in inputs_lens:
            token_count = sum(token_cnt[emb_idx : emb_idx + lens])
            resp.append(
                EmbeddingResponse(
                    data=[
                        EmbeddingData(embedding=emb, index=i)
                        for i, emb in enumerate(embeddings[emb_idx : emb_idx + lens])
                    ],
                    model=self.model_name,
                    usage=TokenUsage(
                        prompt_tokens=token_count,
                        # No completions performed, only embeddings generated.
                        completion_tokens=0,
                        total_tokens=token_count,
                    ),
                )
            )
            emb_idx += lens
        return resp


if __name__ == "__main__":
    MAX_BATCH_SIZE = int(os.environ.get("MAX_BATCH_SIZE", 128))
    MAX_WAIT_TIME = int(os.environ.get("MAX_WAIT_TIME", 10))
    server = Server()
    emb = Runtime(Embedding, max_batch_size=MAX_BATCH_SIZE, max_wait_time=MAX_WAIT_TIME)
    server.register_runtime(
        {
            "/v1/embeddings": [emb],
            "/embeddings": [emb],
        }
    )
    server.run()
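
Because the `Runtime` above is configured with `max_batch_size` and `max_wait_time`, Mosec coalesces concurrent requests into a single `forward()` call. A hedged load sketch, not part of this commit, assuming the endpoint container is listening on localhost:8000 and the `openai` package is installed:

```python
# Exercise Mosec's dynamic batching by sending concurrent embedding requests.
from concurrent.futures import ThreadPoolExecutor

from openai import Client

client = Client(api_key="fake", base_url="http://127.0.0.1:8000")


def embed(text: str):
    return client.embeddings.create(model="/root/bge-large-zh/", input=text)


with ThreadPoolExecutor(max_workers=8) as pool:
    results = list(pool.map(embed, [f"sentence {i}" for i in range(32)]))

# 32 responses; requests arriving within max_wait_time are batched
# server-side, up to max_batch_size per forward() call.
print(len(results))
```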
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""OpenAI embedding client example."""

from openai import Client

DEFAULT_MODEL = "/root/bge-large-zh/"
SERVICE_URL = "http://127.0.0.1:8000"
INPUT_STR = "Hello world!"

client = Client(api_key="fake", base_url=SERVICE_URL)
emb = client.embeddings.create(
    model=DEFAULT_MODEL,
    input=INPUT_STR,
)

print(len(emb.data))  # type: ignore
print(emb.data[0].embedding)  # type: ignore
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
docarray[full]
fastapi
langchain
langchain_community
openai
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-sdk
shortuuid
