Skip to content

Commit 0aee072

Browse files
PaulZhang12facebook-github-bot
authored andcommitted
More efficient cumsum for offset_per_key calculation for torch.export (pytorch#1788)
Summary: Pull Request resolved: pytorch#1788 Differential Revision: D54135686
1 parent d9fbac9 commit 0aee072

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

torchrec/distributed/tests/test_pt2.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,21 @@ def forward(self, kjt: KeyedJaggedTensor):
178178
test_pt2_ir_export=True,
179179
)
180180

181+
def test_kjt_offset_per_key(self) -> None:
182+
class M(torch.nn.Module):
183+
def forward(self, kjt: KeyedJaggedTensor):
184+
return kjt.offset_per_key()
185+
186+
kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
187+
188+
self._test_kjt_input_module(
189+
M(),
190+
kjt.keys(),
191+
(kjt._values, kjt._lengths),
192+
test_aot_inductor=False,
193+
test_pt2_ir_export=True,
194+
)
195+
181196
# pyre-ignore
182197
@unittest.skipIf(
183198
torch.cuda.device_count() <= 1,

torchrec/sparse/jagged_tensor.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,10 @@ def _permute_tensor_by_segments(
244244
return permuted_tensor, permuted_weights
245245

246246

247+
def is_non_strict_exporting() -> bool:
248+
return not torch.compiler.is_dynamo_compiling() and torch.compiler.is_compiling()
249+
250+
247251
class JaggedTensorMeta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta):
248252
pass
249253

@@ -822,9 +826,48 @@ def _maybe_compute_offset_per_key(
822826
offsets=offsets,
823827
values=values,
824828
)
825-
return _length_per_key, _cumsum(_length_per_key)
829+
830+
if is_non_strict_exporting():
831+
# only torch.export non-strict case
832+
return (
833+
_length_per_key,
834+
(
835+
torch.ops.fbgemm.asynchronous_complete_cumsum(
836+
torch._refs.tensor(
837+
_length_per_key,
838+
dtype=torch.int32,
839+
device=torch.device("cpu"),
840+
pin_memory=False,
841+
requires_grad=False,
842+
)
843+
).tolist()
844+
if len(_length_per_key) > 0
845+
else []
846+
),
847+
)
848+
else:
849+
return _length_per_key, _cumsum(_length_per_key)
826850
elif offset_per_key is None:
827-
return length_per_key, _cumsum(length_per_key)
851+
if is_non_strict_exporting():
852+
# only torch.export non-strict case
853+
return (
854+
length_per_key,
855+
(
856+
torch.ops.fbgemm.asynchronous_complete_cumsum(
857+
torch._refs.tensor(
858+
length_per_key,
859+
dtype=torch.int32,
860+
device=torch.device("cpu"),
861+
pin_memory=False,
862+
requires_grad=False,
863+
)
864+
).tolist()
865+
if len(length_per_key) > 0
866+
else []
867+
),
868+
)
869+
else:
870+
return length_per_key, _cumsum(length_per_key)
828871
else:
829872
return length_per_key, offset_per_key
830873

0 commit comments

Comments
 (0)