import struct
import sys
import io
+import pickle
from pathlib import Path
from enum import Enum
from pathlib import Path
from typing import IO, Any, Iterable, List, Optional, Tuple
import numpy as np
import math
+from attr import dataclass

import torch
from torch import nn
@@ -92,6 +94,8 @@ class ModelType(Enum):

    CohereCommand = 0x1400

+    Grok1 = 0x1500
+
    BCE_Embedding = 0x10000100
    BCE_ReRanker = 0x10000101

@@ -205,7 +209,7 @@ def load_all_model_files(model_files) -> Dict:
            r[k] = v
        yield r

-def dump_state_dict(f, weight_names, model_files, ggml_type, config, state_dict_pp):
+def dump_state_dict(f, weight_names, model_files, ggml_type, config, state_dict_pp, loader_fun=None):
    tensor_info = []
    converted_names = []

@@ -214,7 +218,10 @@ def dump_state_dict(f, weight_names, model_files, ggml_type, config, state_dict_
    state_dict_cache = {}
    remaining: List = weight_names.copy()

-    for state_dict in load_all_model_files(model_files):
+    if loader_fun is None:
+        loader_fun = load_all_model_files
+
+    for state_dict in loader_fun(model_files):
        this_round = {}
        state_dict = state_dict_pp(config, state_dict)

@@ -2240,6 +2247,247 @@ def get_weight_names(config):
        r = LlamaConverter.get_weight_names(config)
        return r[:-1]

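+# Minimal stand-in for the 8-bit quantized tensor container used by the original
+# Grok-1 (JAX) checkpoint: `weight` holds the quantized values and `scales` the
+# corresponding scaling factors. It is declared here so that pickle can
+# reconstruct these objects when the tensor files are loaded below.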
+@dataclass
+class QuantizedWeight8bit:
+    def __init__(self):
+        import jax
+        import jax.numpy as jnp
+
+        self.weight: jnp.array
+        self.scales: jnp.array
+
+    @property
+    def shape(self):
+        return self.weight.shape
+
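+# Converter for the original Grok-1 checkpoint (one pickled tensor per file).
+# It maps the JAX parameter names onto the Llama-style names used by the GGML
+# format, dequantizes the 8-bit weights, and feeds everything through the
+# generic dump_state_dict() machinery via a custom loader_fun.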
+class Grok1Converter(BaseConverter):
+    MODEL_TYPE = ModelType.Grok1
+    tensor_map = []
+    file_to_name = {}
+    experts = []
+
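+    # Rename tensors to their Llama-style equivalents: scale the token embedding,
+    # rename multi_head_attention.* to self_attn.*, apply the same permute() helper
+    # the Llama converter uses to the q/k projections, and keep the remaining
+    # norm/router weights as they are.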
+    @classmethod
+    def state_dict_pp(cls, config, state_dict):
+        new_dict = {}
+
+        for name in state_dict:
+            tensor: torch.Tensor = state_dict[name]
+            if name.endswith('embed_tokens.weight'):
+                new_dict['model.embed_tokens.weight'] = tensor * config.embedding_multiplier_scale
+            elif 'multi_head_attention' in name:
+                old_name = name.replace('multi_head_attention', 'self_attn')
+                if name.endswith('k_proj.weight'):
+                    new_dict[old_name] = permute(tensor, config.num_key_value_heads)
+                elif name.endswith('q_proj.weight'):
+                    new_dict[old_name] = permute(tensor, config.num_attention_heads)
+                else:
+                    new_dict[old_name] = tensor
+            elif 'experts' in name:
+                new_dict[name] = tensor
+            else:
+                old_name = ''
+                mapping = {
+                    'language_model.norm.weight': 'model.norm.weight',
+                    'rms_norm.weight': 'rms_norm.weight',
+                    'rms_norm_1.weight': 'rms_norm_1.weight',
+                    'rms_norm_2.weight': 'rms_norm_2.weight',
+                    'rms_norm_3.weight': 'rms_norm_3.weight',
+                    'router.weight': 'router.weight',
+                }
+
+                for k in mapping.keys():
+                    if name.endswith(k):
+                        old_name = name.replace(k, mapping[k])
+                        break
+
+                if old_name == '':
+                    raise Exception(f'unhandled tensor {name}')
+
+                new_dict[old_name] = tensor
+
+        return new_dict
+
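+    # Write the Llama-compatible part of the config first, then append the
+    # MoE-specific fields (KV heads, expert counts) and the RoPE/output scales.
+    # Note that hidden_act is rewritten from 'gelu' to 'silu' before delegating
+    # to LlamaConverter.dump_config.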
+    @staticmethod
+    def dump_config(f, config, ggml_type):
+        assert config.hidden_act == 'gelu', "hidden_act == 'gelu'"
+
+        config.hidden_act = 'silu'
+        LlamaConverter.dump_config(f, config, ggml_type)
+        config_values = [
+            config.num_key_value_heads,
+            config.num_experts,
+            config.num_selected_experts,
+        ]
+        f.write(struct.pack("i" * len(config_values), *config_values))
+        f.write(struct.pack("<f", config.rope_theta))
+        f.write(struct.pack("<f", config.output_multiplier_scale))
+
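+    # Enumerate the converted tensor names in the order they are written to the
+    # GGML file: per layer, the per-expert FFN weights first, then the attention
+    # projections, norms and router, and finally the output norm.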
+    @staticmethod
+    def get_weight_names(config):
+        weight_names = ["model.embed_tokens.weight"]
+        for i in range(config.num_hidden_layers):
+            for j in range(config.num_experts):
+                weight_names += [
+                    f"model.layers.{i}.experts.{j}.w1.weight",
+                    f"model.layers.{i}.experts.{j}.w2.weight",
+                    f"model.layers.{i}.experts.{j}.w3.weight",
+                ]
+
+            weight_names += [
+                f"model.layers.{i}.self_attn.k_proj.weight",
+                f"model.layers.{i}.self_attn.o_proj.weight",
+                f"model.layers.{i}.self_attn.q_proj.weight",
+                f"model.layers.{i}.self_attn.v_proj.weight",
+                f"model.layers.{i}.rms_norm.weight",
+                f"model.layers.{i}.rms_norm_1.weight",
+                f"model.layers.{i}.rms_norm_2.weight",
+                f"model.layers.{i}.rms_norm_3.weight",
+                f"model.layers.{i}.router.weight",
+            ]
+
+        weight_names += [
+            "model.norm.weight",
+        ]
+
+        return weight_names
+
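+    # Load a single pickled tensor file and turn it into torch tensors:
+    # dequantize 8-bit weights (weight * scales, broadcasting sharded scales for
+    # row-parallel layers), transpose 2-D matrices into torch's layout, and for
+    # MoE tensors keep only the selected experts, one entry per expert index.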
+    @staticmethod
+    def load_tensor_file(tensor_name, fn) -> Any:
+        tensor_dict = {}
+        new_dict = {}
+
+        # copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
+        def convert_weight(v):
+            dtype = torch.float32
+            if hasattr(v, 'scales'):
+                weight = torch.from_numpy(np.asarray(v.weight).astype(np.float32)).to(dtype)
+                scale = torch.from_numpy(np.asarray(v.scales).astype(np.float32)).to(dtype)
+                # row parallel layers have sharded scale
+                if len(scale.shape) >= 2 and scale.shape[-2] != 1:
+                    scale = scale[..., None, :]
+                    weight = weight.view(*weight.shape[:-2], 8, -1, weight.shape[-1])
+                    weight = (weight * scale).view(*weight.shape[:-3], -1, weight.shape[-1])
+                else:
+                    weight = weight * scale
+            else:
+                weight = torch.from_numpy(np.asarray(v).astype(np.float32)).to(dtype)
+
+            # Transpose linear matrix
+            if len(weight.shape) >= 2 and 'embed_tokens.weight' not in tensor_name:
+                weight = weight.transpose(-1, -2).contiguous()
+
+            if tensor_name.endswith('router.weight'):
+                new_dict[tensor_name] = weight[Grok1Converter.experts]
+            elif 'experts' not in tensor_name:
+                new_dict[tensor_name] = weight
+            else:
+                # split moe
+                for i in range(len(Grok1Converter.experts)):
+                    new_key_i = tensor_name.replace('experts', f'experts.{i}')
+                    new_dict[new_key_i] = weight[Grok1Converter.experts[i]]
+
+        with open(fn, 'rb') as f:
+            r = pickle.load(f)
+            tensor_dict[tensor_name] = r
+
+        convert_weight(r)
+
+        return new_dict
+
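+    # Generator used as loader_fun for dump_state_dict(): tensor_files is a list
+    # of (tensor name, file path) pairs, and one small state dict is yielded per file.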
+    @staticmethod
+    def load_tensor_files(tensor_files) -> Dict:
+        for (t, f) in tensor_files:
+            print(f)
+            yield Grok1Converter.load_tensor_file(t, f)
+
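+    # Map each original tensor name to its on-disk file. The checkpoint stores
+    # tensors as tensor00000_000, tensor00001_000, ... with indices assigned in
+    # alphabetical order of the original names, so the name list is sorted to
+    # recover that numbering before the GGML file is written.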
+    @classmethod
+    def convert(cls, config, model_files_path, vocab: Any, ggml_type, save_path):
+
+        Grok1Converter.experts = config.experts
+
+        map = ['language_model.embed_tokens.weight',
+               'language_model.norm.weight']
+
+        # caution: alphabet order must not be changed!
+        for i in range(config.num_hidden_layers):
+            map += [
+                f"model.layers.{i}.experts.w1.weight",
+                f"model.layers.{i}.experts.w2.weight",
+                f"model.layers.{i}.experts.w3.weight",
+                f"model.layers.{i}.multi_head_attention.k_proj.weight",
+                f"model.layers.{i}.multi_head_attention.o_proj.weight",
+                f"model.layers.{i}.multi_head_attention.q_proj.weight",
+                f"model.layers.{i}.multi_head_attention.v_proj.weight",
+                f"model.layers.{i}.rms_norm.weight",
+                f"model.layers.{i}.rms_norm_1.weight",
+                f"model.layers.{i}.rms_norm_2.weight",
+                f"model.layers.{i}.rms_norm_3.weight",
+                f"model.layers.{i}.router.weight",
+            ]
+
+        order = list(range(len(map)))
+        order.sort(key=lambda i: map[i])
+
+        for i in range(len(map)):
+            idx = order.index(i)
+            fn = model_files_path + f'/tensor{idx:05}_000'
+            info = (map[i], fn)
+            Grok1Converter.tensor_map.append(info)
+            Grok1Converter.file_to_name[fn] = map[i]
+
+        # convert all weights to fp16
+        with open(save_path, "wb") as f:
+            f.write(b"ggml")    # magic
+            f.write(struct.pack("ii", cls.MODEL_TYPE.value, cls.FILE_VERSION))
+            Grok1Converter.dump_config(f, config, ggml_type)
+            vocab.write_vocab(f)
+
+            weight_names = Grok1Converter.get_weight_names(config)
+            dump_state_dict(f, weight_names, Grok1Converter.tensor_map, ggml_type, config, Grok1Converter.state_dict_pp, loader_fun=Grok1Converter.load_tensor_files)
+
+        print(f"{Grok1Converter.MODEL_TYPE.name} GGML model saved to {save_path}")
+
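+# Entry point used by main() when args.arch is 'grok-1-base': builds the
+# hard-coded Grok-1 config, applies the optional --experts selection, and
+# hands everything to Grok1Converter.convert().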
+def convert_grok_1_base(args, vocab, ggml_type):
+    def ffn_size(emb_size, widening_factor):
+        _ffn_size = int(widening_factor * emb_size) * 2 // 3
+        _ffn_size = _ffn_size + (8 - _ffn_size) % 8  # ensure it's a multiple of 8
+        return _ffn_size
+
+    grok1_config = {
+        'vocab_size': 128 * 1024,
+        'hidden_act': 'gelu',
+        'pad_token_id': 0,
+        'eos_token_id': 2,
+        'max_position_embeddings': 8192,
+        'output_multiplier_scale': 0.5773502691896257,
+        'embedding_multiplier_scale': 78.38367176906169,
+        'hidden_size': 48 * 128,
+        'intermediate_size': -1,
+        'num_attention_heads': 48,
+        'num_key_value_heads': 8,
+        'num_hidden_layers': 64,
+        'num_selected_experts': 2,
+        'rope_theta': 10000,
+        'attn_output_multiplier': 0.08838834764831845,
+    }
+
+    grok1_config['intermediate_size'] = ffn_size(grok1_config['hidden_size'], 8)
+
+    grok1_config['experts'] = list(range(8))
+    if args.experts != '':
+        grok1_config['experts'] = [int(x, 0) for x in args.experts.split(',')]
+
+    grok1_config['num_experts'] = len(grok1_config['experts'])
+
+    if grok1_config['num_experts'] < 2:
+        raise Exception('at least 2 experts must be selected')
+
+    print(f"experts to export: {grok1_config['experts']}")
+
+    Grok1Converter.convert(AttributeDict(grok1_config), args.model_name_or_path, vocab, ggml_type, args.save_path)
+    return
+
def load_vocab(path: Path) -> Any:

    def load_spm(p: Path) -> Any:
@@ -2329,11 +2577,17 @@ def main():
    parser.add_argument("-o", "--save_path", type=Path)
    parser.add_argument("-t", "--type", type=str, default="q8_0", choices=["f32", "f16", "q8_0", "q4_0", "q4_1"])
    parser.add_argument("--vocab_dir", type=str, default='')
+    parser.add_argument("--experts", type=str, default='')
    args = parser.parse_args()

    ggml_type = GGMLType[args.type.upper()]

    vocab = load_vocab(Path(args.model_name_or_path) if args.vocab_dir == '' else Path(args.vocab_dir))
+
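+    # Grok-1 tensors are loaded by the converter itself (one pickled file per
+    # tensor), so this arch bypasses load_some_model() below.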
+    if args.arch.lower() == 'grok-1-base':
+        convert_grok_1_base(args, vocab, ggml_type)
+        return
+
    model_files = load_some_model(Path(args.model_name_or_path))

    #if args.lora_model_name_or_path is not None: