From f249f0e10fdfc1dba0d9f1d82bc46352c2a705e7 Mon Sep 17 00:00:00 2001 From: Chenyu Zhang Date: Wed, 18 Jun 2025 09:49:21 -0700 Subject: [PATCH] kvzch use new operator in model publish (#3108) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/3108 Publish change to enable KVEmbeddingInference when use_virtual_table is set to true Differential Revision: D75321284 --- torchrec/distributed/embedding_types.py | 6 ++ torchrec/distributed/embeddingbag.py | 1 + .../distributed/quant_embedding_kernel.py | 66 ++++++++++--------- torchrec/quant/embedding_modules.py | 26 +++++--- 4 files changed, 59 insertions(+), 40 deletions(-) diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index 20f0a4c88..a46a75191 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -301,6 +301,12 @@ def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]: embedding_shard_metadata.append(table.local_metadata) return embedding_shard_metadata + def is_using_virtual_table(self) -> bool: + return self.compute_kernel in [ + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE, + EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE, + ] + F = TypeVar("F", bound=Multistreamable) T = TypeVar("T") diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 2beaf3aef..024a2a6c4 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -292,6 +292,7 @@ def create_sharding_infos_by_sharding_device_group( getattr(config, "num_embeddings_post_pruning", None) # TODO: Need to check if attribute exists for BC ), + use_virtual_table=config.use_virtual_table, ), param_sharding=parameter_sharding, param=param, diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py index 4e0dc31f3..b6fa8a354 100644 --- a/torchrec/distributed/quant_embedding_kernel.py +++ b/torchrec/distributed/quant_embedding_kernel.py @@ -20,6 +20,7 @@ PoolingMode, rounded_row_size_in_bytes, ) +from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference from torchrec.distributed.batched_embedding_kernel import ( BaseBatchedEmbedding, BaseBatchedEmbeddingBag, @@ -284,6 +285,8 @@ def __init__( if self.lengths_to_tbe: tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegenWithLength + elif config.is_using_virtual_table(): + tbe_clazz = KVEmbeddingInference else: tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen @@ -465,37 +468,40 @@ def __init__( ) # 16 for CUDA, 1 for others like CPU and MTIA. self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1 - self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = ( - IntNBitTableBatchedEmbeddingBagsCodegen( - embedding_specs=[ + embedding_clazz = ( + KVEmbeddingInference + if config.is_using_virtual_table() + else IntNBitTableBatchedEmbeddingBagsCodegen + ) + self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = embedding_clazz( + embedding_specs=[ + ( + table.name, + local_rows, ( - table.name, - local_rows, - ( - local_cols - if self._quant_state_dict_split_scale_bias - else table.embedding_dim - ), - data_type_to_sparse_type(table.data_type), - location, - ) - for local_rows, local_cols, table, location in zip( - self._local_rows, - self._local_cols, - config.embedding_tables, - managed, - ) - ], - device=device, - pooling_mode=PoolingMode.NONE, - feature_table_map=self._feature_table_map, - row_alignment=self._tbe_row_alignment, - uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue - feature_names_per_table=[ - table.feature_names for table in config.embedding_tables - ], - **(tbe_fused_params(fused_params) or {}), - ) + local_cols + if self._quant_state_dict_split_scale_bias + else table.embedding_dim + ), + data_type_to_sparse_type(table.data_type), + location, + ) + for local_rows, local_cols, table, location in zip( + self._local_rows, + self._local_cols, + config.embedding_tables, + managed, + ) + ], + device=device, + pooling_mode=PoolingMode.NONE, + feature_table_map=self._feature_table_map, + row_alignment=self._tbe_row_alignment, + uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue + feature_names_per_table=[ + table.feature_names for table in config.embedding_tables + ], + **(tbe_fused_params(fused_params) or {}), ) if device is not None: self._emb_module.initialize_weights() diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index bcd428a4e..3e979b34d 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -30,6 +30,7 @@ IntNBitTableBatchedEmbeddingBagsCodegen, PoolingMode, ) +from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference from torch import Tensor from torchrec.distributed.utils import none_throws from torchrec.modules.embedding_configs import ( @@ -357,7 +358,7 @@ def __init__( self._is_weighted = is_weighted self._embedding_bag_configs: List[EmbeddingBagConfig] = tables self._key_to_tables: Dict[ - Tuple[PoolingType, DataType, bool], List[EmbeddingBagConfig] + Tuple[PoolingType, bool], List[EmbeddingBagConfig] ] = defaultdict(list) self._feature_names: List[str] = [] self._feature_splits: List[int] = [] @@ -383,15 +384,13 @@ def __init__( key = (table.pooling, table.use_virtual_table) else: key = (table.pooling, False) - # pyre-ignore self._key_to_tables[key].append(table) location = ( EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE ) - for key, emb_configs in self._key_to_tables.items(): - pooling = key[0] + for (pooling, use_virtual_table), emb_configs in self._key_to_tables.items(): embedding_specs = [] weight_lists: Optional[ List[Tuple[torch.Tensor, Optional[torch.Tensor]]] @@ -420,7 +419,12 @@ def __init__( ) feature_table_map.extend([idx] * table.num_features()) - emb_module = IntNBitTableBatchedEmbeddingBagsCodegen( + embedding_clazz = ( + KVEmbeddingInference + if use_virtual_table + else IntNBitTableBatchedEmbeddingBagsCodegen + ) + emb_module = embedding_clazz( embedding_specs=embedding_specs, pooling_mode=pooling_type_to_pooling_mode(pooling), weight_lists=weight_lists, @@ -790,8 +794,7 @@ def __init__( # noqa C901 key = (table.data_type, False) self._key_to_tables[key].append(table) self._feature_splits: List[int] = [] - for key, emb_configs in self._key_to_tables.items(): - data_type = key[0] + for (data_type, use_virtual_table), emb_configs in self._key_to_tables.items(): embedding_specs = [] weight_lists: Optional[ List[Tuple[torch.Tensor, Optional[torch.Tensor]]] @@ -816,10 +819,13 @@ def __init__( # noqa C901 table_name_to_quantized_weights[table.name] ) feature_table_map.extend([idx] * table.num_features()) - # move to here to make sure feature_names order is consistent with the embedding groups self._feature_names.extend(table.feature_names) - - emb_module = IntNBitTableBatchedEmbeddingBagsCodegen( + embedding_clazz = ( + KVEmbeddingInference + if use_virtual_table + else IntNBitTableBatchedEmbeddingBagsCodegen + ) + emb_module = embedding_clazz( embedding_specs=embedding_specs, pooling_mode=PoolingMode.NONE, weight_lists=weight_lists,