diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index cdf3e26fa..55e83e00e 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -65,6 +65,7 @@
     compute_kernel_to_embedding_location,
     DTensorMetadata,
     GroupedEmbeddingConfig,
+    ShardedEmbeddingTable,
 )
 from torchrec.distributed.shards_wrapper import LocalShardsWrapper
 from torchrec.distributed.types import (
@@ -216,6 +217,42 @@ def _populate_zero_collision_tbe_params(
     )
 
 
+def _get_sharded_local_buckets_for_zero_collision(
+    embedding_tables: List[ShardedEmbeddingTable],
+    pg: Optional[dist.ProcessGroup] = None,
+) -> List[Tuple[int, int, int]]:
+    """
+    utils to get bucket offset start, bucket offset end, bucket size based on embedding sharding spec
+    """
+    sharded_local_buckets: List[Tuple[int, int, int]] = []
+    world_size = dist.get_world_size(pg)
+    local_rank = dist.get_rank(pg)
+
+    for table in embedding_tables:
+        total_num_buckets = none_throws(table.total_num_buckets)
+        assert (
+            total_num_buckets % world_size == 0
+        ), f"total_num_buckets={total_num_buckets} must be divisible by world_size={world_size}"
+        assert (
+            table.total_num_buckets
+            and table.num_embeddings % table.total_num_buckets == 0
+        ), f"Table size '{table.num_embeddings}' must be divisible by num_buckets '{table.total_num_buckets}'"
+        bucket_offset_start = total_num_buckets // world_size * local_rank
+        bucket_offset_end = min(
+            total_num_buckets, total_num_buckets // world_size * (local_rank + 1)
+        )
+        bucket_size = (
+            table.num_embeddings + total_num_buckets - 1
+        ) // total_num_buckets
+        sharded_local_buckets.append(
+            (bucket_offset_start, bucket_offset_end, bucket_size)
+        )
+        logger.info(
+            f"bucket_offset: {bucket_offset_start}:{bucket_offset_end}, bucket_size: {bucket_size} for table {table.name}"
+        )
+    return sharded_local_buckets
+
+
 class KeyValueEmbeddingFusedOptimizer(FusedOptimizer):
     def __init__(
         self,
@@ -1076,6 +1113,11 @@ def __init__(
         assert (
             len({table.embedding_dim for table in config.embedding_tables}) == 1
         ), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
+        for table in config.embedding_tables:
+            assert table.local_cols % 4 == 0, (
+                f"table {table.name} has local_cols={table.local_cols} "
+                "not divisible by 4. "
+            )
 
         ssd_tbe_params = _populate_ssd_tbe_params(config)
         compute_kernel = config.embedding_tables[0].compute_kernel
@@ -1263,9 +1305,18 @@ def __init__(
         assert (
             len({table.embedding_dim for table in config.embedding_tables}) == 1
         ), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
+        for table in config.embedding_tables:
+            assert table.local_cols % 4 == 0, (
+                f"table {table.name} has local_cols={table.local_cols} "
+                "not divisible by 4. 
" + ) ssd_tbe_params = _populate_ssd_tbe_params(config) - self._bucket_spec: List[Tuple[int, int, int]] = self.get_sharded_local_buckets() + self._bucket_spec: List[Tuple[int, int, int]] = ( + _get_sharded_local_buckets_for_zero_collision( + self._config.embedding_tables, self._pg + ) + ) _populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec) compute_kernel = config.embedding_tables[0].compute_kernel embedding_location = compute_kernel_to_embedding_location(compute_kernel) @@ -1334,38 +1385,6 @@ def fused_optimizer(self) -> FusedOptimizer: """ return self._optim - def get_sharded_local_buckets(self) -> List[Tuple[int, int, int]]: - """ - utils to get bucket offset start, bucket offset end, bucket size based on embedding sharding spec - """ - sharded_local_buckets: List[Tuple[int, int, int]] = [] - world_size = dist.get_world_size(self._pg) - local_rank = dist.get_rank(self._pg) - - for table in self._config.embedding_tables: - total_num_buckets = none_throws(table.total_num_buckets) - assert ( - total_num_buckets % world_size == 0 - ), f"total_num_buckets={total_num_buckets} must be divisible by world_size={world_size}" - assert ( - table.total_num_buckets - and table.num_embeddings % table.total_num_buckets == 0 - ), f"Table size '{table.num_embeddings}' must be divisible by num_buckets '{table.total_num_buckets}'" - bucket_offset_start = total_num_buckets // world_size * local_rank - bucket_offset_end = min( - total_num_buckets, total_num_buckets // world_size * (local_rank + 1) - ) - bucket_size = ( - table.num_embeddings + total_num_buckets - 1 - ) // total_num_buckets - sharded_local_buckets.append( - (bucket_offset_start, bucket_offset_end, bucket_size) - ) - logger.info( - f"bucket_offset: {bucket_offset_start}:{bucket_offset_end}, bucket_size: {bucket_size} for table {table.name}" - ) - return sharded_local_buckets - def state_dict( self, destination: Optional[Dict[str, Any]] = None, @@ -1553,10 +1572,7 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: self._split_weights_res = None self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None) - return self.emb_module( - indices=features.values().long(), - offsets=features.offsets().long(), - ) + return super().forward(features) class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule): @@ -1901,6 +1917,11 @@ def __init__( assert ( len({table.embedding_dim for table in config.embedding_tables}) == 1 ), "Currently we expect all tables in SSD TBE to have the same embedding dimension." + for table in config.embedding_tables: + assert table.local_cols % 4 == 0, ( + f"table {table.name} has local_cols={table.local_cols} " + "not divisible by 4. " + ) ssd_tbe_params = _populate_ssd_tbe_params(config) compute_kernel = config.embedding_tables[0].compute_kernel @@ -2070,6 +2091,296 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[ return self.emb_module.split_embedding_weights(no_snapshot) +class ZeroCollisionKeyValueEmbeddingBag( + BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule +): + def __init__( + self, + config: GroupedEmbeddingConfig, + pg: Optional[dist.ProcessGroup] = None, + device: Optional[torch.device] = None, + sharding_type: Optional[ShardingType] = None, + backend_type: BackendType = BackendType.SSD, + ) -> None: + super().__init__(config, pg, device, sharding_type) + + assert ( + len(config.embedding_tables) > 0 + ), "Expected to see at least one table in SSD TBE, but found 0." 
+        assert (
+            len({table.embedding_dim for table in config.embedding_tables}) == 1
+        ), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
+
+        for table in config.embedding_tables:
+            assert table.local_cols % 4 == 0, (
+                f"table {table.name} has local_cols={table.local_cols} "
+                "not divisible by 4. "
+            )
+
+        ssd_tbe_params = _populate_ssd_tbe_params(config)
+        self._bucket_spec: List[Tuple[int, int, int]] = (
+            _get_sharded_local_buckets_for_zero_collision(
+                self._config.embedding_tables, self._pg
+            )
+        )
+        _populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec)
+        compute_kernel = config.embedding_tables[0].compute_kernel
+        embedding_location = compute_kernel_to_embedding_location(compute_kernel)
+
+        # every split_embedding_weights call is expensive, since it iterates over all the elements in the backend kv db
+        # use split weights result cache so that multiple calls in the same train iteration will only trigger once
+        self._split_weights_res: Optional[
+            Tuple[
+                List[ShardedTensor],
+                List[ShardedTensor],
+                List[ShardedTensor],
+            ]
+        ] = None
+
+        self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags(
+            embedding_specs=list(zip(self._num_embeddings, self._local_cols)),
+            feature_table_map=self._feature_table_map,
+            ssd_cache_location=embedding_location,
+            pooling_mode=self._pooling,
+            backend_type=backend_type,
+            **ssd_tbe_params,
+        ).to(device)
+
+        logger.info(
+            f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}"
+        )
+        self._table_name_to_weight_count_per_rank: Dict[str, List[int]] = {}
+        self._init_sharded_split_embedding_weights()  # this will populate self._split_weights_res
+        self._optim: ZeroCollisionKeyValueEmbeddingFusedOptimizer = (
+            ZeroCollisionKeyValueEmbeddingFusedOptimizer(
+                config,
+                self._emb_module,
+                # pyre-ignore[16]
+                sharded_embedding_weights_by_table=self._split_weights_res[0],
+                table_name_to_weight_count_per_rank=self._table_name_to_weight_count_per_rank,
+                sharded_embedding_weight_ids=self._split_weights_res[1],
+                pg=pg,
+            )
+        )
+        self._param_per_table: Dict[str, nn.Parameter] = dict(
+            _gen_named_parameters_by_table_ssd_pmt(
+                emb_module=self._emb_module,
+                table_name_to_count=self.table_name_to_count.copy(),
+                config=self._config,
+                pg=pg,
+            )
+        )
+        self.init_parameters()
+
+    def init_parameters(self) -> None:
+        """
+        An advantage of KV TBE is that we don't need to init weights. Hence skipping.
+        """
+        pass
+
+    @property
+    def emb_module(
+        self,
+    ) -> SSDTableBatchedEmbeddingBags:
+        return self._emb_module
+
+    @property
+    def fused_optimizer(self) -> FusedOptimizer:
+        """
+        SSD Embedding fuses the optimizer step with the backward pass.
+        """
+        return self._optim
+
+    def state_dict(
+        self,
+        destination: Optional[Dict[str, Any]] = None,
+        prefix: str = "",
+        keep_vars: bool = False,
+        no_snapshot: bool = True,
+    ) -> Dict[str, Any]:
+        """
+        Args:
+            no_snapshot (bool): the tensors in the returned dict are
+                PartiallyMaterializedTensors. This argument controls whether the
+                PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the
+                PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle. False means the
+                PartiallyMaterializedTensor has a RocksDB snapshot handle.
+        """
+        # in the case no_snapshot=False, a flush is required.
we rely on the flush operation in + # ShardedEmbeddingBagCollection._pre_state_dict_hook() + + emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot) + emb_table_config_copy = copy.deepcopy(self._config.embedding_tables) + for emb_table in emb_table_config_copy: + emb_table.local_metadata.placement._device = torch.device("cpu") + ret = get_state_dict( + emb_table_config_copy, + emb_tables, + self._pg, + destination, + prefix, + ) + return ret + + def named_parameters( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, nn.Parameter]]: + """ + Only allowed ways to get state_dict. + """ + for name, tensor in self.named_split_embedding_weights( + prefix, recurse, remove_duplicate + ): + # hack before we support optimizer on sharded parameter level + # can delete after PEA deprecation + # pyre-ignore [6] + param = nn.Parameter(tensor) + # pyre-ignore + param._in_backward_optimizers = [EmptyFusedOptimizer()] + yield name, param + + # pyre-ignore [15] + def named_split_embedding_weights( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[Tuple[str, Union[PartiallyMaterializedTensor, torch.Tensor]]]: + assert ( + remove_duplicate + ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights" + for config, tensor in zip( + self._config.embedding_tables, + self.split_embedding_weights()[0], + ): + key = append_prefix(prefix, f"{config.name}.weight") + yield key, tensor + + # initialize sharded _split_weights_res if it's None + # this method is used to generate sharded embedding weights once for all following state_dict + # calls in checkpointing and publishing. + # When training is resumed, the cached value will be reset to None and the value needs to be + # rebuilt for next checkpointing and publishing, as the weight id, weight embedding will be updated + # during training in backend k/v store. 
+ def _init_sharded_split_embedding_weights( + self, prefix: str = "", force_regenerate: bool = False + ) -> None: + if not force_regenerate and self._split_weights_res is not None: + return + + pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights( + no_snapshot=False, + ) + emb_table_config_copy = copy.deepcopy(self._config.embedding_tables) + for emb_table in emb_table_config_copy: + none_throws( + none_throws( + emb_table.local_metadata, + f"local_metadata is None for emb_table: {emb_table.name}", + ).placement, + f"placement is None for local_metadata of emb table: {emb_table.name}", + )._device = torch.device("cpu") + + pmt_sharded_t_list = create_virtual_sharded_tensors( + emb_table_config_copy, + pmt_list, + self._pg, + prefix, + self._table_name_to_weight_count_per_rank, + ) + weight_id_sharded_t_list = create_virtual_sharded_tensors( + emb_table_config_copy, + weight_ids_list, # pyre-ignore [6] + self._pg, + prefix, + self._table_name_to_weight_count_per_rank, + ) + bucket_cnt_sharded_t_list = create_virtual_sharded_tensors( + emb_table_config_copy, + bucket_cnt_list, # pyre-ignore [6] + self._pg, + prefix, + self._table_name_to_weight_count_per_rank, + use_param_size_as_rows=True, + ) + # pyre-ignore + assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list) + assert ( + len(pmt_sharded_t_list) + == len(weight_id_sharded_t_list) + == len(bucket_cnt_sharded_t_list) + ) + self._split_weights_res = ( + pmt_sharded_t_list, + weight_id_sharded_t_list, + bucket_cnt_sharded_t_list, + ) + + def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[ + Tuple[ + str, + Union[ShardedTensor, PartiallyMaterializedTensor], + Optional[ShardedTensor], + Optional[ShardedTensor], + ] + ]: + """ + Return an iterator over embedding tables, for each table yielding + table name, + PMT for embedding table with a valid RocksDB snapshot to support tensor IO + optional ShardedTensor for weight_id + optional ShardedTensor for bucket_cnt + """ + self._init_sharded_split_embedding_weights() + # pyre-ignore[16] + self._optim.set_sharded_embedding_weight_ids(self._split_weights_res[1]) + + pmt_sharded_t_list = self._split_weights_res[0] + weight_id_sharded_t_list = self._split_weights_res[1] + bucket_cnt_sharded_t_list = self._split_weights_res[2] + for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list): + table_config = self._config.embedding_tables[table_idx] + key = append_prefix(prefix, f"{table_config.name}") + + yield key, pmt_sharded_t, weight_id_sharded_t_list[ + table_idx + ], bucket_cnt_sharded_t_list[table_idx] + + def flush(self) -> None: + """ + Flush the embeddings in cache back to SSD. Should be pretty expensive. + """ + self.emb_module.flush() + + def purge(self) -> None: + """ + Reset the cache space. This is needed when we load state dict. + """ + # TODO: move the following to SSD TBE. + self.emb_module.lxu_cache_weights.zero_() + self.emb_module.lxu_cache_state.fill_(-1) + + def create_rocksdb_hard_link_snapshot(self) -> None: + """ + Create a RocksDB checkpoint. This is needed before we call state_dict() for publish. 
+ """ + self.emb_module.create_rocksdb_hard_link_snapshot() + + # pyre-ignore [15] + def split_embedding_weights( + self, no_snapshot: bool = True, should_flush: bool = False + ) -> Tuple[ + Union[List[PartiallyMaterializedTensor], List[torch.Tensor]], + Optional[List[torch.Tensor]], + Optional[List[torch.Tensor]], + ]: + return self.emb_module.split_embedding_weights(no_snapshot, should_flush) + + def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: + # reset split weights during training + self._split_weights_res = None + self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None) + + return super().forward(features) + + class BatchedFusedEmbeddingBag( BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule ): diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index fce86dc19..122b7be8e 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -40,6 +40,7 @@ KeyValueEmbedding, KeyValueEmbeddingBag, ZeroCollisionKeyValueEmbedding, + ZeroCollisionKeyValueEmbeddingBag, ) from torchrec.distributed.comm_ops import get_gradient_division from torchrec.distributed.composable.table_batched_embedding_slice import ( @@ -512,6 +513,24 @@ def _create_embedding_kernel( device=device, sharding_type=sharding_type, ) + elif config.compute_kernel == EmbeddingComputeKernel.SSD_VIRTUAL_TABLE: + # for ssd kv + return ZeroCollisionKeyValueEmbeddingBag( + config=config, + pg=pg, + device=device, + sharding_type=sharding_type, + backend_type=BackendType.SSD, + ) + elif config.compute_kernel == EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE: + # for dram kv + return ZeroCollisionKeyValueEmbeddingBag( + config=config, + pg=pg, + device=device, + sharding_type=sharding_type, + backend_type=BackendType.DRAM, + ) else: raise ValueError(f"Compute kernel not supported {config.compute_kernel}") @@ -525,6 +544,8 @@ def _need_prefetch(config: GroupedEmbeddingConfig) -> bool: if ( table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING or table.compute_kernel == EmbeddingComputeKernel.KEY_VALUE + or table.compute_kernel == EmbeddingComputeKernel.SSD_VIRTUAL_TABLE + or table.compute_kernel == EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE ): return True return False @@ -720,7 +741,9 @@ def get_named_split_embedding_weights_snapshot( RocksDB snapshot to support windowed access. 
""" for emb_module in self._emb_modules: - if isinstance(emb_module, KeyValueEmbeddingBag): + if isinstance(emb_module, KeyValueEmbeddingBag) or isinstance( + emb_module, ZeroCollisionKeyValueEmbeddingBag + ): yield from emb_module.get_named_split_embedding_weights_snapshot() def flush(self) -> None: diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 5fecc7272..cdfe7b496 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -35,6 +35,7 @@ from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.comm import get_local_size +from torchrec.distributed.embedding_lookup import PartiallyMaterializedTensor from torchrec.distributed.embedding_sharding import ( EmbeddingSharding, EmbeddingShardingContext, @@ -292,6 +293,8 @@ def create_sharding_infos_by_sharding_device_group( getattr(config, "num_embeddings_post_pruning", None) # TODO: Need to check if attribute exists for BC ), + total_num_buckets=config.total_num_buckets, + use_virtual_table=config.use_virtual_table, ), param_sharding=parameter_sharding, param=param, @@ -558,6 +561,7 @@ def __init__( tbe_module.fused_optimizer.params = params optims.append(("", tbe_module.fused_optimizer)) self._optim: CombinedOptimizer = CombinedOptimizer(optims) + self._skip_missing_weight_key: List[str] = [] for i, (sharding, lookup) in enumerate( zip(self._embedding_shardings, self._lookups) @@ -687,6 +691,8 @@ def create_grouped_sharding_infos( getattr(config, "num_embeddings_post_pruning", None) # TODO: Need to check if attribute exists for BC ), + total_num_buckets=config.total_num_buckets, + use_virtual_table=config.use_virtual_table, ), param_sharding=parameter_sharding, param=param, @@ -788,6 +794,27 @@ def _pre_load_state_dict_hook( to transform from ShardedTensors/DTensors into tensors """ for table_name in self._model_parallel_name_to_local_shards.keys(): + if self._table_name_to_config[table_name].use_virtual_table: + # weight_id and bucket are generated at the runtime of state_dict instead of registered class + # so we need to erase them before passing into load_state_dict + weight_key = f"{prefix}embedding_bags.{table_name}.weight" + weight_id_key = f"{prefix}embedding_bags.{table_name}.weight_id" + bucket_key = f"{prefix}embedding_bags.{table_name}.bucket" + if weight_id_key in state_dict: + del state_dict[weight_id_key] + if bucket_key in state_dict: + del state_dict[bucket_key] + assert weight_key in state_dict + assert ( + len(self._model_parallel_name_to_local_shards[table_name]) == 1 + ), "currently only support 1 shard per rank" + + # for loading state_dict into virtual table, we skip the weights assignment + # if needed, for now this should be handled separately outside of load_state_dict call + self._skip_missing_weight_key.append(weight_key) + del state_dict[weight_key] + continue + key = f"{prefix}embedding_bags.{table_name}.weight" # gather model shards from both DTensor and ShardedTensor maps model_shards_sharded_tensor = self._model_parallel_name_to_local_shards[ @@ -947,6 +974,8 @@ def _initialize_torch_state(self, skip_registering: bool = False) -> None: # no shards_wrapper["global_stride"] = v.stride() shards_wrapper["placements"] = v.placements elif isinstance(v, ShardedTensor): + # for virtual table, we only populate the shardedTensor for Embedding Table during + # initial state_dict calls, skip weight id and bucket tensor 
self._model_parallel_name_to_local_shards[table_name].extend( v.local_shards() ) @@ -956,6 +985,10 @@ def _initialize_torch_state(self, skip_registering: bool = False) -> None: # no # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute # `named_parameters_by_table`. ) in lookup.named_parameters_by_table(): + # for virtual table, currently we don't expose id tensor and bucket tensor + # because they are not updated in real time, and they are created on the fly + # whenever state_dict is called + # reference: ƒbgs _gen_named_parameters_by_table_ssd_pmt self.embedding_bags[table_name].register_parameter("weight", tbe_slice) for table_name in self._model_parallel_name_to_local_shards.keys(): @@ -1061,7 +1094,9 @@ def extract_sharded_kvtensors( sharded_t, ) in module._model_parallel_name_to_sharded_tensor.items(): if _model_parallel_name_to_compute_kernel[table_name] in { - EmbeddingComputeKernel.KEY_VALUE.value + EmbeddingComputeKernel.KEY_VALUE.value, + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value, }: ret[table_name] = sharded_t return ret @@ -1093,24 +1128,88 @@ def post_state_dict_hook( return sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors) + virtual_table_sharded_t_map: Optional[ + Dict[str, Tuple[ShardedTensor, ShardedTensor]] + ] = None for lookup, sharding in zip(module._lookups, module._embedding_shardings): if not isinstance(sharding, DpPooledEmbeddingSharding): for ( - key, - v, - _, - _, + table_name, + weights_t, + weight_ids_sharded_t, + id_cnt_per_bucket_sharded_t, ) in ( lookup.get_named_split_embedding_weights_snapshot() # pyre-ignore ): - assert key in sharded_kvtensors_copy - sharded_kvtensors_copy[key].local_shards()[0].tensor = v + assert table_name in sharded_kvtensors_copy + if self._table_name_to_config[table_name].use_virtual_table: + assert isinstance(weights_t, ShardedTensor) + if virtual_table_sharded_t_map is None: + virtual_table_sharded_t_map = {} + assert ( + weight_ids_sharded_t is not None + and id_cnt_per_bucket_sharded_t is not None + ) + # The logic here assumes there is only one shard per table on any particular rank + # if there are cases each rank has >1 shards, we need to update here accordingly + sharded_kvtensors_copy[table_name] = weights_t + virtual_table_sharded_t_map[table_name] = ( + weight_ids_sharded_t, + id_cnt_per_bucket_sharded_t, + ) + else: + assert isinstance(weights_t, PartiallyMaterializedTensor) + assert ( + weight_ids_sharded_t is None + and id_cnt_per_bucket_sharded_t is None + ) + # The logic here assumes there is only one shard per table on any particular rank + # if there are cases each rank has >1 shards, we need to update here accordingly + # pyre-ignore + sharded_kvtensors_copy[table_name].local_shards()[ + 0 + ].tensor = weights_t + + def update_destination( + table_name: str, + tensor_name: str, + destination: Dict[str, torch.Tensor], + value: torch.Tensor, + ) -> None: + destination_key = f"{prefix}embedding_bags.{table_name}.{tensor_name}" + destination[destination_key] = value + for ( table_name, sharded_kvtensor, ) in sharded_kvtensors_copy.items(): - destination_key = f"{prefix}embedding_bags.{table_name}.weight" - destination[destination_key] = sharded_kvtensor + update_destination(table_name, "weight", destination, sharded_kvtensor) + if ( + virtual_table_sharded_t_map + and table_name in virtual_table_sharded_t_map + ): + update_destination( + table_name, + "weight_id", + destination, + virtual_table_sharded_t_map[table_name][0], + ) + 
update_destination( + table_name, + "bucket", + destination, + virtual_table_sharded_t_map[table_name][1], + ) + + def _post_load_state_dict_hook( + module: "ShardedEmbeddingBagCollection", + incompatible_keys: _IncompatibleKeys, + ) -> None: + if incompatible_keys.missing_keys: + # has to remove the key inplace + for skip_key in module._skip_missing_weight_key: + if skip_key in incompatible_keys.missing_keys: + incompatible_keys.missing_keys.remove(skip_key) if not skip_registering: self.register_state_dict_pre_hook(self._pre_state_dict_hook) @@ -1118,6 +1217,8 @@ def post_state_dict_hook( self._register_load_state_dict_pre_hook( self._pre_load_state_dict_hook, with_module=True ) + self.register_load_state_dict_post_hook(_post_load_state_dict_hook) + self.reset_parameters() def reset_parameters(self) -> None: @@ -1128,6 +1229,8 @@ def reset_parameters(self) -> None: for table_config in self._embedding_bag_configs: if self.module_sharding_plan[table_config.name].compute_kernel in { EmbeddingComputeKernel.KEY_VALUE.value, + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value, }: continue assert table_config.init_fn is not None diff --git a/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py index 6b9ee360b..b7314585a 100644 --- a/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py +++ b/torchrec/distributed/tests/test_model_parallel_nccl_ssd_single_gpu.py @@ -20,6 +20,7 @@ KeyValueEmbedding, KeyValueEmbeddingBag, ZeroCollisionKeyValueEmbedding, + ZeroCollisionKeyValueEmbeddingBag, ) from torchrec.distributed.embedding_types import ( EmbeddingComputeKernel, @@ -853,6 +854,437 @@ def test_ssd_load_state_dict( self._compare_models(m1, m2, is_deterministic=is_deterministic) +class ZeroCollisionModelParallelTest(ModelParallelSingleRankBase): + def _create_tables(self) -> None: + num_features = 4 + self.tables += [ + EmbeddingBagConfig( + num_embeddings=(i + 1) * 1000, + embedding_dim=256, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + total_num_buckets=10, + use_virtual_table=True, + ) + for i in range(num_features) + ] + + @staticmethod + def _copy_ssd_emb_modules( + m1: DistributedModelParallel, m2: DistributedModelParallel + ) -> None: + """ + Util function to copy and set the SSD TBE modules of two models. It + requires both DMP modules to have the same sharding plan. + """ + for lookup1, lookup2 in zip( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + m1.module.sparse.ebc._lookups, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + m2.module.sparse.ebc._lookups, + ): + for emb_module1, emb_module2 in zip( + lookup1._emb_modules, lookup2._emb_modules + ): + ssd_emb_modules = { + ZeroCollisionKeyValueEmbeddingBag, + ZeroCollisionKeyValueEmbedding, + } + if type(emb_module1) in ssd_emb_modules: + assert type(emb_module1) is type(emb_module2), ( + "Expect two emb_modules to be of the same type, either both " + "SSDEmbeddingBag or SSDEmbeddingBag." 
+ ) + emb_module1.flush() + emb_module2.flush() + + emb1_kv = { + t: (sharded_t, sharded_w_id, bucket) + for t, sharded_t, sharded_w_id, bucket in emb_module1.get_named_split_embedding_weights_snapshot() + } + for ( + t, + sharded_t2, + _, + _, + ) in emb_module2.get_named_split_embedding_weights_snapshot(): + assert t in emb1_kv + sharded_t1 = emb1_kv[t][0] + sharded_w1_id = emb1_kv[t][1] + w1_id = sharded_w1_id.local_shards()[0].tensor + + pmt1 = sharded_t1.local_shards()[0].tensor + w1 = pmt1.get_weights_by_ids(w1_id) + + # write value into ssd for both emb module for later comparison + pmt2 = sharded_t2.local_shards()[0].tensor + pmt2.wrapped.set_weights_and_ids(w1, w1_id.view(-1)) + + # purge after loading. This is needed, since we pass a batch + # through dmp when instantiating them. + emb_module1.purge() + emb_module2.purge() + + @staticmethod + def _copy_fused_modules_into_ssd_emb_modules( + fused_m: DistributedModelParallel, ssd_m: DistributedModelParallel + ) -> None: + """ + Util function to copy from fused embedding module to SSD TBE for initialization. It + requires both DMP modules to have the same sharding plan. + """ + + for fused_lookup, ssd_lookup in zip( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`. + fused_m.module.sparse.ebc._lookups, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ec`. + ssd_m.module.sparse.ebc._lookups, + ): + for fused_emb_module, ssd_emb_module in zip( + fused_lookup._emb_modules, ssd_lookup._emb_modules + ): + ssd_emb_modules = { + ZeroCollisionKeyValueEmbeddingBag, + ZeroCollisionKeyValueEmbedding, + } + if type(ssd_emb_module) in ssd_emb_modules: + fused_state_dict = fused_emb_module.state_dict() + for ( + t, + sharded_t, + _, + _, + ) in ssd_emb_module.get_named_split_embedding_weights_snapshot(): + weight_key = f"{t}.weight" + fused_sharded_t = fused_state_dict[weight_key] + fused_weight = fused_sharded_t.local_shards()[0].tensor.to( + "cpu" + ) + + # write value into ssd for both emb module for later comparison + pmt = sharded_t.local_shards()[0].tensor + pmt.wrapped.set_range(0, 0, fused_weight.size(0), fused_weight) + + # purge after loading. This is needed, since we pass a batch + # through dmp when instantiating them. + fused_emb_module.purge() + ssd_emb_module.purge() + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + ] + ), + sharding_type=st.sampled_from( + [ + # TODO: add other test cases when kv embedding support other sharding + ShardingType.ROW_WISE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_kv_zch_load_state_dict( + self, + sharder_type: str, + kernel_type: str, + sharding_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + This test checks that if SSD TBE is deterministic. That is, if two SSD + TBEs start with the same state, they would produce the same output. 
+ """ + self._set_table_weights_precision(dtype) + + fused_params = { + "learning_rate": 0.1, + "stochastic_rounding": stochastic_rounding, + } + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + key_value_params=KeyValueParams(bulk_init_chunk_size=1024), + ) + for i, table in enumerate(self.tables) + } + sharders = [ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params=fused_params, + ), + ] + + # pyre-ignore + models, batch = self._generate_dmps_and_batch(sharders, constraints=constraints) + m1, m2 = models + + # load state dict for dense modules + m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict())) + self._copy_ssd_emb_modules(m1, m2) + + if is_training: + self._train_models(m1, m2, batch) + self._eval_models(m1, m2, batch, is_deterministic=is_deterministic) + self._compare_models( + m1, m2, is_deterministic=is_deterministic, use_virtual_table=True + ) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + ] + ), + sharding_type=st.sampled_from( + [ + # TODO: add other test cases when kv embedding support other sharding + ShardingType.ROW_WISE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_kv_zch_numerical_accuracy( + self, + sharder_type: str, + kernel_type: str, + sharding_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + Make sure it produces same numbers as normal TBE. 
+ """ + self._set_table_weights_precision(dtype) + + base_kernel_type = EmbeddingComputeKernel.FUSED.value + learning_rate = 0.1 + fused_params = { + "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD, + "learning_rate": learning_rate, + "stochastic_rounding": stochastic_rounding, + } + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + fused_sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + base_kernel_type, # base kernel type + fused_params=fused_params, + ), + ), + ] + ssd_sharders = [ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params=fused_params, + ), + ), + ] + ssd_constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + ) + for i, table in enumerate(self.tables) + } + + # for fused model, we need to change the table config to non-kvzch + ssd_tables = copy.deepcopy(self.tables) + for table in self.tables: + table.total_num_buckets = None + table.use_virtual_table = False + (fused_model, _), _ = self._generate_dmps_and_batch(fused_sharders) + self.tables = ssd_tables + (ssd_model, _), batch = self._generate_dmps_and_batch( + ssd_sharders, constraints=ssd_constraints + ) + + # load state dict for dense modules + ssd_model.load_state_dict( + cast("OrderedDict[str, torch.Tensor]", fused_model.state_dict()) + ) + + # for this to work, we expect the order of lookups to be the same + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + assert len(fused_model.module.sparse.ebc._lookups) == len( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + ssd_model.module.sparse.ebc._lookups + ), "Expect same number of lookups" + + for fused_lookup, ssd_lookup in zip( + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + fused_model.module.sparse.ebc._lookups, + # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `ebc`. + ssd_model.module.sparse.ebc._lookups, + ): + assert len(fused_lookup._emb_modules) == len( + ssd_lookup._emb_modules + ), "Expect same number of emb modules" + + self._copy_fused_modules_into_ssd_emb_modules(fused_model, ssd_model) + + if is_training: + self._train_models(fused_model, ssd_model, batch) + self._eval_models( + fused_model, ssd_model, batch, is_deterministic=is_deterministic + ) + + @unittest.skipIf( + not torch.cuda.is_available(), + "Not enough GPUs, this test requires at least one GPU", + ) + # pyre-ignore[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.SSD_VIRTUAL_TABLE.value, + ] + ), + sharding_type=st.sampled_from( + [ + # TODO: add other test cases when kv embedding support other sharding + ShardingType.ROW_WISE.value, + ] + ), + is_training=st.booleans(), + stochastic_rounding=st.booleans(), + dtype=st.sampled_from([DataType.FP32, DataType.FP16]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None) + def test_kv_zch_fused_optimizer( + self, + sharder_type: str, + kernel_type: str, + sharding_type: str, + is_training: bool, + stochastic_rounding: bool, + dtype: DataType, + ) -> None: + """ + Purpose of this test is to make sure it works with warm up policy. 
+ """ + self._set_table_weights_precision(dtype) + + is_deterministic = dtype == DataType.FP32 or not stochastic_rounding + + constraints = { + table.name: ParameterConstraints( + sharding_types=[sharding_type], + compute_kernels=[kernel_type], + key_value_params=KeyValueParams(bulk_init_chunk_size=1024), + ) + for i, table in enumerate(self.tables) + } + + base_sharders = [ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params={ + "learning_rate": 0.2, + "stochastic_rounding": stochastic_rounding, + }, + ), + ] + models, batch = self._generate_dmps_and_batch( + base_sharders, # pyre-ignore + constraints=constraints, + ) + base_model, _ = models + + test_sharders = [ + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + fused_params={ + "learning_rate": 0.1, + "stochastic_rounding": stochastic_rounding, + }, + ), + ] + models, _ = self._generate_dmps_and_batch( + test_sharders, # pyre-ignore + constraints=constraints, + ) + test_model, _ = models + + # load state dict for dense modules + test_model.load_state_dict( + cast("OrderedDict[str, torch.Tensor]", base_model.state_dict()) + ) + self._copy_ssd_emb_modules(base_model, test_model) + + self._eval_models( + base_model, test_model, batch, is_deterministic=is_deterministic + ) + + # change learning rate for test_model + fused_opt = test_model.fused_optimizer + # pyre-ignore + fused_opt.param_groups[0]["lr"] = 0.2 + fused_opt.zero_grad() + + if is_training: + self._train_models(base_model, test_model, batch) + self._eval_models( + base_model, test_model, batch, is_deterministic=is_deterministic + ) + self._compare_models(base_model, test_model, is_deterministic=is_deterministic) + + # TODO: uncomment this when we have multiple kernels in rw support(unblock input dist) + # def test_ssd_mixed_kernels + + # TODO: uncomment this when we support different sharding types, e.g. tw, tw_rw together with rw + # def test_ssd_mixed_sharding_types + + class ZeroCollisionSequenceModelParallelStateDictTest(ModelParallelSingleRankBase): def setUp(self, backend: str = "nccl") -> None: self.shared_features = [] @@ -1196,15 +1628,6 @@ def test_kv_zch_numerical_accuracy( fused_model, ssd_model, batch, is_deterministic=is_deterministic ) - # TODO: uncomment this when we have optimizer plumb through - # def test_ssd_fused_optimizer( - - # TODO: uncomment this when we have multiple kernels in rw support(unblock input dist) - # def test_ssd_mixed_kernels - - # TODO: uncomment this when we support different sharding types, e.g. tw, tw_rw together with rw - # def test_ssd_mixed_sharding_types - # TODO: remove after development is done def main() -> None: