diff --git a/api/api.py b/api/api.py
index e46c6a33e..20b1077f6 100644
--- a/api/api.py
+++ b/api/api.py
@@ -125,6 +125,9 @@ class CompletionRequest:
     parallel_tool_calls: Optional[bool] = None  # unimplemented - Assistant features
     user: Optional[str] = None  # unimplemented
 
+    def __post_init__(self):
+        self.stream = bool(self.stream)
+
 
 @dataclass
 class CompletionChoice:
@@ -204,7 +207,7 @@ class CompletionResponseChunk:
     choices: List[CompletionChoiceChunk]
     created: int
     model: str
-    system_fingerprint: str
+    system_fingerprint: Optional[str] = None
     service_tier: Optional[str] = None
     object: str = "chat.completion.chunk"
     usage: Optional[UsageStats] = None
@@ -311,7 +314,7 @@ def callback(x, *, done_generating=False):
            sequential_prefill=generator_args.sequential_prefill,
            start_pos=start_pos,
            max_seq_length=self.max_seq_length,
-            seed=int(completion_request.seed),
+            seed=int(completion_request.seed or 0),
        ):
            if y is None:
                continue
@@ -333,9 +336,10 @@ def callback(x, *, done_generating=False):
            choice_chunk = CompletionChoiceChunk(
                delta=chunk_delta,
                index=idx,
+                finish_reason=None,
            )
            chunk_response = CompletionResponseChunk(
-                id=str(id),
+                id="chatcmpl-" + str(id),
                choices=[choice_chunk],
                created=int(time.time()),
                model=completion_request.model,
@@ -351,7 +355,7 @@ def callback(x, *, done_generating=False):
        )
 
        yield CompletionResponseChunk(
-            id=str(id),
+            id="chatcmpl-" + str(id),
            choices=[end_chunk],
            created=int(time.time()),
            model=completion_request.model,
@@ -367,7 +371,7 @@ def sync_completion(self, request: CompletionRequest):
        message = AssistantMessage(content=output)
 
        return CompletionResponse(
-            id=str(uuid.uuid4()),
+            id="chatcmpl-" + str(uuid.uuid4()),
            choices=[
                CompletionChoice(
                    finish_reason="stop",
diff --git a/browser/browser.py b/browser/browser.py
index 0074cf392..e702c3539 100644
--- a/browser/browser.py
+++ b/browser/browser.py
@@ -1,91 +1,40 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import time
-
 import streamlit as st
-from api.api import CompletionRequest, OpenAiApiGenerator
-
-from build.builder import BuilderArgs, TokenizerArgs
-
-from generate import GeneratorArgs
-
-
-def main(args):
-    builder_args = BuilderArgs.from_args(args)
-    speculative_builder_args = BuilderArgs.from_speculative_args(args)
-    tokenizer_args = TokenizerArgs.from_args(args)
-    generator_args = GeneratorArgs.from_args(args)
-    generator_args.chat_mode = False
-
-    @st.cache_resource
-    def initialize_generator() -> OpenAiApiGenerator:
-        return OpenAiApiGenerator(
-            builder_args,
-            speculative_builder_args,
-            tokenizer_args,
-            generator_args,
-            args.profile,
-            args.quantize,
-            args.draft_quantize,
-        )
-
-    gen = initialize_generator()
-
-    st.title("torchchat")
-
-    # Initialize chat history
-    if "messages" not in st.session_state:
-        st.session_state.messages = []
-
-    # Display chat messages from history on app rerun
-    for message in st.session_state.messages:
-        with st.chat_message(message["role"]):
-            st.markdown(message["content"])
-
-    # Accept user input
-    if prompt := st.chat_input("What is up?"):
-        # Add user message to chat history
-        st.session_state.messages.append({"role": "user", "content": prompt})
-        # Display user message in chat message container
-        with st.chat_message("user"):
-            st.markdown(prompt)
-
-        # Display assistant response in chat message container
-        with st.chat_message("assistant"), st.status(
-            "Generating... ", expanded=True
-        ) as status:
-
-            req = CompletionRequest(
-                model=gen.builder_args.checkpoint_path,
-                prompt=prompt,
-                temperature=generator_args.temperature,
-                messages=[],
-            )
-
-            def unwrap(completion_generator):
-                start = time.time()
-                tokcount = 0
-                for chunk_response in completion_generator:
-                    content = chunk_response.choices[0].delta.content
-                    if not gen.is_llama3_model or content not in set(
-                        gen.tokenizer.special_tokens.keys()
-                    ):
-                        yield content
-                    if content == gen.tokenizer.eos_id():
-                        yield "."
-                    tokcount += 1
-                status.update(
-                    label="Done, averaged {:.2f} tokens/second".format(
-                        tokcount / (time.time() - start)
-                    ),
-                    state="complete",
-                )
-
-            response = st.write_stream(unwrap(gen.completion(req)))
-
-            # Add assistant response to chat history
-            st.session_state.messages.append({"role": "assistant", "content": response})
+from openai import OpenAI
+
+with st.sidebar:
+    openai_api_key = st.text_input(
+        "OpenAI API Key", key="chatbot_api_key", type="password"
+    )
+    "[Get an OpenAI API key](https://platform.openai.com/account/api-keys)"
+    "[View the source code](https://github.com/streamlit/llm-examples/blob/main/Chatbot.py)"
+    "[![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/streamlit/llm-examples?quickstart=1)"
+
+st.title("💬 Chatbot")
+
+if "messages" not in st.session_state:
+    st.session_state["messages"] = [
+        {
+            "role": "system",
+            "content": "You're an assistant. Be brief, no yapping. Use as few words as possible to respond to the users' questions.",
+        },
+        {"role": "assistant", "content": "How can I help you?"},
+    ]
+
+for msg in st.session_state.messages:
+    st.chat_message(msg["role"]).write(msg["content"])
+
+if prompt := st.chat_input():
+    client = OpenAI(
+        # This is the default and can be omitted
+        base_url="http://127.0.0.1:5000/v1",
+        api_key="YOURMOTHER",
+    )
+
+    st.session_state.messages.append({"role": "user", "content": prompt})
+    st.chat_message("user").write(prompt)
+    response = client.chat.completions.create(
+        model="stories15m", messages=st.session_state.messages, max_tokens=64
+    )
+    msg = response.choices[0].message.content
+    st.session_state.messages.append({"role": "assistant", "content": msg})
+    st.chat_message("assistant").write(msg)
diff --git a/server.py b/server.py
index 074df6646..f1dbbcdc9 100644
--- a/server.py
+++ b/server.py
@@ -41,7 +41,7 @@ def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
            return [_del_none(v) for v in d if v]
        return d
 
-    @app.route(f"/{OPENAI_API_VERSION}/chat", methods=["POST"])
+    @app.route(f"/{OPENAI_API_VERSION}/chat/completions", methods=["POST"])
    def chat_endpoint():
        """
        Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
@@ -63,7 +63,7 @@ def chat_endpoint():
        data = request.get_json()
        req = CompletionRequest(**data)
 
-        if data.get("stream") == "true":
+        if req.stream:
 
            def chunk_processor(chunked_completion_generator):
                """Inline function for postprocessing CompletionResponseChunk objects.
@@ -74,14 +74,19 @@ def chunk_processor(chunked_completion_generator):
                    if (next_tok := chunk.choices[0].delta.content) is None:
                        next_tok = ""
                    print(next_tok, end="", flush=True)
-                    yield json.dumps(_del_none(asdict(chunk)))
+                    yield f"data:{json.dumps(_del_none(asdict(chunk)))}\n\n"
+                    # wasda = json.dumps(asdict(chunk))
+                    # print(wasda)
+                    # yield wasda
 
-            return Response(
+            resp = Response(
                chunk_processor(gen.chunked_completion(req)),
                mimetype="text/event-stream",
            )
+            return resp
 
        else:
            response = gen.sync_completion(req)
+            print(response.choices[0].message.content)
 
            return json.dumps(_del_none(asdict(response)))
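
For reference, a minimal client sketch (not part of the patch) against the renamed /v1/chat/completions route. It assumes the Flask app in server.py is listening on 127.0.0.1:5000, the same base URL the new browser/browser.py points at; the model name "stories15m" and the api_key value are placeholders, since the local server does not check the key.

# Sketch only: exercise the new /v1/chat/completions endpoint with the openai client.
# Assumes server.py is running at 127.0.0.1:5000; the api_key is required by the
# client constructor but ignored by the local server.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:5000/v1", api_key="unused")
messages = [{"role": "user", "content": "Tell me a short story."}]

# Non-streaming: server.py calls gen.sync_completion() and returns a single JSON body.
resp = client.chat.completions.create(
    model="stories15m", messages=messages, max_tokens=64
)
print(resp.choices[0].message.content)

# Streaming: CompletionRequest.__post_init__ coerces the "stream" flag to bool, and
# server.py now frames each chunk as an SSE line ("data:{...}\n\n"), which surfaces
# here as delta chunks.
for chunk in client.chat.completions.create(
    model="stories15m", messages=messages, max_tokens=64, stream=True
):
    print(chunk.choices[0].delta.content or "", end="", flush=True)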