From d44756e8d9a6636d588a72aa587f4810852bf5be Mon Sep 17 00:00:00 2001
From: ensan-hcl <ensan.hcl.dev@gmail.com>
Date: Tue, 5 Dec 2023 00:31:30 +0900
Subject: [PATCH] fix concatenation method to avoid invalid UTF8 stringfication

---
 .../llama.cpp.swift/LibLlama.swift            | 37 +++++++++++++++----
 1 file changed, 30 insertions(+), 7 deletions(-)

diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
index f828106fbf6fb..3754f055163ea 100644
--- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
+++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -11,6 +11,8 @@ actor LlamaContext {
     private var context: OpaquePointer
     private var batch: llama_batch
     private var tokens_list: [llama_token]
+    /// This variable is used to store temporarily invalid cchars
+    private var temporary_invalid_cchars: [CChar]
 
     var n_len: Int32 = 512
     var n_cur: Int32 = 0
@@ -21,6 +23,7 @@ actor LlamaContext {
         self.context = context
         self.tokens_list = []
         self.batch = llama_batch_init(512, 0, 1)
+        self.temporary_invalid_cchars = []
     }
 
     deinit {
@@ -61,6 +64,7 @@ actor LlamaContext {
         print("attempting to complete \"\(text)\"")
 
         tokens_list = tokenize(text: text, add_bos: true)
+        temporary_invalid_cchars = []
 
         let n_ctx = llama_n_ctx(context)
         let n_kv_req = tokens_list.count + (Int(n_len) - tokens_list.count)
@@ -72,7 +76,7 @@ actor LlamaContext {
         }
 
         for id in tokens_list {
-            print(token_to_piece(token: id))
+            print(String(cString: token_to_piece(token: id) + [0]))
         }
 
         // batch = llama_batch_init(512, 0) // done in init()
@@ -115,10 +119,25 @@ actor LlamaContext {
 
         if new_token_id == llama_token_eos(context) || n_cur == n_len {
             print("\n")
-            return ""
+            let new_token_str = String(cString: temporary_invalid_cchars + [0])
+            temporary_invalid_cchars.removeAll()
+            return new_token_str
         }
 
-        let new_token_str = token_to_piece(token: new_token_id)
+        let new_token_cchars = token_to_piece(token: new_token_id)
+        temporary_invalid_cchars.append(contentsOf: new_token_cchars)
+        let new_token_str: String
+        if let string = String(validatingUTF8: temporary_invalid_cchars + [0]) {
+            temporary_invalid_cchars.removeAll()
+            new_token_str = string
+        } else if (0 ..< temporary_invalid_cchars.count).contains(where: {$0 != 0 && String(validatingUTF8: Array(temporary_invalid_cchars.suffix($0)) + [0]) != nil}) {
+            // in this case, at least the suffix of the temporary_invalid_cchars can be interpreted as UTF8 string
+            let string = String(cString: temporary_invalid_cchars + [0])
+            temporary_invalid_cchars.removeAll()
+            new_token_str = string
+        } else {
+            new_token_str = ""
+        }
         print(new_token_str)
         // tokens_list.append(new_token_id)
 
@@ -144,6 +163,7 @@ actor LlamaContext {
 
     func clear() {
         tokens_list.removeAll()
+        temporary_invalid_cchars.removeAll()
     }
 
     private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
@@ -162,7 +182,8 @@ actor LlamaContext {
         return swiftTokens
     }
 
-    private func token_to_piece(token: llama_token) -> String {
+    /// - note: The result does not contain null-terminator
+    private func token_to_piece(token: llama_token) -> [CChar] {
         let result = UnsafeMutablePointer<Int8>.allocate(capacity: 8)
         result.initialize(repeating: Int8(0), count: 8)
         defer {
@@ -176,10 +197,12 @@ actor LlamaContext {
             defer {
                 newResult.deallocate()
             }
-            _ = llama_token_to_piece(model, token, newResult, -nTokens)
-            return String(cString: newResult)
+            let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens)
+            let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
+            return Array(bufferPointer)
         } else {
-            return String(cString: result)
+            let bufferPointer = UnsafeBufferPointer(start: result, count: Int(nTokens))
+            return Array(bufferPointer)
         }
     }
 }