Skip to content

Commit f2e4aff

Browse files
abetlenDon Mahurin
authored and
Don Mahurin
committed
Fix llama_cpp and Llama type signatures. Closes ggml-org#221
1 parent 18eca89 commit f2e4aff

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

examples/llama_cpp.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def llama_free(ctx: llama_context_p):
206206
# nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
207207
def llama_model_quantize(
208208
fname_inp: bytes, fname_out: bytes, ftype: c_int, nthread: c_int
209-
) -> c_int:
209+
) -> int:
210210
return _lib.llama_model_quantize(fname_inp, fname_out, ftype, nthread)
211211

212212

@@ -225,7 +225,7 @@ def llama_apply_lora_from_file(
225225
path_lora: c_char_p,
226226
path_base_model: c_char_p,
227227
n_threads: c_int,
228-
) -> c_int:
228+
) -> int:
229229
return _lib.llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads)
230230

231231

@@ -234,7 +234,7 @@ def llama_apply_lora_from_file(
234234

235235

236236
# Returns the number of tokens in the KV cache
237-
def llama_get_kv_cache_token_count(ctx: llama_context_p) -> c_int:
237+
def llama_get_kv_cache_token_count(ctx: llama_context_p) -> int:
238238
return _lib.llama_get_kv_cache_token_count(ctx)
239239

240240

@@ -253,7 +253,7 @@ def llama_set_rng_seed(ctx: llama_context_p, seed: c_int):
253253

254254
# Returns the maximum size in bytes of the state (rng, logits, embedding
255255
# and kv_cache) - will often be smaller after compacting tokens
256-
def llama_get_state_size(ctx: llama_context_p) -> c_size_t:
256+
def llama_get_state_size(ctx: llama_context_p) -> int:
257257
return _lib.llama_get_state_size(ctx)
258258

259259

@@ -293,7 +293,7 @@ def llama_load_session_file(
293293
tokens_out, # type: Array[llama_token]
294294
n_token_capacity: c_size_t,
295295
n_token_count_out, # type: _Pointer[c_size_t]
296-
) -> c_size_t:
296+
) -> int:
297297
return _lib.llama_load_session_file(
298298
ctx, path_session, tokens_out, n_token_capacity, n_token_count_out
299299
)
@@ -314,7 +314,7 @@ def llama_save_session_file(
314314
path_session: bytes,
315315
tokens, # type: Array[llama_token]
316316
n_token_count: c_size_t,
317-
) -> c_size_t:
317+
) -> int:
318318
return _lib.llama_save_session_file(ctx, path_session, tokens, n_token_count)
319319

320320

@@ -337,7 +337,7 @@ def llama_eval(
337337
n_tokens: c_int,
338338
n_past: c_int,
339339
n_threads: c_int,
340-
) -> c_int:
340+
) -> int:
341341
return _lib.llama_eval(ctx, tokens, n_tokens, n_past, n_threads)
342342

343343

@@ -364,23 +364,23 @@ def llama_tokenize(
364364
_lib.llama_tokenize.restype = c_int
365365

366366

367-
def llama_n_vocab(ctx: llama_context_p) -> c_int:
367+
def llama_n_vocab(ctx: llama_context_p) -> int:
368368
return _lib.llama_n_vocab(ctx)
369369

370370

371371
_lib.llama_n_vocab.argtypes = [llama_context_p]
372372
_lib.llama_n_vocab.restype = c_int
373373

374374

375-
def llama_n_ctx(ctx: llama_context_p) -> c_int:
375+
def llama_n_ctx(ctx: llama_context_p) -> int:
376376
return _lib.llama_n_ctx(ctx)
377377

378378

379379
_lib.llama_n_ctx.argtypes = [llama_context_p]
380380
_lib.llama_n_ctx.restype = c_int
381381

382382

383-
def llama_n_embd(ctx: llama_context_p) -> c_int:
383+
def llama_n_embd(ctx: llama_context_p) -> int:
384384
return _lib.llama_n_embd(ctx)
385385

386386

@@ -426,23 +426,23 @@ def llama_token_to_str(ctx: llama_context_p, token: llama_token) -> bytes:
426426
# Special tokens
427427

428428

429-
def llama_token_bos() -> llama_token:
429+
def llama_token_bos() -> int:
430430
return _lib.llama_token_bos()
431431

432432

433433
_lib.llama_token_bos.argtypes = []
434434
_lib.llama_token_bos.restype = llama_token
435435

436436

437-
def llama_token_eos() -> llama_token:
437+
def llama_token_eos() -> int:
438438
return _lib.llama_token_eos()
439439

440440

441441
_lib.llama_token_eos.argtypes = []
442442
_lib.llama_token_eos.restype = llama_token
443443

444444

445-
def llama_token_nl() -> llama_token:
445+
def llama_token_nl() -> int:
446446
return _lib.llama_token_nl()
447447

448448

@@ -625,7 +625,7 @@ def llama_sample_token_mirostat(
625625
eta: c_float,
626626
m: c_int,
627627
mu, # type: _Pointer[c_float]
628-
) -> llama_token:
628+
) -> int:
629629
return _lib.llama_sample_token_mirostat(ctx, candidates, tau, eta, m, mu)
630630

631631

@@ -651,7 +651,7 @@ def llama_sample_token_mirostat_v2(
651651
tau: c_float,
652652
eta: c_float,
653653
mu, # type: _Pointer[c_float]
654-
) -> llama_token:
654+
) -> int:
655655
return _lib.llama_sample_token_mirostat_v2(ctx, candidates, tau, eta, mu)
656656

657657

@@ -669,7 +669,7 @@ def llama_sample_token_mirostat_v2(
669669
def llama_sample_token_greedy(
670670
ctx: llama_context_p,
671671
candidates, # type: _Pointer[llama_token_data_array]
672-
) -> llama_token:
672+
) -> int:
673673
return _lib.llama_sample_token_greedy(ctx, candidates)
674674

675675

@@ -684,7 +684,7 @@ def llama_sample_token_greedy(
684684
def llama_sample_token(
685685
ctx: llama_context_p,
686686
candidates, # type: _Pointer[llama_token_data_array]
687-
) -> llama_token:
687+
) -> int:
688688
return _lib.llama_sample_token(ctx, candidates)
689689

690690

0 commit comments

Comments
 (0)