
Commit 51d2be6

sryap authored and facebook-github-bot committed
Allow reusing input data in TBE benchmark (#3594)
Summary:
X-link: facebookresearch/FBGEMM#674
Pull Request resolved: #3594

Add `--num-requests` to TBE's `device` benchmark to allow input batches to be reused. By default, `--num-requests` is set to -1, in which case the benchmark generates `iters` batches. If it is set, the benchmark generates `num_requests` batches; when this value is smaller than `iters`, input batches are reused (i.e., iteration `i` uses batch `i % num_requests`).

Reviewed By: gajjanag

Differential Revision: D68340968

fbshipit-source-id: fdae703ec499f3ba2656cba3a3b4967c684058f5
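For illustration, a minimal sketch of the reuse rule (the values here are hypothetical; the real benchmark cycles over request objects produced by `generate_requests`, not strings):

```python
iters = 6
num_requests = 4  # hypothetical values for illustration
requests = [f"batch_{i}" for i in range(num_requests)]

for it in range(iters):
    req = requests[it % num_requests]  # wraps around once batches are exhausted
    print(f"iter {it} -> {req}")
# iterations 4 and 5 reuse batch_0 and batch_1
```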
1 parent 5852de7 · commit 51d2be6

File tree: 2 files changed (+21 −7 lines)


fbgemm_gpu/bench/bench_utils.py

Lines changed: 10 additions & 6 deletions
```diff
@@ -182,6 +182,7 @@ def benchmark_requests(
     callback_after_warmup: Optional[Callable[[], None]] = None,
     periodic_logs: bool = False,
     warmup_ms: Optional[int] = None,
+    iters: int = -1,
 ) -> float:
     times = []
     # Run at least one warmup iteration to avoid the long cudaLaunchKernel time
@@ -209,17 +210,20 @@ def benchmark_requests(
     if callback_after_warmup is not None:
         callback_after_warmup()

-    num_iters = len(requests)
+    num_reqs = len(requests)
+    iters = num_reqs if iters == -1 else iters

     if torch.cuda.is_available():
         torch.cuda.synchronize()
-        start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]
-        end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]
+        start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
+        end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
     else:
         start_events = []
         end_events = []

-    for it, req in enumerate(requests):
+    for it in range(iters):
+        req = requests[it % num_reqs]
+
         indices, offsets, weights = req.unpack_3()
         if bwd_only:
             # Run forward before profiling if does backward only
@@ -259,15 +263,15 @@ def benchmark_requests(
     ]

     if periodic_logs:
-        for it in range(100, num_iters + 1, 100):
+        for it in range(100, iters + 1, 100):
             times_ = times[0:it]
             avg_time = sum(times_) / len(times_) * 1.0e6
             last_100_avg = sum(times_[-100:]) / 100 * 1.0e6
             logging.info(
                 f"Iteration [{it}/{len(requests)}]: Last 100: {last_100_avg:.2f} us, Running avg: {avg_time:.2f} us"
             )

-    avg_time = sum(times) / len(requests)
+    avg_time = sum(times) / iters
     median_time = statistics.median(times)
     return median_time if check_median else avg_time
```
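The CUDA event lists above are sized to `iters` rather than `len(requests)` because one (start, end) event pair is recorded per timed iteration, regardless of how many unique batches exist. A self-contained sketch of this timing pattern, assuming a CUDA device is available (`time_iterations` is a hypothetical name, not part of the benchmark):

```python
import torch

def time_iterations(fn, iters: int) -> list[float]:
    # One (start, end) event pair per timed iteration, as in the diff above.
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    for it in range(iters):
        start_events[it].record()
        fn(it)
        end_events[it].record()
    torch.cuda.synchronize()  # wait for all recorded events to finish
    # Event.elapsed_time returns milliseconds; convert to seconds.
    return [s.elapsed_time(e) * 1.0e-3 for s, e in zip(start_events, end_events)]
```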

fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -161,6 +161,12 @@ def cli() -> None:
     "--ssd-prefix", type=str, default="/tmp/ssd_benchmark", help="SSD directory prefix"
 )
 @click.option("--cache-load-factor", default=0.2)
+@click.option(
+    "--num-requests",
+    default=-1,
+    help="Number of input batches to generate. If the value is smaller than "
+    "iters, the benchmark will reuse the input batches",
+)
 def device(  # noqa C901
     alpha: float,
     bag_size: int,
@@ -191,8 +197,10 @@ def device(  # noqa C901
     ssd: bool,
     ssd_prefix: str,
     cache_load_factor: float,
+    num_requests: int,
 ) -> None:
     assert not ssd or not dense, "--ssd cannot be used together with --dense"
+    num_requests = iters if num_requests == -1 else num_requests
     np.random.seed(42)
     torch.manual_seed(42)
     B = batch_size
@@ -341,7 +349,7 @@ def device(  # noqa C901
         f"Accessed weights per batch: {B * sum(Ds) * L * param_size_multiplier / 1.0e9: .2f} GB"
     )
     requests = generate_requests(
-        iters,
+        num_requests,
         B,
         T,
         L,
@@ -375,6 +383,7 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
         ),
         flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
         num_warmups=warmup_runs,
+        iters=iters,
     )

     logging.info(
@@ -409,6 +418,7 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
         bwd_only=True,
         grad=grad_output,
         num_warmups=warmup_runs,
+        iters=iters,
     )

     logging.info(
```
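Both call sites keep passing the full `iters` to `benchmark_requests`; `num_requests` only caps how many batches `generate_requests` builds. A minimal sketch of the sentinel handling added above (the helper name is hypothetical):

```python
def resolve_num_requests(num_requests: int, iters: int) -> int:
    # -1 is the sentinel meaning "generate one fresh batch per iteration".
    return iters if num_requests == -1 else num_requests

assert resolve_num_requests(-1, 100) == 100  # default: no batch reuse
assert resolve_num_requests(10, 100) == 10   # 100 iterations cycle over 10 batches
```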
