Commit 0da7b40

Michael Gschwind authored and committed
4b embedding quantizer (#3081)
Summary: 4b embedding quantizer

Reviewed By: larryliu0820

Differential Revision: D56229021
1 parent 3d7dcd5 commit 0da7b40
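
For orientation, the diff below lets the embedding quantizer run in a packed 4-bit mode. A minimal sketch of how the new handler might be invoked (the `model` object and the import path are assumptions for illustration, not part of this commit):

```python
# Hypothetical usage of the 4-bit packed embedding quantizer added in this diff.
# `model` stands in for an already-built Llama2 nn.Module; the import path is an
# assumption based on the file location examples/models/llama2/quantize.py.
from examples.models.llama2.quantize import EmbeddingOnlyInt8QuantHandler

handler = EmbeddingOnlyInt8QuantHandler(
    model,            # module whose nn.Embedding layers get quantized
    "cpu",            # new device argument introduced by this diff
    bitwidth=4,       # packing is only allowed for 4-bit quantization
    group_size=32,    # per-row group size for the scales (illustrative value)
    packed=True,      # store two 4-bit values per uint8 byte
)
model = handler.quantized_model()  # quantize, swap modules, load the state dict
```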

File tree

1 file changed: +92 additions, −38 deletions

examples/models/llama2/quantize.py

Lines changed: 92 additions & 38 deletions
@@ -122,6 +122,10 @@ def dynamically_quantize_per_channel(
     return quant, scales, zero_points
 
 
+#########################################################################
+### QuantHandler API definition ###
+
+
 class QuantHandler:
     def __init__(self, mod):
         self.mod = mod
@@ -132,8 +136,15 @@ def create_quantized_state_dict(self) -> Dict: # "StateDict"
     def convert_for_runtime(self) -> nn.Module:
         pass
 
+    def quantized_model(self) -> nn.Module:
+        model_updated_state_dict = self.create_quantized_state_dict()
+        self.convert_for_runtime()
+        self.mod.load_state_dict(model_updated_state_dict)
+        return self.mod
 
-##### Weight-only int8 per-channel quantized code ######
+
+#########################################################################
+### Weight-only int8 per-channel quantized code ###
 
 
 def replace_linear_weight_only_int8_per_channel(module, node_type):
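
The hunk above hoists the quantize → convert → load sequence into the QuantHandler base class, so concrete handlers only need to supply the two abstract steps. A minimal, hypothetical subclass showing the flow (not part of this diff):

```python
# Hypothetical no-op handler illustrating the shared quantized_model() flow the
# base class now provides: create_quantized_state_dict() runs first, then
# convert_for_runtime(), then the state dict is loaded back into the model.
class PassthroughQuantHandler(QuantHandler):
    def create_quantized_state_dict(self):
        # a real handler would return quantized weights/scales here
        return self.mod.state_dict()

    def convert_for_runtime(self):
        # a real handler would swap modules for quantized equivalents here
        return self.mod


# quantized = PassthroughQuantHandler(model).quantized_model()
```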
@@ -151,16 +162,17 @@ def replace_linear_weight_only_int8_per_channel(module, node_type):
                 setattr(
                     module,
                     name,
-                    WeightOnlyInt8Linear(child.in_features, child.out_features),
+                    WeightOnlyInt8Linear("cpu", child.in_features, child.out_features),
                 )
         else:
             replace_linear_weight_only_int8_per_channel(child, node_type)
 
 
-class WeightOnlyInt8QuantHandler:
+class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
         mod,
+        device="cpu",
         *,
         node_type: str = "*",
         bitwidth: Optional[int] = None,
@@ -200,7 +212,7 @@ def create_quantized_state_dict(self) -> Dict:
                 )
             ):
                 print(
-                    f"quantize {self.node_type} {fqn, mod} with groupsize {self.group_size}, bitwidth {self.bitwidth}"
+                    f"quantize {self.node_type} {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
 
                 # print(f"initial weight shape {mod.weight.shape}")
@@ -217,7 +229,7 @@ def create_quantized_state_dict(self) -> Dict:
                 )
 
                 cur_state_dict[f"{fqn}.weight"] = weight
-                # squeeze makes groupsize=rowsize unidimensional
+                # squeeze makes group_size=rowsize unidimensional
                 cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
 
         return cur_state_dict
@@ -241,10 +253,10 @@ class WeightOnlyInt8Linear(torch.nn.Module):
 
     def __init__(
         self,
+        device,
         in_features: int,
         out_features: int,
         bias: bool = True,
-        device=None,
         dtype=None,
     ) -> None:
         super().__init__()
@@ -260,11 +272,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         # return F.linear(input, self.weight.to(dtype=input.dtype)) * se...
 
 
-##### embedding table quantization ######
+#########################################################################
+##### embedding table quantization ######
 
 
 def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, bitwidth: int = 8, group_size: Optional[int] = None
+    module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
 ):
     for name, child in module.named_children():
         # print(f"name: {name}")
@@ -275,25 +288,41 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                 module,
                 name,
                 QuantizedGroupEmbedding(
+                    device=device,
                     vocab_size=child.weight.shape[0],
                     embedding_dim=child.weight.shape[1],
                     group_size=group_size,
+                    packed=packed,
                 ),
             )
         else:
             replace_embedding_weight_only_grouped_int8_per_channel(
-                child, bitwidth, group_size
+                child, device, bitwidth, group_size, packed
             )
 
 
-class EmbeddingOnlyInt8QuantHandler:
-    def __init__(self, mod, *, bitwidth: int = 8, group_size: Optional[int] = None):
+class EmbeddingOnlyInt8QuantHandler(QuantHandler):
+    def __init__(
+        self,
+        mod,
+        device="cpu",
+        *,
+        bitwidth: int = 8,
+        group_size: Optional[int] = None,
+        packed=False,
+    ):
+        if isinstance(packed, str):
+            packed = packed == "True"
         self.mod = mod
+        self.device = device
         self.group_size = group_size
         self.bitwidth = bitwidth
+        self.packed = packed
+        if (bitwidth != 4) and packed:
+            raise RuntimeError("pack only works with bitsize 4")
 
     @torch.no_grad()
-    def create_quantized_state_dict(self) -> Dict:
+    def create_quantized_state_dict(self, packed=False) -> Dict:
         cur_state_dict = self.mod.state_dict()
 
         if self.bitwidth == 4:
@@ -306,18 +335,14 @@ def create_quantized_state_dict(self) -> Dict:
            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
 
         for fqn, mod in self.mod.named_modules():
-            if (
-                isinstance(mod, nn.Embedding)
-                or isinstance(mod, fsEmbedding)
-                or isinstance(mod, fsStandardEmbedding)
-            ):
+            if isinstance(mod, nn.Embedding):
                 # print("****")
                 # print(f"Embedding identified: {fqn, mod}")
                 # print(f"weights size: {mod.weight.size()}")
                 # print(f"quantize {fqn}...")
 
                 print(
-                    f"quantize {fqn, mod} with groupsize {self.group_size}, bitwidth {self.bitwidth}"
+                    f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
                 weight, scales, _ = dynamically_quantize_per_channel(
                     mod.weight.float(),
@@ -328,21 +353,36 @@ def create_quantized_state_dict(self) -> Dict:
                     scales_dtype=mod.weight.dtype,
                 )
 
+                if packed:
+                    if weight.shape[-1] % 2 != 0:
+                        raise RuntimeError("automatic padding not implemented yet")
+
+                    weight_range_shifted = weight.add(8).view(torch.uint8)
+                    weight_view = weight_range_shifted.view(
+                        weight.shape[0], weight.shape[1] // 2, 2
+                    )
+                    weight_even = weight_view[:, :, 0] * 16  # left shift 4
+                    weight_odd = weight_view[:, :, 1]
+                    weight_packed = weight_even + weight_odd
+                    weight = weight_packed
+
+                weight = weight.to(device=self.device)
+                scales = scales.to(device=self.device)
                 # Update state dict
                 cur_state_dict[f"{fqn}.weight"] = weight
-                # squeeze makes groupsize=rowsize unidimensional
+                # squeeze makes group_size=rowsize unidimensional
                 cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
 
         return cur_state_dict
 
     def convert_for_runtime(self) -> nn.Module:
         replace_embedding_weight_only_grouped_int8_per_channel(
-            self.mod, self.bitwidth, self.group_size
+            self.mod, self.device, self.bitwidth, self.group_size, self.packed
         )
         return self.mod
 
     def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
+        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
         self.mod.load_state_dict(model_updated_state_dict)
         return self.mod
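
The packing added above stores two shifted 4-bit values per byte (even column in the high nibble, odd column in the low nibble). A standalone sketch of the same arithmetic plus the inverse transform, useful for checking the round trip (the unpacking code is illustrative and not part of this diff):

```python
import torch

# Signed 4-bit values in [-8, 7] are shifted to [0, 15], then adjacent pairs are
# packed into a single uint8: high nibble = even column, low nibble = odd column.
weight = torch.tensor([[-8, 7, 0, 3]], dtype=torch.int8)

shifted = weight.add(8).view(torch.uint8)              # [[0, 15, 8, 11]]
pairs = shifted.view(weight.shape[0], weight.shape[1] // 2, 2)
packed = pairs[:, :, 0] * 16 + pairs[:, :, 1]          # [[15, 139]]

# Inverse (not in the diff): split the nibbles, undo the +8 shift.
high, low = packed // 16, packed % 16
unpacked = torch.stack([high, low], dim=-1).reshape(weight.shape)
restored = unpacked.to(torch.int8) - 8
assert torch.equal(restored, weight)
```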
@@ -351,39 +391,53 @@ def quantized_model(self) -> nn.Module:
 class QuantizedGroupEmbedding(torch.nn.Module):
     def __init__(
         self,
+        device,
         vocab_size: int,
         embedding_dim: int,
         group_size: Optional[int] = None,
-        device=None,
         dtype=torch.half,
+        packed=False,
     ) -> None:
         super().__init__()
-        if group_size is None:
+        if group_size is None or group_size == 0:
             group_size = embedding_dim
         self.group_size = group_size
         self.dtype = dtype
-        self.register_buffer(
-            "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
-        )
+        self.packed = packed
+        if not packed:
+            self.register_buffer(
+                "weight",
+                torch.empty(
+                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
+                ),
+            )
+        else:  # packed
+            self.register_buffer(
+                "weight",
+                torch.empty(
+                    (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
+                ),
+            )
         groups_per_row = (embedding_dim + group_size - 1) // group_size
         if groups_per_row > 1:
             self.register_buffer(
-                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
+                "scales",
+                torch.ones(
+                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
+                ),
             )
         else:
             self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16)
+                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
             )
 
     @torch.no_grad()
     def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        return torch.ops.llama_quantized.DEPRECATED_DO_NOT_USE_embedding_byte.dtype(
-            self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
-        )
-
-
-        # result_weights = self.weight.index_select(0, indices.view(-1))
-        # result_scales = self.scales.index_select(0, indices.view(-1))
-        #
-        # r = result_weights.to(dtype=result_scales.dtype) * result_scales
-        # return r.view(indices.size() + (-1,))
+        if not self.packed:  # 8bit
+            return torch.ops.llama_quantized.DEPRECATED_DO_NOT_USE_embedding_byte.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
+        else:  # 4bit packed
+            return torch.ops.llama_quantized.embedding_4bit.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
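
For the unpacked 8-bit path, the custom op called in forward() corresponds, for the single-group case (groups_per_row == 1), to the eager fallback the removed comments sketched: gather the quantized rows and their scales, multiply, and reshape. A hedged reference implementation, shown only to make the op's semantics concrete (not part of this diff):

```python
import torch

# Eager reference for the 8-bit, single-group-per-row case: weight is int8 of
# shape (vocab_size, embedding_dim), scales is fp16 of shape (vocab_size,).
# This mirrors the commented-out fallback removed by this diff; the real
# forward() calls a custom llama_quantized op instead.
def reference_embedding_byte(weight, scales, indices, dtype=torch.half):
    rows = weight.index_select(0, indices.view(-1))          # (N, embedding_dim)
    row_scales = scales.index_select(0, indices.view(-1))    # (N,)
    dequant = rows.to(dtype=row_scales.dtype) * row_scales.unsqueeze(-1)
    return dequant.view(indices.size() + (-1,)).to(dtype)
```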
