 )
 from fbgemm_gpu.tbe.bench import (
     benchmark_requests,
+    EmbeddingOpsCommonConfigLoader,
     TBEBenchmarkingConfigLoader,
     TBEDataConfigLoader,
 )
@@ -50,50 +51,39 @@ def cli() -> None:
 
 
 @cli.command()
-@click.option("--weights-precision", type=SparseType, default=SparseType.FP32)
-@click.option("--cache-precision", type=SparseType, default=None)
-@click.option("--stoc", is_flag=True, default=False)
-@click.option(
-    "--managed",
-    default="device",
-    type=click.Choice(["device", "managed", "managed_caching"], case_sensitive=False),
-)
 @click.option(
     "--emb-op-type",
     default="split",
     type=click.Choice(["split", "dense", "ssd"], case_sensitive=False),
+    help="The type of the embedding op to benchmark",
+)
+@click.option(
+    "--row-wise/--no-row-wise",
+    default=True,
+    help="Whether to use row-wise adagrad optimizer or not",
 )
-@click.option("--row-wise/--no-row-wise", default=True)
-@click.option("--pooling", type=str, default="sum")
-@click.option("--weighted-num-requires-grad", type=int, default=None)
-@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.NONE.value)
-@click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
 @click.option(
-    "--uvm-host-mapped",
-    is_flag=True,
-    default=False,
-    help="Use host mapped UVM buffers in SSD-TBE (malloc+cudaHostRegister)",
+    "--weighted-num-requires-grad",
+    type=int,
+    default=None,
+    help="The number of weighted tables that require gradient",
 )
 @click.option(
-    "--ssd-prefix", type=str, default="/tmp/ssd_benchmark", help="SSD directory prefix"
+    "--ssd-prefix",
+    type=str,
+    default="/tmp/ssd_benchmark",
+    help="SSD directory prefix",
 )
 @click.option("--cache-load-factor", default=0.2)
 @TBEBenchmarkingConfigLoader.options
 @TBEDataConfigLoader.options
+@EmbeddingOpsCommonConfigLoader.options
 @click.pass_context
 def device(  # noqa C901
     context: click.Context,
     emb_op_type: click.Choice,
-    weights_precision: SparseType,
-    cache_precision: Optional[SparseType],
-    stoc: bool,
-    managed: click.Choice,
     row_wise: bool,
-    pooling: str,
     weighted_num_requires_grad: Optional[int],
-    bounds_check_mode: int,
-    output_dtype: SparseType,
-    uvm_host_mapped: bool,
     cache_load_factor: float,
     # SSD params
     ssd_prefix: str,
@@ -110,6 +100,9 @@ def device( # noqa C901
     # Load TBE data configuration from cli arguments
     tbeconfig = TBEDataConfigLoader.load(context)
 
+    # Load common embedding op configuration from cli arguments
+    embconfig = EmbeddingOpsCommonConfigLoader.load(context)
+
     # Generate feature_requires_grad
     feature_requires_grad = (
         tbeconfig.generate_feature_requires_grad(weighted_num_requires_grad)
@@ -123,22 +116,8 @@ def device( # noqa C901
     # Determine the optimizer
     optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD
 
-    # Determine the embedding location
-    embedding_location = str_to_embedding_location(str(managed))
-    if embedding_location is EmbeddingLocation.DEVICE and not torch.cuda.is_available():
-        embedding_location = EmbeddingLocation.HOST
-
-    # Determine the pooling mode
-    pooling_mode = str_to_pooling_mode(pooling)
-
     # Construct the common split arguments for the embedding op
-    common_split_args: Dict[str, Any] = {
-        "weights_precision": weights_precision,
-        "stochastic_rounding": stoc,
-        "output_dtype": output_dtype,
-        "pooling_mode": pooling_mode,
-        "bounds_check_mode": BoundsCheckMode(bounds_check_mode),
-        "uvm_host_mapped": uvm_host_mapped,
+    common_split_args: Dict[str, Any] = embconfig.split_args() | {
         "optimizer": optimizer,
         "learning_rate": 0.1,
         "eps": 0.1,
@@ -154,7 +133,7 @@ def device( # noqa C901
                 )
                 for d in Ds
             ],
-            pooling_mode=pooling_mode,
+            pooling_mode=embconfig.pooling_mode,
             use_cpu=not torch.cuda.is_available(),
         )
     elif emb_op_type == "ssd":
@@ -177,7 +156,7 @@ def device( # noqa C901
                 (
                     tbeconfig.E,
                     d,
-                    embedding_location,
+                    embconfig.embedding_location,
                     (
                         ComputeDevice.CUDA
                         if torch.cuda.is_available()
@@ -187,25 +166,27 @@ def device( # noqa C901
                 for d in Ds
             ],
             cache_precision=(
-                weights_precision if cache_precision is None else cache_precision
+                embconfig.weights_dtype
+                if embconfig.cache_dtype is None
+                else embconfig.cache_dtype
             ),
             cache_algorithm=CacheAlgorithm.LRU,
             cache_load_factor=cache_load_factor,
             **common_split_args,
         )
     embedding_op = embedding_op.to(get_device())
 
-    if weights_precision == SparseType.INT8:
+    if embconfig.weights_dtype == SparseType.INT8:
         # pyre-fixme[29]: `Union[(self: DenseTableBatchedEmbeddingBagsCodegen,
         #  min_val: float, max_val: float) -> None, (self:
         #  SplitTableBatchedEmbeddingBagsCodegen, min_val: float, max_val: float) ->
         #  None, Tensor, Module]` is not a function.
         embedding_op.init_embedding_weights_uniform(-0.0003, 0.0003)
 
     nparams = sum(d * tbeconfig.E for d in Ds)
-    param_size_multiplier = weights_precision.bit_rate() / 8.0
-    output_size_multiplier = output_dtype.bit_rate() / 8.0
-    if pooling_mode.do_pooling():
+    param_size_multiplier = embconfig.weights_dtype.bit_rate() / 8.0
+    output_size_multiplier = embconfig.output_dtype.bit_rate() / 8.0
+    if embconfig.pooling_mode.do_pooling():
         read_write_bytes = (
             output_size_multiplier * tbeconfig.batch_params.B * sum(Ds)
             + param_size_multiplier
@@ -225,7 +206,7 @@ def device( # noqa C901
             * tbeconfig.pooling_params.L
         )
 
-    logging.info(f"Managed option: {managed}")
+    logging.info(f"Managed option: {embconfig.embedding_location}")
     logging.info(
         f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, "
         f"{nparams * param_size_multiplier / 1.0e9: .2f} GB"
@@ -274,11 +255,11 @@ def _context_factory(on_trace_ready: Callable[[profile], None]):
         f"T: {time_per_iter * 1.0e6:.0f}us"
     )
 
-    if output_dtype == SparseType.INT8:
+    if embconfig.output_dtype == SparseType.INT8:
         # backward bench not representative
         return
 
-    if pooling_mode.do_pooling():
+    if embconfig.pooling_mode.do_pooling():
         grad_output = torch.randn(tbeconfig.batch_params.B, sum(Ds)).to(get_device())
     else:
         grad_output = torch.randn(
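
The substance of this diff is that the per-flag options previously declared inline on the device command (--weights-precision, --cache-precision, --stoc, --managed, --pooling, --bounds-check-mode, --output-dtype, --uvm-host-mapped) now come from a shared EmbeddingOpsCommonConfigLoader, and the common constructor kwargs are produced by embconfig.split_args() and merged with the optimizer-specific ones via the dict union operator. Below is a minimal, self-contained sketch of that decorator-plus-loader shape; it is not the fbgemm_gpu implementation, and the class name CommonEmbeddingConfigLoader, its flags, and its fields are hypothetical stand-ins chosen only to illustrate the pattern.

# Hypothetical, simplified sketch of the config-loader pattern used in the diff.
from dataclasses import dataclass
from typing import Any, Callable, Dict

import click


@dataclass
class CommonEmbeddingConfig:
    weights_precision: str
    stochastic_rounding: bool
    output_dtype: str

    def split_args(self) -> Dict[str, Any]:
        # Constructor kwargs shared by every embedding-op variant.
        return {
            "weights_precision": self.weights_precision,
            "stochastic_rounding": self.stochastic_rounding,
            "output_dtype": self.output_dtype,
        }


class CommonEmbeddingConfigLoader:
    @classmethod
    def options(cls, fn: Callable[..., Any]) -> Callable[..., Any]:
        # Used bare as `@CommonEmbeddingConfigLoader.options`: attaches the shared
        # CLI flags to the command, like stacking the individual @click.option
        # decorators that the diff removes.
        for opt in reversed(
            [
                click.option("--weights-precision", type=str, default="fp32"),
                click.option("--stoc/--no-stoc", default=False),
                click.option("--output-dtype", type=str, default="fp32"),
            ]
        ):
            fn = opt(fn)
        return fn

    @classmethod
    def load(cls, context: click.Context) -> CommonEmbeddingConfig:
        # Pull the shared flags back out of the parsed CLI context.
        p = context.params
        return CommonEmbeddingConfig(
            weights_precision=p["weights_precision"],
            stochastic_rounding=p["stoc"],
            output_dtype=p["output_dtype"],
        )


@click.command()
@click.option("--row-wise/--no-row-wise", default=True)
@CommonEmbeddingConfigLoader.options
@click.pass_context
def device(context: click.Context, **kwargs: Any) -> None:
    embconfig = CommonEmbeddingConfigLoader.load(context)
    # Shared kwargs merged with op-specific ones via the PEP 584 dict union,
    # mirroring `embconfig.split_args() | {...}` in the benchmark.
    common_split_args = embconfig.split_args() | {
        "learning_rate": 0.1,
        "eps": 0.1,
    }
    click.echo(str(common_split_args))


if __name__ == "__main__":
    device()

The real loader in fbgemm_gpu.tbe.bench evidently also exposes weights_dtype, cache_dtype, output_dtype, pooling_mode, and embedding_location, which is why the diff can drop the ad-hoc str_to_embedding_location / str_to_pooling_mode conversions from the command body.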