# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import time
from typing import Union

import requests
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from langchain_core.prompts import PromptTemplate
from openai import OpenAI

from comps import (
    CustomLogger,
    LVMDoc,
    LVMSearchedMultimodalDoc,
    MetadataTextDoc,
    OpeaComponent,
    OpeaComponentRegistry,
    ServiceType,
    TextDoc,
    statistics_dict,
)

logger = CustomLogger("opea_vllm")
logflag = os.getenv("LOGFLAG", False)

# The maximum number of images that should be sent to the LVM
# max_images = int(os.getenv("MAX_IMAGES", 1))
LLM_MODEL_ID = os.getenv("LLM_MODEL_ID", "llava-hf/llava-1.5-7b-hf")


class ChatTemplate:
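    """Prompt builder for multimodal RAG-on-videos queries."""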

    @staticmethod
    def generate_multimodal_rag_on_videos_prompt(question: str, context: str, has_image: bool = False):

        if has_image:
            template = """The transcript associated with the image is '{context}'. {question}"""
        else:
            template = (
                """Refer to the following results obtained from the local knowledge base: '{context}'. {question}"""
            )

        return template.format(context=context, question=question)


@OpeaComponentRegistry.register("OPEA_VLLM_LVM")
class OpeaVllmLvm(OpeaComponent):
    """A specialized vLLM LVM component derived from OpeaComponent for vLLM LVM services."""

    def __init__(self, name: str, description: str, config: dict = None):
        super().__init__(name, ServiceType.LVM.name.lower(), description, config)
        self.base_url = os.getenv("LVM_ENDPOINT", "http://localhost:8399")
        # https://github.com/huggingface/huggingface_hub/blob/v0.29.1/src/huggingface_hub/inference/_providers/hf_inference.py#L87
        # The latest AsyncInferenceClient hardcodes the model name to "tgi",
        # so we use the OpenAI client instead.
        self.lvm_client = OpenAI(api_key="EMPTY", base_url=f"{self.base_url}/v1")
        health_status = self.check_health()
        # if logflag:
        #     logger.info(f"MAX_IMAGES: {max_images}")
        if not health_status:
            logger.error("OpeaVllmLvm health check failed.")

    async def invoke(
        self,
        request: Union[LVMDoc, LVMSearchedMultimodalDoc],
    ) -> Union[TextDoc, MetadataTextDoc]:
        """Invoke the LVM service to generate an answer for the provided input."""
        if logflag:
            logger.info(request)
        if isinstance(request, LVMSearchedMultimodalDoc):
            # TODO: there may be bugs here
            if logflag:
                logger.info("[LVMSearchedMultimodalDoc] input from retriever microservice")
            retrieved_metadatas = request.metadata
            if retrieved_metadatas is None or len(retrieved_metadatas) == 0:
                raise HTTPException(status_code=500, detail="No video segments were retrieved for the given query!")

            img_b64_str = retrieved_metadatas[0]["b64_img_str"]
            has_image = img_b64_str != ""
            initial_query = request.initial_query
            context = retrieved_metadatas[0]["transcript_for_inference"]
            prompt = initial_query
            if request.chat_template is None:
                prompt = ChatTemplate.generate_multimodal_rag_on_videos_prompt(initial_query, context, has_image)
            else:
                prompt_template = PromptTemplate.from_template(request.chat_template)
                input_variables = prompt_template.input_variables
                if sorted(input_variables) == ["context", "question"]:
                    prompt = prompt_template.format(question=initial_query, context=context)
                else:
                    logger.info(
                        f"[LVMSearchedMultimodalDoc] {prompt_template} is not used; only the two input variables ['context', 'question'] are supported"
                    )
            max_new_tokens = request.max_new_tokens
            stream = request.stream
            repetition_penalty = request.repetition_penalty
            temperature = request.temperature
            top_k = request.top_k
            top_p = request.top_p
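            # NOTE: repetition_penalty and top_k are read here for request-schema compatibility,
            # but the OpenAI-style chat completions calls below only forward max_tokens, temperature and top_p.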
            if logflag:
                logger.info(
                    f"prompt generated for [LVMSearchedMultimodalDoc] input from retriever microservice: {prompt}"
                )

        else:
            # TODO: align legacy LVMDoc with chat completions parameters for vLLM
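            # Plain LVMDoc request: the caller supplies the base64-encoded image and the prompt directly.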
            img_b64_str = request.image
            prompt = request.prompt
            max_new_tokens = request.max_new_tokens
            stream = request.stream
            # repetition_penalty = request.repetition_penalty
            temperature = request.temperature
            # top_k = request.top_k
            top_p = request.top_p

        if not img_b64_str:
            # If img_b64_str is an empty string, we only have a text prompt.
            # Work around an issue where LLaVA-NeXT does not provide good responses when prompted without an image.
            # Provide an image and then instruct the model to ignore the image. The base64 string below is the encoded png:
            # https://raw.githubusercontent.com/opea-project/GenAIExamples/refs/tags/v1.0/AudioQnA/ui/svelte/src/lib/assets/icons/png/audio1.png
            img_b64_str = "iVBORw0KGgoAAAANSUhEUgAAADUAAAAlCAYAAADiMKHrAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAKPSURBVHgB7Zl/btowFMefnUTqf+MAHYMTjN4gvcGOABpM+8E0doLSE4xpsE3rKuAG3KC5Ad0J6MYOkP07YnvvhR9y0lVzupTIVT5SwDjB9fd97WfsMkCef1rUXM8dY9HHK4hWUevzi/oVWAqnF8fzLmAtiPA3Aq0lFsVA1fRKxlgNLIbDPaQUZQuu6YO98aIipHOiFGtIqaYfn1UnUCDds6WPyeANlTFbv9WztbFTK+HNUVAPiz7nbPzq7HsPCoKWIBREGfsJXZit5xT07X0jp6iRdIbEHOnjyyD97OvzH00lVS2K5OS2ax11cBXxJgYxlEIE6XZclzdTX6n8XjkkcEIfbj2nMO0/SNd1vy4vsCNjYPyEovfyy88GZIQCSKOCMf6ORgStoboLJuSWKDYCfK2q4jjrMZ+GOh7Pib/gek5DHxVUJtcgA7mJ4kwZRbN7viQXFzQn0Nl52gXG4Fo7DKAYp0yI3VHQ16oaWV0wYa+iGE8nG+wAdx5DzpS/KGyhFGULpShbKEXZQinqLlBK/IKc2asoh4sZvoXJWhlAzuxV1KBVD3HrfYTFAK8ZHgu0hu36DHLG+Izinw250WUkXHJht02QUnxLP7fZxR7f1I6S7Ir2GgmYvIQM5OYUuYBdainATq2ZjTqPBlnbGXYeBrg9Od18DKmc1U0jpw4OIIwEJFxQSl2b4MN2lf74fw8nFNbHt/5N9xWKTZvJ2S6YZk6RC3j2cKpVhSIShZ0mea6caCOCAjyNHd5gPPxGncMBTvI6hunYdaJ6kf8VoSCP2odxX6RkR6NOtanfj13EswKVqEQrPzzFL1lK+YvCFraiEqs8TrwQLGYraqpX4kr/Hixml+63Z+CoM9DTo438AUmP+KyMWT+tAAAAAElFTkSuQmCC"

        if stream:
            t_start = time.time()

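            # Generator that yields SSE "data:" chunks and records the time-to-first-token latency.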
            def stream_generator(time_start):
                first_token_latency = None
                chat_response = ""

                # https://docs.vllm.ai/en/v0.5.1/getting_started/examples/openai_vision_api_client.html
                # vLLM chat completions API
                # TODO: align legacy LVMDoc with chat completions parameters for vLLM
                # Now we simply keep the intersection of them
                # TODO: check vLLM multi-image inputs https://platform.openai.com/docs/guides/vision#multiple-image-inputs
                text_generation = self.lvm_client.chat.completions.create(
                    model=LLM_MODEL_ID,
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": prompt},
                                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64_str}"}},
                            ],
                        }
                    ],
                    max_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    stream=True,
                )

                for output in text_generation:
                    if first_token_latency is None:
                        first_token_latency = time.time() - time_start
                    text = output.choices[0].delta.content
                    # The streamed delta content can be None (e.g. for role-only or final chunks), so skip those.
                    if text is None:
                        continue
                    chat_response += text
                    chunk_repr = repr(text.encode("utf-8"))
                    if logflag:
                        logger.info(f"[llm - chat_stream] chunk:{chunk_repr}")
                    yield f"data: {chunk_repr}\n\n"
                if logflag:
                    logger.info(f"[llm - chat_stream] stream response: {chat_response}")
                statistics_dict["opea_service@lvm"].append_latency(time.time() - time_start, first_token_latency)
                yield "data: [DONE]\n\n"

            return StreamingResponse(stream_generator(t_start), media_type="text/event-stream")
        else:
            # https://docs.vllm.ai/en/v0.5.1/getting_started/examples/openai_vision_api_client.html
            # vLLM chat completions API
            # TODO: align legacy LVMDoc with chat completions parameters for vLLM
            # Now we simply keep the intersection of them
            # TODO: check vLLM multi-image inputs https://platform.openai.com/docs/guides/vision#multiple-image-inputs
            generated_output = self.lvm_client.chat.completions.create(
                model=LLM_MODEL_ID,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64_str}"}},
                        ],
                    }
                ],
                max_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
            )
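            # Non-streaming: return the full completion from the first choice in a single response.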
            generated_str = generated_output.choices[0].message.content

        if logflag:
            logger.info(generated_str)
        if isinstance(request, LVMSearchedMultimodalDoc):
            # TODO: check for bugs here
            retrieved_metadata = request.metadata[0]
            # This metadata is returned alongside the generated text so the answer can be traced back to its source video segment.
            return_metadata = {
                "video_id": retrieved_metadata["video_id"],
                "source_video": retrieved_metadata["source_video"],
                "time_of_frame_ms": retrieved_metadata["time_of_frame_ms"],
                "transcript_for_inference": retrieved_metadata["transcript_for_inference"],
            }
            return MetadataTextDoc(text=generated_str, metadata=return_metadata)
        else:
            return TextDoc(text=generated_str)

    def check_health(self) -> bool:
        """Checks the health of the LVM service.

        Returns:
            bool: True if the service is reachable and healthy, False otherwise.
        """
        try:
            response = requests.get(f"{self.base_url}/health")
            return response.status_code == 200
        except Exception as e:
            # Handle connection errors, timeouts, etc.
            logger.error(f"Health check failed: {e}")
            return False