Commit a556812

q10 authored and facebook-github-bot committed

Cleanups for the EEG-based TBE benchmark CLI, pt 2

Summary:
X-link: facebookresearch/FBGEMM#890

Cleanups for the EEG-based TBE benchmark CLI, pt 2

Reviewed By: jiawenliu64

Differential Revision: D70426271
1 parent 05d089a commit a556812

File tree: 7 files changed, +207 −60 lines changed

fbgemm_gpu/bench/tbe/tbe_training_benchmark.py

Lines changed: 32 additions & 51 deletions

@@ -32,6 +32,7 @@
 )
 from fbgemm_gpu.tbe.bench import (
     benchmark_requests,
+    EmbeddingOpsCommonConfigLoader,
     TBEBenchmarkingConfigLoader,
     TBEDataConfigLoader,
 )
@@ -50,50 +51,39 @@ def cli() -> None:


 @cli.command()
-@click.option("--weights-precision", type=SparseType, default=SparseType.FP32)
-@click.option("--cache-precision", type=SparseType, default=None)
-@click.option("--stoc", is_flag=True, default=False)
-@click.option(
-    "--managed",
-    default="device",
-    type=click.Choice(["device", "managed", "managed_caching"], case_sensitive=False),
-)
 @click.option(
     "--emb-op-type",
     default="split",
     type=click.Choice(["split", "dense", "ssd"], case_sensitive=False),
+    help="The type of the embedding op to benchmark",
+)
+@click.option(
+    "--row-wise/--no-row-wise",
+    default=True,
+    help="Whether to use row-wise adagrad optimizer or not",
 )
-@click.option("--row-wise/--no-row-wise", default=True)
-@click.option("--pooling", type=str, default="sum")
-@click.option("--weighted-num-requires-grad", type=int, default=None)
-@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.NONE.value)
-@click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
 @click.option(
-    "--uvm-host-mapped",
-    is_flag=True,
-    default=False,
-    help="Use host mapped UVM buffers in SSD-TBE (malloc+cudaHostRegister)",
+    "--weighted-num-requires-grad",
+    type=int,
+    default=None,
+    help="The number of weighted tables that require gradient",
 )
 @click.option(
-    "--ssd-prefix", type=str, default="/tmp/ssd_benchmark", help="SSD directory prefix"
+    "--ssd-prefix",
+    type=str,
+    default="/tmp/ssd_benchmark",
+    help="SSD directory prefix",
 )
 @click.option("--cache-load-factor", default=0.2)
 @TBEBenchmarkingConfigLoader.options
 @TBEDataConfigLoader.options
+@EmbeddingOpsCommonConfigLoader.options
 @click.pass_context
 def device(  # noqa C901
     context: click.Context,
     emb_op_type: click.Choice,
-    weights_precision: SparseType,
-    cache_precision: Optional[SparseType],
-    stoc: bool,
-    managed: click.Choice,
     row_wise: bool,
-    pooling: str,
     weighted_num_requires_grad: Optional[int],
-    bounds_check_mode: int,
-    output_dtype: SparseType,
-    uvm_host_mapped: bool,
     cache_load_factor: float,
     # SSD params
     ssd_prefix: str,
@@ -110,6 +100,9 @@ def device(  # noqa C901
     # Load TBE data configuration from cli arguments
     tbeconfig = TBEDataConfigLoader.load(context)

+    # Load common embedding op configuration from cli arguments
+    embconfig = EmbeddingOpsCommonConfigLoader.load(context)
+
     # Generate feature_requires_grad
     feature_requires_grad = (
         tbeconfig.generate_feature_requires_grad(weighted_num_requires_grad)
@@ -123,22 +116,8 @@ def device(  # noqa C901
     # Determine the optimizer
     optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD

-    # Determine the embedding location
-    embedding_location = str_to_embedding_location(str(managed))
-    if embedding_location is EmbeddingLocation.DEVICE and not torch.cuda.is_available():
-        embedding_location = EmbeddingLocation.HOST
-
-    # Determine the pooling mode
-    pooling_mode = str_to_pooling_mode(pooling)
-
     # Construct the common split arguments for the embedding op
-    common_split_args: Dict[str, Any] = {
-        "weights_precision": weights_precision,
-        "stochastic_rounding": stoc,
-        "output_dtype": output_dtype,
-        "pooling_mode": pooling_mode,
-        "bounds_check_mode": BoundsCheckMode(bounds_check_mode),
-        "uvm_host_mapped": uvm_host_mapped,
+    common_split_args: Dict[str, Any] = embconfig.split_args() | {
         "optimizer": optimizer,
         "learning_rate": 0.1,
         "eps": 0.1,
@@ -154,7 +133,7 @@ def device(  # noqa C901
                 )
                 for d in Ds
             ],
-            pooling_mode=pooling_mode,
+            pooling_mode=embconfig.pooling_mode,
             use_cpu=not torch.cuda.is_available(),
         )
     elif emb_op_type == "ssd":
@@ -177,7 +156,7 @@ def device(  # noqa C901
                 (
                     tbeconfig.E,
                     d,
-                    embedding_location,
+                    embconfig.embedding_location,
                     (
                         ComputeDevice.CUDA
                         if torch.cuda.is_available()
@@ -187,25 +166,27 @@ def device(  # noqa C901
                 for d in Ds
             ],
             cache_precision=(
-                weights_precision if cache_precision is None else cache_precision
+                embconfig.weights_dtype
+                if embconfig.cache_dtype is None
+                else embconfig.cache_dtype
             ),
             cache_algorithm=CacheAlgorithm.LRU,
             cache_load_factor=cache_load_factor,
             **common_split_args,
         )
     embedding_op = embedding_op.to(get_device())

-    if weights_precision == SparseType.INT8:
+    if embconfig.weights_dtype == SparseType.INT8:
         # pyre-fixme[29]: `Union[(self: DenseTableBatchedEmbeddingBagsCodegen,
         #  min_val: float, max_val: float) -> None, (self:
         #  SplitTableBatchedEmbeddingBagsCodegen, min_val: float, max_val: float) ->
         #  None, Tensor, Module]` is not a function.
         embedding_op.init_embedding_weights_uniform(-0.0003, 0.0003)

     nparams = sum(d * tbeconfig.E for d in Ds)
-    param_size_multiplier = weights_precision.bit_rate() / 8.0
-    output_size_multiplier = output_dtype.bit_rate() / 8.0
-    if pooling_mode.do_pooling():
+    param_size_multiplier = embconfig.weights_dtype.bit_rate() / 8.0
+    output_size_multiplier = embconfig.output_dtype.bit_rate() / 8.0
+    if embconfig.pooling_mode.do_pooling():
         read_write_bytes = (
             output_size_multiplier * tbeconfig.batch_params.B * sum(Ds)
             + param_size_multiplier
@@ -225,7 +206,7 @@ def device(  # noqa C901
         * tbeconfig.pooling_params.L
     )

-    logging.info(f"Managed option: {managed}")
+    logging.info(f"Managed option: {embconfig.embedding_location}")
     logging.info(
         f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, "
         f"{nparams * param_size_multiplier / 1.0e9: .2f} GB"
@@ -274,11 +255,11 @@ def _context_factory(on_trace_ready: Callable[[profile], None]):
         f"T: {time_per_iter * 1.0e6:.0f}us"
     )

-    if output_dtype == SparseType.INT8:
+    if embconfig.output_dtype == SparseType.INT8:
         # backward bench not representative
         return

-    if pooling_mode.do_pooling():
+    if embconfig.pooling_mode.do_pooling():
         grad_output = torch.randn(tbeconfig.batch_params.B, sum(Ds)).to(get_device())
     else:
         grad_output = torch.randn(
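
The `common_split_args` construction above leans on the PEP 584 dict-union operator (Python 3.9+). A minimal sketch of the merge semantics, with placeholder values standing in for the real `embconfig.split_args()` output:

from typing import Any, Dict

# Placeholders for embconfig.split_args() and the optimizer-specific args.
split_args: Dict[str, Any] = {"weights_precision": "FP32", "uvm_host_mapped": False}
optimizer_args: Dict[str, Any] = {"optimizer": "EXACT_ROWWISE_ADAGRAD", "learning_rate": 0.1, "eps": 0.1}

# `|` returns a new dict; neither operand is mutated, and on duplicate
# keys the right-hand operand wins.
common_split_args: Dict[str, Any] = split_args | optimizer_args
assert common_split_args["weights_precision"] == "FP32"
assert common_split_args["eps"] == 0.1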

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 15 additions & 0 deletions

@@ -33,6 +33,21 @@ class EmbeddingLocation(enum.IntEnum):
     HOST = 3
     MTIA = 4

+    @classmethod
+    # pyre-ignore[3]
+    def from_str(cls, key: str):
+        lookup = {
+            "device": EmbeddingLocation.DEVICE,
+            "managed": EmbeddingLocation.MANAGED,
+            "managed_caching": EmbeddingLocation.MANAGED_CACHING,
+            "host": EmbeddingLocation.HOST,
+            "mtia": EmbeddingLocation.MTIA,
+        }
+        if key in lookup:
+            return lookup[key]
+        else:
+            raise ValueError(f"Cannot parse value into EmbeddingLocation: {key}")
+

 def str_to_embedding_location(key: str) -> EmbeddingLocation:
     lookup = {
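
A quick sketch of how the new classmethod behaves, assuming fbgemm_gpu is importable; note the lookup keys are lowercase as written, so the parse is effectively case-sensitive:

from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation

assert EmbeddingLocation.from_str("managed_caching") is EmbeddingLocation.MANAGED_CACHING
assert EmbeddingLocation.from_str("mtia") is EmbeddingLocation.MTIA

try:
    EmbeddingLocation.from_str("DEVICE")  # uppercase key is not in the lookup
except ValueError as e:
    print(e)  # Cannot parse value into EmbeddingLocation: DEVICE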

fbgemm_gpu/fbgemm_gpu/tbe/bench/__init__.py

Lines changed: 8 additions & 3 deletions

@@ -19,12 +19,17 @@
     benchmark_requests_refer,
     benchmark_vbe,
 )
-from .config import TBEDataConfig  # noqa F401
-from .config_loader import TBEDataConfigLoader  # noqa F401
-from .config_param_models import BatchParams, IndicesParams, PoolingParams  # noqa F401
+from .embedding_ops_common_config import EmbeddingOpsCommonConfigLoader  # noqa F401
 from .eval_compression import (  # noqa F401
     benchmark_eval_compression,
     EvalCompressionBenchmarkOutput,
 )
 from .reporter import BenchmarkReporter  # noqa F401
+from .tbe_data_config import TBEDataConfig  # noqa F401
+from .tbe_data_config_loader import TBEDataConfigLoader  # noqa F401
+from .tbe_data_config_param_models import (  # noqa F401
+    BatchParams,
+    IndicesParams,
+    PoolingParams,
+)
 from .utils import fill_random_scale_bias  # noqa F401
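
Because the renamed modules are still re-exported here, imports that go through the package root are unaffected by this commit; only direct submodule imports need to move to the new file names. A sketch, assuming fbgemm_gpu is installed:

# Unchanged for callers: package-level re-exports still resolve.
from fbgemm_gpu.tbe.bench import BatchParams, TBEDataConfig, TBEDataConfigLoader

# Direct submodule imports must be updated to the new module names:
# from fbgemm_gpu.tbe.bench.config import TBEDataConfig            # old path, removed
from fbgemm_gpu.tbe.bench.tbe_data_config import TBEDataConfig     # new path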
fbgemm_gpu/fbgemm_gpu/tbe/bench/embedding_ops_common_config.py

Lines changed: 144 additions & 0 deletions

@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import dataclasses
+from typing import Any, Dict, Optional
+
+import click
+import torch
+from fbgemm_gpu.split_embedding_configs import SparseType
+from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
+    BoundsCheckMode,
+    EmbeddingLocation,
+    PoolingMode,
+    str_to_embedding_location,
+    str_to_pooling_mode,
+)
+
+
+@dataclasses.dataclass(frozen=True)
+class EmbeddingOpsCommonConfig:
+    weights_dtype: SparseType
+    cache_dtype: Optional[SparseType]
+    output_dtype: SparseType
+    stochastic_rounding: bool
+    pooling_mode: PoolingMode
+    uvm_host_mapped: bool
+    embedding_location: EmbeddingLocation
+    bounds_check_mode: BoundsCheckMode
+
+    # pyre-ignore [3]
+    def validate(self):
+        return self
+
+    def split_args(self) -> Dict[str, Any]:
+        return {
+            "weights_precision": self.weights_dtype,
+            "stochastic_rounding": self.stochastic_rounding,
+            "output_dtype": self.output_dtype,
+            "pooling_mode": self.pooling_mode,
+            "bounds_check_mode": self.bounds_check_mode,
+            "uvm_host_mapped": self.uvm_host_mapped,
+        }
+
+
+class EmbeddingOpsCommonConfigLoader:
+    @classmethod
+    # pyre-ignore [2]
+    def options(cls, func) -> click.Command:
+        options = [
+            click.option(
+                "--emb-weights-dtype",
+                type=SparseType,
+                default=SparseType.FP32,
+                help="Precision of the embedding weights",
+            ),
+            click.option(
+                "--emb-cache-dtype",
+                type=SparseType,
+                default=None,
+                help="Precision of the embedding cache",
+            ),
+            click.option(
+                "--emb-output-dtype",
+                type=SparseType,
+                default=SparseType.FP32,
+                help="Precision of the embedding output",
+            ),
+            click.option(
+                "--emb-stochastic-rounding",
+                is_flag=True,
+                default=False,
+                help="Enable stochastic rounding when performing quantization",
+            ),
+            click.option(
+                "--emb-pooling-mode",
+                type=click.Choice(["sum", "mean", "none"], case_sensitive=False),
+                default="sum",
+                help="Pooling operation to perform",
+            ),
+            click.option(
+                "--emb-uvm-host-mapped",
+                is_flag=True,
+                default=False,
+                help="Use host-mapped UVM buffers",
+            ),
+            click.option(
+                "--emb-location",
+                default="device",
+                type=click.Choice(
+                    ["device", "managed", "managed_caching"], case_sensitive=False
+                ),
+                help="Memory location of the embeddings",
+            ),
+            click.option(
+                "--emb-bounds-check",
+                type=int,
+                default=BoundsCheckMode.WARNING.value,
+                help="Bounds check mode. "
+                f"Available modes: FATAL={BoundsCheckMode.FATAL.value}, "
+                f"WARNING={BoundsCheckMode.WARNING.value}, "
+                f"IGNORE={BoundsCheckMode.IGNORE.value}, "
+                f"NONE={BoundsCheckMode.NONE.value}",
+            ),
+        ]
+
+        for option in reversed(options):
+            func = option(func)
+        return func
+
+    @classmethod
+    def load(cls, context: click.Context) -> EmbeddingOpsCommonConfig:
+        params = context.params
+
+        weights_dtype = params["emb_weights_dtype"]
+        cache_dtype = params["emb_cache_dtype"]
+        output_dtype = params["emb_output_dtype"]
+        stochastic_rounding = params["emb_stochastic_rounding"]
+        pooling_mode = str_to_pooling_mode(str(params["emb_pooling_mode"]))
+        uvm_host_mapped = params["emb_uvm_host_mapped"]
+        bounds_check_mode = BoundsCheckMode(params["emb_bounds_check"])
+
+        embedding_location = str_to_embedding_location(str(params["emb_location"]))
+        if (
+            embedding_location is EmbeddingLocation.DEVICE
+            and not torch.cuda.is_available()
+        ):
+            embedding_location = EmbeddingLocation.HOST
+
+        return EmbeddingOpsCommonConfig(
+            weights_dtype,
+            cache_dtype,
+            output_dtype,
+            stochastic_rounding,
+            pooling_mode,
+            uvm_host_mapped,
+            embedding_location,
+            bounds_check_mode,
+        ).validate()
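
The `options()` classmethod above stacks a list of `click.option` decorators onto a command by applying them in reverse, which keeps `--help` output in declaration order; `load()` then reads the values back out of `context.params`. A self-contained sketch of the same pattern (the `demo` command and its two flags are illustrative, not part of this commit):

import click
from click.testing import CliRunner


def common_options(func):
    options = [
        click.option("--emb-location", default="device",
                     type=click.Choice(["device", "managed", "managed_caching"])),
        click.option("--emb-uvm-host-mapped", is_flag=True, default=False),
    ]
    # Apply in reverse so --help lists options in declaration order.
    for option in reversed(options):
        func = option(func)
    return func


@click.command()
@common_options
@click.pass_context
def demo(context: click.Context, emb_location: str, emb_uvm_host_mapped: bool) -> None:
    # A loader class would pull these values back out of context.params.
    click.echo(f"location={context.params['emb_location']}")


if __name__ == "__main__":
    result = CliRunner().invoke(demo, ["--emb-location", "managed"])
    print(result.output, end="")  # location=managed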

fbgemm_gpu/fbgemm_gpu/tbe/bench/config.py renamed to fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@
     TBERequest,
 )

-from .config_param_models import BatchParams, IndicesParams, PoolingParams
+from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams

 try:
     torch.ops.load_library(
