|
| 1 | +import argparse |
| 2 | +from flask import Flask, jsonify, request, Response |
| 3 | +import urllib.parse |
| 4 | +import requests |
| 5 | +import time |
| 6 | +import json |
| 7 | + |
| 8 | + |
# Flask application object; the routes below implement an OpenAI-compatible
# HTTP facade that forwards requests to llama.cpp's server.cpp.
app = Flask(__name__)

# NOTE: arguments are parsed at import time, so importing this module reads
# sys.argv as a side effect. Literal "\n" sequences in the prompt/name
# options are unescaped to real newlines later, in convert_chat().
parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ")
parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>")
parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080')
parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
parser.add_argument("--host", type=str, help="Set the ip address to listen.(default: 127.0.0.1)", default='127.0.0.1')
parser.add_argument("--port", type=int, help="Set the port to listen.(default: 8081)", default=8081)

args = parser.parse_args()
| 23 | + |
def is_present(json, key):
    """Return True if *key* exists in the mapping *json*, else False.

    A plain membership test replaces the original try/except KeyError
    (EAFP where LBYL is simpler and avoids evaluating the value).
    NOTE: the parameter name shadows the imported ``json`` module; it is
    kept unchanged for backward compatibility with existing callers.
    """
    return key in json
| 30 | + |
| 31 | + |
| 32 | + |
| 33 | +#convert chat to prompt |
def convert_chat(messages):
    """Flatten an OpenAI-style message list into a single llama.cpp prompt.

    Role prefixes and the stop token come from the command-line options;
    literal "\\n" sequences in those options are unescaped to newlines.
    Messages with unrecognized roles are silently skipped, and the prompt
    ends with the (right-stripped) assistant prefix to cue the model.
    """
    def _unescape(text):
        return text.replace("\\n", "\n")

    system_name = _unescape(args.system_name)
    user_name = _unescape(args.user_name)
    assistant_name = _unescape(args.ai_name)
    stop_token = _unescape(args.stop)

    parts = [_unescape(args.chat_prompt)]
    for message in messages:
        role = message["role"]
        if role == "system":
            parts.append(f"{system_name}{message['content']}")
        elif role == "user":
            parts.append(f"{user_name}{message['content']}")
        elif role == "assistant":
            parts.append(f"{assistant_name}{message['content']}{stop_token}")
    parts.append(assistant_name.rstrip())

    return "".join(parts)
| 53 | + |
def make_postData(body, chat=False, stream=False):
    """Translate an OpenAI-style request *body* into a llama.cpp
    /completion payload.

    :param body:   parsed JSON request body (chat or completion schema)
    :param chat:   True for /chat/completions (messages are flattened to a
                   prompt via convert_chat), False for raw "prompt"
    :param stream: forwarded as the "stream" flag to server.cpp
    :return: dict ready to be POSTed to server.cpp's /completion endpoint
    """
    postData = {}
    if chat:
        postData["prompt"] = convert_chat(body["messages"])
    else:
        postData["prompt"] = body["prompt"]

    # Sampling parameters copied through verbatim when present.
    for key in ("temperature", "top_k", "top_p", "presence_penalty",
                "frequency_penalty", "repeat_penalty", "mirostat",
                "mirostat_tau", "mirostat_eta", "seed"):
        if is_present(body, key):
            postData[key] = body[key]

    # OpenAI "max_tokens" maps to llama.cpp "n_predict".
    if is_present(body, "max_tokens"):
        postData["n_predict"] = body["max_tokens"]
    # OpenAI logit_bias is {token_id_str: bias}; server.cpp wants [[id, bias], ...].
    if is_present(body, "logit_bias"):
        postData["logit_bias"] = [[int(token), bias] for token, bias in body["logit_bias"].items()]

    postData["stop"] = [args.stop] if args.stop != "" else []
    if is_present(body, "stop"):
        # The OpenAI API allows "stop" to be either a single string or a
        # list of strings; the original `+=` splintered a bare string into
        # its individual characters.
        stop = body["stop"]
        postData["stop"] += [stop] if isinstance(stop, str) else stop

    postData["n_keep"] = -1  # keep the whole prompt on context overflow
    postData["stream"] = stream

    return postData
| 81 | + |
def make_resData(data, chat=False, promptToken=None):
    """Convert a non-streaming server.cpp /completion response into an
    OpenAI-style response body.

    :param data:        JSON dict returned by server.cpp (must contain
                        "truncated", "tokens_evaluated", "tokens_predicted",
                        "content", "stopped_eos", "stopped_word")
    :param chat:        True -> "chat.completion" schema, else "text_completion"
    :param promptToken: optional token list echoed back as "promptToken".
                        Default changed from the mutable ``[]`` to ``None``
                        (same truthiness, avoids the shared-default pitfall).
    :return: OpenAI-compatible response dict (single choice only)
    """
    resData = {
        "id": "chatcmpl" if chat else "cmpl",
        "object": "chat.completion" if chat else "text_completion",
        "created": int(time.time()),
        "truncated": data["truncated"],
        "model": "LLaMA_CPP",
        "usage": {
            "prompt_tokens": data["tokens_evaluated"],
            "completion_tokens": data["tokens_predicted"],
            "total_tokens": data["tokens_evaluated"] + data["tokens_predicted"]
        }
    }
    if promptToken:
        resData["promptToken"] = promptToken

    # "stop" if generation ended on EOS or a stop word, else it hit n_predict.
    finish_reason = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
    if chat:
        # only one choice is supported
        resData["choices"] = [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": data["content"],
            },
            "finish_reason": finish_reason
        }]
    else:
        # only one choice is supported
        resData["choices"] = [{
            "text": data["content"],
            "index": 0,
            "logprobs": None,
            "finish_reason": finish_reason
        }]
    return resData
