 # pyre-unsafe

+import logging
+import signal
+
 import click
 import numpy as np
 import tabulate

 torch.ops.load_library(
     "//deeplearning/fbgemm/fbgemm_gpu:merge_pooled_embeddings_cpu"
 )
+torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
+torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
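+# sparse_ops/sparse_ops_cpu register the FBGEMM quantization operators
+# (FloatToFused8BitRowwiseQuantized etc.) used by the INT8/INT4 paths below.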


-@click.command()
-@click.option("--all-to-one-only", is_flag=True, default=False)
-@click.option("--num-ads", default=1024, type=int)
-@click.option("--embedding-dimension", default=300, type=int)
-@click.option("--ads-tables", default=400, type=int)
-@click.option("--iters", default=10, type=int)
-@click.option("--p2p_bw", is_flag=True, default=False)
-@click.option("--dst-device", default=0, type=int)
-def main(
-    all_to_one_only, num_ads, embedding_dimension, ads_tables, iters, p2p_bw, dst_device
-) -> None:
+def _get_random_tensor(
+    num_ads: int,
+    embedding_dimension: int,
+    ads_tables: int,
+    data_type: str,
+    gpu_idx: int,
+    include_quantization: bool,
+):
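+    # Builds one GPU's random "pooled embeddings" input. FP16 rows hold
+    # embedding_dimension values per table; INT8/INT4 rows are stored as
+    # uint8 with a 4-byte (2 x FP16) scale/bias overhead per table.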
+    if data_type == "FP16" or include_quantization:
+        result_tensor = torch.randn(
+            num_ads,
+            embedding_dimension * ads_tables,
+            dtype=torch.float16,
+            device=torch.device(f"cuda:{gpu_idx}"),
+        )
+    elif data_type == "INT8":
+        assert (
+            embedding_dimension % 2
+        ) == 0, "needs to align to 2 bytes (half type size) for INT8"
+        result_tensor = torch.randint(
+            0,
+            255,
+            # 2 FP16 numbers for scale and bias, total of 4 bytes overhead
+            size=(num_ads, (embedding_dimension + 4) * ads_tables),
+            dtype=torch.uint8,
+            device=torch.device(f"cuda:{gpu_idx}"),
+        )
+    elif data_type == "INT4":
+        assert (
+            embedding_dimension % 4
+        ) == 0, "needs to align to 2 bytes (half type size) for INT4"
+        result_tensor = torch.randint(
+            0,
+            255,
+            # Using torch.uint8 for INT4 storage (two INT4 values per byte)
+            size=(num_ads, (embedding_dimension // 2 + 4) * ads_tables),
+            dtype=torch.uint8,
+            device=torch.device(f"cuda:{gpu_idx}"),
+        )
+    else:
+        raise ValueError(f"Unsupported data type: {data_type}")
+
+    return result_tensor
+
+
+def benchmark(
+    all_to_one_only,
+    num_ads,
+    embedding_dimension,
+    ads_tables,
+    iters=10,
+    p2p_bw=False,
+    dst_device=0,
+    data_type="FP16",
+    include_quantization=False,
+) -> str:
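+    # Runs a single merge configuration, logs the timing, and returns one
+    # CSV row (main() prints the matching csv_header).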
     torch.cuda.set_device(dst_device)
     num_gpus = torch.cuda.device_count()
-    ad_ds = [embedding_dimension * ads_tables for _ in range(num_gpus)]
     batch_indices = torch.zeros(num_ads).long().cuda()
     pooled_ad_embeddings = [
-        torch.randn(
-            num_ads, ad_d, dtype=torch.float16, device=torch.device(f"cuda:{i}")
+        _get_random_tensor(
+            num_ads,
+            embedding_dimension,
+            ads_tables,
+            data_type,
+            gpu_idx,
+            include_quantization,
         )
-        for i, ad_d in enumerate(ad_ds)
+        for gpu_idx in range(num_gpus)
     ]
+    # FP16 (and FP16 inputs awaiting quantization) take 2 bytes per element;
+    # INT8/INT4 results are stored as uint8, 1 byte per element.
+    bytes_per_element = 2 if (data_type == "FP16" or include_quantization) else 1
+    total_elements = num_ads * embedding_dimension * ads_tables * num_gpus
+
+    logging.debug(
+        f"B: {num_ads}, D: {embedding_dimension}, T: {ads_tables}, Data Type: {data_type}, Num GPUs: {num_gpus}, Destination GPU: {dst_device}"
+    )

     def benchmark_torch_function(iters: int, f, *args) -> float:
         f(*args)
@@ -68,7 +129,9 @@ def benchmark_torch_function(iters: int, f, *args) -> float:
                         if i != j
                         else pooled_ad_embeddings[i].clone(),
                     )
-                    p2p_copy_bw[i, j] = pooled_ad_embeddings[i].numel() * 2 / t / 1.0e9
+                    p2p_copy_bw[i, j] = (
+                        pooled_ad_embeddings[i].numel() * bytes_per_element / t / 1.0e9
+                    )
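+                    # Effective copy bandwidth in GB/s: bytes moved / seconds / 1e9.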
         table = tabulate.tabulate(
             p2p_copy_bw,
             headers=[f"GPU {i}" for i in range(num_gpus)],
@@ -80,32 +143,180 @@ def benchmark_torch_function(iters: int, f, *args) -> float:
     streams = [torch.cuda.Stream(device=i) for i in range(num_gpus)]
     import contextlib

-    with contextlib.ExitStack() as stack:
-        for stream in streams:
-            stack.enter_context(torch.cuda.stream(stream))
+    def pool_func_with_quantization(
+        pooled_ad_embeddings,
+        batch_indices,
+        include_quantization,
+        data_type,
+    ):
+        if include_quantization:
+            assert data_type == "INT8" or data_type == "INT4"
+            quantized = [
+                torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(t.float())
+                if data_type == "INT8"
+                else torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
+                    t.float(), 4
+                )
+                for t in pooled_ad_embeddings
+            ]
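+            # Rowwise quantization appends per-row scale/bias metadata, so each
+            # quantized tensor is wider than embedding_dimension * ads_tables.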
+            pooled_quantized_result = torch.ops.fbgemm.merge_pooled_embeddings(
+                quantized, batch_indices.size(0), batch_indices.device
+            )
+            PooledEmbeddingDequantizeDataTypeFP16 = 1
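+            # Output-dtype selector for the mixed-dim dequantize op below;
+            # per the constant's name, 1 requests FP16 output.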

-    merged = torch.ops.fbgemm.merge_pooled_embeddings(
-        pooled_ad_embeddings, batch_indices.size(0), batch_indices.device
-    )
+            if data_type == "INT8":
+                offset = torch.cumsum(
+                    torch.tensor(
+                        [0] + [quantized[0].shape[1] for _ in range(len(quantized))],
+                        device=batch_indices.device,
+                    ),
+                    dim=0,
+                ).to(torch.int)
+                return torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatMixedDim(
+                    pooled_quantized_result,
+                    offset,
+                    PooledEmbeddingDequantizeDataTypeFP16,
+                )
+            else:
+                # TODO: the result here is wrong. Once the MixedDim version of
+                # FusedNBit quantization is done, switch to that. Performance is
+                # similar, so keep FusedNBitRowwiseQuantizedSBHalfToFloat for now.
+                return torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloat(
+                    pooled_quantized_result, 4
+                ).half()

         if all_to_one_only:
-            t = benchmark_torch_function(
-                iters,
-                lambda: torch.ops.fbgemm.all_to_one_device(
-                    pooled_ad_embeddings, batch_indices.device
-                ),
+            return torch.ops.fbgemm.all_to_one_device(
+                pooled_ad_embeddings, batch_indices.device
             )
         else:
-            t = benchmark_torch_function(
-                iters,
-                lambda: torch.ops.fbgemm.merge_pooled_embeddings(
-                    pooled_ad_embeddings, batch_indices.size(0), batch_indices.device
-                ),
+            return torch.ops.fbgemm.merge_pooled_embeddings(
+                pooled_ad_embeddings, batch_indices.size(0), batch_indices.device
             )

-    print(
-        f"Merge, B: {num_ads}, D: {embedding_dimension}, T: {ads_tables}, Num GPUs: {num_gpus}, Destination GPU: {dst_device} Output Size: {merged.numel() * 2 / 1.0e6:.2f} MB, BW: {merged.numel() * 2 / t / 1.0e9:.2f} GB/s, t: {t * 1.0e3:.2f} ms"
+    with contextlib.ExitStack() as stack:
+        for stream in streams:
+            stack.enter_context(torch.cuda.stream(stream))
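+
+        # With every source GPU's stream entered, the timed merge below can
+        # overlap its cross-device copies instead of serializing them.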
+        merged = pool_func_with_quantization(
+            pooled_ad_embeddings, batch_indices, include_quantization, data_type
+        )
+        t = benchmark_torch_function(
+            iters,
+            lambda: pool_func_with_quantization(
+                pooled_ad_embeddings, batch_indices, include_quantization, data_type
+            ),
+        )
+
+    logging.debug(
+        f"Merge, B: {num_ads}, D: {embedding_dimension}, T: {ads_tables}, Data Type: {data_type}, Num GPUs: {num_gpus}, Destination GPU: {dst_device}, "
+        f"Number of elements: {total_elements / 1.0e6:.0f} Million, Billion elements per sec: {total_elements / t / 1.0e9:.1f}, "
+        f"Output Size: {merged.numel() * bytes_per_element / 1.0e6:.0f} MB, BW: {merged.numel() * bytes_per_element / t / 1.0e9:.1f} GB/s, "
+        f"t: {t * 1.0e3:.2f} ms"
+    )
+    # Return the result as one CSV row matching csv_header in main()
+    return (
+        f"{num_ads}, {embedding_dimension}, {ads_tables}, {data_type}, {num_gpus}, {dst_device}, "
+        f"{total_elements / 1.0e6:.0f}, {total_elements / t / 1.0e9:.1f}, "
+        f"{merged.numel() * bytes_per_element / 1.0e6:.0f}, {merged.numel() * bytes_per_element / t / 1.0e9:.1f}, "
+        f"{t * 1.0e3:.2f}"
+    )
+
+
+@click.command()
+@click.option("--all-to-one-only", is_flag=True, default=False)
+@click.option("--num_ads", default=1024, type=int)
+@click.option("--embedding_dimension", default=300, type=int)
+@click.option("--ads_tables", default=100, type=int)
+@click.option("--iters", default=10, type=int)
+@click.option("--p2p_bw", is_flag=True, default=False)
+@click.option("--dst_device", default=0, type=int)
+@click.option(
+    "--data_type",
+    type=click.Choice(["FP16", "INT8", "INT4"]),
+    default="FP16",
+)
+# For INT8/INT4 data types: whether to start from FP16 and include the
+# quantization overhead in the measurement
+@click.option("--include_quantization", is_flag=True, default=False)
+@click.option("--sweep", is_flag=True, default=False)
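+# --sweep overrides num_ads/embedding_dimension/ads_tables/data_type and runs
+# a grid of configurations, printing one CSV row per run.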
+def main(
+    all_to_one_only,
+    num_ads,
+    embedding_dimension,
+    ads_tables,
+    iters,
+    p2p_bw,
+    dst_device,
+    data_type,
+    include_quantization,
+    sweep,
+) -> None:
+    assert sweep or not (
+        include_quantization and data_type == "FP16"
+    ), "no quantization is needed for FP16"
+
+    csv_header = (
+        "num_ads, embedding_dimension, ads_tables, data_type, num_gpus, "
+        "dst_device, number of elements (Million), throughput (billion elements per sec), "
+        "output size (MB), BW (GB/s), t (ms)"
+    )
+    if sweep:
+
+        def handler(signum, frame):
+            logging.error("timeout")
+            raise TimeoutError()
+
+        results = []
+        num_gpu = torch.cuda.device_count()
+        for num_ads in [128, 256, 512, 1024, 2048]:
+            # Scale num_ads so every GPU count sweeps the same total number of elements
+            num_ads *= 8 // num_gpu
+            for embedding_dimension in [16, 64, 104, 300]:
+                for ads_tables in [25, 50, 100, 400, 800]:
+                    data_type_list = (
+                        ["INT8", "INT4"]
+                        if include_quantization
+                        else ["FP16", "INT8", "INT4"]
+                    )
+                    for data_type in data_type_list:
+                        if num_ads * embedding_dimension * ads_tables > 1228800000:
+                            continue  # Skip configurations that are too large
+                        signal.signal(signal.SIGALRM, handler)
+                        signal.alarm(600)
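+                        # 10-minute watchdog: alarm(600) delivers SIGALRM if
+                        # benchmark() hangs; handler() raises TimeoutError.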
+                        try:
+                            result = benchmark(
+                                all_to_one_only,
+                                num_ads,
+                                embedding_dimension,
+                                ads_tables,
+                                iters,
+                                p2p_bw,
+                                dst_device,
+                                data_type,
+                                include_quantization,
+                            )
+                            results.append(result)
+                        except (TimeoutError, RuntimeError) as err:
+                            logging.error(f"timed out or failed: {err}")
+                            logging.error(
+                                f"B: {num_ads}, D: {embedding_dimension}, T: {ads_tables}, Data Type: {data_type}, Num GPU: {num_gpu}"
+                            )
+        print(csv_header)
+        print(*results, sep="\n")
+        return
+
+    result = benchmark(
+        all_to_one_only,
+        num_ads,
+        embedding_dimension,
+        ads_tables,
+        iters,
+        p2p_bw,
+        dst_device,
+        data_type,
+        include_quantization,
     )
+    print(csv_header)
+    print(result)


 if __name__ == "__main__":