Commit e3ec7ac

Make API and server compatible with OpenAI API (#1034)
1 parent c7f56f2

3 files changed (+57 -99 lines)

api/api.py (+9 -5)
@@ -125,6 +125,9 @@ class CompletionRequest:
     parallel_tool_calls: Optional[bool] = None  # unimplemented - Assistant features
     user: Optional[str] = None  # unimplemented
 
+    def __post_init__(self):
+        self.stream = bool(self.stream)
+
 
 @dataclass
 class CompletionChoice:
@@ -204,7 +207,7 @@ class CompletionResponseChunk:
     choices: List[CompletionChoiceChunk]
     created: int
     model: str
-    system_fingerprint: str
+    system_fingerprint: Optional[str] = None
     service_tier: Optional[str] = None
     object: str = "chat.completion.chunk"
     usage: Optional[UsageStats] = None
@@ -311,7 +314,7 @@ def callback(x, *, done_generating=False):
             sequential_prefill=generator_args.sequential_prefill,
             start_pos=start_pos,
             max_seq_length=self.max_seq_length,
-            seed=int(completion_request.seed),
+            seed=int(completion_request.seed or 0),
         ):
             if y is None:
                 continue
@@ -333,9 +336,10 @@ def callback(x, *, done_generating=False):
             choice_chunk = CompletionChoiceChunk(
                 delta=chunk_delta,
                 index=idx,
+                finish_reason=None,
             )
             chunk_response = CompletionResponseChunk(
-                id=str(id),
+                id="chatcmpl-" + str(id),
                 choices=[choice_chunk],
                 created=int(time.time()),
                 model=completion_request.model,
@@ -351,7 +355,7 @@ def callback(x, *, done_generating=False):
             )
 
             yield CompletionResponseChunk(
-                id=str(id),
+                id="chatcmpl-" + str(id),
                 choices=[end_chunk],
                 created=int(time.time()),
                 model=completion_request.model,
@@ -367,7 +371,7 @@ def sync_completion(self, request: CompletionRequest):
 
         message = AssistantMessage(content=output)
         return CompletionResponse(
-            id=str(uuid.uuid4()),
+            id="chatcmpl-" + str(uuid.uuid4()),
             choices=[
                 CompletionChoice(
                     finish_reason="stop",

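A note on the __post_init__ hook added above: dataclasses call __post_init__ immediately after the generated __init__, which makes it a natural place to normalize fields parsed from raw request JSON. A minimal sketch of the pattern (the Req class name is illustrative, not from the commit):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Req:
    stream: Optional[bool] = None

    def __post_init__(self):
        # Normalize whatever the JSON carried (None, 0/1, true/false)
        # into a plain bool so `if req.stream:` behaves predictably.
        self.stream = bool(self.stream)

print(Req().stream)          # False (was None)
print(Req(stream=1).stream)  # True

One caveat: bool() on any non-empty string, including "false", yields True, so this coercion assumes clients send a real JSON boolean rather than a string.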
browser/browser.py (+39 -90)
@@ -1,91 +1,40 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import time
-
 import streamlit as st
-from api.api import CompletionRequest, OpenAiApiGenerator
-
-from build.builder import BuilderArgs, TokenizerArgs
-
-from generate import GeneratorArgs
-
-
-def main(args):
-    builder_args = BuilderArgs.from_args(args)
-    speculative_builder_args = BuilderArgs.from_speculative_args(args)
-    tokenizer_args = TokenizerArgs.from_args(args)
-    generator_args = GeneratorArgs.from_args(args)
-    generator_args.chat_mode = False
-
-    @st.cache_resource
-    def initialize_generator() -> OpenAiApiGenerator:
-        return OpenAiApiGenerator(
-            builder_args,
-            speculative_builder_args,
-            tokenizer_args,
-            generator_args,
-            args.profile,
-            args.quantize,
-            args.draft_quantize,
-        )
-
-    gen = initialize_generator()
-
-    st.title("torchchat")
-
-    # Initialize chat history
-    if "messages" not in st.session_state:
-        st.session_state.messages = []
-
-    # Display chat messages from history on app rerun
-    for message in st.session_state.messages:
-        with st.chat_message(message["role"]):
-            st.markdown(message["content"])
-
-    # Accept user input
-    if prompt := st.chat_input("What is up?"):
-        # Add user message to chat history
-        st.session_state.messages.append({"role": "user", "content": prompt})
-        # Display user message in chat message container
-        with st.chat_message("user"):
-            st.markdown(prompt)
-
-        # Display assistant response in chat message container
-        with st.chat_message("assistant"), st.status(
-            "Generating... ", expanded=True
-        ) as status:
-
-            req = CompletionRequest(
-                model=gen.builder_args.checkpoint_path,
-                prompt=prompt,
-                temperature=generator_args.temperature,
-                messages=[],
-            )
-
-            def unwrap(completion_generator):
-                start = time.time()
-                tokcount = 0
-                for chunk_response in completion_generator:
-                    content = chunk_response.choices[0].delta.content
-                    if not gen.is_llama3_model or content not in set(
-                        gen.tokenizer.special_tokens.keys()
-                    ):
-                        yield content
-                    if content == gen.tokenizer.eos_id():
-                        yield "."
-                    tokcount += 1
-                status.update(
-                    label="Done, averaged {:.2f} tokens/second".format(
-                        tokcount / (time.time() - start)
-                    ),
-                    state="complete",
-                )
-
-            response = st.write_stream(unwrap(gen.completion(req)))
-
-        # Add assistant response to chat history
-        st.session_state.messages.append({"role": "assistant", "content": response})
+from openai import OpenAI
+
+with st.sidebar:
+    openai_api_key = st.text_input(
+        "OpenAI API Key", key="chatbot_api_key", type="password"
+    )
+    "[Get an OpenAI API key](https://platform.openai.com/account/api-keys)"
+    "[View the source code](https://github.com/streamlit/llm-examples/blob/main/Chatbot.py)"
+    "[![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/streamlit/llm-examples?quickstart=1)"
+
+st.title("💬 Chatbot")
+
+if "messages" not in st.session_state:
+    st.session_state["messages"] = [
+        {
+            "role": "system",
+            "content": "You're an assistant. Be brief, no yapping. Use as few words as possible to respond to the users' questions.",
+        },
+        {"role": "assistant", "content": "How can I help you?"},
+    ]
+
+for msg in st.session_state.messages:
+    st.chat_message(msg["role"]).write(msg["content"])
+
+if prompt := st.chat_input():
+    client = OpenAI(
+        # This is the default and can be omitted
+        base_url="http://127.0.0.1:5000/v1",
+        api_key="YOURMOTHER",
+    )
+
+    st.session_state.messages.append({"role": "user", "content": prompt})
+    st.chat_message("user").write(prompt)
+    response = client.chat.completions.create(
+        model="stories15m", messages=st.session_state.messages, max_tokens=64
+    )
+    msg = response.choices[0].message.content
+    st.session_state.messages.append({"role": "assistant", "content": msg})
+    st.chat_message("assistant").write(msg)

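The rewritten browser is essentially the stock Streamlit/OpenAI chatbot example pointed at the local server, so the same endpoint can also be exercised outside Streamlit with the official openai client. A hedged sketch of a streaming variant (not part of the commit; it reuses the hardcoded base URL and model name above and assumes the client's SSE parser accepts the server's framing):

from openai import OpenAI

# The local server does not appear to validate the key,
# but the client library requires one to be set.
client = OpenAI(base_url="http://127.0.0.1:5000/v1", api_key="unused")

stream = client.chat.completions.create(
    model="stories15m",
    messages=[{"role": "user", "content": "Tell me a short story."}],
    max_tokens=64,
    stream=True,
)
for chunk in stream:
    # Each chunk mirrors CompletionResponseChunk; delta.content may be None.
    print(chunk.choices[0].delta.content or "", end="", flush=True)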
server.py (+9 -4)
@@ -41,7 +41,7 @@ def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
         return [_del_none(v) for v in d if v]
     return d
 
-@app.route(f"/{OPENAI_API_VERSION}/chat", methods=["POST"])
+@app.route(f"/{OPENAI_API_VERSION}/chat/completions", methods=["POST"])
 def chat_endpoint():
     """
     Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
@@ -63,7 +63,7 @@ def chat_endpoint():
     data = request.get_json()
     req = CompletionRequest(**data)
 
-    if data.get("stream") == "true":
+    if req.stream:
 
         def chunk_processor(chunked_completion_generator):
             """Inline function for postprocessing CompletionResponseChunk objects.
@@ -74,14 +74,19 @@ def chunk_processor(chunked_completion_generator):
                 if (next_tok := chunk.choices[0].delta.content) is None:
                     next_tok = ""
                 print(next_tok, end="", flush=True)
-                yield json.dumps(_del_none(asdict(chunk)))
+                yield f"data:{json.dumps(_del_none(asdict(chunk)))}\n\n"
+                # wasda = json.dumps(asdict(chunk))
+                # print(wasda)
+                # yield wasda
 
-        return Response(
+        resp = Response(
             chunk_processor(gen.chunked_completion(req)),
             mimetype="text/event-stream",
         )
+        return resp
     else:
         response = gen.sync_completion(req)
+        print(response.choices[0].message.content)
 
         return json.dumps(_del_none(asdict(response)))

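The data: prefix and trailing blank line added to chunk_processor give the stream the server-sent-events framing that OpenAI-style consumers expect. A minimal sketch of reading that stream by hand with requests (host, port, and model name are assumptions carried over from the browser code):

import json

import requests

resp = requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",
    json={
        "model": "stories15m",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines():
    if not line:
        continue  # skip the blank separator between SSE events
    text = line.decode("utf-8")
    if text.startswith("data:"):
        chunk = json.loads(text[len("data:"):])
        # _del_none may have stripped empty fields, so guard the lookups.
        delta = chunk["choices"][0].get("delta", {})
        print(delta.get("content", ""), end="", flush=True)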