
Commit 2a53e25

Authored by srinarayan-srikanthan, pre-commit-ci[bot], XuehaoSun, letonghan, and ZailiWang
Add embedding support for CLIP-based models for the VideoRAGQnA example, v0.9 (#538)
* clip embedding support
* test script for embedding
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* fix freeze workflow (#522)
* Fix Dataprep potential error in get_file (#540): fix get_file error and refine logs
* Support SearchedDoc input type in LLM for the no-rerank pipeline (#541)
* Add dependency for pdf2image and OCR processing (#421)
* Add local_embedding return of length 768 to align with the ChatQnA example (#313)
* Add telemetry doc (#536)
* Add video-llama LVM microservice under lvms (#495)
* Fix the data load issue for structured files (#505)
* Add finetuning component (#502)
* Add torchvision into requirements (#546)
* Use Gaudi base images from Docker Hub (#526); fix malformed tags
* Add toxicity detection microservice with E2E testing (#338); change the microservice to use Service.GUARDRAILS with TextDoc input/output, simplify the Dockerfile to use langchain, and sort requirements (#432); minor SPDX header update (#434); remove 'langsmith' per code review (#534)
* Rename script and use 5xxx
* Add proxy for build
* Fix commit issues
* Fix docarray constraint and update docarray
* Remove telemetry, which caused an error in mega
* Rename dirs and tests

Signed-off-by: srinarayan-srikanthan; Sun, Xuehao; letonghan; Chendi.Xue; BaoHuiling; XuhuiRen; Xinyu Ye; chensuyue; Abolfazl Shahbazi; Qun Gao; Tyler Wilbers
Co-authored-by: pre-commit-ci[bot]; Sun, Xuehao; Letong Han; Zaili Wang; Chendi.Xue; Sihan Chen; Huiling Bao; XuhuiRen; XinyuYe-Intel; lkk; test; root; chen, suyue; Abolfazl Shahbazi; qgao007; Tyler W
Parent: 8325d5d · Commit: 2a53e25

File tree

22 files changed: +409 additions, −0 deletions

comps/cores/proto/docarray.py

Lines changed: 1 addition & 0 deletions

@@ -64,6 +64,7 @@ class EmbedDoc(BaseDoc):
     fetch_k: int = 20
     lambda_mult: float = 0.5
     score_threshold: float = 0.2
+    constraints: Optional[Union[Dict[str, Any], None]] = None


 class EmbedMultimodalDoc(EmbedDoc):
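
For reference, a minimal sketch (not part of this commit) of how the new `constraints` field can be populated on an `EmbedDoc`; the vector and filter values below are placeholders:

```python
# Hypothetical usage sketch for the new field; values are illustrative only.
from comps import EmbedDoc

doc = EmbedDoc(
    text="videos from last Friday",
    embedding=[0.12, -0.03, 0.55],               # placeholder embedding vector
    constraints={"date": ["==", "2024-08-16"]},  # placeholder metadata filter
)
print(doc.constraints)
```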

comps/dataprep/milvus/prepare_doc_milvus.py

Lines changed: 1 addition & 0 deletions

@@ -133,6 +133,7 @@ def ingest_data_to_milvus(doc_path: DocPath, embedder):
     )

     content = document_loader(path)
+
     if logflag:
         logger.info("[ ingest data ] file content loaded")

comps/embeddings/multimodal_clip/README.md

Lines changed: 52 additions & 0 deletions

# Multimodal CLIP Embeddings Microservice

The Multimodal CLIP Embedding Microservice converts text strings and images into vectorized embeddings for seamless integration into machine learning and data processing workflows. It uses the CLIP model (openai/clip-vit-base-patch32) to generate embeddings that capture the semantic essence of the input text and images, making it well suited for multi-modal data processing, information retrieval, and similar applications.

Key Features:

**High Performance**: Optimized for quick and reliable conversion of text and image inputs into vector embeddings.

**Scalability**: Built to handle high volumes of concurrent requests, ensuring robust performance even under heavy load.

**Ease of Integration**: Provides a simple and intuitive API for straightforward integration into existing systems and workflows.

**Customizable**: Supports configuration and customization for specific use cases, including different embedding models and preprocessing techniques.

Users are able to configure and build embedding-related services according to their actual needs.

## 🚀1. Start Microservice with Docker

### 1.1 Build Docker Image

#### Build Langchain Docker

```bash
cd ../../..
docker build -t opea/embedding-multimodal:latest --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/embeddings/multimodal_clip/docker/Dockerfile .
```

### 1.2 Run Docker with Docker Compose

```bash
cd comps/embeddings/multimodal_clip/docker
docker compose -f docker_compose_embedding.yaml up -d
```

## 🚀2. Consume Embedding Service

### 2.1 Check Service Status

```bash
curl http://localhost:6000/v1/health_check \
  -X GET \
  -H 'Content-Type: application/json'
```

### 2.2 Consume Embedding Service

```bash
curl http://localhost:6000/v1/embeddings \
  -X POST -d '{"text":"Sample text"}' \
  -H 'Content-Type: application/json'
```
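
For a quick test from Python instead of curl, a minimal client sketch (assuming the service is reachable on localhost:6000 as configured above; the query string is arbitrary):

```python
import requests

# Same request as the curl example above, sent from Python.
resp = requests.post(
    "http://localhost:6000/v1/embeddings",
    json={"text": "a soccer game played last Friday"},
    timeout=30,
)
resp.raise_for_status()
doc = resp.json()
# The response mirrors EmbedDoc: the input text, the CLIP embedding vector,
# and any date constraints parsed from the query.
print(len(doc["embedding"]), doc.get("constraints"))
```
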
Lines changed: 2 additions & 0 deletions

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
comps/embeddings/multimodal_clip/docker/Dockerfile

Lines changed: 29 additions & 0 deletions

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

FROM langchain/langchain:latest

ARG ARCH="cpu"

RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
    libgl1-mesa-glx \
    libjemalloc-dev \
    vim

RUN useradd -m -s /bin/bash user && \
    mkdir -p /home/user && \
    chown -R user /home/user/

USER user

COPY comps /home/user/comps

RUN pip install --no-cache-dir --upgrade pip && \
    if [ ${ARCH} = "cpu" ]; then pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu; fi && \
    pip install --no-cache-dir -r /home/user/comps/embeddings/multimodal_clip/requirements.txt

ENV PYTHONPATH=$PYTHONPATH:/home/user

WORKDIR /home/user/comps/embeddings/multimodal_clip

ENTRYPOINT ["python", "embedding_multimodal.py"]
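
Usage note (not part of the commit): the `ARCH` build argument defaults to `cpu`, which installs CPU-only torch/torchvision wheels; it can be set explicitly at build time, for example:

```bash
# Build from the repository root; ARCH=cpu is the default.
docker build -t opea/embedding-multimodal:latest \
  --build-arg ARCH=cpu \
  --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy \
  -f comps/embeddings/multimodal_clip/docker/Dockerfile .
```
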
comps/embeddings/multimodal_clip/docker/docker_compose_embedding.yaml

Lines changed: 22 additions & 0 deletions

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

version: "3.8"

services:
  embedding:
    image: opea/embedding-multimodal:latest
    container_name: embedding-multimodal-server
    ports:
      - "6000:6000"
    ipc: host
    environment:
      no_proxy: ${no_proxy}
      http_proxy: ${http_proxy}
      https_proxy: ${https_proxy}
      LANGCHAIN_API_KEY: ${LANGCHAIN_API_KEY}
    restart: unless-stopped

networks:
  default:
    driver: bridge
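
Once the stack is up, a quick sanity check (a usage sketch; the container name comes from the compose file above):

```bash
cd comps/embeddings/multimodal_clip/docker
docker compose -f docker_compose_embedding.yaml ps
docker logs embedding-multimodal-server --tail 50
```
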
comps/embeddings/multimodal_clip/embedding_multimodal.py

Lines changed: 86 additions & 0 deletions

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import datetime
import os
import time
from typing import Union

from dateparser.search import search_dates
from embeddings_clip import vCLIP

from comps import (
    EmbedDoc,
    ServiceType,
    TextDoc,
    opea_microservices,
    register_microservice,
    register_statistics,
    statistics_dict,
)


def filter_dates(prompt):
    """Extract date constraints from a natural-language prompt using dateparser."""
    base_date = datetime.datetime.today()
    today_date = base_date.date()
    dates_found = search_dates(prompt, settings={"PREFER_DATES_FROM": "past", "RELATIVE_BASE": base_date})

    if dates_found is not None:
        for date_tuple in dates_found:
            date_string, parsed_date = date_tuple
            date_out = str(parsed_date.date())
            time_out = str(parsed_date.time())
            hours, minutes, seconds = map(float, time_out.split(":"))
            year, month, day_out = map(int, date_out.split("-"))

            rounded_seconds = min(round(parsed_date.second + 0.5), 59)
            parsed_date = parsed_date.replace(second=rounded_seconds, microsecond=0)

            iso_date_time = parsed_date.isoformat()
            iso_date_time = str(iso_date_time)

            if date_string == "today":
                constraints = {"date": ["==", date_out]}
            elif date_out != str(today_date) and time_out == "00:00:00":  ## exact day (example: last friday)
                constraints = {"date": ["==", date_out]}
            elif (
                date_out == str(today_date) and time_out == "00:00:00"
            ):  ## when search_dates interprets words as dates, the output is today's date + time 00:00:00
                constraints = {}
            else:  ## interval of time: last 48 hours, last 2 days, ...
                constraints = {"date_time": [">=", {"_date": iso_date_time}]}
            return constraints

    else:
        return {}


@register_microservice(
    name="opea_service@embedding_multimodal",
    service_type=ServiceType.EMBEDDING,
    endpoint="/v1/embeddings",
    host="0.0.0.0",
    port=6000,
    input_datatype=TextDoc,
    output_datatype=EmbedDoc,
)
@register_statistics(names=["opea_service@embedding_multimodal"])
def embedding(input: TextDoc) -> EmbedDoc:
    start = time.time()

    if isinstance(input, TextDoc):
        # Handle text input
        embed_vector = embeddings.embed_query(input.text).tolist()[0]
        res = EmbedDoc(text=input.text, embedding=embed_vector, constraints=filter_dates(input.text))

    else:
        raise ValueError("Invalid input type")

    statistics_dict["opea_service@embedding_multimodal"].append_latency(time.time() - start, None)
    return res


if __name__ == "__main__":
    embeddings = vCLIP({"model_name": "openai/clip-vit-base-patch32", "num_frm": 4})
    opea_microservices["opea_service@embedding_multimodal"].start()
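
To illustrate the constraint shapes produced by `filter_dates` above (a sketch; it assumes the service's dependencies, including `comps` and the CLIP model, are importable, e.g. inside the container, and the exact dates depend on when it runs):

```python
# Illustrative only: the shapes below mirror the branches in filter_dates above.
from embedding_multimodal import filter_dates

print(filter_dates("game highlights from last friday"))
# e.g. {"date": ["==", "<date of that Friday>"]}

print(filter_dates("anything interesting in the last 48 hours"))
# e.g. {"date_time": [">=", {"_date": "<ISO timestamp ~48 hours ago>"}]}

print(filter_dates("a man riding a bicycle"))
# {} when no date phrase is found
```
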
comps/embeddings/multimodal_clip/embeddings_clip.py

Lines changed: 50 additions & 0 deletions

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn as nn
from einops import rearrange
from transformers import AutoProcessor, AutoTokenizer, CLIPModel

model_name = "openai/clip-vit-base-patch32"

clip = CLIPModel.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


class vCLIP(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.num_frm = cfg["num_frm"]
        self.model_name = cfg["model_name"]

    def embed_query(self, texts):
        """Input is list of texts."""
        text_inputs = tokenizer(texts, padding=True, return_tensors="pt")
        text_features = clip.get_text_features(**text_inputs)
        return text_features

    def get_embedding_length(self):
        return len(self.embed_query("sample_text"))

    def get_image_embeddings(self, images):
        """Input is list of images."""
        image_inputs = processor(images=images, return_tensors="pt")
        image_features = clip.get_image_features(**image_inputs)
        return image_features

    def get_video_embeddings(self, frames_batch):
        """Input is list of list of frames in video."""
        self.batch_size = len(frames_batch)
        vid_embs = []
        for frames in frames_batch:
            frame_embeddings = self.get_image_embeddings(frames)
            frame_embeddings = rearrange(frame_embeddings, "(b n) d -> b n d", b=len(frames_batch))
            # Normalize, mean aggregate and return normalized video_embeddings
            frame_embeddings = frame_embeddings / frame_embeddings.norm(dim=-1, keepdim=True)
            video_embeddings = frame_embeddings.mean(dim=1)
            video_embeddings = video_embeddings / video_embeddings.norm(dim=-1, keepdim=True)
            vid_embs.append(video_embeddings)
        return torch.cat(vid_embs, dim=0)
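
A minimal usage sketch of the `vCLIP` wrapper above (assuming `transformers`, `einops`, `torch`, and `Pillow` are installed; the blank frames are placeholders for decoded video frames):

```python
from PIL import Image
from embeddings_clip import vCLIP

model = vCLIP({"model_name": "openai/clip-vit-base-patch32", "num_frm": 4})

# Text query -> tensor of shape (1, 512) for this checkpoint
text_emb = model.embed_query(["a person walking a dog"])

# One "video" represented by 4 blank RGB frames (placeholders for real frames)
frames_batch = [[Image.new("RGB", (224, 224)) for _ in range(4)]]
video_emb = model.get_video_embeddings(frames_batch)

print(text_emb.shape, video_emb.shape)
```
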
comps/embeddings/multimodal_clip/requirements.txt

Lines changed: 14 additions & 0 deletions

dateparser
docarray[full]
einops
fastapi
huggingface_hub
langchain
open_clip_torch
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-sdk
prometheus-fastapi-instrumentator
sentence_transformers
shortuuid
uvicorn

comps/finetuning/README.md

Lines changed: 6 additions & 0 deletions

@@ -92,10 +92,12 @@ Assuming a training file `alpaca_data.json` is uploaded, it can be downloaded in

 ```bash
 # upload a training file
+
 curl http://${your_ip}:8015/v1/finetune/upload_training_files -X POST -H "Content-Type: multipart/form-data" -F "files=@./alpaca_data.json"

 # create a finetuning job
 curl http://${your_ip}:8015/v1/fine_tuning/jobs \
+
 -X POST \
 -H "Content-Type: application/json" \
 -d '{
@@ -104,18 +106,22 @@ curl http://${your_ip}:8015/v1/fine_tuning/jobs \
 }'

 # list finetuning jobs
+
 curl http://${your_ip}:8015/v1/fine_tuning/jobs -X GET

 # retrieve one finetuning job
 curl http://localhost:8015/v1/fine_tuning/jobs/retrieve -X POST -H "Content-Type: application/json" -d '{
+
 "fine_tuning_job_id": ${fine_tuning_job_id}}'

 # cancel one finetuning job

+
 curl http://localhost:8015/v1/fine_tuning/jobs/cancel -X POST -H "Content-Type: application/json" -d '{
 "fine_tuning_job_id": ${fine_tuning_job_id}}'

 # list checkpoints of a finetuning job
 curl http://${your_ip}:8015/v1/finetune/list_checkpoints -X POST -H "Content-Type: application/json" -d '{"fine_tuning_job_id": ${fine_tuning_job_id}}'

+
 ```
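
For completeness, the job-listing call from Python (a hypothetical client sketch; it only exercises the `GET /v1/fine_tuning/jobs` endpoint shown above, with a host placeholder standing in for `${your_ip}`):

```python
import requests

your_ip = "localhost"  # placeholder, mirroring ${your_ip} in the curl examples
resp = requests.get(f"http://{your_ip}:8015/v1/fine_tuning/jobs", timeout=30)
resp.raise_for_status()
print(resp.json())
```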
