Skip to content

Commit 276e5fa

Browse files
author
mo374z
committed
fix token count
1 parent 70ae296 commit 276e5fa

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

promptolution/tasks/base_task.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def _compute_costs(
216216
per_prompt_outputs: List[np.ndarray] = []
217217

218218
for prompt in prompts:
219-
prompt_tokens = token_counter(prompt.construct_prompt())
219+
prompt_token_count = token_counter(prompt.construct_prompt())
220220
seq_token_counts: List[float] = []
221221
input_token_counts = []
222222
for x, y in zip(xs, ys):
@@ -228,9 +228,9 @@ def _compute_costs(
228228
continue
229229
seq_text = self.seq_cache[cache_key]
230230
seq_token_counts.append(token_counter(seq_text))
231-
input_token_counts.append(token_counter(prompt.construct_prompt() + " " + x))
231+
input_token_counts.append(token_counter(x))
232232

233-
prompt_input_tokens = prompt_tokens + np.array(input_token_counts, dtype=float)
233+
prompt_input_tokens = np.array(input_token_counts, dtype=float) + prompt_token_count
234234
output_token_counts = np.array(seq_token_counts, dtype=float) - np.array(input_token_counts, dtype=float)
235235

236236
per_prompt_inputs.append(np.asarray(prompt_input_tokens, dtype=float))

0 commit comments

Comments
 (0)