Commit 1dc7a5c

Define embedding_4bit ops
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 9dab3e8 commit 1dc7a5c


exir/passes/_quant_patterns_and_replacements.py

Lines changed: 267 additions & 1 deletion
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
-from typing import Callable, List, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import torch
 from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
@@ -15,6 +15,7 @@
 )
 from torch import fx
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
+from torch.library import impl, impl_abstract
 
 
 __all__ = [
@@ -31,9 +32,274 @@
 
 quantized_decomposed_lib.define(
     "embedding_byte.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None) -> Tensor",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_byte.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+
+def embedding_weight_checks(weight, weight_scales, weight_zero_points):
+    assert weight.dtype in [
+        torch.int8,
+        torch.uint8,
+    ], f"Expecting weights to be of dtype in [torch.int8, torch.uint8], but got {weight.dtype}"
+    assert (
+        weight.dim() == 2
+    ), f"Expecting weight tensor to have dim()==2, but found {weight.dim()}"
+
+    assert weight_scales.dtype in [
+        torch.float16,
+        torch.float32,
+    ], f"Expecting weight_scales to be of dtype in [torch.float16, torch.float32], but got {weight_scales.dtype}"
+    assert (
+        weight_scales.dim() == 1 or weight_scales.dim() == 2
+    ), f"Expecting weight_scales tensor to have rank 1 or 2, but found {weight_scales.dim()}"
+    assert weight_scales.size(0) == weight.size(
+        0
+    ), f"Expecting weight and scale tensor to have same number of rows, but found {weight.size()} and {weight_scales.size()}"
+
+    assert (
+        weight_zero_points is None or weight_zero_points.dtype == weight_scales.dtype
+    ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
+    assert (
+        weight_zero_points is None or weight_zero_points.dim() == 1
+    ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"
+    assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
+        0
+    ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"
+
+
+@impl(quantized_decomposed_lib, "embedding_byte", "CompositeExplicitAutograd")
+def embedding_byte(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+) -> torch.Tensor:
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
+    group_size = weight.size(1) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        weight.dtype,
+        group_size,
+        weight_scales.dtype,
+    )
+    return torch.ops.aten.embedding.default(weight, indices)
+
+
+@impl_abstract("quantized_decomposed::embedding_byte.out")
+def embedding_byte_out_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_byte(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+    )
+
+
+@impl(quantized_decomposed_lib, "embedding_byte.dtype", "CompositeExplicitAutograd")
+def embedding_byte_dtype(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+) -> torch.Tensor:
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
+    group_size = weight.size(1) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        weight.dtype,
+        group_size,
+        dtype,
+    )
+    return torch.ops.aten.embedding.default(weight, indices)
+
+
+@impl_abstract("quantized_decomposed::embedding_byte.dtype_out")
+def embedding_byte_dtype_out_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_byte_dtype(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+        dtype,
+    )
+
+
+quantized_decomposed_lib.define(
+    "embedding_4bit(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_4bit.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
     "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
 )
 
+quantized_decomposed_lib.define(
+    "embedding_4bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+
+@impl(quantized_decomposed_lib, "embedding_4bit", "CompositeExplicitAutograd")
+def embedding_4bit(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+) -> torch.Tensor:
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
+    group_size = (2 * weight.size(1)) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight_even = weight.div(16, rounding_mode="trunc")
+    weight_odd = weight.remainder(16)
+    weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
+    weight = weight_unpacked.view(weight.shape[0], -1)
+    weight = weight.view(torch.int8).add(-8)
+
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        weight.dtype,
+        group_size,
+        weight_scales.dtype,
+    )
+    return torch.ops.aten.embedding.default(weight, indices)
+
+
+@impl_abstract("quantized_decomposed::embedding_4bit.out")
+def embedding_4bit_out_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_4bit(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+    )
+
+
+@impl(quantized_decomposed_lib, "embedding_4bit.dtype", "CompositeExplicitAutograd")
+def embedding_4bit_dtype(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+) -> torch.Tensor:
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
+    group_size = (2 * weight.size(1)) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight_even = weight.div(16, rounding_mode="trunc")
+    weight_odd = weight.remainder(16)
+    weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
+    weight = weight_unpacked.view(weight.shape[0], -1)
+    weight = weight.view(torch.int8).add(-8)
+
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        weight.dtype,
+        group_size,
+        dtype,
+    )
+    return torch.ops.aten.embedding.default(weight, indices)
+
+
+@impl_abstract("quantized_decomposed::embedding_4bit.dtype_out")
+def embedding_4bit_dtype_out_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_4bit_dtype(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+        dtype,
+    )
+
+
 quantized_decomposed_lib.define(
     "mixed_mm(Tensor input, Tensor weight, Tensor weight_scales, Tensor? weight_zero_points) -> Tensor",
 )
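
The new embedding_4bit ops mirror the existing embedding_byte ops, but assume a weight layout with two 4-bit values packed per byte, high nibble first: unpacking splits each byte with a truncating divide and a remainder, interleaves the halves, and shifts the [0, 15] range back to [-8, 7]. A minimal round-trip sketch of that layout (the pack_4bit helper is hypothetical and not part of this commit; unpack_4bit mirrors the body of embedding_4bit):

import torch

def pack_4bit(w: torch.Tensor) -> torch.Tensor:
    # Hypothetical packer: signed 4-bit values in [-8, 7], an even number of
    # columns per row, packed two per byte with the even column in the high nibble.
    shifted = (w + 8).to(torch.uint8)  # [-8, 7] -> [0, 15]
    return shifted[:, 0::2] * 16 + shifted[:, 1::2]

def unpack_4bit(packed: torch.Tensor) -> torch.Tensor:
    # Mirrors embedding_4bit: high nibble, low nibble, interleave, recenter.
    weight_even = packed.div(16, rounding_mode="trunc")
    weight_odd = packed.remainder(16)
    weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
    weight = weight_unpacked.view(packed.shape[0], -1)
    return weight.view(torch.int8).add(-8)

w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)
assert torch.equal(unpack_4bit(pack_4bit(w)), w)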

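After unpacking, both op families defer to quantized_decomposed.dequantize_per_channel_group, which scales each contiguous group of a row by its own scale (and offsets by the matching zero point when one is supplied). A hand-rolled reference for the rank-2 scales case, ignoring zero points (dequantize_per_group_ref is an illustrative name, not an API in this file):

import torch

def dequantize_per_group_ref(
    w_q: torch.Tensor, scales: torch.Tensor, group_size: int
) -> torch.Tensor:
    # w_q: (rows, cols) integer weights; scales: (rows, cols // group_size).
    rows, cols = w_q.shape
    grouped = w_q.to(scales.dtype).view(rows, cols // group_size, group_size)
    return (grouped * scales.unsqueeze(-1)).view(rows, cols)

Note how group_size is recovered from the shapes alone: weight.size(1) // weight_scales.size(1) for byte weights, and (2 * weight.size(1)) // weight_scales.size(1) for 4-bit weights, since each packed column holds two values; rank-1 scales degenerate to one group spanning the whole row.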