
Commit 5f634ef

falcon7b example
1 parent 758471b

6 files changed, +1073 -0 lines changed

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -26,3 +26,4 @@ add_subdirectory(dolly-v2)
 add_subdirectory(replit)
 add_subdirectory(mpt)
 add_subdirectory(starcoder)
+add_subdirectory(falcon)

examples/falcon/CMakeLists.txt

Lines changed: 13 additions & 0 deletions
#
# falcon

set(TEST_TARGET falcon)
add_executable(${TEST_TARGET} main.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)

#
# falcon-quantize

set(TEST_TARGET falcon-quantize)
add_executable(${TEST_TARGET} quantize.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)

examples/falcon/README.md

Lines changed: 9 additions & 0 deletions
# falcon

Transformer architecture: falcon-7b

## Notes

- No guarantees of correctness
- The tokenizer is currently hacked together and probably works only for English text (see the sketch below)
- The non-parallel residual variant of the architecture is not supported
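
The tokenizer caveat is easy to demonstrate: a multibyte character may span several BPE tokens, so decoding one token at a time loses bytes. A small sketch, assuming the Hugging Face tiiuae/falcon-7b tokenizer:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
ids = tok.encode("🤖")                 # one emoji can map to several tokens
print([tok.decode([i]) for i in ids]) # per-token decode yields byte fragments
print(tok.decode(ids))                # decoding all ids together recovers it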

examples/falcon/convert-hf-to-ggml.py

Lines changed: 122 additions & 0 deletions
# Convert Hugging Face Falcon models to ggml format
#
# Usage:
#
#   python3 convert-hf-to-ggml.py model_name dir-output [use-f32]
#
# This script is similar to "convert-pt-to-ggml.py"
#

import os
import sys
import struct

import torch
import numpy as np

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
def bytes_to_unicode():
    """
    Returns a mapping from utf-8 bytes to unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    The mapping avoids whitespace/control characters that the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

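# Illustrative aside (a sketch, not part of the upstream script): byte 0x20
# (space) falls outside the printable ranges in `bs` above, so it is remapped
# to the printable stand-in chr(256 + 32), i.e. 'Ġ'; inverting the table
# recovers the raw byte from the stand-in character.
byte_encoder = bytes_to_unicode()
assert byte_encoder[ord(" ")] == "Ġ"
byte_decoder = {v: k for k, v in byte_encoder.items()}
assert byte_decoder["Ġ"] == ord(" ")
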
if len(sys.argv) < 3:
    print("Usage: python convert-hf-to-ggml.py model_name dir-output [use-f32]")
    print("  model_name: name of the model to convert. Example: 'tiiuae/falcon-7b'")
    print("  dir-output: directory where the output file will be written")
    print("  use-f32:    if present, use float32 instead of float16")
    sys.exit(1)

model_name = sys.argv[1]
dir_out = sys.argv[2]

# make sure the output directory exists
os.makedirs(dir_out, exist_ok=True)

# possible data types:
#   ftype == 0 -> float32
#   ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]
ftype = 1
if len(sys.argv) > 3:
    ftype = 0

tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
hparams = config.to_dict()

print("Loading model: ", model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.float16 if ftype == 1 else torch.float32,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
print("Model loaded: ", model_name)

fname_out = dir_out + f"/ggml-model-{model_name.split('/')[-1]}-{ftype_str[ftype]}.bin"
fout = open(fname_out, "wb")

# file header: magic, then the hyperparameters needed to rebuild the graph
fout.write(struct.pack("i", 0x67676d6c)) # magic: "ggml" in hex
fout.write(struct.pack("i", hparams["vocab_size"]))
fout.write(struct.pack("i", hparams["hidden_size"]))
fout.write(struct.pack("i", hparams["n_head"]))
fout.write(struct.pack("i", hparams["n_layer"]))
fout.write(struct.pack("i", ftype))

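# Illustrative sketch (assumes exactly the header layout written above; this
# helper is not part of the upstream tooling): the six 32-bit fields can be
# read back with the same struct codes to sanity-check a converted file.
def read_ggml_header(path):
    with open(path, "rb") as f:
        magic, n_vocab, n_embd, n_head, n_layer, ftype_read = struct.unpack("6i", f.read(24))
    assert magic == 0x67676d6c, "not a ggml model file"
    return {"n_vocab": n_vocab, "n_embd": n_embd, "n_head": n_head,
            "n_layer": n_layer, "ftype": ftype_read}
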
# Write the vocab. Is this correct?
#
# No. Multibyte characters that span multiple tokens, like the emoji 🤖, won't
# be decoded properly (see the byte-exact alternative sketched below).
for i in range(hparams["vocab_size"]):
    text = tokenizer.decode([i]).encode('utf-8')
    fout.write(struct.pack("i", len(text)))
    fout.write(text)

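# Illustrative byte-exact alternative (a sketch, not what this script does):
# BPE vocab entries are stored in the bytes_to_unicode() alphabet, so mapping
# each character back through the inverse table recovers the raw UTF-8 bytes
# even when a token covers only a fragment of a multibyte character.
def write_vocab_bytes(fout, tokenizer, vocab_size):
    byte_decoder = {v: k for k, v in bytes_to_unicode().items()}
    reverse_vocab = {idx: tok for tok, idx in tokenizer.get_vocab().items()}
    for i in range(vocab_size):
        text = bytes(byte_decoder[c] for c in reverse_vocab[i])
        fout.write(struct.pack("i", len(text)))
        fout.write(text)
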
list_vars = model.state_dict()
for name in list_vars.keys():
    data = list_vars[name].squeeze().numpy()
    data = data.astype(np.float32)

    n_dims = len(data.shape)
    print(name, n_dims, data.shape)

    # default type is fp32; 1-D tensors (biases, norms) are kept in fp32
    ftype_cur = 0
    if ftype == 1 and n_dims > 1:
        print("  Converting to float16")
        data = data.astype(np.float16)
        ftype_cur = 1

    # tensor header: n_dims, name length, per-tensor ftype, then the
    # dimensions stored innermost-first, then the utf-8 name
    name_bytes = name.encode('utf-8')
    fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype_cur))
    for i in range(n_dims):
        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
    fout.write(name_bytes)

    # raw tensor data
    data.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")
