
Commit 912bacc

Author: Judd

add support of Grok

1 parent 4fdccf3 commit 912bacc

File tree: 6 files changed, +596 -12 lines


README.md  (+5)

```diff
@@ -89,6 +89,11 @@ pure C++ implementation based on [@ggerganov](https://github.com/ggerganov)'s [g
 * Cohere (`CohereForCausalLM`)
     * [x] [C4AI Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01)
 
+* Grok-1
+    * [x] [Base](https://huggingface.co/xai-org/grok-1)
+
+    About [Grok-1](./docs/grok.md).
+
 * Text Embedding (`XLMRobertaModel`)
     * [x] [BCE-Embedding](https://huggingface.co/maidalun1020/bce-embedding-base_v1)
```

convert.py  (+256 -2)

```diff
@@ -8,12 +8,14 @@
 import struct
 import sys
 import io
+import pickle
 from pathlib import Path
 from enum import Enum
 from pathlib import Path
 from typing import IO, Any, Iterable, List, Optional, Tuple
 import numpy as np
 import math
+from attr import dataclass
 
 import torch
 from torch import nn
```

```diff
@@ -92,6 +94,8 @@ class ModelType(Enum):
 
     CohereCommand = 0x1400
 
+    Grok1 = 0x1500
+
     BCE_Embedding = 0x10000100
     BCE_ReRanker = 0x10000101
 
```

```diff
@@ -205,7 +209,7 @@ def load_all_model_files(model_files) -> Dict:
             r[k] = v
         yield r
 
-def dump_state_dict(f, weight_names, model_files, ggml_type, config, state_dict_pp):
+def dump_state_dict(f, weight_names, model_files, ggml_type, config, state_dict_pp, loader_fun = None):
     tensor_info = []
     converted_names = []
 
```

```diff
@@ -214,7 +218,10 @@ def dump_state_dict(f, weight_names, model_files, ggml_type, config, state_dict_
     state_dict_cache = {}
     remaining: List = weight_names.copy()
 
-    for state_dict in load_all_model_files(model_files):
+    if loader_fun is None:
+        loader_fun = load_all_model_files
+
+    for state_dict in loader_fun(model_files):
         this_round = {}
         state_dict = state_dict_pp(config, state_dict)
 
```

```diff
@@ -2240,6 +2247,247 @@ def get_weight_names(config):
         r = LlamaConverter.get_weight_names(config)
         return r[:-1]
 
+@dataclass
+class QuantizedWeight8bit:
+    def __init__(self):
+        import jax
+        import jax.numpy as jnp
+
+        self.weight: jnp.array
+        self.scales: jnp.array
+
+    @property
+    def shape(self):
+        return self.weight.shape
+
+class Grok1Converter(BaseConverter):
+    MODEL_TYPE = ModelType.Grok1
+    tensor_map = []
+    file_to_name = {}
+    experts = []
+
+    @classmethod
+    def state_dict_pp(cls, config, state_dict):
+        new_dict = {}
+
+        for name in state_dict:
+            tensor: torch.Tensor = state_dict[name]
+            if name.endswith('embed_tokens.weight'):
+                new_dict['model.embed_tokens.weight'] = tensor * config.embedding_multiplier_scale
+            elif 'multi_head_attention' in name:
+                old_name = name.replace('multi_head_attention', 'self_attn')
+                if name.endswith('k_proj.weight'):
+                    new_dict[old_name] = permute(tensor, config.num_key_value_heads)
+                elif name.endswith('q_proj.weight'):
+                    new_dict[old_name] = permute(tensor, config.num_attention_heads)
+                else:
+                    new_dict[old_name] = tensor
+            elif 'experts' in name:
+                new_dict[name] = tensor
+            else:
+                old_name = ''
+                mapping = {
+                    'language_model.norm.weight': 'model.norm.weight',
+                    'rms_norm.weight': 'rms_norm.weight',
+                    'rms_norm_1.weight': 'rms_norm_1.weight',
+                    'rms_norm_2.weight': 'rms_norm_2.weight',
+                    'rms_norm_3.weight': 'rms_norm_3.weight',
+                    'router.weight': 'router.weight',
+                }
+
+                for k in mapping.keys():
+                    if name.endswith(k):
+                        old_name = name.replace(k, mapping[k])
+                        break
+
+                if old_name == '':
+                    raise Exception(f'unhandled tensor {name}')
+
+                new_dict[old_name] = tensor
+
+        return new_dict
+
+    @staticmethod
+    def dump_config(f, config, ggml_type):
+        assert config.hidden_act == 'gelu', "hidden_act == 'gelu'"
+
+        config.hidden_act = 'silu'
+        LlamaConverter.dump_config(f, config, ggml_type)
+        config_values = [
+            config.num_key_value_heads,
+            config.num_experts,
+            config.num_selected_experts,
+        ]
+        f.write(struct.pack("i" * len(config_values), *config_values))
+        f.write(struct.pack("<f", config.rope_theta))
+        f.write(struct.pack("<f", config.output_multiplier_scale))
+
+    @staticmethod
+    def get_weight_names(config):
+        weight_names = ["model.embed_tokens.weight"]
+        for i in range(config.num_hidden_layers):
+            for j in range(config.num_experts):
+                weight_names += [
+                    f"model.layers.{i}.experts.{j}.w1.weight",
+                    f"model.layers.{i}.experts.{j}.w2.weight",
+                    f"model.layers.{i}.experts.{j}.w3.weight",
+                ]
+
+            weight_names += [
+                f"model.layers.{i}.self_attn.k_proj.weight",
+                f"model.layers.{i}.self_attn.o_proj.weight",
+                f"model.layers.{i}.self_attn.q_proj.weight",
+                f"model.layers.{i}.self_attn.v_proj.weight",
+                f"model.layers.{i}.rms_norm.weight",
+                f"model.layers.{i}.rms_norm_1.weight",
+                f"model.layers.{i}.rms_norm_2.weight",
+                f"model.layers.{i}.rms_norm_3.weight",
+                f"model.layers.{i}.router.weight",
+            ]
+
+        weight_names += [
+            "model.norm.weight",
+        ]
+
+        return weight_names
+
+    @staticmethod
+    def load_tensor_file(tensor_name, fn) -> Any:
+        tensor_dict = {}
+        new_dict = {}
+
+        # copied from https://gist.github.com/chu-tianxiang/ec310e15d56949fd0f351cb5f65ee7a1
+        def convert_weight(v):
+            dtype = torch.float32
+            if hasattr(v, 'scales'):
+                weight = torch.from_numpy(np.asarray(v.weight).astype(np.float32)).to(dtype)
+                scale = torch.from_numpy(np.asarray(v.scales).astype(np.float32)).to(dtype)
+                # row parallel layers have sharded scale
+                if len(scale.shape) >= 2 and scale.shape[-2] != 1:
+                    scale = scale[..., None, :]
+                    weight = weight.view(*weight.shape[:-2], 8, -1, weight.shape[-1])
+                    weight = (weight * scale).view(*weight.shape[:-3], -1, weight.shape[-1])
+                else:
+                    weight = weight * scale
+            else:
+                weight = torch.from_numpy(np.asarray(v).astype(np.float32)).to(dtype)
+
+            # Transpose linear matrix
+            if len(weight.shape) >= 2 and 'embed_tokens.weight' not in tensor_name:
+                weight = weight.transpose(-1, -2).contiguous()
+
+            if tensor_name.endswith('router.weight'):
+                new_dict[tensor_name] = weight[Grok1Converter.experts]
+            elif 'experts' not in tensor_name:
+                new_dict[tensor_name] = weight
+            else:
+                # split moe
+                for i in range(len(Grok1Converter.experts)):
+                    new_key_i = tensor_name.replace('experts', f'experts.{i}')
+                    new_dict[new_key_i] = weight[Grok1Converter.experts[i]]
+
+        with open(fn, 'rb') as f:
+            r = pickle.load(f)
+            tensor_dict[tensor_name] = r
+
+        convert_weight(r)
+
+        return new_dict
+
+    @staticmethod
+    def load_tensor_files(tensor_files) -> Dict:
+        for (t, f) in tensor_files:
+            print(f)
+            yield Grok1Converter.load_tensor_file(t, f)
+
+    @classmethod
+    def convert(cls, config, model_files_path, vocab: Any, ggml_type, save_path):
+
+        Grok1Converter.experts = config.experts
+
+        map = ['language_model.embed_tokens.weight',
+               'language_model.norm.weight']
+
+        # caution: alphabet order must not be changed!
+        for i in range(config.num_hidden_layers):
+            map += [
+                f"model.layers.{i}.experts.w1.weight",
+                f"model.layers.{i}.experts.w2.weight",
+                f"model.layers.{i}.experts.w3.weight",
+                f"model.layers.{i}.multi_head_attention.k_proj.weight",
+                f"model.layers.{i}.multi_head_attention.o_proj.weight",
+                f"model.layers.{i}.multi_head_attention.q_proj.weight",
+                f"model.layers.{i}.multi_head_attention.v_proj.weight",
+                f"model.layers.{i}.rms_norm.weight",
+                f"model.layers.{i}.rms_norm_1.weight",
+                f"model.layers.{i}.rms_norm_2.weight",
+                f"model.layers.{i}.rms_norm_3.weight",
+                f"model.layers.{i}.router.weight",
+            ]
+
+        order = list(range(len(map)))
+        order.sort(key=lambda i: map[i])
+
+        for i in range(len(map)):
+            idx = order.index(i)
+            fn = model_files_path + f'/tensor{idx:05}_000'
+            info = (map[i], fn)
+            Grok1Converter.tensor_map.append(info)
+            Grok1Converter.file_to_name[fn] = map[i]
+
+        # convert all weights to fp16
+        with open(save_path, "wb") as f:
+            f.write(b"ggml")    # magic
+            f.write(struct.pack("ii", cls.MODEL_TYPE.value, cls.FILE_VERSION))
+            Grok1Converter.dump_config(f, config, ggml_type)
+            vocab.write_vocab(f)
+
+            weight_names = Grok1Converter.get_weight_names(config)
+            dump_state_dict(f, weight_names, Grok1Converter.tensor_map, ggml_type, config, Grok1Converter.state_dict_pp, loader_fun=Grok1Converter.load_tensor_files)
+
+        print(f"{Grok1Converter.MODEL_TYPE.name} GGML model saved to {save_path}")
+
+def convert_grok_1_base(args, vocab, ggml_type):
+    def ffn_size(emb_size, widening_factor):
+        _ffn_size = int(widening_factor * emb_size) * 2 // 3
+        _ffn_size = _ffn_size + (8 - _ffn_size) % 8  # ensure it's a multiple of 8
+        return _ffn_size
+
+    grok1_config = {
+        'vocab_size': 128 * 1024,
+        'hidden_act': 'gelu',
+        'pad_token_id': 0,
+        'eos_token_id': 2,
+        'max_position_embeddings': 8192,
+        'output_multiplier_scale': 0.5773502691896257,
+        'embedding_multiplier_scale': 78.38367176906169,
+        'hidden_size': 48 * 128,
+        'intermediate_size': -1,
+        'num_attention_heads': 48,
+        'num_key_value_heads': 8,
+        'num_hidden_layers': 64,
+        'num_selected_experts': 2,
+        'rope_theta': 10000,
+        'attn_output_multiplier': 0.08838834764831845,
+    }
+
+    grok1_config['intermediate_size'] = ffn_size(grok1_config['hidden_size'], 8)
+
+    grok1_config['experts'] = list(range(8))
+    if args.experts != '':
+        grok1_config['experts'] = [int(x, 0) for x in args.experts.split(',')]
+
+    grok1_config['num_experts'] = len(grok1_config['experts'])
+
+    if grok1_config['num_experts'] < 2:
+        raise Exception("at least 2 experts are required")
+
+    print(f"experts to export: {grok1_config['experts']}")
+
+    Grok1Converter.convert(AttributeDict(grok1_config), args.model_name_or_path, vocab, ggml_type, args.save_path)
+    return
+
 def load_vocab(path: Path) -> Any:
 
     def load_spm(p: Path) -> Any:
```

```diff
@@ -2329,11 +2577,17 @@ def main():
     parser.add_argument("-o", "--save_path", type=Path)
     parser.add_argument("-t", "--type", type=str, default="q8_0", choices=["f32", "f16", "q8_0", "q4_0", "q4_1"])
     parser.add_argument("--vocab_dir", type=str, default='')
+    parser.add_argument("--experts", type=str, default='')
     args = parser.parse_args()
 
     ggml_type = GGMLType[args.type.upper()]
 
     vocab = load_vocab(Path(args.model_name_or_path) if args.vocab_dir == '' else Path(args.vocab_dir))
+
+    if args.arch.lower() == 'grok-1-base':
+        convert_grok_1_base(args, vocab, ggml_type)
+        return
+
     model_files = load_some_model(Path(args.model_name_or_path))
 
     #if args.lora_model_name_or_path is not None:
```

docs/grok.md  (+50, new file)

# About Grok-1

Disclaimer: I am not sure whether this implementation is correct, because I don't have enough compute resources to run the full model.

## Convert the model

To convert the base model, `jax` is needed:

```sh
pip install jax[cpu]
```

Download the [base model](https://huggingface.co/xai-org/grok-1) and its [repository](https://github.com/xai-org/grok-1).

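One possible way to fetch both is sketched below (this assumes the `huggingface-cli` tool from the `huggingface_hub` package is installed; any other download method works just as well, and the target paths are placeholders):

```sh
# download the checkpoint files (roughly 300 GB) from Hugging Face
huggingface-cli download xai-org/grok-1 --local-dir /path/to/model
# clone the official repository, which contains the tokenizer
git clone https://github.com/xai-org/grok-1 /path/to/repository
```
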
Use `convert.py` to convert it, for example quantized to Q4_0:

```sh
python convert.py -i /path/to/model/ckpt-0 --vocab_dir /path/to/repository -o grok.bin -a Grok-1-Base -t q4_0
```

**Bonus**: Use `--experts` to export a subset of the experts, such as `--experts 0,1,2,3` for the first 4 experts (see the example below).
The converted model will have fewer parameters, but its quality will degrade significantly.
At least 2 experts are required, and `NUM_EXPERTS` in `grok.cpp` must match the actual number of exported experts.
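For example, a conversion that keeps only the first 4 experts could look like this (the input paths and the output name `grok-1-4_q4_0.bin` are placeholders):

```sh
# export only experts 0-3, quantized to Q4_0
python convert.py -i /path/to/model/ckpt-0 --vocab_dir /path/to/repository \
    -o grok-1-4_q4_0.bin -a Grok-1-Base -t q4_0 --experts 0,1,2,3
```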

## Test

Below is a test run with the first 4 experts:

```sh
./bin/main -m ../grok-1-4_q4_0.bin -i --temp 0 --max_length 1024

    ________          __  __    __    __  ___
   / ____/ /_  ____ _/ /_/ /   / /   /  |/  /_________  ____
  / /   / __ \/ __ `/ __/ /   / /   / /|_/ // ___/ __ \/ __ \
 / /___/ / / / /_/ / /_/ /___/ /___/ /  / // /__/ /_/ / /_/ /
 \____/_/ /_/\__,_/\__/_____/_____/_/  /_(_)___/ .___/ .___/
You are served by Grok-1,                      /_/   /_/
with 161064425472 (83.8B effect.) parameters.

You  > what is your name?
A.I. >

what is your age?

what is your weight?

...
```
