Commit 614a5d8

TroyGarden authored and facebook-github-bot committed
implementation of fbgemm op - permute_multi_embedding (#2120)
Summary: X-link: pytorch/FBGEMM#2738

# context
* currently we have a working function `permute_pooled_embs_auto_grad` to do a full permute of KTs, including forward and backward
* it has several limitations: a) it has to be a full permute, duplicates are not supported; b) in the main [use case](https://fburl.com/code/89od0rqm) there has to be a torch.concat on the input KTs, which is not very efficient; c) the function outputs a single KT, which requires a split operation
* there is some attempt to support duplicated outputs, but the backward doesn't work
* this diff creates a new kernel (named `permute_multi_embedding`) to support a multiple-KT to multiple-KT mapping operation with backward support

# notes
* this diff focuses on the implementation and test of the operator
* performance analysis and benchmark are in the next diff

# operator example usage
* used in python
```
# test inputs: 3 KTs with batch_size=2048
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
values = [
    torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
    for lens in lengths
]

# target outputs: 4 KTs with re-arranged keys (features), duplicates are allowed
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]

# auxiliary arguments to the op/kernel
permutes, in_lengths, out_lengths = _multi_remap_to_groups(
    keys, lengths, groups
)

# op call
outputs = torch.ops.fbgemm.permute_multi_embedding(
    values, permutes, in_lengths, out_lengths
)
```
* permutes
```
permutes = tensor(
    [
        [0, 0, 0, 0, 3, 4],  # f1
        [1, 0, 0, 3, 5, 0],  # f3
        [0, 1, 3, 0, 4, 0],  # f2
        [1, 2, 5, 0, 6, 0],  # f4
        [0, 2, 0, 6, 3, -6],  # f1
        [2, 2, 0, 9, 8, 0],  # f6
        [0, 3, 0, 0, 3, -8],  # f1
        [1, 3, 11, 3, 7, 0],  # f5
    ]
)
```

# details
1. from the above example usage, we can clearly see that the operator takes in the following:
   a) `values`: List[torch.Tensor], which represents the input KTs
   b) `permutes`: torch.Tensor, which contains the permute information (explained in item 3 below)
   c) `output_lengths_list`: List[int], the lengths of the output tensors (KTs), which is needed to allocate memory on device ahead of time
   d) `in_lengths`: torch.Tensor, lengths of the input tensors, which is on device
   e) `out_lengths`: torch.Tensor, lengths of the output tensors, which is on device
2. the operator returns a list of tensors, which represents the permuted KTs
3. `permutes` is the most critical argument of this operator:
   a) it is a 2-D tensor
   b) each row represents one key (feature) permute move
   c) a permute move = [input_tensor_id, output_tensor_id, input_start_idx, output_start_idx, feature_length, jump]
   d) jump is used in backward when a key (feature) from the input tensor is mapped to multiple places in the output tensors (see the reference sketch below)

Differential Revision: D57055616
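To make the semantics of `permutes` concrete, here is a minimal pure-PyTorch sketch of what the forward pass computes. The helper name `permute_multi_embedding_ref` is made up for illustration and is not part of this commit; it mirrors the reference logic used in the unit tests below, while the real operator is `torch.ops.fbgemm.permute_multi_embedding`.
```
from typing import List

import torch


def permute_multi_embedding_ref(
    values: List[torch.Tensor],
    permutes: List[int],  # flattened rows of [in_tensor, out_tensor, in_start, out_start, length, jump]
    out_lengths: List[int],
) -> List[torch.Tensor]:
    batch_size = values[0].size(0)
    # allocate the output KTs; out_lengths gives the width of each output tensor
    outputs = [
        torch.zeros(batch_size, length, device=values[0].device)
        for length in out_lengths
    ]
    for i in range(len(permutes) // 6):
        in_idx, out_idx, in_start, out_start, length, _jump = permutes[i * 6 : i * 6 + 6]
        # copy one key (feature) slice from an input KT into its slot in an output KT;
        # the jump field only matters in backward, where duplicated keys need gradient accumulation
        outputs[out_idx][:, out_start : out_start + length] = values[in_idx][
            :, in_start : in_start + length
        ]
    return outputs
```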
1 parent 66d713c commit 614a5d8

File tree

2 files changed: +257 −2 lines changed


torchrec/sparse/jagged_tensor.py

Lines changed: 82 additions & 0 deletions
@@ -36,6 +36,12 @@
     torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
     )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
+    )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu"
+    )
 except OSError:
     pass

@@ -240,6 +246,82 @@ def _remap_to_groups(
     return permute, inv_permute, offsets, inv_offsets, splits


+def _multi_remap_to_groups(
+    keys: List[List[str]],
+    key_lengths: List[List[int]],
+    groups: List[List[str]],
+) -> Tuple[List[int], List[int], List[int]]:
+    """
+    Given the keys and per-key lengths of each input tensor, and the key groups of each
+    output tensor, return the flattened permute rows plus the input and output lengths:
+    [[input_tensor_idx, output_tensor_idx, input_start, output_start, length, jump]],
+    [input_length], [output_length]
+    """
+    # key => (tensor_idx, key_index)
+    key_map: Dict[str, Tuple[int, int]] = {
+        key: (tensor_idx, key_idx)
+        for tensor_idx, tensor in enumerate(keys)
+        for key_idx, key in enumerate(tensor)
+    }
+
+    # [offsets per tensor]
+    in_offsets: List[List[int]] = [[] for _ in key_lengths]
+    for i, tensor in enumerate(key_lengths):
+        in_offsets[i] = _cumsum(tensor)
+    in_lengths: List[int] = [sum(lengths) for lengths in key_lengths]
+
+    # set total_permutes as the jump stop sign
+    total_permutes: int = sum(len(tensor) for tensor in groups)
+    out_lengths: List[int] = [0] * len(groups)
+
+    # [input_tensor_idx, output_tensor_idx, input_start, output_start, length, jump]
+    permute_param = 6
+    permutes: List[int] = [0] * (total_permutes * permute_param)
+
+    # record the last seen index, so that we can make the jump from last_seen to current
+    last_seen: Dict[str, int] = {}
+    permute_idx = 0
+    for output_tensor_idx, output_tensor in enumerate(groups):
+        output_start = 0
+        for output_key in output_tensor:
+            input_tensor_idx, input_key_idx = key_map[output_key]
+            input_start = in_offsets[input_tensor_idx][input_key_idx]
+            length = key_lengths[input_tensor_idx][input_key_idx]
+
+            # add jump data
+            if output_key not in last_seen:
+                jump = 0  # don't need to jump yet
+                # positive as a potential jump start
+                last_seen[output_key] = permute_idx
+            else:
+                prev = last_seen[output_key]
+                if prev >= 0:  # positive ==> it's a jump start
+                    # jump to current idx, positive as the jump start
+                    permutes[prev * permute_param + 5] = permute_idx
+                else:  # it's already in a jump sequence, mark as negative
+                    permutes[-prev * permute_param + 5] = -permute_idx
+                # mark last_seen negative since it's already in jump
+                last_seen[output_key] = -permute_idx
+                # it's a potential jump stop
+                jump = -total_permutes
+
+            permutes[permute_idx * permute_param : permute_idx * permute_param + 6] = [
+                input_tensor_idx,
+                output_tensor_idx,
+                input_start,
+                output_start,
+                length,
+                jump,
+            ]
+            permute_idx += 1
+            output_start += length
+        out_lengths[output_tensor_idx] = output_start
+
+    return (
+        permutes,
+        in_lengths,
+        out_lengths,
+    )
+
+
 def _values_string(values: torch.Tensor, start: int, end: int) -> str:
     size = values.size()
     if len(size) == 1:
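As a reading aid for the jump encoding built above, here is a small hedged sketch that follows the duplicate occurrences of a key through the flattened `permutes` rows; the helper `trace_jump_chain` is hypothetical and not part of this commit. For the test data below (`total_permutes = 8`), starting from row 0 for "f1" it visits rows 0 → 4 → 6, which the backward pass can presumably use to accumulate the gradient of a duplicated key.
```
from typing import List


def trace_jump_chain(permutes: List[int], start_row: int, total_permutes: int) -> List[int]:
    """Follow all permute rows that read the same input key, starting at its first occurrence."""
    chain = [start_row]
    jump = permutes[start_row * 6 + 5]  # column 5 of the 6-wide row is the jump field
    # jump == 0: the key appears only once; positive: chain start pointing at the next row;
    # negative: a row inside the chain; abs(jump) >= total_permutes marks the end of the chain
    while jump != 0 and abs(jump) < total_permutes:
        row = abs(jump)
        chain.append(row)
        jump = permutes[row * 6 + 5]
    return chain
```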

torchrec/sparse/tests/test_jagged_tensor.py

Lines changed: 175 additions & 2 deletions
@@ -16,6 +16,7 @@
 from torch.testing import FileCheck
 from torchrec.fx import symbolic_trace
 from torchrec.sparse.jagged_tensor import (
+    _multi_remap_to_groups,
     _regroup_keyed_tensors,
     ComputeJTDictToKJT,
     ComputeKJTToJTDict,
@@ -1374,6 +1375,180 @@ def test_permute_vb(self) -> None:
         )
         self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
 
+    def test_multi_remap_to_group(self) -> None:
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        ref_permutes = [
+            [0, 0, 0, 0, 3, 4],  # f1, jump to 4, as a start
+            [1, 0, 0, 3, 5, 0],  # f3
+            [0, 1, 3, 0, 4, 0],  # f2
+            [1, 2, 5, 0, 6, 0],  # f4
+            [0, 2, 0, 6, 3, -6],  # f1 jump to 6, as in a jump sequence
+            [2, 2, 0, 9, 8, 0],  # f6
+            [0, 3, 0, 0, 3, -8],  # f1 jump stop, as out of boundary
+            [1, 3, 11, 3, 7, 0],  # f5
+        ]
+        self.assertEqual(permutes, [i for p in ref_permutes for i in p])
+        self.assertEqual(in_lengths, [7, 18, 8])
+        self.assertEqual(out_lengths, [8, 4, 17, 10])
+
+    def test_multi_permute_forward_cpu(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(len(permutes) // 6):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
+            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_lengths, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+    def test_multi_permute_forward_meta(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="meta", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(len(permutes) // 6):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
+            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_lengths, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertEqual(out.shape, ref.shape)
+
+    # pyre-ignore[56]
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "CUDA is not available",
+    )
+    def test_multi_permute_forward_gpu(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(len(permutes) // 6):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
+            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_lengths, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+    def test_multi_permute_backward_cpu(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
+            for lens in lengths
+        ]
+        ref_values = [v.detach() for v in values]
+        for v in ref_values:
+            v.requires_grad = True
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(len(permutes) // 6):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
+            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_lengths, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+        ref_loss, loss = refs[0].sum(), outputs[0].sum()
+        for i in range(1, len(refs)):
+            ref_loss += (i + 1.1) * refs[i].sum()
+            loss += (i + 1.1) * outputs[i].sum()
+        ref_loss.backward()
+        loss.backward()
+        for val, ref in zip(values, ref_values):
+            val_grad, ref_grad = val.grad, ref.grad
+            assert isinstance(val_grad, torch.Tensor)
+            self.assertTrue(torch.allclose(val_grad, ref_grad))
+
+    # pyre-ignore[56]
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "CUDA is not available",
+    )
+    def test_multi_permute_backward_gpu(self) -> None:
+        batch_size = 2048
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[96, 256], [512, 128, 768], [1024]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
+            for lens in lengths
+        ]
+        ref_values = [v.detach() for v in values]
+        for v in ref_values:
+            v.requires_grad = True
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(len(permutes) // 6):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
+            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_lengths, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+        ref_loss, loss = refs[0].sum(), outputs[0].sum()
+        for i in range(1, len(refs)):
+            ref_loss += (i + 1.1) * refs[i].sum()
+            loss += (i + 1.1) * outputs[i].sum()
+        ref_loss.backward()
+        loss.backward()
+        for val, ref in zip(values, ref_values):
+            val_grad, ref_grad = val.grad, ref.grad
+            assert isinstance(val_grad, torch.Tensor)
+            self.assertTrue(torch.allclose(val_grad, ref_grad))
+
     def test_permute_duplicates(self) -> None:
         values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
         lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0])
@@ -1650,8 +1825,6 @@ def test_string_vb(self) -> None:
             stride_per_key_per_rank=stride_per_key_per_rank,
         )
 
-        print(str(jag_tensor))
-
         self.assertEqual(
             str(jag_tensor),
             """\
