
Commit 99be1bd

Authored by tileintel, siddhivelankar23, sjagtap1803, pre-commit-ci[bot], and Spycsh
Add Megaservice support for MMRAG - MultimodalRAGQnAWithVideos usecase (#626)
* updates (Signed-off-by: Tiep Le <[email protected]>)
* cosmetic (Signed-off-by: siddhivelankar23 <[email protected]>)
* update redis schema (Signed-off-by: siddhivelankar23 <[email protected]>)
* update multimodal config and docker compose retriever (Signed-off-by: siddhivelankar23 <[email protected]>)
* update requirements (Signed-off-by: siddhivelankar23 <[email protected]>)
* update retriever redis (Signed-off-by: siddhivelankar23 <[email protected]>)
* multimodal retriever implementation (Signed-off-by: siddhivelankar23 <[email protected]>)
* test for multimodal retriever (Signed-off-by: siddhivelankar23 <[email protected]>)
* include prompt preparation for multimodal rag on videos application (Signed-off-by: sjagtap1803 <[email protected]>)
* fix template (Signed-off-by: sjagtap1803 <[email protected]>)
* add test for llava for mm_rag_on_videos (Signed-off-by: sjagtap1803 <[email protected]>)
* update test (Signed-off-by: sjagtap1803 <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* first update on gateway (Signed-off-by: sjagtap1803 <[email protected]>)
* fix index not found (Signed-off-by: sjagtap1803 <[email protected]>)
* add LVMSearchedMultimodalDoc (Signed-off-by: sjagtap1803 <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* implement gateway for MultimodalRagQnAWithVideos (Signed-off-by: siddhivelankar23 <[email protected]>)
* remove INDEX_SCHEMA (Signed-off-by: siddhivelankar23 <[email protected]>)
* update MultimodalRAGQnAWithVideosGateway with 2 megaservices (Signed-off-by: sjagtap1803 <[email protected]>)
* revise folder structure to comps/retrievers/langchain/redis_multimodal (Signed-off-by: siddhivelankar23 <[email protected]>)
* update test (Signed-off-by: siddhivelankar23 <[email protected]>)
* add unittest for multimodalrag_qna_with_videos_gateway (Signed-off-by: siddhivelankar23 <[email protected]>)
* update test mmrag qna with videos (Signed-off-by: Tiep Le <[email protected]>)
* change port of redis to resolve CI test (Signed-off-by: siddhivelankar23 <[email protected]>)
* update test (Signed-off-by: siddhivelankar23 <[email protected]>)
* update lvms test (Signed-off-by: siddhivelankar23 <[email protected]>)
* update test (Signed-off-by: siddhivelankar23 <[email protected]>)
* update test (Signed-off-by: siddhivelankar23 <[email protected]>)
* update test for multimodal rag qna with videos gateway (Signed-off-by: siddhivelankar23 <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* add more test to increase coverage (Signed-off-by: Tiep Le <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* cosmetic (Signed-off-by: Tiep Le <[email protected]>)
* add more test (Signed-off-by: Tiep Le <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* update name of gateway (Signed-off-by: Tiep Le <[email protected]>)

---------

Signed-off-by: Tiep Le <[email protected]>
Signed-off-by: siddhivelankar23 <[email protected]>
Signed-off-by: sjagtap1803 <[email protected]>
Co-authored-by: siddhivelankar23 <[email protected]>
Co-authored-by: sjagtap1803 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sihan Chen <[email protected]>
1 parent 2705e93 commit 99be1bd
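
For orientation, a minimal client call against the new endpoint might look like the sketch below. This is not part of the commit: the host and port are assumptions (MultimodalRAGWithVideosGateway defaults to 0.0.0.0:9999), and the question text is illustrative.

import requests

# Hypothetical deployment address; the gateway below defaults to port 9999.
url = "http://localhost:9999/v1/mmragvideoqna"

# A plain-text first query is routed through the full multimodal RAG pipeline
# (embedding -> retriever -> LVM); see handle_request in gateway.py below.
payload = {"messages": "What is the man doing in the video?", "max_tokens": 128}

resp = requests.post(url, json=payload, timeout=120)
# The gateway returns a ChatCompletionResponse-shaped JSON body.
print(resp.json()["choices"][0]["message"]["content"])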

File tree

4 files changed: +374, -0 lines changed


comps/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@
     RetrievalToolGateway,
     FaqGenGateway,
     VisualQnAGateway,
+    MultimodalRAGWithVideosGateway,
 )

 # Telemetry

comps/cores/mega/constants.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ class MegaServiceEndpoint(Enum):
     CODE_TRANS = "/v1/codetrans"
     DOC_SUMMARY = "/v1/docsum"
     SEARCH_QNA = "/v1/searchqna"
+    MULTIMODAL_RAG_WITH_VIDEOS = "/v1/mmragvideoqna"
     TRANSLATION = "/v1/translation"
     RETRIEVALTOOL = "/v1/retrievaltool"
     FAQ_GEN = "/v1/faqgen"
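
The gateway in the next file registers its route by stringifying this enum member. A quick sanity check, assuming MegaServiceEndpoint overrides __str__ to return the member's value (which the gateway's str(...) call implies):

from comps.cores.mega.constants import MegaServiceEndpoint

# Assumption: __str__ returns self.value, as implied by the gateway passing
# str(MegaServiceEndpoint.MULTIMODAL_RAG_WITH_VIDEOS) as its route path.
assert MegaServiceEndpoint.MULTIMODAL_RAG_WITH_VIDEOS.value == "/v1/mmragvideoqna"
print(str(MegaServiceEndpoint.MULTIMODAL_RAG_WITH_VIDEOS))  # expected: /v1/mmragvideoqna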

comps/cores/mega/gateway.py

Lines changed: 157 additions & 0 deletions
@@ -108,6 +108,7 @@ def _handle_message(self, messages):
                     messages_dict[msg_role] = message["content"]
                 else:
                     raise ValueError(f"Unknown role: {msg_role}")
+
             if system_prompt:
                 prompt = system_prompt + "\n"
             for role, message in messages_dict.items():
@@ -582,3 +583,159 @@ def parser_input(data, TypeClass, key):
         response = result_dict[last_node]
         print("response is ", response)
         return response
+
+
+class MultimodalRAGWithVideosGateway(Gateway):
+    def __init__(self, multimodal_rag_megaservice, lvm_megaservice, host="0.0.0.0", port=9999):
+        self.lvm_megaservice = lvm_megaservice
+        super().__init__(
+            multimodal_rag_megaservice,
+            host,
+            port,
+            str(MegaServiceEndpoint.MULTIMODAL_RAG_WITH_VIDEOS),
+            ChatCompletionRequest,
+            ChatCompletionResponse,
+        )
+
+    # this overrides the _handle_message method of Gateway
+    def _handle_message(self, messages):
+        images = []
+        messages_dicts = []
+        if isinstance(messages, str):
+            prompt = messages
+        else:
+            messages_dict = {}
+            system_prompt = ""
+            prompt = ""
+            for message in messages:
+                msg_role = message["role"]
+                messages_dict = {}
+                if msg_role == "system":
+                    system_prompt = message["content"]
+                elif msg_role == "user":
+                    if type(message["content"]) == list:
+                        text = ""
+                        text_list = [item["text"] for item in message["content"] if item["type"] == "text"]
+                        text += "\n".join(text_list)
+                        image_list = [
+                            item["image_url"]["url"] for item in message["content"] if item["type"] == "image_url"
+                        ]
+                        if image_list:
+                            messages_dict[msg_role] = (text, image_list)
+                        else:
+                            messages_dict[msg_role] = text
+                    else:
+                        messages_dict[msg_role] = message["content"]
+                    messages_dicts.append(messages_dict)
+                elif msg_role == "assistant":
+                    messages_dict[msg_role] = message["content"]
+                    messages_dicts.append(messages_dict)
+                else:
+                    raise ValueError(f"Unknown role: {msg_role}")
+
+            if system_prompt:
+                prompt = system_prompt + "\n"
+            for messages_dict in messages_dicts:
+                for i, (role, message) in enumerate(messages_dict.items()):
+                    if isinstance(message, tuple):
+                        text, image_list = message
+                        if i == 0:
+                            # do not add role for the very first message.
+                            # this will be added by llava_server
+                            if text:
+                                prompt += text + "\n"
+                        else:
+                            if text:
+                                prompt += role.upper() + ": " + text + "\n"
+                            else:
+                                prompt += role.upper() + ":"
+                        for img in image_list:
+                            # URL
+                            if img.startswith("http://") or img.startswith("https://"):
+                                response = requests.get(img)
+                                image = Image.open(BytesIO(response.content)).convert("RGBA")
+                                image_bytes = BytesIO()
+                                image.save(image_bytes, format="PNG")
+                                img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
+                            # Local Path
+                            elif os.path.exists(img):
+                                image = Image.open(img).convert("RGBA")
+                                image_bytes = BytesIO()
+                                image.save(image_bytes, format="PNG")
+                                img_b64_str = base64.b64encode(image_bytes.getvalue()).decode()
+                            # Bytes
+                            else:
+                                img_b64_str = img
+
+                            images.append(img_b64_str)
+                    else:
+                        if i == 0:
+                            # do not add role for the very first message.
+                            # this will be added by llava_server
+                            if message:
+                                prompt += role.upper() + ": " + message + "\n"
+                        else:
+                            if message:
+                                prompt += role.upper() + ": " + message + "\n"
+                            else:
+                                prompt += role.upper() + ":"
+        if images:
+            return prompt, images
+        else:
+            return prompt
+
+    async def handle_request(self, request: Request):
+        data = await request.json()
+        stream_opt = bool(data.get("stream", False))
+        if stream_opt:
+            print("[ MultimodalRAGWithVideosGateway ] stream=True requested, but streaming is not supported yet; falling back to non-streaming.")
+            stream_opt = False
+        chat_request = ChatCompletionRequest.model_validate(data)
+        # Multimodal RAG QnA With Videos does not yet accept images as input during QnA.
+        prompt_and_image = self._handle_message(chat_request.messages)
+        if isinstance(prompt_and_image, tuple):
+            # print("This request includes an image, so it is a follow-up query. Using the LVM megaservice.")
+            prompt, images = prompt_and_image
+            cur_megaservice = self.lvm_megaservice
+            initial_inputs = {"prompt": prompt, "image": images[0]}
+        else:
+            # print("This is the first query, requiring multimodal retrieval. Using the multimodal RAG megaservice.")
+            prompt = prompt_and_image
+            cur_megaservice = self.megaservice
+            initial_inputs = {"text": prompt}
+
+        parameters = LLMParams(
+            max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+            streaming=stream_opt,
+            chat_template=chat_request.chat_template if chat_request.chat_template else None,
+        )
+        result_dict, runtime_graph = await cur_megaservice.schedule(
+            initial_inputs=initial_inputs, llm_parameters=parameters
+        )
+        for node, response in result_dict.items():
+            # The last microservice in this megaservice is the LVM.
+            # Check whether the LVM returned a StreamingResponse.
+            # Currently, LVM with LLaVA does not yet support streaming.
+            # @TODO: revisit once LVM with LLaVA supports streaming.
+            if (
+                isinstance(response, StreamingResponse)
+                and node == runtime_graph.all_leaves()[-1]
+                and self.megaservice.services[node].service_type == ServiceType.LVM
+            ):
+                return response
+        last_node = runtime_graph.all_leaves()[-1]
+        response = result_dict[last_node]["text"]
+        choices = []
+        usage = UsageInfo()
+        choices.append(
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=response),
+                finish_reason="stop",
+            )
+        )
+        return ChatCompletionResponse(model="multimodalragwithvideos", choices=choices, usage=usage)
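
To make the two-path routing in handle_request concrete: when _handle_message returns a (prompt, images) tuple, the turn is treated as a follow-up and scheduled on the lvm_megaservice with {"prompt", "image"} inputs; a bare string prompt goes to the full RAG megaservice with {"text"}. A hedged sketch of a follow-up request, where the image URL, host, and port are placeholders:

import requests

# Including an image_url routes this turn to the LVM megaservice directly,
# bypassing retrieval (initial_inputs={"prompt": ..., "image": images[0]}).
payload = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is shown in this frame?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/frame.png"}},
            ],
        },
    ],
    "max_tokens": 128,
}
resp = requests.post("http://localhost:9999/v1/mmragvideoqna", json=payload, timeout=120)
print(resp.json()["choices"][0]["message"]["content"])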
Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import unittest
from typing import Union

import requests
from fastapi import Request

from comps import (
    EmbedDoc,
    EmbedMultimodalDoc,
    LVMDoc,
    LVMSearchedMultimodalDoc,
    MultimodalDoc,
    MultimodalRAGWithVideosGateway,
    SearchedMultimodalDoc,
    ServiceOrchestrator,
    TextDoc,
    opea_microservices,
    register_microservice,
)


@register_microservice(name="mm_embedding", host="0.0.0.0", port=8083, endpoint="/v1/mm_embedding")
async def mm_embedding_add(request: MultimodalDoc) -> EmbedDoc:
    req = request.model_dump_json()
    req_dict = json.loads(req)
    text = req_dict["text"]
    res = {}
    res["text"] = text
    res["embedding"] = [0.12, 0.45]
    return res


@register_microservice(name="mm_retriever", host="0.0.0.0", port=8084, endpoint="/v1/mm_retriever")
async def mm_retriever_add(request: EmbedMultimodalDoc) -> SearchedMultimodalDoc:
    req = request.model_dump_json()
    req_dict = json.loads(req)
    text = req_dict["text"]
    res = {}
    res["retrieved_docs"] = []
    res["initial_query"] = text
    res["top_n"] = 1
    res["metadata"] = [
        {
            "b64_img_str": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC",
            "transcript_for_inference": "yellow image",
        }
    ]
    res["chat_template"] = "The caption of the image is: '{context}'. {question}"
    return res


@register_microservice(name="lvm", host="0.0.0.0", port=8085, endpoint="/v1/lvm")
async def lvm_add(request: Union[LVMDoc, LVMSearchedMultimodalDoc]) -> TextDoc:
    req = request.model_dump_json()
    req_dict = json.loads(req)
    if isinstance(request, LVMSearchedMultimodalDoc):
        print("request is the output of multimodal retriever")
        text = req_dict["initial_query"]
        text += "opea project!"
    else:
        print("request is from user.")
        text = req_dict["prompt"]
        text = f"<image>\nUSER: {text}\nASSISTANT:"

    res = {}
    res["text"] = text
    return res


class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase):
    @classmethod
    def setUpClass(cls):
        cls.mm_embedding = opea_microservices["mm_embedding"]
        cls.mm_retriever = opea_microservices["mm_retriever"]
        cls.lvm = opea_microservices["lvm"]
        cls.mm_embedding.start()
        cls.mm_retriever.start()
        cls.lvm.start()

        cls.service_builder = ServiceOrchestrator()

        cls.service_builder.add(opea_microservices["mm_embedding"]).add(opea_microservices["mm_retriever"]).add(
            opea_microservices["lvm"]
        )
        cls.service_builder.flow_to(cls.mm_embedding, cls.mm_retriever)
        cls.service_builder.flow_to(cls.mm_retriever, cls.lvm)

        cls.follow_up_query_service_builder = ServiceOrchestrator()
        cls.follow_up_query_service_builder.add(cls.lvm)

        cls.gateway = MultimodalRAGWithVideosGateway(
            cls.service_builder, cls.follow_up_query_service_builder, port=9898
        )

    @classmethod
    def tearDownClass(cls):
        cls.mm_embedding.stop()
        cls.mm_retriever.stop()
        cls.lvm.stop()
        cls.gateway.stop()

    async def test_service_builder_schedule(self):
        result_dict, _ = await self.service_builder.schedule(initial_inputs={"text": "hello, "})
        self.assertEqual(result_dict[self.lvm.name]["text"], "hello, opea project!")

    async def test_follow_up_query_service_builder_schedule(self):
        result_dict, _ = await self.follow_up_query_service_builder.schedule(
            initial_inputs={"prompt": "chao, ", "image": "some image"}
        )
        # print(result_dict)
        self.assertEqual(result_dict[self.lvm.name]["text"], "<image>\nUSER: chao, \nASSISTANT:")

    def test_multimodal_rag_with_videos_gateway(self):
        json_data = {"messages": "hello, "}
        response = requests.post("http://0.0.0.0:9898/v1/mmragvideoqna", json=json_data)
        response = response.json()
        self.assertEqual(response["choices"][-1]["message"]["content"], "hello, opea project!")

    def test_follow_up_mm_rag_with_videos_gateway(self):
        json_data = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "hello, "},
                        {
                            "type": "image_url",
                            "image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
                        },
                    ],
                },
                {"role": "assistant", "content": "opea project! "},
                {"role": "user", "content": "chao, "},
            ],
            "max_tokens": 300,
        }
        response = requests.post("http://0.0.0.0:9898/v1/mmragvideoqna", json=json_data)
        response = response.json()
        self.assertEqual(
            response["choices"][-1]["message"]["content"],
            "<image>\nUSER: hello, \nASSISTANT: opea project! \nUSER: chao, \n\nASSISTANT:",
        )

    def test_handle_message(self):
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "hello, "},
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
                    },
                ],
            },
            {"role": "assistant", "content": "opea project! "},
            {"role": "user", "content": "chao, "},
        ]
        prompt, images = self.gateway._handle_message(messages)
        self.assertEqual(prompt, "hello, \nASSISTANT: opea project! \nUSER: chao, \n")

    def test_handle_message_with_system_prompt(self):
        messages = [
            {"role": "system", "content": "System Prompt"},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "hello, "},
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
                    },
                ],
            },
            {"role": "assistant", "content": "opea project! "},
            {"role": "user", "content": "chao, "},
        ]
        prompt, images = self.gateway._handle_message(messages)
        self.assertEqual(prompt, "System Prompt\nhello, \nASSISTANT: opea project! \nUSER: chao, \n")

    async def test_handle_request(self):
        json_data = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "hello, "},
                        {
                            "type": "image_url",
                            "image_url": {"url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
                        },
                    ],
                },
                {"role": "assistant", "content": "opea project! "},
                {"role": "user", "content": "chao, "},
            ],
            "max_tokens": 300,
        }
        mock_request = Request(scope={"type": "http"})
        mock_request._json = json_data
        res = await self.gateway.handle_request(mock_request)
        res = json.loads(res.json())
        self.assertEqual(
            res["choices"][-1]["message"]["content"],
            "<image>\nUSER: hello, \nASSISTANT: opea project! \nUSER: chao, \n\nASSISTANT:",
        )


if __name__ == "__main__":
    unittest.main()
