
Commit 959e027

dstaay-fb authored and facebook-github-bot committed
Fix bug on VBE+CPU
Summary: Internal users reported a bug when working with VBE + CPU. The regression was identified as having been introduced by a stray edit in D55695198. The fix itself is a single line, but unit tests were added to cover this edge case for both CPU and GPU setups.

Differential Revision: D60430765
1 parent 2771a90 commit 959e027
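For context on the one-line fix: in a variable-batch-size (VBE) KeyedJaggedTensor, each key contributes its own number of values, and permute may repeat keys. With duplicate indices, the permuted output holds more values than the input, so sizing the output by tensor.numel() under-allocates it. A rough sketch of the size arithmetic, using the same data as the new unit tests below (the per-key counts follow from their lengths and stride_per_key_per_rank; this is an illustration, not torchrec code):

# Size arithmetic behind the fix (illustration only).
# Per-key value counts taken from the new test: index_0 -> 1, index_1 -> 5, index_2 -> 2.
per_key_num_values = {"index_0": 1, "index_1": 5, "index_2": 2}
keys = ["index_0", "index_1", "index_2"]
indices = [1, 1, 0, 0, 2, 2]  # permutation that repeats keys

input_size = sum(per_key_num_values.values())                    # 8, i.e. tensor.numel()
output_size = sum(per_key_num_values[keys[i]] for i in indices)  # 16
assert (input_size, output_size) == (8, 16)  # numel() would under-size the output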

4 files changed: +114 additions, -2 deletions


torchrec/modules/utils.py

Lines changed: 1 addition & 1 deletion
@@ -184,7 +184,7 @@ def _permute_tensor_by_segments(
             segment_sizes,
             tensor,
             weights,
-            tensor.numel(),
+            output_size,
         )
     return permuted_tensor, permuted_weights

torchrec/sparse/jagged_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -453,7 +453,7 @@ def _permute_tensor_by_segments(
             segment_sizes,
             tensor,
             weights,
-            tensor.numel(),
+            output_size,
         )
     return permuted_tensor, permuted_weights
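Both call sites pass the same argument list, so the change is identical: the output of the underlying permute op must be sized by output_size, the total of the permuted segment sizes, rather than by tensor.numel(), which only matches when no segment is duplicated or dropped. A minimal plain-PyTorch stand-in (a hypothetical permute_segments helper, not the torchrec/FBGEMM kernel) that shows the behavior on the test data:

import torch

def permute_segments(tensor: torch.Tensor, segment_sizes: list, order: list) -> torch.Tensor:
    # Split the flat tensor into per-key segments, then reorder them; the order
    # may repeat indices, exactly as a VBE permute with duplicate keys does.
    segments = torch.split(tensor, segment_sizes)
    permuted = [segments[i] for i in order]
    output_size = sum(s.numel() for s in permuted)  # 16 here, not tensor.numel() == 8
    out = torch.empty(output_size, dtype=tensor.dtype)
    torch.cat(permuted, out=out)
    return out

values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
# index_0 holds 1 value, index_1 holds 5, index_2 holds 2; the order repeats every key.
permuted = permute_segments(values, [1, 5, 2], [1, 1, 0, 0, 2, 2])
assert permuted.numel() == 16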

torchrec/sparse/tests/test_jagged_tensor.py

Lines changed: 53 additions & 0 deletions
@@ -1400,6 +1400,59 @@ def test_permute_vb(self) -> None:
         )
         self.assertEqual(permuted_jag_tensor.weights_or_none(), None)

+    def test_permute_vb_duplicate(self) -> None:
+        values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
+        lengths = torch.IntTensor([1, 0, 1, 3, 0, 1, 0, 2, 0])
+        keys = ["index_0", "index_1", "index_2"]
+        stride_per_key_per_rank = [[2], [4], [3]]
+
+        jag_tensor = KeyedJaggedTensor.from_lengths_sync(
+            values=values,
+            keys=keys,
+            lengths=lengths,
+            stride_per_key_per_rank=stride_per_key_per_rank,
+        )
+
+        indices = [1, 1, 0, 0, 2, 2]
+        permuted_jag_tensor = jag_tensor.permute(indices)
+
+        self.assertEqual(
+            permuted_jag_tensor.keys(),
+            ["index_1", "index_1", "index_0", "index_0", "index_2", "index_2"],
+        )
+        self.assertTrue(
+            torch.equal(
+                permuted_jag_tensor.values(),
+                torch.Tensor(
+                    [
+                        2.0,
+                        3.0,
+                        4.0,
+                        5.0,
+                        6.0,
+                        2.0,
+                        3.0,
+                        4.0,
+                        5.0,
+                        6.0,
+                        1.0,
+                        1.0,
+                        7.0,
+                        8.0,
+                        7.0,
+                        8.0,
+                    ]
+                ),
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                permuted_jag_tensor.lengths(),
+                torch.IntTensor([1, 3, 0, 1, 1, 3, 0, 1, 1, 0, 1, 0, 0, 2, 0, 0, 2, 0]),
+            )
+        )
+        self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
+
     def test_permute_duplicates(self) -> None:
         values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
         lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0])

torchrec/sparse/tests/test_jagged_tensor_gpu.py

Lines changed: 59 additions & 0 deletions
@@ -187,6 +187,65 @@ def test_permute_vb(self) -> None:
         )
         self.assertEqual(permuted_jag_tensor.weights_or_none(), None)

+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "Not enough GPUs, this test requires at least one GPUs",
+    )
+    def test_permute_vb_duplicate(self) -> None:
+        values = torch.tensor(
+            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
+        )
+        lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device)
+        keys = ["index_0", "index_1", "index_2"]
+        stride_per_key_per_rank = [[2], [4], [3]]
+
+        jag_tensor = KeyedJaggedTensor.from_lengths_sync(
+            values=values,
+            keys=keys,
+            lengths=lengths,
+            stride_per_key_per_rank=stride_per_key_per_rank,
+        )
+
+        indices = [1, 1, 0, 0, 2, 2]
+        permuted_jag_tensor = jag_tensor.permute(indices)
+
+        self.assertEqual(
+            permuted_jag_tensor.keys(),
+            ["index_1", "index_1", "index_0", "index_0", "index_2", "index_2"],
+        )
+        self.assertTrue(
+            torch.equal(
+                permuted_jag_tensor.values().cpu(),
+                torch.Tensor(
+                    [
+                        2.0,
+                        3.0,
+                        4.0,
+                        5.0,
+                        6.0,
+                        2.0,
+                        3.0,
+                        4.0,
+                        5.0,
+                        6.0,
+                        1.0,
+                        1.0,
+                        7.0,
+                        8.0,
+                        7.0,
+                        8.0,
+                    ]
+                ),
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                permuted_jag_tensor.lengths().cpu(),
+                torch.IntTensor([1, 3, 0, 1, 1, 3, 0, 1, 1, 0, 1, 0, 0, 2, 0, 0, 2, 0]),
+            )
+        )
+        self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
+
     # pyre-ignore
     @unittest.skipIf(
         torch.cuda.device_count() <= 0,
