diff --git a/torchrec/distributed/test_utils/infer_utils.py b/torchrec/distributed/test_utils/infer_utils.py
index 534d88cb1..44cc95f23 100644
--- a/torchrec/distributed/test_utils/infer_utils.py
+++ b/torchrec/distributed/test_utils/infer_utils.py
@@ -91,6 +91,7 @@
     quant_prep_enable_quant_state_dict_split_scale_bias_for_types,
     quant_prep_enable_register_tbes,
 )
+from torchrec.sparse.jagged_tensor import is_non_strict_exporting


 @dataclass
@@ -126,6 +127,7 @@ def forward(
         self,
         values: torch.Tensor,
         lengths: torch.Tensor,
+        weights: Optional[torch.Tensor] = None,
         # pyre-ignore
         *args,
         # pyre-ignore
@@ -135,12 +137,47 @@ def forward(
             keys=self._kjt_keys,
             values=values,
             lengths=lengths,
+            weights=weights,
         )
         output = self._module_kjt_input(kjt, *args, **kwargs)
         # TODO(ivankobzarev): Support of None leaves in dynamo/export (e.g. KJT offsets)
         return [leaf for leaf in pytree.tree_leaves(output) if leaf is not None]


+class KJTInputExportDynamicShapeWrapper(torch.nn.Module):
+    def __init__(
+        self,
+        kjt_input_wrapper: KJTInputExportWrapper,
+    ) -> None:
+        super().__init__()
+        self.kjt_input_wrapper = kjt_input_wrapper
+
+    # pyre-ignore
+    def forward(
+        self,
+        values: torch.Tensor,
+        lengths: torch.Tensor,
+        weights: Optional[torch.Tensor] = None,
+        # pyre-ignore
+        *args,
+        # pyre-ignore
+        **kwargs,
+    ):
+        # Generate unbacked symints to represent sizes
+        # for values and weights, constrain them reasonably
+        values_size = values[0].item()
+        torch._check_is_size(values_size)
+        torch._check(values_size >= lengths.shape[0])
+        values = torch.ones(values_size).to(values.device)
+        if weights is not None:
+            weights_size = weights.int()[0].item()
+            torch._check_is_size(weights_size)
+            torch._check(weights_size >= lengths.shape[0])
+            weights = torch.ones(weights_size).to(weights.device)
+
+        return self.kjt_input_wrapper(values, lengths, weights, *args, **kwargs)
+
+
 def prep_inputs(
     model_info: TestModelInfo,
     world_size: int,
diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py
index 201bc7024..eae2d1454 100644
--- a/torchrec/distributed/tests/test_pt2.py
+++ b/torchrec/distributed/tests/test_pt2.py
@@ -9,10 +9,13 @@
 import sys
 import unittest
-from typing import List, Tuple
+from typing import Any, List, Tuple

 import torch
-from torchrec.distributed.test_utils.infer_utils import TestQuantFPEBCSharder
+from torchrec.distributed.test_utils.infer_utils import (
+    KJTInputExportDynamicShapeWrapper,
+    TestQuantFPEBCSharder,
+)

 try:
     # pyre-ignore
@@ -43,11 +46,13 @@ def make_kjt(values: List[int], lengths: List[int]) -> KeyedJaggedTensor:
     values_tensor = torch.tensor(values, dtype=torch.int32)
     lengths_tensor = torch.tensor(lengths, dtype=torch.int32)
+    weights_tensor = torch.randn(len(values), dtype=torch.float32)
     torch._check(torch.sum(lengths_tensor).item() == values_tensor.size(0))
     kjt = KeyedJaggedTensor(
         keys=[f"key{i}" for i in range(len(lengths))],
         values=values_tensor,
         lengths=lengths_tensor,
+        weights=weights_tensor,
     )
     return kjt

@@ -125,21 +130,21 @@ class TestPt2(unittest.TestCase):
     def _test_kjt_input_module(
         self,
         kjt_input_module: torch.nn.Module,
-        kjt_keys: List[str],
-        # pyre-ignore
-        inputs,
+        kjt: KeyedJaggedTensor,
+        inputs: Tuple[Any],
         test_dynamo: bool = True,
         test_aot_inductor: bool = True,
         test_pt2_ir_export: bool = False,
     ) -> None:
         with dynamo_skipfiles_allow("torchrec"):
-            EM: torch.nn.Module = KJTInputExportWrapper(kjt_input_module, kjt_keys)
-            eager_output = EM(*inputs)
+            EM: torch.nn.Module = KJTInputExportWrapper(kjt_input_module, kjt.keys())
+            em_inputs = (kjt.values(), kjt.lengths(), kjt.weights_or_none(), *inputs)
+            eager_output = EM(*em_inputs)
             if test_dynamo:
-                x = torch._dynamo.export(EM, same_signature=True)(*inputs)
+                x = torch._dynamo.export(EM, same_signature=True)(*em_inputs)
                 export_gm = x.graph_module
-                export_gm_output = export_gm(*inputs)
+                export_gm_output = export_gm(*em_inputs)
                 assert_close(eager_output, export_gm_output)
@@ -152,12 +157,23 @@ def _test_kjt_input_module(
                 device = "cuda"
                 # pyre-ignore
                 aot_inductor_module = AOTIRunnerUtil.load(device, so_path)
-                aot_actual_output = aot_inductor_module(*inputs)
+                aot_actual_output = aot_inductor_module(*em_inputs)
                 assert_close(eager_output, aot_actual_output)

             if test_pt2_ir_export:
-                pt2_ir = torch.export.export(EM, inputs, {}, strict=False)
-                pt2_ir_output = pt2_ir.module()(*inputs)
+                symint_wrapper = KJTInputExportDynamicShapeWrapper(EM)
+
+                # KJTInputExportDynamicShapeWrapper represents sizes of values/weights
+                # from first element of values/weights respectively (simulate symint)
+                # Need to set as size in order to run a proper forward
+                em_inputs[0][0] = kjt.values().size(0)
+                em_inputs[2][0] = kjt.weights().size(0)
+                eager_output = symint_wrapper(*em_inputs)
+                pt2_ir = torch.export.export(
+                    symint_wrapper, em_inputs, {}, strict=False
+                )
+
+                pt2_ir_output = pt2_ir.module()(*em_inputs)
                 assert_close(eager_output, pt2_ir_output)

     def test_kjt_split(self) -> None:
@@ -166,11 +182,11 @@ def forward(self, kjt: KeyedJaggedTensor):
                 return kjt.split([1, 2, 1])

         kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
-        segments: List[int] = [1, 2, 1]
+
         self._test_kjt_input_module(
             M(),
-            kjt.keys(),
-            (kjt._values, kjt._lengths),
+            kjt,
+            (),
             test_aot_inductor=False,
             test_dynamo=False,
             test_pt2_ir_export=True,
         )
@@ -185,8 +201,8 @@ def forward(self, kjt: KeyedJaggedTensor, indices: List[int]):
         indices: List[int] = [1, 0, 3, 2]
         self._test_kjt_input_module(
             M(),
-            kjt.keys(),
-            (kjt._values, kjt._lengths, indices),
+            kjt,
+            (indices,),
             test_aot_inductor=False,
             test_pt2_ir_export=True,
         )
@@ -200,8 +216,8 @@ def forward(self, kjt: KeyedJaggedTensor):

         self._test_kjt_input_module(
             M(),
-            kjt.keys(),
-            (kjt._values, kjt._lengths),
+            kjt,
+            (),
             test_aot_inductor=False,
             test_pt2_ir_export=True,
         )
@@ -215,8 +231,8 @@ def forward(self, kjt: KeyedJaggedTensor):

         self._test_kjt_input_module(
             M(),
-            kjt.keys(),
-            (kjt._values, kjt._lengths),
+            kjt,
+            (),
             test_aot_inductor=False,
             test_pt2_ir_export=True,
         )
@@ -230,12 +246,13 @@ def forward(self, kjt: KeyedJaggedTensor):

             return out0, out1

+        # First element represents symint for values and weights shape
         kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])

         self._test_kjt_input_module(
             M(),
-            kjt.keys(),
-            (kjt._values, kjt._lengths),
+            kjt,
+            (),
             test_dynamo=False,
             test_aot_inductor=False,
             test_pt2_ir_export=True,
         )
@@ -367,8 +384,8 @@ def test_maybe_compute_kjt_to_jt_dict(self) -> None:
         kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
         self._test_kjt_input_module(
             ComputeKJTToJTDict(),
-            kjt.keys(),
-            (kjt._values, kjt._lengths),
+            kjt,
+            (),
             # TODO: turn on AOT Inductor test once the support is ready
             test_aot_inductor=False,
         )
diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 7dab61b92..7ef22056d 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -1865,6 +1865,11 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
                 _lengths = torch.narrow(
                     self.lengths(), 0, lengths_start, lengths_sz
                 )
+
+                if self.weights_or_none() is not None:
+                    torch._check(start_offset + sz <= self.weights().size(0))
+                    torch._check(start_offset <= self.weights().size(0))
+
                 split_list.append(
                     KeyedJaggedTensor(
                         keys=keys,
@@ -2063,6 +2068,10 @@ def __getitem__(self, key: str) -> JaggedTensor:
         torch._check(start_offset <= self.values().size(0))
         torch._check(sz <= self.values().size(0))

+        if self.weights_or_none() is not None:
+            torch._check(start_offset <= self.weights().size(0))
+            torch._check(sz <= self.weights().size(0))
+
         return JaggedTensor(
             values=torch.narrow(
                 self.values(),
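
A minimal standalone sketch of the unbacked-symint pattern that KJTInputExportDynamicShapeWrapper relies on; the module name SizeFromData and the sample tensors are illustrative only, not part of the diff:

import torch

class SizeFromData(torch.nn.Module):
    # Reading a size out of tensor data via .item() produces an unbacked
    # SymInt under torch.export; torch._check_is_size / torch._check then
    # constrain it so the tracer can prove the allocation is valid.
    def forward(self, values: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        values_size = values[0].item()  # unbacked SymInt during export
        torch._check_is_size(values_size)  # assert it is a valid (>= 0) size
        torch._check(values_size >= lengths.shape[0])  # lower-bound hint
        return torch.ones(values_size).to(values.device)

values = torch.tensor([5, 3, 4, 5, 6], dtype=torch.int32)
lengths = torch.tensor([1, 2, 1, 1], dtype=torch.int32)
ep = torch.export.export(SizeFromData(), (values, lengths), strict=False)
print(ep.module()(values, lengths).shape)  # torch.Size([5])

This also explains the em_inputs[0][0] / em_inputs[2][0] assignments in _test_kjt_input_module: since the wrapper derives the size from the first element, that element must hold the real size for the eager forward to produce comparable output.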
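The jagged_tensor.py hunks apply to weights the same idea the surrounding code already applies to values: before torch.narrow runs over offsets that may be unbacked symints, torch._check records the range facts export-time shape reasoning needs. A hedged sketch of the pattern in isolation, where narrow_weights is a made-up helper rather than torchrec API:

import torch

def narrow_weights(weights: torch.Tensor, start_offset: int, sz: int) -> torch.Tensor:
    # Mirrors the __getitem__ hunk: without these hints, export cannot
    # prove the slice stays in bounds when start_offset/sz are unbacked.
    # The split() hunk additionally asserts start_offset + sz <= size(0).
    torch._check(start_offset <= weights.size(0))
    torch._check(sz <= weights.size(0))
    return torch.narrow(weights, 0, start_offset, sz)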