
Commit 0f2c2b1

vllm langchain: Add Document Retriever Support (#687)
* vllm langchain: Add Document Retriever Support

  Include SearchedDoc in the /v1/chat/completions endpoint so it accepts document data retrieved from the retriever service and passes it to the LLM for answer generation.

  Signed-off-by: Yeoh, Hoong Tee <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

* vllm: Update README documentation

  Signed-off-by: Yeoh, Hoong Tee <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  For more information, see https://pre-commit.ci

---------

Signed-off-by: Yeoh, Hoong Tee <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 574fecf commit 0f2c2b1

File tree

3 files changed: +185 -30 lines changed


comps/llms/text-generation/vllm/langchain/README.md

Lines changed: 39 additions & 2 deletions
@@ -165,7 +165,7 @@ curl http://${your_ip}:8008/v1/completions \
 
 ## 🚀3. Set up LLM microservice
 
-Then we warp the VLLM service into LLM microcervice.
+Then we wrap the vLLM service into an LLM microservice.
 
 ### Build docker
 
@@ -179,11 +179,48 @@ bash build_docker_microservice.sh
 bash launch_microservice.sh
 ```
 
-### Query the microservice
+### Consume the microservice
+
+#### Check microservice status
 
 ```bash
+curl http://${your_ip}:9000/v1/health_check \
+  -X GET \
+  -H 'Content-Type: application/json'
+
+# Output
+# {"Service Title":"opea_service@llm_vllm/MicroService","Service Description":"OPEA Microservice Infrastructure"}
+```
+
+#### Consume vLLM Service
+
+Users can set the following model parameters according to their needs:
+
+- max_new_tokens: maximum number of output tokens
+- streaming (true/false): return the text response in streaming or non-streaming mode
+
+```bash
+# 1. Non-streaming mode
 curl http://${your_ip}:9000/v1/chat/completions \
   -X POST \
   -d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_p":0.95,"temperature":0.01,"streaming":false}' \
   -H 'Content-Type: application/json'
+
+# 2. Streaming mode
+curl http://${your_ip}:9000/v1/chat/completions \
+  -X POST \
+  -d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true}' \
+  -H 'Content-Type: application/json'
+
+# 3. Custom chat template with streaming mode
+curl http://${your_ip}:9000/v1/chat/completions \
+  -X POST \
+  -d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true, "chat_template":"### You are a helpful, respectful and honest assistant to help the user with questions.\n### Context: {context}\n### Question: {question}\n### Answer:"}' \
+  -H 'Content-Type: application/json'
+
+# 4. Chat with SearchedDoc (Retrieval context)
+curl http://${your_ip}:9000/v1/chat/completions \
+  -X POST \
+  -d '{"initial_query":"What is Deep Learning?","retrieved_docs":[{"text":"Deep Learning is a ..."},{"text":"Deep Learning is b ..."}]}' \
+  -H 'Content-Type: application/json'
 ```
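For readers who script against the microservice from Python rather than curl, the sketch below mirrors examples 1 and 4 above using the `requests` package. It is not part of this commit; the localhost host/port and the printed response shape are assumptions based on the curl calls and the GeneratedDoc returned by the non-streaming path.

```python
# Hypothetical Python equivalents of curl examples 1 and 4 above
# (assumes the microservice is reachable at localhost:9000 and `requests` is installed).
import requests

BASE_URL = "http://localhost:9000/v1/chat/completions"  # assumption: local deployment

# 1. Non-streaming query, mirroring the LLMParamsDoc-style curl example
resp = requests.post(
    BASE_URL,
    json={"query": "What is Deep Learning?", "max_new_tokens": 17,
          "top_p": 0.95, "temperature": 0.01, "streaming": False},
)
print(resp.json())  # GeneratedDoc payload, including "text" and "prompt" fields

# 4. SearchedDoc-style request carrying retriever output (the new path in this commit)
resp = requests.post(
    BASE_URL,
    json={"initial_query": "What is Deep Learning?",
          "retrieved_docs": [{"text": "Deep Learning is a ..."},
                             {"text": "Deep Learning is b ..."}]},
)
print(resp.json())
```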

comps/llms/text-generation/vllm/langchain/llm.py

Lines changed: 117 additions & 28 deletions
@@ -2,23 +2,31 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
+from typing import Union
 
 from fastapi.responses import StreamingResponse
 from langchain_community.llms import VLLMOpenAI
+from langchain_core.prompts import PromptTemplate
+from template import ChatTemplate
 
 from comps import (
     CustomLogger,
     GeneratedDoc,
     LLMParamsDoc,
+    SearchedDoc,
     ServiceType,
     opea_microservices,
     opea_telemetry,
     register_microservice,
 )
+from comps.cores.proto.api_protocol import ChatCompletionRequest
 
 logger = CustomLogger("llm_vllm")
 logflag = os.getenv("LOGFLAG", False)
 
+llm_endpoint = os.getenv("vLLM_ENDPOINT", "http://localhost:8008")
+model_name = os.getenv("LLM_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
+
 
 @opea_telemetry
 def post_process_text(text: str):
@@ -39,39 +47,120 @@ def post_process_text(text: str):
     host="0.0.0.0",
     port=9000,
 )
-def llm_generate(input: LLMParamsDoc):
+def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc]):
     if logflag:
         logger.info(input)
-    llm_endpoint = os.getenv("vLLM_ENDPOINT", "http://localhost:8008")
-    model_name = os.getenv("LLM_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
-    llm = VLLMOpenAI(
-        openai_api_key="EMPTY",
-        openai_api_base=llm_endpoint + "/v1",
-        max_tokens=input.max_new_tokens,
-        model_name=model_name,
-        top_p=input.top_p,
-        temperature=input.temperature,
-        streaming=input.streaming,
-    )
-
-    if input.streaming:
-
-        def stream_generator():
-            chat_response = ""
-            for text in llm.stream(input.query):
-                chat_response += text
-                chunk_repr = repr(text.encode("utf-8"))
-                yield f"data: {chunk_repr}\n\n"
+
+    prompt_template = None
+
+    if not isinstance(input, SearchedDoc) and input.chat_template:
+        prompt_template = PromptTemplate.from_template(input.chat_template)
+        input_variables = prompt_template.input_variables
+
+    if isinstance(input, SearchedDoc):
+        if logflag:
+            logger.info("[ SearchedDoc ] input from retriever microservice")
+
+        prompt = input.initial_query
+
+        if input.retrieved_docs:
+            docs = [doc.text for doc in input.retrieved_docs]
             if logflag:
-                logger.info(f"[llm - chat_stream] stream response: {chat_response}")
-            yield "data: [DONE]\n\n"
+                logger.info(f"[ SearchedDoc ] combined retrieved docs: {docs}")
+
+            prompt = ChatTemplate.generate_rag_prompt(input.initial_query, docs)
+
+        # use default llm parameter for inference
+        new_input = LLMParamsDoc(query=prompt)
 
-        return StreamingResponse(stream_generator(), media_type="text/event-stream")
-    else:
-        response = llm.invoke(input.query)
         if logflag:
-            logger.info(response)
-        return GeneratedDoc(text=response, prompt=input.query)
+            logger.info(f"[ SearchedDoc ] final input: {new_input}")
+
+        llm = VLLMOpenAI(
+            openai_api_key="EMPTY",
+            openai_api_base=llm_endpoint + "/v1",
+            max_tokens=new_input.max_new_tokens,
+            model_name=model_name,
+            top_p=new_input.top_p,
+            temperature=new_input.temperature,
+            streaming=new_input.streaming,
+        )
+
+        if new_input.streaming:
+
+            def stream_generator():
+                chat_response = ""
+                for text in llm.stream(new_input.query):
+                    chat_response += text
+                    chunk_repr = repr(text.encode("utf-8"))
+                    if logflag:
+                        logger.info(f"[ SearchedDoc ] chunk: {chunk_repr}")
+                    yield f"data: {chunk_repr}\n\n"
+                if logflag:
+                    logger.info(f"[ SearchedDoc ] stream response: {chat_response}")
+                yield "data: [DONE]\n\n"
+
+            return StreamingResponse(stream_generator(), media_type="text/event-stream")
+
+        else:
+            response = llm.invoke(new_input.query)
+            if logflag:
+                logger.info(response)
+
+            return GeneratedDoc(text=response, prompt=new_input.query)
+
+    elif isinstance(input, LLMParamsDoc):
+        if logflag:
+            logger.info("[ LLMParamsDoc ] input from rerank microservice")
+
+        prompt = input.query
+
+        if prompt_template:
+            if sorted(input_variables) == ["context", "question"]:
+                prompt = prompt_template.format(question=input.query, context="\n".join(input.documents))
+            elif input_variables == ["question"]:
+                prompt = prompt_template.format(question=input.query)
+            else:
+                logger.info(
+                    f"[ LLMParamsDoc ] {prompt_template} not used, we only support 2 input variables ['question', 'context']"
+                )
+        else:
+            if input.documents:
+                # use rag default template
+                prompt = ChatTemplate.generate_rag_prompt(input.query, input.documents)
+
+        llm = VLLMOpenAI(
+            openai_api_key="EMPTY",
+            openai_api_base=llm_endpoint + "/v1",
+            max_tokens=input.max_new_tokens,
+            model_name=model_name,
+            top_p=input.top_p,
+            temperature=input.temperature,
+            streaming=input.streaming,
+        )
+
+        if input.streaming:
+
+            def stream_generator():
+                chat_response = ""
+                for text in llm.stream(input.query):
+                    chat_response += text
+                    chunk_repr = repr(text.encode("utf-8"))
+                    if logflag:
+                        logger.info(f"[ LLMParamsDoc ] chunk: {chunk_repr}")
+                    yield f"data: {chunk_repr}\n\n"
+                if logflag:
+                    logger.info(f"[ LLMParamsDoc ] stream response: {chat_response}")
+                yield "data: [DONE]\n\n"
+
+            return StreamingResponse(stream_generator(), media_type="text/event-stream")
+
+        else:
+            response = llm.invoke(input.query)
+            if logflag:
+                logger.info(response)
+
+            return GeneratedDoc(text=response, prompt=input.query)
 
 
 if __name__ == "__main__":
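Since stream_generator() above emits server-sent events whose data payload is repr(text.encode("utf-8")) and the stream ends with "data: [DONE]", a minimal client sketch can decode it as shown below. This is not part of the commit; it assumes a local deployment on port 9000 and the `requests` package.

```python
# Minimal streaming-client sketch for the SSE format produced by stream_generator().
# Assumptions: service at localhost:9000, `requests` installed.
import ast
import requests

payload = {"query": "What is Deep Learning?", "max_new_tokens": 17, "streaming": True}
with requests.post("http://localhost:9000/v1/chat/completions", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip the blank separator lines between events
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        # Each event carries repr(bytes), e.g. b'Deep Learning is', so undo the repr
        # with literal_eval before decoding back to text.
        print(ast.literal_eval(data).decode("utf-8"), end="", flush=True)
```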
comps/llms/text-generation/vllm/langchain/template.py

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+
+
+class ChatTemplate:
+    @staticmethod
+    def generate_rag_prompt(question, documents):
+        context_str = "\n".join(documents)
+        if context_str and len(re.findall("[\u4E00-\u9FFF]", context_str)) / len(context_str) >= 0.3:
+            # chinese context
+            template = """
+### 你将扮演一个乐于助人、尊重他人并诚实的助手,你的目标是帮助用户解答问题。有效地利用来自本地知识库的搜索结果。确保你的回答中只包含相关信息。如果你不确定问题的答案,请避免分享不准确的信息。
+### 搜索结果:{context}
+### 问题:{question}
+### 回答:
+"""
+        else:
+            template = """
+### You are a helpful, respectful and honest assistant to help the user with questions. \
+Please refer to the search results obtained from the local knowledge base. \
+But be careful to not incorporate the information that you think is not relevant to the question. \
+If you don't know the answer to a question, please don't share false information. \n
+### Search results: {context} \n
+### Question: {question} \n
+### Answer:
+"""
+        return template.format(context=context_str, question=question)
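A quick usage sketch of the new helper (hypothetical, not part of the commit): generate_rag_prompt joins the documents into {context} and picks the Chinese template when at least 30% of the context characters fall in the CJK Unified Ideographs range, otherwise the English one.

```python
# Hypothetical usage of ChatTemplate.generate_rag_prompt, imported the same way
# llm.py does in this commit (`from template import ChatTemplate`).
from template import ChatTemplate

english_docs = ["Deep Learning is a subset of machine learning ...", "It uses neural networks ..."]
print(ChatTemplate.generate_rag_prompt("What is Deep Learning?", english_docs))
# -> English template: the CJK ratio of the joined context is 0, below the 0.3 threshold.

chinese_docs = ["深度学习是机器学习的一个分支"]
print(ChatTemplate.generate_rag_prompt("什么是深度学习?", chinese_docs))
# -> Chinese template: nearly every context character is a CJK ideograph.
```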
