
Commit 3a31295

Authored by XinyaoWa; co-authored by kevinintel, pre-commit-ci[bot], and lvliang-intel
Align parameters for "max_tokens, repetition_penalty, presence_penalty, frequency_penalty" (#608)
Squashed commits:

* align max_tokens
* align repetition_penalty
* fix bug
* align penalty parameters
* fix bug
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix bug
* fix bug
* fix bug
* fix bug
* align max_tokens
* fix bug
* fix bug
* debug
* debug
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix langchain version bug
* fix langchain version bug
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

---------

Signed-off-by: Xinyao Wang <[email protected]>
Co-authored-by: kevinintel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: lvliang-intel <[email protected]>
1 parent 00227b8 commit 3a31295

30 files changed: +107, -57 lines

comps/cores/mega/gateway.py

Lines changed: 33 additions & 15 deletions
@@ -160,11 +160,13 @@ async def handle_request(self, request: Request):
         chat_request = ChatCompletionRequest.parse_obj(data)
         prompt = self._handle_message(chat_request.messages)
         parameters = LLMParams(
-            max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
             top_k=chat_request.top_k if chat_request.top_k else 10,
             top_p=chat_request.top_p if chat_request.top_p else 0.95,
             temperature=chat_request.temperature if chat_request.temperature else 0.01,
-            repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
             streaming=stream_opt,
             chat_template=chat_request.chat_template if chat_request.chat_template else None,
         )
@@ -214,11 +216,13 @@ async def handle_request(self, request: Request):
         chat_request = ChatCompletionRequest.parse_obj(data)
         prompt = self._handle_message(chat_request.messages)
         parameters = LLMParams(
-            max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
             top_k=chat_request.top_k if chat_request.top_k else 10,
             top_p=chat_request.top_p if chat_request.top_p else 0.95,
             temperature=chat_request.temperature if chat_request.temperature else 0.01,
-            repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
             streaming=stream_opt,
         )
         result_dict, runtime_graph = await self.megaservice.schedule(
@@ -350,11 +354,13 @@ async def handle_request(self, request: Request):
         chat_request = ChatCompletionRequest.parse_obj(data)
         prompt = self._handle_message(chat_request.messages)
         parameters = LLMParams(
-            max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
             top_k=chat_request.top_k if chat_request.top_k else 10,
             top_p=chat_request.top_p if chat_request.top_p else 0.95,
             temperature=chat_request.temperature if chat_request.temperature else 0.01,
-            repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
             streaming=stream_opt,
         )
         result_dict, runtime_graph = await self.megaservice.schedule(
@@ -399,11 +405,13 @@ async def handle_request(self, request: Request):
         chat_request = AudioChatCompletionRequest.parse_obj(data)
         parameters = LLMParams(
             # relatively lower max_tokens for audio conversation
-            max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
             top_k=chat_request.top_k if chat_request.top_k else 10,
             top_p=chat_request.top_p if chat_request.top_p else 0.95,
             temperature=chat_request.temperature if chat_request.temperature else 0.01,
-            repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
             streaming=False,  # TODO add streaming LLM output as input to TTS
         )
         result_dict, runtime_graph = await self.megaservice.schedule(
@@ -428,11 +436,13 @@ async def handle_request(self, request: Request):
         chat_request = ChatCompletionRequest.parse_obj(data)
         prompt = self._handle_message(chat_request.messages)
         parameters = LLMParams(
-            max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
             top_k=chat_request.top_k if chat_request.top_k else 10,
             top_p=chat_request.top_p if chat_request.top_p else 0.95,
             temperature=chat_request.temperature if chat_request.temperature else 0.01,
-            repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
             streaming=stream_opt,
         )
         result_dict, runtime_graph = await self.megaservice.schedule(
@@ -472,11 +482,13 @@ async def handle_request(self, request: Request):
         chat_request = ChatCompletionRequest.parse_obj(data)
         prompt = self._handle_message(chat_request.messages)
         parameters = LLMParams(
-            max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
             top_k=chat_request.top_k if chat_request.top_k else 10,
             top_p=chat_request.top_p if chat_request.top_p else 0.95,
             temperature=chat_request.temperature if chat_request.temperature else 0.01,
-            repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
             streaming=stream_opt,
         )
         result_dict, runtime_graph = await self.megaservice.schedule(
@@ -520,7 +532,9 @@ async def handle_request(self, request: Request):
             top_k=chat_request.top_k if chat_request.top_k else 10,
             top_p=chat_request.top_p if chat_request.top_p else 0.95,
             temperature=chat_request.temperature if chat_request.temperature else 0.01,
-            repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
             streaming=stream_opt,
         )
         result_dict, runtime_graph = await self.megaservice.schedule(
@@ -569,7 +583,9 @@ async def handle_request(self, request: Request):
             top_k=chat_request.top_k if chat_request.top_k else 10,
             top_p=chat_request.top_p if chat_request.top_p else 0.95,
             temperature=chat_request.temperature if chat_request.temperature else 0.01,
-            repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
             streaming=stream_opt,
         )
         result_dict, runtime_graph = await self.megaservice.schedule(
@@ -758,7 +774,9 @@ async def handle_request(self, request: Request):
             top_k=chat_request.top_k if chat_request.top_k else 10,
             top_p=chat_request.top_p if chat_request.top_p else 0.95,
             temperature=chat_request.temperature if chat_request.temperature else 0.01,
-            repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+            frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+            presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+            repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
             streaming=stream_opt,
             chat_template=chat_request.chat_template if chat_request.chat_template else None,
         )
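For context, the effect of the change above is that each OpenAI-style penalty now reaches LLMParams under its own name instead of `presence_penalty` being reused as `repetition_penalty`. Below is a minimal sketch of the before/after mapping, with a hypothetical `req` object standing in for the parsed ChatCompletionRequest (not the actual gateway code):

```python
from types import SimpleNamespace

def old_mapping(req):
    # Before this commit: the OpenAI presence_penalty value was reused as the
    # TGI-style repetition_penalty, and frequency_penalty was not forwarded.
    return {"repetition_penalty": req.presence_penalty if req.presence_penalty else 1.03}

def new_mapping(req):
    # After this commit: each penalty keeps its own meaning and default.
    return {
        "frequency_penalty": req.frequency_penalty if req.frequency_penalty else 0.0,
        "presence_penalty": req.presence_penalty if req.presence_penalty else 0.0,
        "repetition_penalty": req.repetition_penalty if req.repetition_penalty else 1.03,
    }

# Hypothetical request: only frequency_penalty is set by the caller.
req = SimpleNamespace(frequency_penalty=0.5, presence_penalty=None, repetition_penalty=None)
print(old_mapping(req))  # {'repetition_penalty': 1.03}
print(new_mapping(req))  # {'frequency_penalty': 0.5, 'presence_penalty': 0.0, 'repetition_penalty': 1.03}
```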

comps/cores/proto/api_protocol.py

Lines changed: 3 additions & 1 deletion
@@ -285,8 +285,9 @@ class AudioChatCompletionRequest(BaseModel):
     max_tokens: Optional[int] = 1024
     stop: Optional[Union[str, List[str]]] = None
     stream: Optional[bool] = False
-    presence_penalty: Optional[float] = 1.03
+    presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
+    repetition_penalty: Optional[float] = 1.03
     user: Optional[str] = None


@@ -345,6 +346,7 @@ class CompletionRequest(BaseModel):
     echo: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
+    repetition_penalty: Optional[float] = 1.03
     user: Optional[str] = None
     use_beam_search: Optional[bool] = False
     best_of: Optional[int] = None
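As a quick illustration of the schema change, the sketch below reproduces only the touched fields in a stand-in pydantic model (not the real class), with defaults taken from the diff: `presence_penalty` drops to the OpenAI-style 0.0 and the TGI-style 1.03 default moves to a new `repetition_penalty` field.

```python
from typing import Optional
from pydantic import BaseModel

class AudioChatCompletionRequestSketch(BaseModel):
    # Trimmed stand-in for the real AudioChatCompletionRequest; only the
    # fields touched by this commit are reproduced here.
    max_tokens: Optional[int] = 1024
    presence_penalty: Optional[float] = 0.0     # was 1.03 before this commit
    frequency_penalty: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.03  # new field

req = AudioChatCompletionRequestSketch()
print(req.presence_penalty, req.repetition_penalty)  # 0.0 1.03
```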

comps/cores/proto/docarray.py

Lines changed: 6 additions & 0 deletions
@@ -145,11 +145,14 @@ class RerankedDoc(BaseDoc):
 class LLMParamsDoc(BaseDoc):
     model: Optional[str] = None  # for openai and ollama
     query: str
+    max_tokens: int = 1024
     max_new_tokens: int = 1024
     top_k: int = 10
     top_p: float = 0.95
     typical_p: float = 0.95
     temperature: float = 0.01
+    frequency_penalty: float = 0.0
+    presence_penalty: float = 0.0
     repetition_penalty: float = 1.03
     streaming: bool = True
 
@@ -179,11 +182,14 @@ def chat_template_must_contain_variables(cls, v):
 
 
 class LLMParams(BaseDoc):
+    max_tokens: int = 1024
     max_new_tokens: int = 1024
     top_k: int = 10
     top_p: float = 0.95
     typical_p: float = 0.95
     temperature: float = 0.01
+    frequency_penalty: float = 0.0
+    presence_penalty: float = 0.0
     repetition_penalty: float = 1.03
     streaming: bool = True
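A short sketch of the extended parameter set as it now travels through the pipeline; it mirrors the fields in the diff but uses a plain dataclass instead of the docarray `BaseDoc` base, so it is illustrative only:

```python
from dataclasses import dataclass

@dataclass
class LLMParamsSketch:
    # Field names and defaults copied from the diff above; the real classes
    # derive from docarray's BaseDoc rather than being dataclasses.
    max_tokens: int = 1024        # new OpenAI-style name
    max_new_tokens: int = 1024    # legacy name kept alongside it
    top_k: int = 10
    top_p: float = 0.95
    typical_p: float = 0.95
    temperature: float = 0.01
    frequency_penalty: float = 0.0   # new
    presence_penalty: float = 0.0    # new
    repetition_penalty: float = 1.03
    streaming: bool = True

params = LLMParamsSketch(max_tokens=256, presence_penalty=0.2)
print(params.max_tokens, params.presence_penalty, params.repetition_penalty)
```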

comps/llms/faq-generation/tgi/langchain/llm.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def llm_generate(input: LLMParamsDoc):
     llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
     llm = HuggingFaceEndpoint(
         endpoint_url=llm_endpoint,
-        max_new_tokens=input.max_new_tokens,
+        max_new_tokens=input.max_tokens,
         top_k=input.top_k,
         top_p=input.top_p,
         typical_p=input.typical_p,
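Worth noting: the langchain TGI client keyword is still `max_new_tokens`; only its value now comes from the aligned `max_tokens` field on the incoming doc. A minimal sketch of that hand-off, using a hypothetical stand-in for the `input` doc rather than the real microservice wiring:

```python
from types import SimpleNamespace

# Hypothetical stand-in for the LLMParamsDoc passed into llm_generate().
doc = SimpleNamespace(max_tokens=512, top_k=10, top_p=0.95, typical_p=0.95)

# The HuggingFaceEndpoint kwarg is still named max_new_tokens; only the
# source field on the doc changed to max_tokens.
hf_kwargs = dict(
    endpoint_url="http://localhost:8080",  # TGI_LLM_ENDPOINT fallback from the diff
    max_new_tokens=doc.max_tokens,
    top_k=doc.top_k,
    top_p=doc.top_p,
    typical_p=doc.typical_p,
)
print(hf_kwargs["max_new_tokens"])  # 512
```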

comps/llms/faq-generation/tgi/langchain/requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -2,7 +2,10 @@ docarray[full]
 fastapi
 huggingface_hub
 langchain
+langchain-huggingface
+langchain-openai
 langchain_community
+langchainhub
 opentelemetry-api
 opentelemetry-exporter-otlp
 opentelemetry-sdk

comps/llms/summarization/tgi/langchain/llm.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ def llm_generate(input: LLMParamsDoc):
     llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
     llm = HuggingFaceEndpoint(
         endpoint_url=llm_endpoint,
-        max_new_tokens=input.max_new_tokens,
+        max_new_tokens=input.max_tokens,
         top_k=input.top_k,
         top_p=input.top_p,
         typical_p=input.typical_p,

comps/llms/text-generation/README.md

Lines changed: 3 additions & 3 deletions
@@ -374,7 +374,7 @@ curl http://${your_ip}:8008/v1/chat/completions \
 
 ### 3.3 Consume LLM Service
 
-You can set the following model parameters according to your actual needs, such as `max_new_tokens`, `streaming`.
+You can set the following model parameters according to your actual needs, such as `max_tokens`, `streaming`.
 
 The `streaming` parameter determines the format of the data returned by the API. It will return text string with `streaming=false`, return text streaming flow with `streaming=true`.
 
@@ -385,7 +385,7 @@ curl http://${your_ip}:9000/v1/chat/completions \
 -H 'Content-Type: application/json' \
 -d '{
 "query":"What is Deep Learning?",
-"max_new_tokens":17,
+"max_tokens":17,
 "top_k":10,
 "top_p":0.95,
 "typical_p":0.95,
@@ -401,7 +401,7 @@ curl http://${your_ip}:9000/v1/chat/completions \
 -H 'Content-Type: application/json' \
 -d '{
 "query":"What is Deep Learning?",
-"max_new_tokens":17,
+"max_tokens":17,
 "top_k":10,
 "top_p":0.95,
 "typical_p":0.95,

comps/llms/text-generation/ollama/langchain/README.md

Lines changed: 1 addition & 1 deletion
@@ -70,5 +70,5 @@ docker run --network host -e http_proxy=$http_proxy -e https_proxy=$https_proxy
 ## Consume the Ollama Microservice
 
 ```bash
-curl http://127.0.0.1:9000/v1/chat/completions -X POST -d '{"model": "llama3", "query":"What is Deep Learning?","max_new_tokens":32,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true}' -H 'Content-Type: application/json'
+curl http://127.0.0.1:9000/v1/chat/completions -X POST -d '{"model": "llama3", "query":"What is Deep Learning?","max_tokens":32,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true}' -H 'Content-Type: application/json'
 ```

comps/llms/text-generation/ollama/langchain/llm.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ def llm_generate(input: LLMParamsDoc):
     ollama = Ollama(
         base_url=ollama_endpoint,
         model=input.model if input.model else model_name,
-        num_predict=input.max_new_tokens,
+        num_predict=input.max_tokens,
         top_k=input.top_k,
         top_p=input.top_p,
         temperature=input.temperature,
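From the client side, consumers now send `max_tokens` instead of `max_new_tokens`. Below is a hedged Python equivalent of the curl command in the README diff above, assuming the Ollama microservice is listening locally on port 9000:

```python
import requests

# Payload mirrors the README curl; streaming is disabled here so the whole
# response body can be printed at once.
payload = {
    "model": "llama3",
    "query": "What is Deep Learning?",
    "max_tokens": 32,            # renamed from max_new_tokens in this commit
    "top_k": 10,
    "top_p": 0.95,
    "typical_p": 0.95,
    "temperature": 0.01,
    "repetition_penalty": 1.03,
    "streaming": False,
}
resp = requests.post("http://127.0.0.1:9000/v1/chat/completions", json=payload)
print(resp.status_code, resp.text[:200])
```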

comps/llms/text-generation/predictionguard/README.md

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@ curl -X POST http://localhost:9000/v1/chat/completions \
 -d '{
 "model": "Hermes-2-Pro-Llama-3-8B",
 "query": "Tell me a joke.",
-"max_new_tokens": 100,
+"max_tokens": 100,
 "temperature": 0.7,
 "top_p": 0.9,
 "top_k": 50,
@@ -45,7 +45,7 @@ curl -N -X POST http://localhost:9000/v1/chat/completions \
 -d '{
 "model": "Hermes-2-Pro-Llama-3-8B",
 "query": "Tell me a joke.",
-"max_new_tokens": 100,
+"max_tokens": 100,
 "temperature": 0.7,
 "top_p": 0.9,
 "top_k": 50,

0 commit comments
