From acf26acd4b7811a76c29ce1af21d275b422ca048 Mon Sep 17 00:00:00 2001
From: Yong Hoon Shin
Date: Tue, 14 May 2024 09:19:38 -0700
Subject: [PATCH] Improve CPU benchmark for KJT (#2001)

Summary:
Add benchmark for KJT methods:
- `permute`
- `to_dict`
- `split`
- `__getitem__`
- `dist_init`

Reviewed By: gnahzg

Differential Revision: D57314675
---
 .../distributed/benchmark/benchmark_utils.py  |  14 +-
 torchrec/distributed/test_utils/test_model.py |  61 ++--
 .../tests/keyed_jagged_tensor_benchmark.py    | 325 ++++++++++++------
 3 files changed, 274 insertions(+), 126 deletions(-)

diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py
index 209ea11a9..513cf4f9e 100644
--- a/torchrec/distributed/benchmark/benchmark_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_utils.py
@@ -106,14 +106,20 @@ class BenchmarkResult:
     max_mem_allocated: List[int]  # megabytes
     rank: int = -1
 
-    def runtime_percentile(self, percentile: int = 50) -> torch.Tensor:
+    def runtime_percentile(
+        self, percentile: int = 50, interpolation: str = "nearest"
+    ) -> torch.Tensor:
         return torch.quantile(
-            self.elapsed_time, percentile / 100.0, interpolation="nearest"
+            self.elapsed_time,
+            percentile / 100.0,
+            interpolation=interpolation,
         )
 
-    def max_mem_percentile(self, percentile: int = 50) -> torch.Tensor:
+    def max_mem_percentile(
+        self, percentile: int = 50, interpolation: str = "nearest"
+    ) -> torch.Tensor:
         max_mem = torch.tensor(self.max_mem_allocated, dtype=torch.float)
-        return torch.quantile(max_mem, percentile / 100.0, interpolation="nearest")
+        return torch.quantile(max_mem, percentile / 100.0, interpolation=interpolation)
 
 
 class ECWrapper(torch.nn.Module):
diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py
index c5ba0147f..846cb343b 100644
--- a/torchrec/distributed/test_utils/test_model.py
+++ b/torchrec/distributed/test_utils/test_model.py
@@ -69,6 +69,8 @@ def generate(
         long_indices: bool = True,
         tables_pooling: Optional[List[int]] = None,
         weighted_tables_pooling: Optional[List[int]] = None,
+        randomize_indices: bool = True,
+        device: Optional[torch.device] = None,
     ) -> Tuple["ModelInput", List["ModelInput"]]:
         """
         Returns a global (single-rank training) batch
@@ -132,15 +134,16 @@ def _validate_pooling_factor(
                         idlist_pooling_factor[idx],
                         idlist_pooling_factor[idx] / 10,
                         [batch_size * world_size],
+                        device=device,
                     ),
-                    torch.tensor(1.0),
+                    torch.tensor(1.0, device=device),
                 ).int()
             else:
                 lengths_ = torch.abs(
-                    torch.randn(batch_size * world_size) + pooling_avg
+                    torch.randn(batch_size * world_size, device=device) + pooling_avg,
                 ).int()
             if variable_batch_size:
-                lengths = torch.zeros(batch_size * world_size).int()
+                lengths = torch.zeros(batch_size * world_size, device=device).int()
                 for r in range(world_size):
                     lengths[r * batch_size : r * batch_size + batch_size_by_rank[r]] = (
                         lengths_[
@@ -150,12 +153,20 @@ def _validate_pooling_factor(
             else:
                 lengths = lengths_
             num_indices = cast(int, torch.sum(lengths).item())
-            indices = torch.randint(
-                0,
-                ind_range,
-                (num_indices,),
-                dtype=torch.long if long_indices else torch.int32,
-            )
+            if randomize_indices:
+                indices = torch.randint(
+                    0,
+                    ind_range,
+                    (num_indices,),
+                    dtype=torch.long if long_indices else torch.int32,
+                    device=device,
+                )
+            else:
+                indices = torch.zeros(
+                    (num_indices),
+                    dtype=torch.long if long_indices else torch.int32,
+                    device=device,
+                )
             global_idlist_lengths.append(lengths)
             global_idlist_indices.append(indices)
         global_idlist_kjt = KeyedJaggedTensor(
@@ -167,7 +178,7 @@ def _validate_pooling_factor(
         for idx in range(len(idscore_ind_ranges)):
             ind_range = idscore_ind_ranges[idx]
             lengths_ = torch.abs(
-                torch.randn(batch_size * world_size)
+                torch.randn(batch_size * world_size, device=device)
                 + (
                     idscore_pooling_factor[idx]
                     if idscore_pooling_factor
@@ -175,7 +186,7 @@ def _validate_pooling_factor(
                 )
             ).int()
             if variable_batch_size:
-                lengths = torch.zeros(batch_size * world_size).int()
+                lengths = torch.zeros(batch_size * world_size, device=device).int()
                 for r in range(world_size):
                     lengths[r * batch_size : r * batch_size + batch_size_by_rank[r]] = (
                         lengths_[
@@ -185,13 +196,21 @@ def _validate_pooling_factor(
             else:
                 lengths = lengths_
             num_indices = cast(int, torch.sum(lengths).item())
-            indices = torch.randint(
-                0,
-                ind_range,
-                (num_indices,),
-                dtype=torch.long if long_indices else torch.int32,
-            )
-            weights = torch.rand((num_indices,))
+            if randomize_indices:
+                indices = torch.randint(
+                    0,
+                    ind_range,
+                    (num_indices,),
+                    dtype=torch.long if long_indices else torch.int32,
+                    device=device,
+                )
+            else:
+                indices = torch.zeros(
+                    (num_indices),
+                    dtype=torch.long if long_indices else torch.int32,
+                    device=device,
+                )
+            weights = torch.rand((num_indices,), device=device)
             global_idscore_lengths.append(lengths)
             global_idscore_indices.append(indices)
             global_idscore_weights.append(weights)
@@ -206,8 +225,10 @@ def _validate_pooling_factor(
             else None
         )
 
-        global_float = torch.rand((batch_size * world_size, num_float_features))
-        global_label = torch.rand(batch_size * world_size)
+        global_float = torch.rand(
+            (batch_size * world_size, num_float_features), device=device
+        )
+        global_label = torch.rand(batch_size * world_size, device=device)
 
         # Split global batch into local batches.
         local_inputs = []
diff --git a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark.py b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark.py
index cfdf17390..24324395e 100644
--- a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark.py
+++ b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark.py
@@ -6,149 +6,270 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
 
-import copy
-import multiprocessing
+import random
 import time
+import timeit
-from typing import Callable, Dict, List
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import click
 
+import torch
+from torchrec.distributed.benchmark.benchmark_utils import BenchmarkResult
+from torchrec.distributed.dist_data import _get_recat
 from torchrec.distributed.test_utils.test_model import ModelInput
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
-from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 
 def generate_kjt(
-    table_configs: List[EmbeddingBagConfig],
-    sparse_features_per_kjt: int,
+    tables: List[EmbeddingBagConfig],
+    batch_size: int,
+    mean_pooling_factor: int,
+    device: torch.device,
 ) -> KeyedJaggedTensor:
-    raw_value = ModelInput.generate(
-        batch_size=1,
-        world_size=1,
+    global_input = ModelInput.generate(
+        batch_size=batch_size,
+        world_size=1,  # 1 for cpu
         num_float_features=0,
-        tables=table_configs,
+        tables=tables,
         weighted_tables=[],
+        # mean pooling factor per feature
+        tables_pooling=[mean_pooling_factor] * len(tables),
+        # returns KJTs with values all set to 0
+        # we don't care about KJT values for benchmark, and this saves time
+        randomize_indices=False,
+        device=device,
+    )[0]
+    return global_input.idlist_features
+
+
+def build_kjts(
+    tables: List[EmbeddingBagConfig],
+    batch_size: int,
+    mean_pooling_factor: int,
+    device: torch.device,
+) -> KeyedJaggedTensor:
+    start = time.perf_counter()
+    print("Starting to build KJTs")
+
+    kjt = generate_kjt(
+        tables,
+        batch_size,
+        mean_pooling_factor,
+        device,
     )
-    return raw_value[0].idlist_features
+    end = time.perf_counter()
+    time_taken_s = end - start
+    print(f"Took {time_taken_s * 1000:.1f}ms to build KJT\n")
+    return kjt
 
 
-def prepare_benchmark(
-    sparse_features_per_kjt: int,
-    test_size: int,
-    num_embeddings: int = 1000,
-    embedding_dim: int = 50,
-    in_parallel: bool = False,
-) -> List[KeyedJaggedTensor]:
-    tables: List[EmbeddingBagConfig] = [
-        EmbeddingBagConfig(
-            num_embeddings=num_embeddings,
-            embedding_dim=embedding_dim,
-            name=f"table_{i}",
-            feature_names=[f"feature_{i}"],
+
+def wrapped_func(
+    kjt: KeyedJaggedTensor,
+    test_func: Callable[[KeyedJaggedTensor], object],
+    fn_kwargs: Dict[str, Any],
+) -> Callable[..., object]:
+    def fn() -> object:
+        return test_func(kjt, **fn_kwargs)
+
+    return fn
+
+
+def benchmark_kjt(
+    method_name: str,
+    kjt: KeyedJaggedTensor,
+    num_repeat: int,
+    num_warmup: int,
+    num_features: int,
+    batch_size: int,
+    mean_pooling_factor: int,
+    fn_kwargs: Dict[str, Any],
+    is_static_method: bool,
+) -> None:
+    test_name = method_name
+
+    # pyre-ignore
+    def test_func(kjt: KeyedJaggedTensor, **kwargs):
+        return getattr(KeyedJaggedTensor if is_static_method else kjt, method_name)(
+            **kwargs
         )
-        for i in range(sparse_features_per_kjt)
-    ]
-    kjt_lists: List[KeyedJaggedTensor] = []
+    for _ in range(num_warmup):
+        test_func(kjt, **fn_kwargs)
 
-    if in_parallel:
-        # TODO Make this parallel version performance efficient
-        with multiprocessing.Pool() as pool:
-            kjt_lists: List[KeyedJaggedTensor] = pool.starmap(
-                generate_kjt,
-                [(tables, sparse_features_per_kjt) for _ in range(test_size)],
-            )
-    else:
-        for _ in range(test_size):
-            kjt_lists.append(generate_kjt(tables, sparse_features_per_kjt))
-    return list(kjt_lists)
+    times = []
+    for _ in range(num_repeat):
+        time_elapsed = timeit.timeit(wrapped_func(kjt, test_func, fn_kwargs), number=1)
+        # remove length_per_key and offset_per_key cache for fairer comparison
+        kjt.unsync()
+        times.append(time_elapsed * 1e3)  # timeit reports seconds; convert to ms
+
+    result = BenchmarkResult(
+        short_name=test_name,
+        elapsed_time=torch.tensor(times),
+        max_mem_allocated=[0],
+    )
 
-def kjt_to_dict(kjt: KeyedJaggedTensor) -> Dict[str, JaggedTensor]:
-    return kjt.to_dict()
+    print(
+        f" {test_name : <{35}} | B: {batch_size : <{8}} | F: {num_features : <{8}} | Mean Pooling Factor: {mean_pooling_factor : <{8}} | Runtime (P50): {result.runtime_percentile(50, interpolation='linear'):.5f} ms | Runtime (P90): {result.runtime_percentile(90, interpolation='linear'):.5f} ms"
+    )
 
 
-def benchmark(
-    input_data: List[KeyedJaggedTensor],
-    test_name: str,
-    warmup_size: int,
-    test_size: int,
-    test_func: Callable[..., object],
-) -> None:
-    start = time.perf_counter()
-    for i in range(warmup_size):
-        test_func(input_data[i])
-    end = time.perf_counter()
-    print(f"warmup time for {test_name} {(end-start)*1000/warmup_size:.1f}ms")
-    start = time.perf_counter()
-    for i in range(warmup_size, warmup_size + test_size):
-        test_func(input_data[i])
-    end = time.perf_counter()
-    print(f"benmark avarge time {test_name} {(end-start)*1000/test_size:.1f}ms")
+def get_k_splits(n: int, k: int) -> List[int]:
+    split_size, _ = divmod(n, k)
+    splits = [split_size] * (k - 1) + [n - split_size * (k - 1)]
+    return splits
+
+
+def gen_dist_split_input(
+    tables: List[EmbeddingBagConfig],
+    batch_size: int,
+    num_workers: int,
+    num_features: int,
+    mean_pooling_factor: int,
+    device: torch.device,
+) -> Tuple[torch.Tensor, torch.Tensor, List[int], Optional[torch.Tensor]]:
+    batch_size_per_rank = get_k_splits(n=batch_size, k=num_workers)
+    kjts = [
+        generate_kjt(tables, batch_size_rank, mean_pooling_factor, device)
+        for batch_size_rank in batch_size_per_rank
+    ]
+    kjt_lengths = torch.cat([kjt.lengths() for kjt in kjts])
+    kjt_values = torch.cat([kjt.values() for kjt in kjts])
+    recat = _get_recat(
+        local_split=num_features,
+        num_splits=num_workers,
+        device=device,
+        batch_size_per_rank=batch_size_per_rank,
+        use_tensor_compute=False,
+    )
+
+    return (kjt_lengths, kjt_values, batch_size_per_rank, recat)
 
 
 def bench(
-    feature_per_kjt: int,
-    test_size: int,
-    warmup_size: int,
-    test_jitscripted: bool = True,
-    parallel: bool = False,
+    num_repeat: int,
+    num_warmup: int,
+    num_features: int,
+    batch_size: int,
+    mean_pooling_factor: int,
+    num_workers: int,
 ) -> None:
-    generated_input_data = prepare_benchmark(
-        feature_per_kjt, test_size + warmup_size, in_parallel=parallel
-    )
-    assert len(generated_input_data) == test_size + warmup_size
-
-    test_sets = [generated_input_data]
-    test_names = ["eager"]
-    if test_jitscripted:
-        test_sets.append(copy.deepcopy(generated_input_data))
-        test_names.append("jitscripted")
-
-    benchmark(
-        input_data=test_sets[0],
-        test_name=test_names[0],
-        warmup_size=warmup_size,
-        test_size=test_size,
-        test_func=lambda x: x.to_dict(),
+    # TODO: support CUDA benchmark
+    device: torch.device = torch.device("cpu")
+
+    tables: List[EmbeddingBagConfig] = [
+        EmbeddingBagConfig(
+            num_embeddings=20,  # determines indices range
+            embedding_dim=10,  # doesn't matter for benchmark
+            name=f"table_{i}",
+            feature_names=[f"feature_{i}"],
+        )
+        for i in range(num_features)
+    ]
+
+    kjt = build_kjts(
+        tables,
+        batch_size,
+        mean_pooling_factor,
+        device,
     )
-    benchmark(
-        input_data=test_sets[1],
-        test_name=test_names[1],
-        warmup_size=warmup_size,
-        test_size=test_size,
-        test_func=lambda x: kjt_to_dict(x),
+
+    splits = get_k_splits(n=num_features, k=8)
+    permute_indices = random.sample(range(num_features), k=num_features)
+    key = f"feature_{random.randint(0, num_features - 1)}"
+
+    kjt_lengths, kjt_values, strides_per_rank, recat = gen_dist_split_input(
+        tables, batch_size, num_workers, num_features, mean_pooling_factor, device
     )
 
+    benchmarked_methods: List[Tuple[str, Dict[str, Any], bool]] = [
+        ("permute", {"indices": permute_indices}, False),
+        ("to_dict", {}, False),
+        ("split", {"segments": splits}, False),
+        ("__getitem__", {"key": key}, False),
+        ("dist_splits", {"key_splits": splits}, False),
+        (
+            "dist_init",
+            {
+                "keys": kjt.keys(),
+                "tensors": [
+                    # lengths from each rank, should add up to num_features x batch_size in total
+                    kjt_lengths,
+                    # values from each rank
+                    kjt_values,
+                ],
+                "variable_stride_per_key": False,
+                "num_workers": num_workers,
+                "recat": recat,
+                "stride_per_rank": strides_per_rank,
+            },
+            True,  # is static method
+        ),
+    ]
+
+    for method_name, fn_kwargs, is_static_method in benchmarked_methods:
+        benchmark_kjt(
+            method_name=method_name,
+            kjt=kjt,
+            num_repeat=num_repeat,
+            num_warmup=num_warmup,
+            num_features=num_features,
+            batch_size=batch_size,
+            mean_pooling_factor=mean_pooling_factor,
+            fn_kwargs=fn_kwargs,
+            is_static_method=is_static_method,
+        )
 
 
 @click.command()
 @click.option(
-    "--feature_per_kjt",
-    default=50,
-    help="Total number of sparse features per KJT. Loosely corresponds to lengths of KJT values.",
+    "--num-repeat",
+    default=30,
+    help="Number of times method under test is run",
 )
 @click.option(
-    "--test_size",
-    default=100,
-    help="Total number of KJT tested in the benchmark.",
+    "--num-warmup",
+    default=10,
+    help="Number of times method under test is run for warmup",
 )
 @click.option(
-    "--warmup_size",
-    default=10,
-    help="Total warmup number of KJT tested before the formal benchmark.",
+    "--num-features",
+    default=1280,
+    help="Total number of sparse features per KJT",
 )
 @click.option(
-    "--parallel",
-    "-p",
-    help="Generate input data in parallel.",
-    required=False,
-    type=bool,
-    is_flag=False,
+    "--batch-size",
+    default=4096,
+    help="Batch size per KJT (assumes non-VBE)",
+)
+@click.option(
+    "--mean-pooling-factor",
+    default=100,
+    help="Avg pooling factor for KJT",
+)
+@click.option(
+    "--num-workers",
+    default=4,
+    help="World size to simulate for dist_init",
 )
 def main(
-    feature_per_kjt: int, test_size: int, warmup_size: int, parallel: bool = False
+    num_repeat: int,
+    num_warmup: int,
+    num_features: int,
+    batch_size: int,
+    mean_pooling_factor: int,
+    num_workers: int,
 ) -> None:
-    bench(feature_per_kjt, test_size, warmup_size)
+    bench(
+        num_repeat,
+        num_warmup,
+        num_features,
+        batch_size,
+        mean_pooling_factor,
+        num_workers,
+    )
 
 
 if __name__ == "__main__":
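
Note for reviewers (not part of the patch): a minimal sketch of the KJT calls this
benchmark exercises, assuming only an installed torchrec. The tiny tensors below are
illustrative, not the sizes the benchmark uses.

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # 2 keys with stride 2 -> 4 lengths; lengths must sum to len(values)
    kjt = KeyedJaggedTensor(
        keys=["feature_0", "feature_1"],
        values=torch.tensor([1, 2, 3, 4, 5]),
        lengths=torch.tensor([2, 1, 1, 1]),
    )

    kjt.permute([1, 0])  # reorder features
    kjt.to_dict()        # key -> JaggedTensor
    kjt.split([1, 1])    # one KJT per group of keys
    kjt["feature_0"]     # __getitem__ -> JaggedTensor

Once this lands, the script should be runnable directly with the options added above, e.g.:

    python torchrec/sparse/tests/keyed_jagged_tensor_benchmark.py --num-repeat 30 --batch-size 4096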