Commit 501dfa7

caogao authored and facebook-github-bot committed
Add INT8 and INT4 support to P2P benchmark. (#918)
Summary:
Pull Request resolved: #918

Add two options: 1) the default, which transfers the quantized tensor directly; 2) --include_quantization, which starts with an FP16 tensor, quantizes, transfers, and finally dequantizes back to an FP16 tensor. Also, add an option to sweep through data types and shapes.

Caveat: INT4 dequantization is not numerically correct, but it is added as a proxy for performance measurement.

Reviewed By: brad-mengchi

Differential Revision: D31098854

fbshipit-source-id: 7e4e4ca7f81f537c0fe37c91a36e46b862c28cdd
1 parent e385d02 commit 501dfa7
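
Why the INT8 and INT4 modes move fewer bytes: each quantized table slice stores its elements at 1 byte (INT8) or half a byte (INT4), plus the 4 bytes of fused FP16 scale-and-bias overhead that the diff's _get_random_tensor allocates for. A minimal sketch of that per-row arithmetic (the helper name bytes_per_row_per_table is illustrative and not part of the commit):

# Illustrative helper, not part of this commit: per-row bytes for one table
# slice, mirroring the sizes that _get_random_tensor allocates in the diff below.
def bytes_per_row_per_table(embedding_dimension: int, data_type: str) -> int:
    if data_type == "FP16":
        return embedding_dimension * 2  # 2 bytes per FP16 element
    if data_type == "INT8":
        return embedding_dimension + 4  # 1 byte per element + scale/bias overhead
    if data_type == "INT4":
        return embedding_dimension // 2 + 4  # 2 elements per byte + scale/bias overhead
    raise ValueError(data_type)

# With the default D=300: FP16 -> 600 B, INT8 -> 304 B, INT4 -> 154 B per table slice.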

File tree

1 file changed: +245 -34 lines changed

fbgemm_gpu/bench/merge_embeddings_benchmark.py

Lines changed: 245 additions & 34 deletions
@@ -7,6 +7,9 @@
 
 # pyre-unsafe
 
+import logging
+import signal
+
 import click
 import numpy as np
 import tabulate
@@ -20,29 +23,87 @@
 torch.ops.load_library(
     "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu"
 )
+torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
+torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
 
 
-@click.command()
-@click.option("--all-to-one-only", is_flag=True, default=False)
-@click.option("--num-ads", default=1024, type=int)
-@click.option("--embedding-dimension", default=300, type=int)
-@click.option("--ads-tables", default=400, type=int)
-@click.option("--iters", default=10, type=int)
-@click.option("--p2p_bw", is_flag=True, default=False)
-@click.option("--dst-device", default=0, type=int)
-def main(
-    all_to_one_only, num_ads, embedding_dimension, ads_tables, iters, p2p_bw, dst_device
-) -> None:
+def _get_random_tensor(
+    num_ads: int,
+    embedding_dimension: int,
+    ads_tables: int,
+    data_type: str,
+    gpu_idx: int,
+    include_quantization: bool,
+):
+    if data_type == "FP16" or include_quantization:
+        result_tensor = torch.randn(
+            num_ads,
+            embedding_dimension * ads_tables,
+            dtype=torch.float16,
+            device=torch.device(f"cuda:{gpu_idx}"),
+        )
+    elif data_type == "INT8":
+        assert (
+            embedding_dimension % 2
+        ) == 0, "needs to align to 2 bytes (half type size) for INT8"
+        result_tensor = torch.randint(
+            0,
+            255,
+            # 2 FP16 numbers for scale and bias, total of 4 bytes overhead
+            size=(num_ads, (embedding_dimension + 4) * ads_tables),
+            dtype=torch.uint8,
+            device=torch.device(f"cuda:{gpu_idx}"),
+        )
+    elif data_type == "INT4":
+        assert (
+            embedding_dimension % 4
+        ) == 0, "needs to align to 2 bytes (half type size) for INT4"
+        result_tensor = torch.randint(
+            0,
+            255,
+            # Using torch.uint8 for int4 storage
+            size=(num_ads, (embedding_dimension // 2 + 4) * ads_tables),
+            dtype=torch.uint8,
+            device=torch.device(f"cuda:{gpu_idx}"),
+        )
+    else:
+        raise ValueError
+
+    return result_tensor
+
+
+def benchmark(
+    all_to_one_only,
+    num_ads,
+    embedding_dimension,
+    ads_tables,
+    iters=10,
+    p2p_bw=False,
+    dst_device=0,
+    data_type="FP16",
+    include_quantization=False,
+) -> str:
     torch.cuda.set_device(dst_device)
     num_gpus = torch.cuda.device_count()
-    ad_ds = [embedding_dimension * ads_tables for _ in range(num_gpus)]
     batch_indices = torch.zeros(num_ads).long().cuda()
     pooled_ad_embeddings = [
-        torch.randn(
-            num_ads, ad_d, dtype=torch.float16, device=torch.device(f"cuda:{i}")
+        _get_random_tensor(
+            num_ads,
+            embedding_dimension,
+            ads_tables,
+            data_type,
+            gpu_idx,
+            include_quantization,
         )
-        for i, ad_d in enumerate(ad_ds)
+        for gpu_idx in range(num_gpus)
     ]
+    # Using torch.int8 for int4 storage
+    bytes_per_element = 2 if (data_type == "FP16" or include_quantization) else 1
+    total_elements = num_ads * embedding_dimension * ads_tables * num_gpus
+
+    logging.debug(
+        f"B: {num_ads}, D: {embedding_dimension}, T: {ads_tables}, Data Type: {data_type}, Num GPUs: {num_gpus}, Destination GPU: {dst_device}"
+    )
 
     def benchmark_torch_function(iters: int, f, *args) -> float:
         f(*args)
@@ -68,7 +129,9 @@ def benchmark_torch_function(iters: int, f, *args) -> float:
                     if i != j
                     else pooled_ad_embeddings[i].clone(),
                 )
-                p2p_copy_bw[i, j] = pooled_ad_embeddings[i].numel() * 2 / t / 1.0e9
+                p2p_copy_bw[i, j] = (
+                    pooled_ad_embeddings[i].numel() * bytes_per_element / t / 1.0e9
+                )
         table = tabulate.tabulate(
             p2p_copy_bw,
             headers=[f"GPU {i}" for i in range(num_gpus)],
@@ -80,32 +143,180 @@ def benchmark_torch_function(iters: int, f, *args) -> float:
     streams = [torch.cuda.Stream(device=i) for i in range(num_gpus)]
     import contextlib
 
-    with contextlib.ExitStack() as stack:
-        for stream in streams:
-            stack.enter_context(torch.cuda.stream(stream))
+    def pool_func_with_quantization(
+        pooled_ad_embeddings,
+        batch_indices,
+        include_quantization,
+        data_type,
+    ):
+        if include_quantization:
+            assert data_type == "INT8" or data_type == "INT4"
+            quantized = [
+                torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(t.float())
+                if data_type == "INT8"
+                else torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
+                    t.float(), 4
+                )
+                for t in pooled_ad_embeddings
+            ]
+            pooled_quantized_result = torch.ops.fbgemm.merge_pooled_embeddings(
+                quantized, batch_indices.size(0), batch_indices.device
+            )
+            PooledEmbeddingDequantizeDataTypeFP16 = 1
 
-        merged = torch.ops.fbgemm.merge_pooled_embeddings(
-            pooled_ad_embeddings, batch_indices.size(0), batch_indices.device
-        )
+            if data_type == "INT8":
+                offset = torch.cumsum(
+                    torch.tensor(
+                        [0] + [quantized[0].shape[1] for _ in range(len(quantized))],
+                        device=batch_indices.device,
+                    ),
+                    dim=0,
+                ).to(torch.int)
+                return torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatMixedDim(
+                    pooled_quantized_result,
+                    offset,
+                    PooledEmbeddingDequantizeDataTypeFP16,
+                )
+            else:
+                # TODO: the result here is wrong. Once MixedDim version for FusedNBit quantization is done, switch to that.
+                # Since their performance is similar, keep using FusedNBitRowwiseQuantizedSBHalfToFloat for now.
+                return torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloat(
+                    pooled_quantized_result, 4
+                ).half()
 
         if all_to_one_only:
-            t = benchmark_torch_function(
-                iters,
-                lambda: torch.ops.fbgemm.all_to_one_device(
-                    pooled_ad_embeddings, batch_indices.device
-                ),
+            return torch.ops.fbgemm.all_to_one_device(
+                pooled_ad_embeddings, batch_indices.device
             )
         else:
-            t = benchmark_torch_function(
-                iters,
-                lambda: torch.ops.fbgemm.merge_pooled_embeddings(
-                    pooled_ad_embeddings, batch_indices.size(0), batch_indices.device
-                ),
+            return torch.ops.fbgemm.merge_pooled_embeddings(
+                pooled_ad_embeddings, batch_indices.size(0), batch_indices.device
            )
 
-    print(
-        f"Merge, B: {num_ads}, D: {embedding_dimension}, T: {ads_tables}, Num GPUs: {num_gpus}, Destination GPU: {dst_device} Output Size: {merged.numel() * 2 / 1.0e6:.2f}MB, BW: {merged.numel() * 2 / t / 1.0e9:.2f}GB/s, t: {t * 1.0e3:.2f}ms"
+    with contextlib.ExitStack() as stack:
+        for stream in streams:
+            stack.enter_context(torch.cuda.stream(stream))
+
+        merged = pool_func_with_quantization(
+            pooled_ad_embeddings, batch_indices, include_quantization, data_type
+        )
+        t = benchmark_torch_function(
+            iters,
+            lambda: pool_func_with_quantization(
+                pooled_ad_embeddings, batch_indices, include_quantization, data_type
+            ),
+        )
+
+    logging.debug(
+        f"Merge, B: {num_ads}, D: {embedding_dimension}, T: {ads_tables}, Data Type: {data_type}, Num GPUs: {num_gpus}, Destination GPU: {dst_device}, "
+        f"Number of elements: {total_elements / 1.0e6:.0f} Million, Billion elements per sec: {total_elements / t / 1.0e9:.1f}, "
+        f"Output Size: {merged.numel() * bytes_per_element / 1.0e6:.0f}MB, BW: {merged.numel() * bytes_per_element / t / 1.0e9:.1f}GB/s, "
+        f"t: {t * 1.0e3:.2f}ms"
+    )
+    # return result in CSV format
+    return (
+        f"{num_ads}, {embedding_dimension}, {ads_tables}, {data_type}, {num_gpus}, {dst_device}, "
+        f"{total_elements / 1.0e6:.0f}, {total_elements / t / 1.0e9:.1f}, "
+        f"{merged.numel() * bytes_per_element / 1.0e6:.0f}, {merged.numel() * bytes_per_element / t / 1.0e9:.1f}, "
+        f"{t * 1.0e3:.2f}"
+    )
+
+
+@click.command()
+@click.option("--all-to-one-only", is_flag=True, default=False)
+@click.option("--num_ads", default=1024, type=int)
+@click.option("--embedding_dimension", default=300, type=int)
+@click.option("--ads_tables", default=100, type=int)
+@click.option("--iters", default=10, type=int)
+@click.option("--p2p_bw", is_flag=True, default=False)
+@click.option("--dst_device", default=0, type=int)
+@click.option(
+    "--data_type",
+    type=click.Choice(["FP16", "INT8", "INT4"]),
+    default="FP16",
+)
+# For INT8/INT4 data type, whether to start with FP16 and include quantization overhead
+@click.option("--include_quantization", is_flag=True, default=False)
+@click.option("--sweep", is_flag=True, default=False)
+def main(
+    all_to_one_only,
+    num_ads,
+    embedding_dimension,
+    ads_tables,
+    iters,
+    p2p_bw,
+    dst_device,
+    data_type,
+    include_quantization,
+    sweep,
+) -> None:
+    assert sweep or not (
+        include_quantization and data_type == "FP16"
+    ), "no quantization is needed for FP16"
+
+    csv_header = (
+        "num_ads, embedding_dimension, ads_tables, data_type, num_gpus,"
+        "dst_device, number of elements (Million), throughput (billion elements per sec), "
+        "output size (MB), BW (GB/s), t (ms)"
+    )
+    if sweep:
+
+        def handler(signum, frame):
+            logging.error("timeout")
+            raise TimeoutError()
+
+        results = []
+        num_gpu = torch.cuda.device_count()
+        for num_ads in [128, 256, 512, 1024, 2048]:
+            # Scale num_ads so all GPUs have sweep through the same number of total elements
+            num_ads *= 8 // num_gpu
+            for embedding_dimension in [16, 64, 104, 300]:
+                for ads_tables in [25, 50, 100, 400, 800]:
+                    data_type_list = (
+                        ["INT8", "INT4"]
+                        if include_quantization
+                        else ["FP16", "INT8", "INT4"]
+                    )
+                    for data_type in data_type_list:
+                        if num_ads * embedding_dimension * ads_tables > 1228800000:
+                            continue  # Skip tests that are too large
+                        signal.signal(signal.SIGTERM, handler)
+                        signal.alarm(600)
+                        try:
+                            result = benchmark(
+                                all_to_one_only,
+                                num_ads,
+                                embedding_dimension,
+                                ads_tables,
+                                iters,
+                                p2p_bw,
+                                dst_device,
+                                data_type,
+                                include_quantization,
+                            )
+                            results.append(result)
+                        except (TimeoutError, RuntimeError) as err:
+                            logging.error(f"timed out or failed: {err}")
+                            logging.error(
+                                f"B: {num_ads}, D: {embedding_dimension}, T: {ads_tables}, Data Type: {data_type}, Num GPU: {num_gpu}"
+                            )
+        print(csv_header)
+        print(*results, sep="\n")
+        return
+
+    result = benchmark(
+        all_to_one_only,
+        num_ads,
+        embedding_dimension,
+        ads_tables,
+        iters,
+        p2p_bw,
+        dst_device,
+        data_type,
+        include_quantization,
     )
+    print(csv_header)
+    print(result)
 
 
 if __name__ == "__main__":
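
For readers skimming the diff, here is a condensed sketch of the round trip that pool_func_with_quantization times when --include_quantization is set with --data_type INT8. It only reuses ops already called in the file (FloatToFused8BitRowwiseQuantized, merge_pooled_embeddings, Fused8BitRowwiseQuantizedToFloatMixedDim); the tensor names and shapes are illustrative, and it assumes a CUDA build of fbgemm_gpu with the sparse_ops libraries loaded as at the top of the benchmark:

import torch

# One FP16 shard per GPU (B=1024, D*T=300*100), as the benchmark would create.
pooled = [
    torch.randn(1024, 300 * 100, dtype=torch.float16, device=f"cuda:{i}")
    for i in range(torch.cuda.device_count())
]

# 1) Quantize each shard to fused 8-bit rowwise storage (scale and bias appended per row).
quantized = [
    torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(t.float()) for t in pooled
]

# 2) Transfer and concatenate the quantized shards onto the destination GPU.
merged = torch.ops.fbgemm.merge_pooled_embeddings(
    quantized, 1024, torch.device("cuda:0")
)

# 3) Dequantize back to FP16 on the destination; the trailing 1 selects FP16 output,
#    matching PooledEmbeddingDequantizeDataTypeFP16 in the diff.
offsets = torch.cumsum(
    torch.tensor([0] + [q.shape[1] for q in quantized], device=torch.device("cuda:0")),
    dim=0,
).to(torch.int)
dequantized = torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatMixedDim(
    merged, offsets, 1
)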
