Commit f257fd2

Add an API example using server.cpp similar to OAI. (#2009)
* add api_like_OAI.py
* add evaluated token count to server
* add /v1/ endpoints binding
1 parent 7ee76e4 commit f257fd2

File tree

3 files changed: +244 -5 lines changed


examples/server/README.md

+16

@@ -190,3 +190,19 @@ Run with bash:

```sh
bash chat.sh
```

### API like OAI

API example using Python Flask: [api_like_OAI.py](api_like_OAI.py)
This example must be used with server.cpp

```sh
python api_like_OAI.py
```

After running the API server, you can use it in Python by setting the API base URL.

```python
openai.api_base = "http://<Your api-server IP>:port"
```

Then you can use llama.cpp as an OpenAI-compatible **chat.completion** or **text_completion** API.
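For reference, a minimal client sketch (not part of this commit), assuming the legacy `openai` 0.x Python package (the one that exposes `openai.api_base`) and the proxy running on its default 127.0.0.1:8081:

```python
# Sketch only: the model name is a placeholder that the proxy ignores, and the
# API key is only checked if api_like_OAI.py was started with --api-key.
import openai

openai.api_base = "http://127.0.0.1:8081/v1"
openai.api_key = "sk-placeholder"

completion = openai.ChatCompletion.create(
    model="LLaMA_CPP",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
)
print(completion["choices"][0]["message"]["content"])
```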

examples/server/api_like_OAI.py

+219

@@ -0,0 +1,219 @@
import argparse
from flask import Flask, jsonify, request, Response
import urllib.parse
import requests
import time
import json


app = Flask(__name__)

parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ")
parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ")
parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '</s>')", default="</s>")
parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080')
parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="")
parser.add_argument("--host", type=str, help="Set the ip address to listen.(default: 127.0.0.1)", default='127.0.0.1')
parser.add_argument("--port", type=int, help="Set the port to listen.(default: 8081)", default=8081)

args = parser.parse_args()

def is_present(json, key):
    try:
        buf = json[key]
    except KeyError:
        return False
    return True


#convert chat to prompt
def convert_chat(messages):
    prompt = "" + args.chat_prompt.replace("\\n", "\n")

    system_n = args.system_name.replace("\\n", "\n")
    user_n = args.user_name.replace("\\n", "\n")
    ai_n = args.ai_name.replace("\\n", "\n")
    stop = args.stop.replace("\\n", "\n")

    for line in messages:
        if (line["role"] == "system"):
            prompt += f"{system_n}{line['content']}"
        if (line["role"] == "user"):
            prompt += f"{user_n}{line['content']}"
        if (line["role"] == "assistant"):
            prompt += f"{ai_n}{line['content']}{stop}"
    prompt += ai_n.rstrip()

    return prompt

def make_postData(body, chat=False, stream=False):
    postData = {}
    if (chat):
        postData["prompt"] = convert_chat(body["messages"])
    else:
        postData["prompt"] = body["prompt"]
    if(is_present(body, "temperature")): postData["temperature"] = body["temperature"]
    if(is_present(body, "top_k")): postData["top_k"] = body["top_k"]
    if(is_present(body, "top_p")): postData["top_p"] = body["top_p"]
    if(is_present(body, "max_tokens")): postData["n_predict"] = body["max_tokens"]
    if(is_present(body, "presence_penalty")): postData["presence_penalty"] = body["presence_penalty"]
    if(is_present(body, "frequency_penalty")): postData["frequency_penalty"] = body["frequency_penalty"]
    if(is_present(body, "repeat_penalty")): postData["repeat_penalty"] = body["repeat_penalty"]
    if(is_present(body, "mirostat")): postData["mirostat"] = body["mirostat"]
    if(is_present(body, "mirostat_tau")): postData["mirostat_tau"] = body["mirostat_tau"]
    if(is_present(body, "mirostat_eta")): postData["mirostat_eta"] = body["mirostat_eta"]
    if(is_present(body, "seed")): postData["seed"] = body["seed"]
    if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()]
    if (args.stop != ""):
        postData["stop"] = [args.stop]
    else:
        postData["stop"] = []
    if(is_present(body, "stop")): postData["stop"] += body["stop"]
    postData["n_keep"] = -1
    postData["stream"] = stream

    return postData

def make_resData(data, chat=False, promptToken=[]):
    resData = {
        "id": "chatcmpl" if (chat) else "cmpl",
        "object": "chat.completion" if (chat) else "text_completion",
        "created": int(time.time()),
        "truncated": data["truncated"],
        "model": "LLaMA_CPP",
        "usage": {
            "prompt_tokens": data["tokens_evaluated"],
            "completion_tokens": data["tokens_predicted"],
            "total_tokens": data["tokens_evaluated"] + data["tokens_predicted"]
        }
    }
    if (len(promptToken) != 0):
        resData["promptToken"] = promptToken
    if (chat):
        #only one choice is supported
        resData["choices"] = [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": data["content"],
            },
            "finish_reason": "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
        }]
    else:
        #only one choice is supported
        resData["choices"] = [{
            "text": data["content"],
            "index": 0,
            "logprobs": None,
            "finish_reason": "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
        }]
    return resData

def make_resData_stream(data, chat=False, time_now = 0, start=False):
    resData = {
        "id": "chatcmpl" if (chat) else "cmpl",
        "object": "chat.completion.chunk" if (chat) else "text_completion.chunk",
        "created": time_now,
        "model": "LLaMA_CPP",
        "choices": [
            {
                "finish_reason": None,
                "index": 0
            }
        ]
    }
    if (chat):
        if (start):
            resData["choices"][0]["delta"] = {
                "role": "assistant"
            }
        else:
            resData["choices"][0]["delta"] = {
                "content": data["content"]
            }
            if (data["stop"]):
                resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"
    else:
        resData["choices"][0]["text"] = data["content"]
        if (data["stop"]):
            resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length"

    return resData


@app.route('/chat/completions', methods=['POST'])
@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
    if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
        return Response(status=403)
    body = request.get_json()
    stream = False
    tokenize = False
    if(is_present(body, "stream")): stream = body["stream"]
    if(is_present(body, "tokenize")): tokenize = body["tokenize"]
    postData = make_postData(body, chat=True, stream=stream)

    promptToken = []
    if (tokenize):
        tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
        promptToken = tokenData["tokens"]

    if (not stream):
        data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
        print(data.json())
        resData = make_resData(data.json(), chat=True, promptToken=promptToken)
        return jsonify(resData)
    else:
        def generate():
            data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
            time_now = int(time.time())
            resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
            yield 'data: {}\n'.format(json.dumps(resData))
            for line in data.iter_lines():
                if line:
                    decoded_line = line.decode('utf-8')
                    resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
                    yield 'data: {}\n'.format(json.dumps(resData))
        return Response(generate(), mimetype='text/event-stream')


@app.route('/completions', methods=['POST'])
@app.route('/v1/completions', methods=['POST'])
def completion():
    if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key):
        return Response(status=403)
    body = request.get_json()
    stream = False
    tokenize = False
    if(is_present(body, "stream")): stream = body["stream"]
    if(is_present(body, "tokenize")): tokenize = body["tokenize"]
    postData = make_postData(body, chat=False, stream=stream)

    promptToken = []
    if (tokenize):
        tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json()
        promptToken = tokenData["tokens"]

    if (not stream):
        data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData))
        print(data.json())
        resData = make_resData(data.json(), chat=False, promptToken=promptToken)
        return jsonify(resData)
    else:
        def generate():
            data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
            time_now = int(time.time())
            for line in data.iter_lines():
                if line:
                    decoded_line = line.decode('utf-8')
                    resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
                    yield 'data: {}\n'.format(json.dumps(resData))
        return Response(generate(), mimetype='text/event-stream')

if __name__ == '__main__':
    app.run(args.host, port=args.port)
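To exercise the proxy end to end, a small request sketch (not part of the commit), assuming api_like_OAI.py is running on its default 127.0.0.1:8081 and server.cpp on 127.0.0.1:8080; field names follow the handlers above:

```python
# Hypothetical smoke test for the proxy above.
import requests

resp = requests.post(
    "http://127.0.0.1:8081/v1/completions",
    json={
        "prompt": "Building a website can be done in 10 simple steps:",
        "max_tokens": 64,     # forwarded to server.cpp as n_predict
        "temperature": 0.8,
        "stop": ["\n\n"],     # appended to the proxy's own --stop word
        "tokenize": True,     # non-standard field: echoes prompt tokens as promptToken
    },
)
body = resp.json()
print(body["choices"][0]["text"])
print(body["usage"])          # prompt_tokens / completion_tokens / total_tokens
```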

examples/server/server.cpp

+9 -5

@@ -158,6 +158,7 @@ struct llama_server_context {
     std::string generated_text;
     std::vector<completion_token_output> generated_token_probs;

+    size_t num_prompt_tokens = 0;
     size_t num_tokens_predicted = 0;
     size_t n_past = 0;
     size_t n_remain = 0;
@@ -195,6 +196,7 @@ struct llama_server_context {

     void rewind() {
         params.antiprompt.clear();
+        num_prompt_tokens = 0;
         num_tokens_predicted = 0;
         generated_text = "";
         generated_text.reserve(params.n_ctx);
@@ -226,17 +228,18 @@ struct llama_server_context {
     void loadPrompt() {
         params.prompt.insert(0, 1, ' '); // always add a first space
         std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
+        num_prompt_tokens = prompt_tokens.size();

         if (params.n_keep < 0) {
-            params.n_keep = (int)prompt_tokens.size();
+            params.n_keep = (int)num_prompt_tokens;
         }
         params.n_keep = std::min(params.n_ctx - 4, params.n_keep);

         // if input prompt is too big, truncate like normal
-        if (prompt_tokens.size() >= (size_t)params.n_ctx) {
+        if (num_prompt_tokens >= (size_t)params.n_ctx) {
             const int n_left = (params.n_ctx - params.n_keep) / 2;
             std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-            const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_left - 1) / n_left;
+            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
             new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
             std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());

@@ -250,15 +253,15 @@ struct llama_server_context {
             truncated = true;
             prompt_tokens = new_tokens;
         } else {
-            const size_t ps = prompt_tokens.size();
+            const size_t ps = num_prompt_tokens;
             std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
             std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
         }

         // compare the evaluated prompt with the new prompt
         n_past = common_part(embd, prompt_tokens);
         embd = prompt_tokens;
-        if (n_past == prompt_tokens.size()) {
+        if (n_past == num_prompt_tokens) {
             // we have to evaluate at least 1 token to generate logits.
             n_past--;
         }
@@ -763,6 +766,7 @@ static json format_final_response(llama_server_context & llama, const std::strin
         { "stop", true },
         { "model", llama.params.model_alias },
         { "tokens_predicted", llama.num_tokens_predicted },
+        { "tokens_evaluated", llama.num_prompt_tokens },
         { "generation_settings", format_generation_settings(llama) },
         { "prompt", llama.params.prompt },
         { "truncated", llama.truncated },
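With this change, server.cpp reports the evaluated prompt length alongside the predicted count, which is what the proxy maps into the OpenAI-style usage object. A quick check against server.cpp directly (a sketch, assuming it listens on the default 127.0.0.1:8080):

```python
# Sketch: hit server.cpp's /completion endpoint and read the new field.
import requests

r = requests.post(
    "http://127.0.0.1:8080/completion",
    json={"prompt": "Hello", "n_predict": 8},
).json()

# tokens_evaluated (prompt) + tokens_predicted (completion) = OpenAI-style total_tokens
print(r["tokens_evaluated"], r["tokens_predicted"], r["truncated"])
```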
