
Commit cc8ac10

Merge pull request #1 from KerfuffleV2/feat-improve-falcon-convert-hf
Allow converting HF Falcon models with only one shard in memory at a time
2 parents fe13c37 + ac64e94 commit cc8ac10
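
The change replaces loading the whole model through AutoModelForCausalLM with a loop that torch.load()s one checkpoint shard at a time, so peak memory stays around a single shard rather than the full set of weights. Below is a minimal sketch of that pattern, separate from the converter itself: the convert_shards wrapper and the process_tensor callback are illustrative names, not part of the script, and the shard naming follows the pytorch_model-NNNNN-of-NNNNN.bin scheme used in the diff further down.

import torch

def convert_shards(model_dir, num_parts, process_tensor):
    # Single-file checkpoint when num_parts == 0, otherwise the sharded
    # Hugging Face naming scheme used by the converter.
    if num_parts == 0:
        partnames = ('pytorch_model.bin',)
    else:
        partnames = tuple(f'pytorch_model-{n:05}-of-{num_parts:05}.bin'
                          for n in range(1, num_parts + 1))

    for partname in partnames:
        # Only this shard's tensors are resident in memory.
        part = torch.load(f'{model_dir}/{partname}', map_location='cpu')
        for name, tensor in part.items():
            process_tensor(name, tensor)  # stand-in for the rearrange-and-write step
        del part  # drop the shard before loading the next one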

File tree

1 file changed (+52, -54)

examples/falcon/convert-hf-to-ggml.py

Lines changed: 52 additions & 54 deletions
@@ -2,7 +2,7 @@
 #
 # Usage:
 #
-#  python3 models/convert-h5-to-ggml.py
+#  python3 models/convert-h5-to-ggml.py
 #
 # This script is similar to "convert-pt-to-ggml.py"
 #
@@ -40,15 +40,17 @@ def bytes_to_unicode():
     cs = [chr(n) for n in cs]
     return dict(zip(bs, cs))

-if len(sys.argv) < 3:
-    print("Usage: python convert-hf-to-ggml.py model_name dir-output [use-f32]")
+if len(sys.argv) < 4:
+    print("Usage: python convert-hf-to-ggml.py num_parts model_name dir-output [use-f32]")
+    print(" num_parts: number of pytorch parts, use 0 if not a multipart model. example: 9")
     print(" model_name: name of the model to convert. Example: 'bigscience/bloomz-560m'")
     print(" dir-output: directory where the output file will be written")
     print(" use-f32: if present, use float32 instead of float16")
     sys.exit(1)

-model_name = sys.argv[1]
-dir_out = sys.argv[2]
+num_parts = int(sys.argv[1])
+model_name = sys.argv[2]
+dir_out = sys.argv[3]

 # make sure the output directory exists
 os.makedirs(dir_out, exist_ok=True)
@@ -60,19 +62,17 @@ def bytes_to_unicode():
 # map from ftype to string
 ftype_str = ["f32", "f16"]
 ftype = 1
-if len(sys.argv) > 3:
+if len(sys.argv) > 4:
     ftype = 0

 tokenizer = AutoTokenizer.from_pretrained(model_name)
 config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
 hparams = config.to_dict()
-print("Loading model: ", model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16 if ftype == 1 else torch.float32, low_cpu_mem_usage=True, trust_remote_code=True)
-print("Model loaded: ", model_name)

 n_head = hparams["n_head"]
 n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1
 head_dim = hparams["hidden_size"] // n_head
+print("* Loading model from: ", model_name)

 fname_out = dir_out + f"/ggml-model-{model_name.split('/')[-1]}-{ftype_str[ftype]}.bin"
 fout = open(fname_out, "wb")
@@ -93,51 +93,49 @@ def bytes_to_unicode():
     text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
     fout.write(struct.pack("i", len(text)))
     fout.write(text)
-
-list_vars = model.state_dict()
-for name in list_vars.keys():
-    src = name
-
-    # The original query_key_value tensor contains n_head_kv "kv groups",
-    # each consisting of n_head/n_head_kv query weights followed by one key
-    # and one value weight (shared by all query heads in the kv group).
-    # This layout makes it a big pain to work with in GGML.
-    # So we rearrange them here, so that we have n_head query weights
-    # followed by n_head_kv key weights followed by n_head_kv value weights,
-    # in contiguous fashion.
-
-    if "query_key_value" in src:
-        qkv = list_vars[src].view(
-            n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
-
-        q = qkv[:, :-2 ].reshape(n_head * head_dim, head_dim * n_head)
-        k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
-        v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
-
-        list_vars[src] = torch.cat((q,k,v)).reshape_as(list_vars[src])
-
-    data = list_vars[src].squeeze().numpy()
-    data = data.astype(np.float32)
-
-    n_dims = len(data.shape)
-    print(name, n_dims, data.shape)
-
-    # default type is fp32
-    ftype_cur = 0
-    if ftype == 1 and n_dims > 1:
-        print(" Converting to float16")
-        data = data.astype(np.float16)
-        ftype_cur = 1
-
-    # header
-    str = name.encode('utf-8')
-    fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
-    for i in range(n_dims):
-        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
-    fout.write(str)
-
-    # data
-    data.tofile(fout)
+
+if num_parts == 0:
+    partnames = ('pytorch_model.bin',)
+else:
+    partnames = (f'pytorch_model-{n:05}-of-{num_parts:05}.bin' for n in range(1, num_parts + 1))
+for partname in partnames:
+    filename = f'{model_name}/{partname}'
+    print(f'\n* Loading part: {partname}')
+    model = torch.load(filename, map_location = 'cpu')
+    for name in model.keys():
+        src = name
+        # The original query_key_value tensor contains n_head_kv "kv groups",
+        # each consisting of n_head/n_head_kv query weights followed by one key
+        # and one value weight (shared by all query heads in the kv group).
+        # This layout makes it a big pain to work with in GGML.
+        # So we rearrange them here, so that we have n_head query weights
+        # followed by n_head_kv key weights followed by n_head_kv value weights,
+        # in contiguous fashion.
+
+        if "query_key_value" in src:
+            qkv = model[src].view(
+                n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
+
+            q = qkv[:, :-2 ].reshape(n_head * head_dim, head_dim * n_head)
+            k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
+            v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
+
+            model[src] = torch.cat((q,k,v)).reshape_as(model[src])
+        data = model[src].squeeze()
+        n_dims = len(data.shape)
+        # default type is fp32
+        ftype_cur = 1 if ftype == 1 and n_dims > 1 else 0
+        data = data.to(dtype = torch.float16 if ftype_cur == 1 else torch.float32).numpy()
+        print(f' |', name, data.shape, '->', data.dtype)
+        # header
+        str = name.encode('utf-8')
+        fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+        for i in range(n_dims):
+            fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+        fout.write(str)
+
+        # data
+        data.tofile(fout)

 fout.close()

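The comment block in the tensor loop above explains the rearrangement of the fused query_key_value weight: the checkpoint stores n_head_kv groups of (n_head/n_head_kv queries, one key, one value), while the converter wants all queries, then all keys, then all values, laid out contiguously. Here is a small self-contained sketch of that reshuffle, using made-up head counts rather than Falcon's real hyperparameters, only to show that the view/reshape arithmetic lines up:

import torch

# Made-up sizes for illustration only; not Falcon's actual hyperparameters.
n_head, n_head_kv, head_dim = 8, 2, 4
hidden = n_head * head_dim  # 32

# Fused weight as stored in the checkpoint: n_head_kv groups of
# (n_head/n_head_kv query blocks + 1 key block + 1 value block), each block head_dim rows tall.
fused = torch.randn((n_head + 2 * n_head_kv) * head_dim, hidden)

qkv = fused.view(n_head_kv, n_head // n_head_kv + 2, head_dim, hidden)

q = qkv[:, :-2 ].reshape(n_head * head_dim, hidden)     # all query heads
k = qkv[:, [-2]].reshape(n_head_kv * head_dim, hidden)  # one key per kv group
v = qkv[:, [-1]].reshape(n_head_kv * head_dim, hidden)  # one value per kv group

rearranged = torch.cat((q, k, v)).reshape_as(fused)
assert rearranged.shape == fused.shape  # same storage size, new row order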
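For reference, each tensor record the script writes is self-describing: three int32 values (n_dims, length of the name, ftype), then the dims in reverse order, the UTF-8 name, and finally the raw data. The following is a hedged sketch of reading one such record back; it assumes the file position is already past the magic number, hyperparameters, and vocabulary that the script writes earlier, which are not shown in this diff:

import struct
import numpy as np

def read_tensor(f):
    # Per-tensor header: int32 n_dims, int32 name length, int32 ftype.
    header = f.read(3 * 4)
    if len(header) < 12:
        return None  # end of file
    n_dims, name_len, ftype = struct.unpack("iii", header)
    # Dims were written innermost-first (reversed numpy order).
    dims = struct.unpack(f"{n_dims}i", f.read(n_dims * 4))
    name = f.read(name_len).decode("utf-8")
    dtype = np.float16 if ftype == 1 else np.float32
    data = np.fromfile(f, dtype=dtype, count=int(np.prod(dims)))
    return name, data.reshape(dims[::-1])  # restore the original numpy shape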