torchrec/distributed/test_utils/infer_utils.py (37 additions, 0 deletions)

@@ -91,6 +91,7 @@
     quant_prep_enable_quant_state_dict_split_scale_bias_for_types,
     quant_prep_enable_register_tbes,
 )
+from torchrec.sparse.jagged_tensor import is_non_strict_exporting


 @dataclass
@@ -126,6 +127,7 @@ def forward(
         self,
         values: torch.Tensor,
         lengths: torch.Tensor,
+        weights: Optional[torch.Tensor] = None,
         # pyre-ignore
         *args,
         # pyre-ignore
@@ -135,12 +137,47 @@
             keys=self._kjt_keys,
             values=values,
             lengths=lengths,
+            weights=weights,
         )
         output = self._module_kjt_input(kjt, *args, **kwargs)
         # TODO(ivankobzarev): Support of None leaves in dynamo/export (e.g. KJT offsets)
         return [leaf for leaf in pytree.tree_leaves(output) if leaf is not None]


+class KJTInputExportDynamicShapeWrapper(torch.nn.Module):
+    def __init__(
+        self,
+        kjt_input_wrapper: KJTInputExportWrapper,
+    ) -> None:
+        super().__init__()
+        self.kjt_input_wrapper = kjt_input_wrapper
+
+    # pyre-ignore
+    def forward(
+        self,
+        values: torch.Tensor,
+        lengths: torch.Tensor,
+        weights: Optional[torch.Tensor] = None,
+        # pyre-ignore
+        *args,
+        # pyre-ignore
+        **kwargs,
+    ):
+        # Generate unbacked symints to represent sizes
+        # for values and weights, constrain them reasonably
+        values_size = values[0].item()
+        torch._check_is_size(values_size)
+        torch._check(values_size >= lengths.shape[0])
+        values = torch.ones(values_size).to(values.device)
+        if weights is not None:
+            weights_size = weights.int()[0].item()
+            torch._check_is_size(weights_size)
+            torch._check(weights_size >= lengths.shape[0])
+            weights = torch.ones(weights_size).to(weights.device)
+
+        return self.kjt_input_wrapper(values, lengths, weights, *args, **kwargs)
+
+
 def prep_inputs(
     model_info: TestModelInfo,
     world_size: int,
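Note: a minimal sketch of how the new wrapper can be driven (the Split module, keys, and tensors below are illustrative, not part of this PR). The first element of values carries the desired dynamic size, as does the first element of weights (read back via .int() since weights are float); under non-strict export, .item() yields an unbacked SymInt, the torch._check calls bound it, and the tensors are then rebuilt as torch.ones(size), so only the sizes matter for tracing:

    import torch
    from torchrec.distributed.test_utils.infer_utils import (
        KJTInputExportDynamicShapeWrapper,
        KJTInputExportWrapper,
    )
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # Illustrative stand-in for any module that consumes a KeyedJaggedTensor.
    class Split(torch.nn.Module):
        def forward(self, kjt: KeyedJaggedTensor):
            return kjt.split([1, 1])

    dyn = KJTInputExportDynamicShapeWrapper(KJTInputExportWrapper(Split(), ["f1", "f2"]))

    values = torch.tensor([5, 3, 4, 5, 6], dtype=torch.int32)  # values[0] == 5 == numel
    lengths = torch.tensor([3, 2], dtype=torch.int32)
    weights = torch.tensor([5.0, 0.1, 0.2, 0.3, 0.4])          # weights[0] == 5.0 -> size 5
    ep = torch.export.export(dyn, (values, lengths, weights), strict=False)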
torchrec/distributed/tests/test_pt2.py (42 additions, 25 deletions)

@@ -9,10 +9,13 @@

 import sys
 import unittest
-from typing import List, Tuple
+from typing import Any, List, Tuple

 import torch
-from torchrec.distributed.test_utils.infer_utils import TestQuantFPEBCSharder
+from torchrec.distributed.test_utils.infer_utils import (
+    KJTInputExportDynamicShapeWrapper,
+    TestQuantFPEBCSharder,
+)

 try:
     # pyre-ignore
@@ -43,11 +46,13 @@
 def make_kjt(values: List[int], lengths: List[int]) -> KeyedJaggedTensor:
     values_tensor = torch.tensor(values, dtype=torch.int32)
     lengths_tensor = torch.tensor(lengths, dtype=torch.int32)
+    weights_tensor = torch.randn(len(values), dtype=torch.float32)
     torch._check(torch.sum(lengths_tensor).item() == values_tensor.size(0))
     kjt = KeyedJaggedTensor(
         keys=[f"key{i}" for i in range(len(lengths))],
         values=values_tensor,
         lengths=lengths_tensor,
+        weights=weights_tensor,
     )
     return kjt

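With this change every test KJT is weighted: make_kjt attaches one random float weight per value, and the torch._check keeps the KJT invariant sum(lengths) == values.size(0) visible to the compiler. For reference, the KJT these tests construct behaves as follows (a quick illustration, not part of the diff):

    kjt = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
    assert kjt.keys() == ["key0", "key1", "key2", "key3"]
    assert kjt.values().tolist() == [2, 3, 4, 5, 6]
    assert kjt.lengths().tolist() == [1, 2, 1, 1]  # 1 + 2 + 1 + 1 == 5 values
    assert kjt.weights().size(0) == 5              # one random weight per value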
@@ -125,21 +130,21 @@ class TestPt2(unittest.TestCase):
     def _test_kjt_input_module(
         self,
         kjt_input_module: torch.nn.Module,
-        kjt_keys: List[str],
-        # pyre-ignore
-        inputs,
+        kjt: KeyedJaggedTensor,
+        inputs: Tuple[Any],
         test_dynamo: bool = True,
         test_aot_inductor: bool = True,
         test_pt2_ir_export: bool = False,
     ) -> None:
         with dynamo_skipfiles_allow("torchrec"):
-            EM: torch.nn.Module = KJTInputExportWrapper(kjt_input_module, kjt_keys)
-            eager_output = EM(*inputs)
+            EM: torch.nn.Module = KJTInputExportWrapper(kjt_input_module, kjt.keys())
+            em_inputs = (kjt.values(), kjt.lengths(), kjt.weights_or_none(), *inputs)
+            eager_output = EM(*em_inputs)
             if test_dynamo:
-                x = torch._dynamo.export(EM, same_signature=True)(*inputs)
+                x = torch._dynamo.export(EM, same_signature=True)(*em_inputs)

                 export_gm = x.graph_module
-                export_gm_output = export_gm(*inputs)
+                export_gm_output = export_gm(*em_inputs)

                 assert_close(eager_output, export_gm_output)

@@ -152,12 +157,23 @@ def _test_kjt_input_module(
                 device = "cuda"
                 # pyre-ignore
                 aot_inductor_module = AOTIRunnerUtil.load(device, so_path)
-                aot_actual_output = aot_inductor_module(*inputs)
+                aot_actual_output = aot_inductor_module(*em_inputs)
                 assert_close(eager_output, aot_actual_output)

             if test_pt2_ir_export:
-                pt2_ir = torch.export.export(EM, inputs, {}, strict=False)
-                pt2_ir_output = pt2_ir.module()(*inputs)
+                symint_wrapper = KJTInputExportDynamicShapeWrapper(EM)
+
+                # KJTInputExportDynamicShapeWrapper derives the sizes of values
+                # and weights from the first element of each (to simulate
+                # symints), so the real sizes must be planted there for a
+                # proper eager forward.
+                em_inputs[0][0] = kjt.values().size(0)
+                em_inputs[2][0] = kjt.weights().size(0)
+                eager_output = symint_wrapper(*em_inputs)
+                pt2_ir = torch.export.export(
+                    symint_wrapper, em_inputs, {}, strict=False
+                )
+
+                pt2_ir_output = pt2_ir.module()(*em_inputs)
                 assert_close(eager_output, pt2_ir_output)

     def test_kjt_split(self) -> None:
@@ -166,11 +182,11 @@ def forward(self, kjt: KeyedJaggedTensor):
                 return kjt.split([1, 2, 1])

         kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
-        segments: List[int] = [1, 2, 1]

         self._test_kjt_input_module(
             M(),
-            kjt.keys(),
-            (kjt._values, kjt._lengths),
+            kjt,
+            (),
             test_aot_inductor=False,
             test_dynamo=False,
             test_pt2_ir_export=True,
@@ -185,8 +201,8 @@ def forward(self, kjt: KeyedJaggedTensor, indices: List[int]):
         indices: List[int] = [1, 0, 3, 2]
         self._test_kjt_input_module(
             M(),
-            kjt.keys(),
-            (kjt._values, kjt._lengths, indices),
+            kjt,
+            (indices,),
             test_aot_inductor=False,
             test_pt2_ir_export=True,
         )
@@ -200,8 +216,8 @@ def forward(self, kjt: KeyedJaggedTensor):

         self._test_kjt_input_module(
             M(),
-            kjt.keys(),
-            (kjt._values, kjt._lengths),
+            kjt,
+            (),
             test_aot_inductor=False,
             test_pt2_ir_export=True,
         )
@@ -215,8 +231,8 @@ def forward(self, kjt: KeyedJaggedTensor):

         self._test_kjt_input_module(
             M(),
-            kjt.keys(),
-            (kjt._values, kjt._lengths),
+            kjt,
+            (),
             test_aot_inductor=False,
             test_pt2_ir_export=True,
         )
@@ -230,12 +246,13 @@ def forward(self, kjt: KeyedJaggedTensor):

         return out0, out1

+        # The first element of values/weights carries the symint for their sizes.
         kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])

         self._test_kjt_input_module(
             M(),
-            kjt.keys(),
-            (kjt._values, kjt._lengths),
+            kjt,
+            (),
             test_dynamo=False,
             test_aot_inductor=False,
             test_pt2_ir_export=True,
@@ -367,8 +384,8 @@ def test_maybe_compute_kjt_to_jt_dict(self) -> None:
         kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
         self._test_kjt_input_module(
             ComputeKJTToJTDict(),
-            kjt.keys(),
-            (kjt._values, kjt._lengths),
+            kjt,
+            (),
             # TODO: turn on AOT Inductor test once the support is ready
             test_aot_inductor=False,
         )
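A note on the pt2 IR path above: KJTInputExportDynamicShapeWrapper reads the sizes of values and weights from element 0 of each tensor, so the test overwrites those slots with the true lengths before running. values is int32, so the size can be stored directly; weights is float32, which is why the wrapper reads it back via weights.int()[0].item(). A condensed sketch of the flow (names match the diff; Split stands in for any KJT module):

    kjt = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
    em = KJTInputExportWrapper(Split(), kjt.keys())
    symint_wrapper = KJTInputExportDynamicShapeWrapper(em)

    em_inputs = (kjt.values(), kjt.lengths(), kjt.weights_or_none())
    em_inputs[0][0] = kjt.values().size(0)   # plant 5 into the int32 values tensor
    em_inputs[2][0] = kjt.weights().size(0)  # plant 5.0 into the float32 weights tensor

    eager_output = symint_wrapper(*em_inputs)
    pt2_ir = torch.export.export(symint_wrapper, em_inputs, {}, strict=False)
    assert_close(eager_output, pt2_ir.module()(*em_inputs))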
torchrec/sparse/jagged_tensor.py (9 additions, 0 deletions)

@@ -1865,6 +1865,11 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
                     _lengths = torch.narrow(
                         self.lengths(), 0, lengths_start, lengths_sz
                     )
+
+                    if self.weights_or_none() is not None:
+                        torch._check(start_offset + sz <= self.weights().size(0))
+                        torch._check(start_offset <= self.weights().size(0))
+
                     split_list.append(
                         KeyedJaggedTensor(
                             keys=keys,
@@ -2063,6 +2068,10 @@ def __getitem__(self, key: str) -> JaggedTensor:
             torch._check(start_offset <= self.values().size(0))
             torch._check(sz <= self.values().size(0))

+            if self.weights_or_none() is not None:
+                torch._check(start_offset <= self.weights().size(0))
+                torch._check(sz <= self.weights().size(0))
+
             return JaggedTensor(
                 values=torch.narrow(
                     self.values(),
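These weights guards mirror the existing ones on values(): with unbacked sizes in the graph, the exporter cannot prove that the subsequent torch.narrow on weights stays in bounds, so the bounds are asserted explicitly. A standalone sketch of the same pattern on toy tensors (not TorchRec code, assuming a recent PyTorch with torch.export):

    import torch

    class NarrowUnbacked(torch.nn.Module):
        def forward(self, t: torch.Tensor, sizes: torch.Tensor) -> torch.Tensor:
            sz = sizes[0].item()           # becomes an unbacked SymInt under export
            torch._check_is_size(sz)       # sz >= 0
            torch._check(sz <= t.size(0))  # narrow() stays in bounds
            return torch.narrow(t, 0, 0, sz)

    ep = torch.export.export(
        NarrowUnbacked(),
        (torch.randn(8), torch.tensor([4])),
        strict=False,
    )
    print(ep.module()(torch.randn(8), torch.tensor([4])).shape)  # torch.Size([4])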