Skip to content

Commit ec67b8c

Browse files
jhen0409 and mglambda
authored and committed
swift : fix llama-vocab api usage (ggml-org#11645)
* swiftui : fix vocab api usage
* batched.swift : fix vocab api usage
1 parent c579d6c commit ec67b8c

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

examples/batched.swift/Sources/main.swift

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ defer {
3131
llama_model_free(model)
3232
}
3333

34+
guard let vocab = llama_model_get_vocab(model) else {
35+
print("Failed to get vocab")
36+
exit(1)
37+
}
38+
3439
var tokens = tokenize(text: prompt, add_bos: true)
3540

3641
let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)
@@ -41,7 +46,7 @@ context_params.n_batch = UInt32(max(n_len, n_parallel))
4146
context_params.n_threads = 8
4247
context_params.n_threads_batch = 8
4348

44-
let context = llama_new_context_with_model(model, context_params)
49+
let context = llama_init_from_model(model, context_params)
4550
guard context != nil else {
4651
print("Failed to initialize context")
4752
exit(1)
@@ -141,7 +146,7 @@ while n_cur <= n_len {
141146
let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
142147

143148
// is it an end of stream? -> mark the stream as finished
144-
if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len {
149+
if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
145150
i_batch[i] = -1
146151
// print("")
147152
if n_parallel > 1 {
@@ -207,7 +212,7 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
207212
let utf8Count = text.utf8.count
208213
let n_tokens = utf8Count + (add_bos ? 1 : 0)
209214
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
210-
let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false)
215+
let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false)
211216
var swiftTokens: [llama_token] = []
212217
for i in 0 ..< tokenCount {
213218
swiftTokens.append(tokens[Int(i)])
@@ -218,12 +223,12 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
218223

219224
private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? {
220225
var result = [CChar](repeating: 0, count: 8)
221-
let nTokens = llama_token_to_piece(model, token, &result, Int32(result.count), 0, false)
226+
let nTokens = llama_token_to_piece(vocab, token, &result, Int32(result.count), 0, false)
222227
if nTokens < 0 {
223228
let actualTokensCount = -Int(nTokens)
224229
result = .init(repeating: 0, count: actualTokensCount)
225230
let check = llama_token_to_piece(
226-
model,
231+
vocab,
227232
token,
228233
&result,
229234
Int32(result.count),

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
2424
actor LlamaContext {
2525
private var model: OpaquePointer
2626
private var context: OpaquePointer
27+
private var vocab: OpaquePointer
2728
private var sampling: UnsafeMutablePointer<llama_sampler>
2829
private var batch: llama_batch
2930
private var tokens_list: [llama_token]
@@ -47,6 +48,7 @@ actor LlamaContext {
4748
self.sampling = llama_sampler_chain_init(sparams)
4849
llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
4950
llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
51+
vocab = llama_model_get_vocab(model)
5052
}
5153

5254
deinit {
@@ -79,7 +81,7 @@ actor LlamaContext {
7981
ctx_params.n_threads = Int32(n_threads)
8082
ctx_params.n_threads_batch = Int32(n_threads)
8183

82-
let context = llama_new_context_with_model(model, ctx_params)
84+
let context = llama_init_from_model(model, ctx_params)
8385
guard let context else {
8486
print("Could not load context!")
8587
throw LlamaError.couldNotInitializeContext
@@ -151,7 +153,7 @@ actor LlamaContext {
151153

152154
new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
153155

154-
if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len {
156+
if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
155157
print("\n")
156158
is_done = true
157159
let new_token_str = String(cString: temporary_invalid_cchars + [0])
@@ -297,7 +299,7 @@ actor LlamaContext {
297299
let utf8Count = text.utf8.count
298300
let n_tokens = utf8Count + (add_bos ? 1 : 0) + 1
299301
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
300-
let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)
302+
let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)
301303

302304
var swiftTokens: [llama_token] = []
303305
for i in 0..<tokenCount {
@@ -316,15 +318,15 @@ actor LlamaContext {
316318
defer {
317319
result.deallocate()
318320
}
319-
let nTokens = llama_token_to_piece(model, token, result, 8, 0, false)
321+
let nTokens = llama_token_to_piece(vocab, token, result, 8, 0, false)
320322

321323
if nTokens < 0 {
322324
let newResult = UnsafeMutablePointer<Int8>.allocate(capacity: Int(-nTokens))
323325
newResult.initialize(repeating: Int8(0), count: Int(-nTokens))
324326
defer {
325327
newResult.deallocate()
326328
}
327-
let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens, 0, false)
329+
let nNewTokens = llama_token_to_piece(vocab, token, newResult, -nTokens, 0, false)
328330
let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
329331
return Array(bufferPointer)
330332
} else {

0 commit comments

Comments (0)