
Commit 816cbd2

Pawan Osman and mrwyattii authored
Adding OpenAI Compatible RESTful API (#317)
Co-authored-by: Pawan Osman <[email protected]>
Co-authored-by: Michael Wyatt <[email protected]>
1 parent 237a2c9 commit 816cbd2

File tree

5 files changed
+974 -0 lines changed
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}

{% for message in messages %}
{% if message['role'] == 'user' %}
### Instruction:
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% elif message['role'] == 'assistant' %}
### Response:
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% elif message['role'] == 'user_context' %}
### Input:
{{ message['content']|trim -}}
{% if not loop.last %}


{% endif %}
{% endif %}
{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
### Response:
{% endif %}

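For context (not part of the commit): the template above emits an Alpaca-style Instruction/Input/Response prompt. A minimal sketch of rendering it with the jinja2 package, assuming the template is saved locally; the file name alpaca.jinja and the sample messages are illustrative only, not taken from this diff.

# Hypothetical usage sketch: render the chat template above with jinja2.
# "alpaca.jinja" is an assumed local file name, not a path from this commit.
from jinja2 import Template

with open("alpaca.jinja") as f:
    template = Template(f.read())

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the DeepSpeed-MII project."},
]

# add_generation_prompt=True appends the trailing "### Response:" block,
# cueing the model to produce the assistant turn next.
prompt = template.render(messages=messages, add_generation_prompt=True)
print(prompt)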
mii/entrypoints/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
mii/entrypoints/api_server.py

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
# Standard library imports
import json
import argparse

# Third-party imports
import grpc
import uvicorn
import mii
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
from mii.grpc_related.proto.modelresponse_pb2_grpc import ModelResponseStub
from mii.grpc_related.proto import modelresponse_pb2
from mii.utils import kwarg_dict_to_proto

# Local module imports
from .data_models import CompletionRequest

app = FastAPI()
load_balancer = "localhost:50050"


@app.post("/generate")
async def generate(request: CompletionRequest) -> Response:
    # TODO: Add support for multiple stop tokens; for now only one is supported.
    # If the stop token is a list, keep only its first element.
    if request.stop is not None and isinstance(request.stop, list):
        request.stop = request.stop[0]

    # Set defaults
    if request.max_tokens is None:
        request.max_tokens = 128

    if request.stream is None:
        request.stream = False

    if request.prompt is None:
        return JSONResponse({"error": "Prompt is required."}, status_code=400)

    if isinstance(request.prompt, str):
        request.prompt = [request.prompt]

    # Set up the generation arguments
    generate_args = {"ignore_eos": False, "do_sample": True, "return_full_text": False}

    # Set optional generation arguments
    if request.max_length is not None:
        generate_args["max_length"] = request.max_length

    if request.min_tokens is not None:
        generate_args["min_new_tokens"] = request.min_tokens

    if request.max_tokens is not None:
        generate_args["max_new_tokens"] = request.max_tokens

    if request.top_p is not None:
        generate_args["top_p"] = request.top_p

    if request.top_k is not None:
        generate_args["top_k"] = request.top_k

    if request.temperature is not None:
        generate_args["temperature"] = request.temperature

    if request.stop is not None:
        generate_args["stop"] = request.stop

    if request.stream:
        generate_args["stream"] = True

    # Forward the request to the load balancer over gRPC
    channel = grpc.aio.insecure_channel(load_balancer)
    stub = ModelResponseStub(channel)
    requestData = modelresponse_pb2.MultiStringRequest(
        request=request.prompt,
        query_kwargs=kwarg_dict_to_proto(generate_args),
    )

    # Streaming case
    if request.stream:
        return JSONResponse({"error": "Streaming is not yet supported."},
                            status_code=400)
        # async def StreamResults() -> AsyncGenerator[bytes, None]:
        #     # Send an empty chunk to start the stream and prevent timeout
        #     yield ""
        #     async for response_chunk in stub.GeneratorReplyStream(requestData):
        #         # Send the response chunk
        #         responses = [obj.response for obj in response_chunk.response]
        #         dataOut = {"text": responses}
        #         yield f"data: {json.dumps(dataOut)}\n\n"
        #     yield f"data: [DONE]\n\n"
        # return StreamingResponse(StreamResults(), media_type="text/event-stream")

    # Non-streaming case
    responseData = await stub.GeneratorReply(requestData)
    responses = [obj.response for obj in responseData.response]
    result = {"text": responses}
    return JSONResponse(result)


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return JSONResponse({"status": "ok"}, status_code=200)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        "DeepSpeed-MII Simple Text Generation RESTful API Server")
    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mistral-7B-Instruct-v0.1",
        help="model name or path to model directory "
        "(defaults to mistralai/Mistral-7B-Instruct-v0.1)")
    parser.add_argument(
        "--deployment-name",
        type=str,
        default="deepspeed-mii",
        help='a unique identifying string for the persistent model '
        '(defaults to "deepspeed-mii")')
    parser.add_argument("--load-balancer",
                        type=str,
                        default=None,
                        help="load balancer address (defaults to None)")
    parser.add_argument(
        "--max-length",
        type=int,
        default=None,
        help="default maximum token length for the prompt + response "
        "(defaults to the maximum sequence length in the model config)")
    parser.add_argument("--host",
                        type=str,
                        default="0.0.0.0",
                        help="host address (defaults to 0.0.0.0)")
    parser.add_argument("--port", type=int, default=8000, help="port (defaults to 8000)")
    parser.add_argument("--allow-credentials",
                        action="store_true",
                        help="allow credentials")
    parser.add_argument("--allowed-origins",
                        type=json.loads,
                        default=["*"],
                        help="allowed origins")
    parser.add_argument("--allowed-methods",
                        type=json.loads,
                        default=["*"],
                        help="allowed methods")
    parser.add_argument("--allowed-headers",
                        type=json.loads,
                        default=["*"],
                        help="allowed headers")
    parser.add_argument("--tensor-parallel",
                        type=int,
                        default=1,
                        help="number of GPUs to split the model across (defaults to 1)")
    parser.add_argument("--replica-num",
                        type=int,
                        default=1,
                        help="the number of model replicas to stand up (defaults to 1)")

    args = parser.parse_args()

    # Add CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    # Use an existing load balancer if one was specified,
    # otherwise stand up a DeepSpeed-MII instance
    if args.load_balancer is not None:
        load_balancer = args.load_balancer
    else:
        mii.serve(args.model,
                  deployment_name=args.deployment_name,
                  tensor_parallel=args.tensor_parallel,
                  replica_num=args.replica_num,
                  max_length=args.max_length)

    # Start the server
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="info",
                timeout_keep_alive=300)
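
The CompletionRequest model imported from .data_models is part of this commit but not shown in this excerpt. Judging only by the fields the /generate handler reads, a plausible minimal sketch might look like the following; this is a guess, not the actual definition (FastAPI request bodies are conventionally pydantic models).

# Hypothetical sketch of mii/entrypoints/data_models.py (not shown in the
# diff above); field names are inferred from the handler's attribute access.
from typing import List, Optional, Union
from pydantic import BaseModel


class CompletionRequest(BaseModel):
    prompt: Optional[Union[str, List[str]]] = None
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None  # maximum new tokens to generate
    min_tokens: Optional[int] = None  # minimum new tokens to generate
    max_length: Optional[int] = None  # prompt + response token budget
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    temperature: Optional[float] = None
    stream: Optional[bool] = None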

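Once the server is up (the file's __main__ guard suggests an invocation like python -m mii.entrypoints.api_server --model <model>, though the exact entry point is not documented in this diff), the endpoints can be exercised with any HTTP client. A minimal sketch using the requests library; the prompt text and parameter values are illustrative only.

# Minimal client sketch (not part of the commit): exercise the /generate
# and /health endpoints of a locally running server.
import requests

BASE = "http://localhost:8000"  # matches the server's default host/port

# Health check
print(requests.get(f"{BASE}/health").json())  # expected: {"status": "ok"}

# Text generation; the payload mirrors the CompletionRequest fields
payload = {
    "prompt": "### Instruction:\nWrite a haiku about GPUs.",
    "max_tokens": 64,
    "temperature": 0.7,
}
resp = requests.post(f"{BASE}/generate", json=payload)
print(resp.json())  # expected shape: {"text": ["..."]}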