Skip to content

Split of "[TorchRec][PT2] KJT custom op for 1d lengths input" (#2163) #2176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1955,8 +1955,20 @@ def permute(
indices_tensor,
self.weights_or_none(),
)
elif is_torchdynamo_compiling():
(
permuted_lengths,
permuted_values,
permuted_weights,
) = torch.ops.fbgemm.permute_2D_sparse_data_input1D(
indices_tensor,
self.lengths(),
self.values(),
self.stride(),
self.weights_or_none(),
permuted_length_per_key_sum,
)
else:

(
permuted_lengths,
permuted_values,
Expand Down Expand Up @@ -2338,7 +2350,20 @@ def dist_init(
s == stride for s in stride_per_rank
)

if single_batch_per_rank:
if single_batch_per_rank and is_torchdynamo_compiling():
(
lengths,
values,
weights,
) = torch.ops.fbgemm.permute_2D_sparse_data_input1D(
torch.jit._unwrap_optional(recat),
lengths,
values,
stride,
weights,
values.numel(),
)
elif single_batch_per_rank:
(
lengths,
values,
Expand Down
131 changes: 130 additions & 1 deletion torchrec/sparse/tests/test_jagged_tensor_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import unittest

import torch
from torchrec.sparse.jagged_tensor import _regroup_keyed_tensors, KeyedTensor
from torchrec.sparse.jagged_tensor import (
_regroup_keyed_tensors,
KeyedJaggedTensor,
KeyedTensor,
)
from torchrec.sparse.tests.utils import build_groups, build_kts
from torchrec.test_utils import skip_if_asan_class

Expand Down Expand Up @@ -111,3 +115,128 @@ def test_regroup_backward(self) -> None:

torch.allclose(actual_kt_0_grad, expected_kt_0_grad)
torch.allclose(actual_kt_1_grad, expected_kt_1_grad)

# pyre-ignore
@unittest.skipIf(
torch.cuda.device_count() <= 0,
"Not enough GPUs, this test requires at least one GPUs",
)
def test_permute(self) -> None:
values = torch.tensor(
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
)
lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device)
keys = ["index_0", "index_1", "index_2"]

jag_tensor = KeyedJaggedTensor.from_lengths_sync(
values=values,
keys=keys,
lengths=lengths,
)
indices = [1, 0, 2]
permuted_jag_tensor = jag_tensor.permute(indices)

self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"])
self.assertEqual(
permuted_jag_tensor.offset_per_key(),
[0, 3, 5, 8],
)
self.assertEqual(
permuted_jag_tensor.values().tolist(),
[3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0],
)
self.assertEqual(
permuted_jag_tensor.lengths().tolist(), [1, 1, 1, 0, 2, 0, 0, 3, 0]
)
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)

# pyre-ignore
@unittest.skipIf(
torch.cuda.device_count() <= 0,
"Not enough GPUs, this test requires at least one GPUs",
)
def test_permute_vb(self) -> None:
values = torch.tensor(
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
)
lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device)
keys = ["index_0", "index_1", "index_2"]
stride_per_key_per_rank = [[2], [4], [3]]

jag_tensor = KeyedJaggedTensor.from_lengths_sync(
values=values,
keys=keys,
lengths=lengths,
stride_per_key_per_rank=stride_per_key_per_rank,
)

indices = [1, 0, 2]
permuted_jag_tensor = jag_tensor.permute(indices)

self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"])
self.assertEqual(
permuted_jag_tensor.offset_per_key(),
[0, 5, 6, 8],
)
self.assertEqual(
permuted_jag_tensor.values().tolist(),
[2.0, 3.0, 4.0, 5.0, 6.0, 1.0, 7.0, 8.0],
)
self.assertEqual(
permuted_jag_tensor.lengths().tolist(), [1, 3, 0, 1, 1, 0, 0, 2, 0]
)
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)

# pyre-ignore
@unittest.skipIf(
torch.cuda.device_count() <= 0,
"Not enough GPUs, this test requires at least one GPUs",
)
def test_permute_duplicates(self) -> None:
values = torch.tensor(
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
)
lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device)
keys = ["index_0", "index_1", "index_2"]

jag_tensor = KeyedJaggedTensor.from_lengths_sync(
values=values,
keys=keys,
lengths=lengths,
)

indices = [1, 0, 2, 1, 1]
permuted_jag_tensor = jag_tensor.permute(indices)

self.assertEqual(
permuted_jag_tensor.keys(),
["index_1", "index_0", "index_2", "index_1", "index_1"],
)
self.assertEqual(
permuted_jag_tensor.offset_per_key(),
[0, 3, 5, 8, 11, 14],
)
self.assertEqual(
permuted_jag_tensor.values().tolist(),
[
3.0,
4.0,
5.0,
1.0,
2.0,
6.0,
7.0,
8.0,
3.0,
4.0,
5.0,
3.0,
4.0,
5.0,
],
)
self.assertEqual(
permuted_jag_tensor.lengths().tolist(),
[1, 1, 1, 0, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1, 1],
)
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
Loading