 #
 # Usage:
 #
-#  python3 models/convert-h5-to-ggml.py
+#  python3 models/convert-h5-to-ggml.py
 #
 # This script is similar to "convert-pt-to-ggml.py"
 #
@@ -40,15 +40,17 @@ def bytes_to_unicode():
     cs = [chr(n) for n in cs]
     return dict(zip(bs, cs))
 
-if len(sys.argv) < 3:
-    print("Usage: python convert-hf-to-ggml.py model_name dir-output [use-f32]")
+if len(sys.argv) < 4:
+    print("Usage: python convert-hf-to-ggml.py num_parts model_name dir-output [use-f32]")
+    print("  num_parts: number of pytorch parts, use 0 if not a multipart model. example: 9")
     print("  model_name: name of the model to convert. Example: 'bigscience/bloomz-560m'")
     print("  dir-output: directory where the output file will be written")
     print("  use-f32: if present, use float32 instead of float16")
     sys.exit(1)
 
-model_name = sys.argv[1]
-dir_out = sys.argv[2]
+num_parts = int(sys.argv[1])
+model_name = sys.argv[2]
+dir_out = sys.argv[3]
 
 # make sure the output directory exists
 os.makedirs(dir_out, exist_ok=True)
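Note: the script now takes the number of checkpoint parts as its first argument and loads the pytorch_model*.bin files directly from disk (see the loading loop further down), so model_name also serves as the local directory holding those files. As a hypothetical invocation using the example values from the usage text above, "python3 models/convert-h5-to-ggml.py 0 bigscience/bloomz-560m ./output" would convert a single-file checkpoint into ./output (an assumed output directory); with num_parts set to 9, the parts are expected to be named pytorch_model-00001-of-00009.bin through pytorch_model-00009-of-00009.bin.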
@@ -60,19 +62,17 @@ def bytes_to_unicode():
 # map from ftype to string
 ftype_str = ["f32", "f16"]
 ftype = 1
-if len(sys.argv) > 3:
+if len(sys.argv) > 4:
     ftype = 0
 
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
 hparams = config.to_dict()
-print("Loading model: ", model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16 if ftype == 1 else torch.float32, low_cpu_mem_usage=True, trust_remote_code=True)
-print("Model loaded: ", model_name)
 
 n_head = hparams["n_head"]
 n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1
 head_dim = hparams["hidden_size"] // n_head
+print("* Loading model from: ", model_name)
 
 fname_out = dir_out + f"/ggml-model-{model_name.split('/')[-1]}-{ftype_str[ftype]}.bin"
 fout = open(fname_out, "wb")
@@ -93,51 +93,49 @@ def bytes_to_unicode():
     text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
     fout.write(struct.pack("i", len(text)))
     fout.write(text)
-
-list_vars = model.state_dict()
-for name in list_vars.keys():
-    src = name
-
-    # The original query_key_value tensor contains n_head_kv "kv groups",
-    # each consisting of n_head/n_head_kv query weights followed by one key
-    # and one value weight (shared by all query heads in the kv group).
-    # This layout makes it a big pain to work with in GGML.
-    # So we rearrange them here,, so that we have n_head query weights
-    # followed by n_head_kv key weights followed by n_head_kv value weights,
-    # in contiguous fashion.
-
-    if "query_key_value" in src:
-        qkv = list_vars[src].view(
-            n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
-
-        q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head)
-        k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
-        v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
-
-        list_vars[src] = torch.cat((q,k,v)).reshape_as(list_vars[src])
-
-    data = list_vars[src].squeeze().numpy()
-    data = data.astype(np.float32)
-
-    n_dims = len(data.shape)
-    print(name, n_dims, data.shape)
-
-    # default type is fp32
-    ftype_cur = 0
-    if ftype == 1 and n_dims > 1:
-        print("  Converting to float16")
-        data = data.astype(np.float16)
-        ftype_cur = 1
-
-    # header
-    str = name.encode('utf-8')
-    fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
-    for i in range(n_dims):
-        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
-    fout.write(str)
-
-    # data
-    data.tofile(fout)
+
+if num_parts == 0:
+    partnames = ('pytorch_model.bin',)
+else:
+    partnames = (f'pytorch_model-{n:05}-of-{num_parts:05}.bin' for n in range(1, num_parts + 1))
+for partname in partnames:
+    filename = f'{model_name}/{partname}'
+    print(f'\n* Loading part: {partname}')
+    model = torch.load(filename, map_location = 'cpu')
+    for name in model.keys():
+        src = name
+        # The original query_key_value tensor contains n_head_kv "kv groups",
+        # each consisting of n_head/n_head_kv query weights followed by one key
+        # and one value weight (shared by all query heads in the kv group).
+        # This layout makes it a big pain to work with in GGML.
+        # So we rearrange them here, so that we have n_head query weights
+        # followed by n_head_kv key weights followed by n_head_kv value weights,
+        # in contiguous fashion.
+
+        if "query_key_value" in src:
+            qkv = model[src].view(
+                n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
+
+            q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head)
+            k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
+            v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
+
+            model[src] = torch.cat((q,k,v)).reshape_as(model[src])
+        data = model[src].squeeze()
+        n_dims = len(data.shape)
+        # default type is fp32
+        ftype_cur = 1 if ftype == 1 and n_dims > 1 else 0
+        data = data.to(dtype = torch.float16 if ftype_cur == 1 else torch.float32).numpy()
+        print(f'  |', name, data.shape, '->', data.dtype)
+        # header
+        str = name.encode('utf-8')
+        fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+        for i in range(n_dims):
+            fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+        fout.write(str)
+
+        # data
+        data.tofile(fout)
 
 fout.close()
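Note on the query_key_value handling in the new loading loop: as the comment in the diff explains, the fused weight is stored as n_head_kv groups, each holding n_head/n_head_kv query rows followed by one key row and one value row, and the view/reshape/cat sequence regroups it into all queries, then all keys, then all values. The following standalone sketch (not part of the commit; the dimensions are made-up toy values) reproduces those steps on a dummy tensor:

    import torch

    # toy dimensions, assumed purely for illustration
    n_head = 4
    n_head_kv = 2
    head_dim = 3
    hidden_size = n_head * head_dim

    # fused query_key_value weight as stored in the checkpoint:
    # n_head_kv groups of (n_head // n_head_kv query rows, 1 key row, 1 value row)
    rows = (n_head + 2 * n_head_kv) * head_dim
    fused = torch.arange(rows * hidden_size, dtype=torch.float32).reshape(rows, hidden_size)

    qkv = fused.view(n_head_kv, n_head // n_head_kv + 2, head_dim, hidden_size)
    q = qkv[:, :-2].reshape(n_head * head_dim, hidden_size)      # all query heads
    k = qkv[:, [-2]].reshape(n_head_kv * head_dim, hidden_size)  # all key heads
    v = qkv[:, [-1]].reshape(n_head_kv * head_dim, hidden_size)  # all value heads

    rearranged = torch.cat((q, k, v)).reshape_as(fused)
    assert rearranged.shape == fused.shape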
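For reference, each tensor record written by the loop above consists of a three-int header (n_dims, length of the utf-8 encoded name, ftype where 0 means f32 and 1 means f16), followed by n_dims ints giving the shape in reverse order, then the name bytes, then the raw tensor data. A minimal reader sketch (not from the repository; it assumes native int sizes, a file opened in binary mode, and that the file position is already at the start of a tensor record):

    import struct
    import numpy as np

    def read_tensor_record(f):
        # three-int header, matching struct.pack("iii", n_dims, len(name), ftype_cur)
        n_dims, name_len, ftype_cur = struct.unpack("iii", f.read(struct.calcsize("iii")))
        # the shape was written one "i" at a time, in reverse order
        shape = struct.unpack("i" * n_dims, f.read(struct.calcsize("i") * n_dims))[::-1]
        name = f.read(name_len).decode("utf-8")
        dtype = np.float16 if ftype_cur == 1 else np.float32
        data = np.fromfile(f, dtype=dtype, count=int(np.prod(shape))).reshape(shape)
        return name, data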