| 116 | + |
def make_resData_stream(data, chat=False, time_now=0, start=False):
    """Build one OpenAI-style streaming chunk from a server.cpp stream event.

    :param data:     decoded JSON event from server.cpp (ignored when
                     chat=True and start=True)
    :param chat:     chat.completion.chunk vs text_completion.chunk schema
    :param time_now: epoch seconds placed in "created"
    :param start:    chat only — emit the initial role-announcing delta
    :return: OpenAI-compatible chunk dict (single choice only)
    """
    choice = {"finish_reason": None, "index": 0}
    resData = {
        "id": "chatcmpl" if chat else "cmpl",
        "object": "chat.completion.chunk" if chat else "text_completion.chunk",
        "created": time_now,
        "model": "LLaMA_CPP",
        "choices": [choice],
    }

    def _finish_reason():
        # "stop" on EOS/stop-word, otherwise the length limit was hit.
        return "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"

    if chat:
        if start:
            choice["delta"] = {"role": "assistant"}
        else:
            choice["delta"] = {"content": data["content"]}
            if data["stop"]:
                choice["finish_reason"] = _finish_reason()
    else:
        choice["text"] = data["content"]
        if data["stop"]:
            choice["finish_reason"] = _finish_reason()

    return resData
| 147 | + |
| 148 | + |
@app.route('/chat/completions', methods=['POST'])
@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
    """OpenAI-compatible /chat/completions endpoint.

    Flattens the chat messages into a prompt, forwards it to server.cpp's
    /completion endpoint, and translates the response (optionally as a
    server-sent-event stream) back into the OpenAI schema.
    """
    # Auth: the original indexed request.headers["Authorization"].split()[1]
    # directly, which raised KeyError/IndexError (HTTP 500) on a missing or
    # malformed header; a bad credential must yield 403 instead.
    if args.api_key != "":
        auth = request.headers.get("Authorization", "").split()
        if len(auth) != 2 or auth[1] != args.api_key:
            return Response(status=403)
    body = request.get_json()
    stream = False
    tokenize = False
    if is_present(body, "stream"):
        stream = body["stream"]
    if is_present(body, "tokenize"):
        tokenize = body["tokenize"]
    postData = make_postData(body, chat=True, stream=stream)

    promptToken = []
    if tokenize:
        # Optional extension: echo back the prompt's token ids.
        tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
        promptToken = tokenData["tokens"]

    if not stream:
        data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
        print(data.json())
        resData = make_resData(data.json(), chat=True, promptToken=promptToken)
        return jsonify(resData)
    else:
        def generate():
            data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
            time_now = int(time.time())
            # First chunk announces the assistant role, per the OAI protocol.
            resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
            yield 'data: {}\n'.format(json.dumps(resData))
            for line in data.iter_lines():
                if line:
                    # server.cpp lines look like "data: {...}"; strip the
                    # 6-character "data: " prefix before decoding.
                    decoded_line = line.decode('utf-8')
                    resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
                    yield 'data: {}\n'.format(json.dumps(resData))
        return Response(generate(), mimetype='text/event-stream')
| 183 | + |
| 184 | + |
@app.route('/completions', methods=['POST'])
@app.route('/v1/completions', methods=['POST'])
def completion():
    """OpenAI-compatible /completions endpoint.

    Forwards the raw "prompt" to server.cpp's /completion endpoint and
    translates the response (optionally as a server-sent-event stream)
    back into the OpenAI text_completion schema.
    """
    # Auth: the original indexed request.headers["Authorization"].split()[1]
    # directly, which raised KeyError/IndexError (HTTP 500) on a missing or
    # malformed header; a bad credential must yield 403 instead.
    if args.api_key != "":
        auth = request.headers.get("Authorization", "").split()
        if len(auth) != 2 or auth[1] != args.api_key:
            return Response(status=403)
    body = request.get_json()
    stream = False
    tokenize = False
    if is_present(body, "stream"):
        stream = body["stream"]
    if is_present(body, "tokenize"):
        tokenize = body["tokenize"]
    postData = make_postData(body, chat=False, stream=stream)

    promptToken = []
    if tokenize:
        # Optional extension: echo back the prompt's token ids.
        tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
        promptToken = tokenData["tokens"]

    if not stream:
        data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
        print(data.json())
        resData = make_resData(data.json(), chat=False, promptToken=promptToken)
        return jsonify(resData)
    else:
        def generate():
            data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
            time_now = int(time.time())
            for line in data.iter_lines():
                if line:
                    # server.cpp lines look like "data: {...}"; strip the
                    # 6-character "data: " prefix before decoding.
                    decoded_line = line.decode('utf-8')
                    resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
                    yield 'data: {}\n'.format(json.dumps(resData))
        return Response(generate(), mimetype='text/event-stream')
| 217 | + |
if __name__ == '__main__':
    # Start the Flask development server; bind address and port come from
    # the --host/--port command-line options parsed above.
    app.run(args.host, port=args.port)
0 commit comments