
Commit 3db28b3

TroyGarden authored and facebook-github-bot committed
implementation of fbgemm op - permute_multi_embedding (#2120)
Summary:
X-link: pytorch/FBGEMM#2738
Pull Request resolved: #2120

# context
* currently we have a working function `permute_pooled_embs_auto_grad` that does a full permute of KTs, including forward and backward
* it has several limitations: a) it has to be a full permute, duplicates are not supported; b) in the main [use case](https://fburl.com/code/89od0rqm) there has to be a torch.concat on the input KTs, which is not very efficient; c) the function outputs a single KT, which then requires a split operation
* there is some attempt to support duplicated outputs, but the backward doesn't work
* this diff creates a new kernel (named `permute_multi_embedding`) to support a multiple-KT to multiple-KT mapping operation with backward support

# notes
* this diff focuses on the implementation and tests of the operator
* performance analysis and benchmark are in the next diff

# operator example usage
* used in python
```
# test inputs: 3 KTs with batch_size=2048
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
values = [
    torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
    for lens in lengths
]

# target outputs: 4 KTs with re-arranged keys (features), duplicates are allowed
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]

# auxiliary arguments to the op/kernel
permutes, in_lengths, out_lengths = _multi_remap_to_groups(
    keys, lengths, groups
)

# invoke the operator
outputs = torch.ops.fbgemm.permute_multi_embedding_internal_testing(
    values, permutes, in_lengths, out_lengths
)
```
* permutes
```
# each row represents a key (feature) permute move, which consists of the following parameters:
# [input_tensor_idx, output_tensor_idx, input_key_idx, output_key_idx, key_length, magic_jump]
permutes = tensor(
    [
        [0, 0, 0, 0, 3, 4],  # f1
        [1, 0, 0, 3, 5, 0],  # f3
        [0, 1, 3, 0, 4, 0],  # f2
        [1, 2, 5, 0, 6, 0],  # f4
        [0, 2, 0, 6, 3, -6],  # f1
        [2, 2, 0, 9, 8, 0],  # f6
        [0, 3, 0, 0, 3, -8],  # f1
        [1, 3, 11, 3, 7, 0],  # f5
    ]
)
```

# details
1. from the example usage above, the operator takes the following inputs:
   a) values: List[torch.Tensor], the input KTs
   b) permutes: torch.Tensor, the permute information, explained below
   c) output_lengths_list: List[int], the lengths of the output tensors (KTs), needed to allocate memory on the device ahead of time
   d) in_lengths: torch.Tensor, lengths of the input tensors, on device
   e) out_lengths: torch.Tensor, lengths of the output tensors, on device
2. the operator returns a list of tensors, the permuted KTs
3. `permutes` is the most critical argument of this operator:
   a) it is a 2-D tensor
   b) each row represents one key (feature) permute move
   c) a permute move = [input_tensor_id, output_tensor_id, input_start_idx, output_start_idx, feature_length, jump]
   d) jump is used in backward when a key (feature) from the input tensor is mapped to multiple places in the output tensors
4. the magic_jump
   a) it is only used in the backward computation
   b) it is usually 0, meaning no jump
   c) it is non-zero when there is a duplicate in the permute, i.e., the same feature appears more than once in the output
   d) the `magic_jump` is the index of the next occurrence of the very same feature in the permute sequence, with some modifications
   e) modification-1: `magic_jump` is positive when the row is the first of its kind [Start]
   f) modification-2: `magic_jump` is negative when the row is not the first of its kind [Continue]
   g) modification-3: `magic_jump` is the negative of the length of the permute sequence when the row is the last of its kind [Stop]

Reviewed By: sryap

Differential Revision: D57055616

fbshipit-source-id: 16673d3a2eafab93b08d4ff3c43d54366966064a
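For intuition, the forward semantics of each `permutes` row can be emulated in plain PyTorch by slicing, as in the sketch below; it mirrors the reference computation used in the unit tests of this diff, and `reference_forward` is just an illustrative name, not part of the op.

```
import torch

# minimal reference sketch (not the kernel): emulate the forward semantics of
# `permutes` with plain slicing and copies
def reference_forward(values, permutes, out_lengths):
    batch_size = values[0].size(0)
    outputs = [
        torch.zeros(batch_size, length, device=values[0].device)
        for length in out_lengths
    ]
    for in_idx, out_idx, in_start, out_start, length, _jump in permutes.tolist():
        # copy one feature slice from its input KT into its place in the output KT
        outputs[out_idx][:, out_start : out_start + length] = values[in_idx][
            :, in_start : in_start + length
        ]
    return outputs

# In backward, a feature duplicated across outputs (e.g. "f1" above) must sum the
# gradients from all of its output copies; the magic_jump values 4, -6, -8 on the
# three f1 rows chain those rows together (start, continue, stop) so the kernel
# knows where the next copy's gradient slice lives.
```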

2 files changed: +295 −2 lines changed


torchrec/sparse/jagged_tensor.py

Lines changed: 108 additions & 0 deletions
@@ -36,6 +36,12 @@
     torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
     )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
+    )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu"
+    )
 except OSError:
     pass
 

@@ -164,6 +170,24 @@ def _all_keys_used_once(
     return len(key_set) == len(group_set) == len(flat_keys) == len(flat_groups)
 
 
+@torch.fx.wrap
+def permute_multi_embedding(
+    keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
+) -> List[torch.Tensor]:
+    keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
+    permutes, in_shape, out_shape, out_lengths = _kt_regroup_permutes(
+        values[0], keys, lengths, groups
+    )
+    permuted_values = torch.ops.fbgemm.permute_multi_embedding(
+        values,
+        permutes,
+        in_shape,
+        out_shape,
+        out_lengths,
+    )
+    return permuted_values
+
+
 @torch.fx.wrap
 def _fbgemm_permute_pooled_embs(
     keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
@@ -240,6 +264,90 @@ def _remap_to_groups(
     return permute, inv_permute, offsets, inv_offsets, splits
 
 
+def _kt_regroup_permutes(
+    value: torch.Tensor,
+    keys: List[List[str]],
+    key_lengths: List[List[int]],
+    groups: List[List[str]],
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
+    """
+    returns: permutes, in_shapes, out_shapes, out_lengths
+    """
+    # key => (tensor_idx, key_index)
+    key_map: Dict[str, Tuple[int, int]] = {
+        key: (tensor_idx, key_idx)
+        for tensor_idx, tensor in enumerate(keys)
+        for key_idx, key in enumerate(tensor)
+    }
+
+    # [offsets per tensor]
+    in_offsets: List[List[int]] = [[] for _ in key_lengths]
+    for i, tensor in enumerate(key_lengths):
+        in_offsets[i] = _cumsum(tensor)
+    in_lengths: List[int] = [sum(lengths) for lengths in key_lengths]
+
+    # set total_permutes as the jump stop sign
+    total_permutes: int = sum(len(tensor) for tensor in groups)
+    out_lengths: List[int] = [0] * len(groups)
+
+    # [input_tensor_idx, output_tensor_idx, input_start, output_start, length, jump]
+    permute_param = 6
+    permutes: List[List[int]] = [[0] * permute_param for _ in range(total_permutes)]
+
+    # record the last seen index, so that can make the jump from last_seen to current
+    last_seen: Dict[str, int] = {}
+    permute_idx = 0
+    for output_tensor_idx, output_tenser in enumerate(groups):
+        output_start = 0
+        for output_key in output_tenser:
+            input_tensor_idx, input_key_idx = key_map[output_key]
+            input_start = in_offsets[input_tensor_idx][input_key_idx]
+            length = key_lengths[input_tensor_idx][input_key_idx]
+
+            # add jump data
+            if output_key not in last_seen:
+                jump = 0  # don't need to jump yet
+                # positive as a potential jump start
+                last_seen[output_key] = permute_idx
+            else:
+                prev = last_seen[output_key]
+                if prev >= 0:  # positive ==> it's a jump start
+                    # jump to current idx, positive as the jump start
+                    permutes[prev][5] = permute_idx
+                else:  # it's already in a jump sequence, mark as negative
+                    permutes[-prev][5] = -permute_idx
+                # mark last_seen negative since it's already in jump
+                last_seen[output_key] = -permute_idx
+                # it's a potential jump stop
+                jump = -total_permutes
+
+            permutes[permute_idx][:] = [
+                input_tensor_idx,
+                output_tensor_idx,
+                input_start,
+                output_start,
+                length,
+                jump,
+            ]
+            permute_idx += 1
+            output_start += length
+        out_lengths[output_tensor_idx] = output_start
+
+    permute_tensor = torch.tensor(permutes, dtype=torch.int32)
+    in_shapes = torch.tensor(in_lengths, dtype=torch.int32)
+    out_shapes = torch.tensor(out_lengths, dtype=torch.int32)
+    device = value.device
+    permute_tensor = _pin_and_move(permute_tensor, device)
+    in_shapes = _pin_and_move(in_shapes, device)
+    out_shapes = _pin_and_move(out_shapes, device)
+    return (
+        permute_tensor,
+        in_shapes,
+        out_shapes,
+        out_lengths,
+    )
+
+
 def _values_string(values: torch.Tensor, start: int, end: int) -> str:
     size = values.size()
     if len(size) == 1:
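For reference, a hypothetical end-to-end usage of the new `permute_multi_embedding` wrapper on KeyedTensors might look like the sketch below; it is not part of this diff and assumes the fbgemm_gpu `permute_multi_embedding` ops loaded above are available, with keys, lengths, and batch size made up for illustration.

```
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor, permute_multi_embedding

# two input KTs; "f1" is requested twice across the output groups, which the
# older permute_pooled_embs path cannot express
kt1 = KeyedTensor(keys=["f1", "f2"], length_per_key=[3, 4], values=torch.randn(8, 7))
kt2 = KeyedTensor(keys=["f3"], length_per_key=[5], values=torch.randn(8, 5))

groups = [["f1", "f3"], ["f2", "f1"]]
out = permute_multi_embedding([kt1, kt2], groups)
# out is a List[torch.Tensor]: out[0] has shape (8, 3 + 5), out[1] has shape (8, 4 + 3)
```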

torchrec/sparse/tests/test_jagged_tensor.py

Lines changed: 187 additions & 2 deletions
@@ -16,6 +16,7 @@
 from torch.testing import FileCheck
 from torchrec.fx import symbolic_trace
 from torchrec.sparse.jagged_tensor import (
+    _kt_regroup_permutes,
     _regroup_keyed_tensors,
     ComputeJTDictToKJT,
     ComputeKJTToJTDict,
@@ -1397,6 +1398,192 @@ def test_permute_vb(self) -> None:
         )
         self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
 
+    def test_kt_regroup_permutes(self) -> None:
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        for device in ["cpu", "meta", "cuda"]:
+            if device == "cuda" and not torch.cuda.is_available():
+                continue
+            device = torch.device(device)
+            permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
+                torch.empty(0, device=device), keys, lengths, groups
+            )
+            ref_permutes = [
+                [0, 0, 0, 0, 3, 4],  # f1, jump to 4, as a start
+                [1, 0, 0, 3, 5, 0],  # f3
+                [0, 1, 3, 0, 4, 0],  # f2
+                [1, 2, 5, 0, 6, 0],  # f4
+                [0, 2, 0, 6, 3, -6],  # f1 jump to 6, as in a jump sequence
+                [2, 2, 0, 9, 8, 0],  # f6
+                [0, 3, 0, 0, 3, -8],  # f1 jump stop, as out of boundary
+                [1, 3, 11, 3, 7, 0],  # f5
+            ]
+            if device.type == "meta":
+                self.assertEqual(
+                    permutes.shape, (len(ref_permutes), len(ref_permutes[0]))
+                )
+                self.assertEqual(in_shapes.shape, (3,))
+                self.assertEqual(out_shapes.shape, (4,))
+            else:
+                self.assertTrue(
+                    torch.equal(
+                        permutes,
+                        torch.tensor(ref_permutes, dtype=torch.int32, device=device),
+                    )
+                )
+                self.assertEqual(in_shapes.tolist(), [7, 18, 8])
+                self.assertEqual(out_shapes.tolist(), [8, 4, 17, 10])
+            self.assertEqual(out_lengths, [8, 4, 17, 10])
+
+    def test_multi_permute_forward_cpu(self) -> None:
+        batch_size = 32
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
+            values[0], keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(permutes.size(0)):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
+            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_shapes, out_shapes, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+    def test_multi_permute_forward_meta(self) -> None:
+        batch_size = 32
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="meta", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
+            values[0], keys, lengths, groups
+        )
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_shapes, out_shapes, out_lengths
+        )
+        for out, ref in zip(outputs, out_lengths):
+            self.assertEqual(out.shape, (batch_size, ref))
+
+    # pyre-ignore[56]
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "CUDA is not available",
+    )
+    def test_multi_permute_forward_gpu(self) -> None:
+        batch_size = 1024
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[96, 256], [512, 128, 768], [1024]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
+            values[0], keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(permutes.size(0)):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
+            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_shapes, out_shapes, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+    def test_multi_permute_backward_cpu(self) -> None:
+        batch_size = 32
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
+            for lens in lengths
+        ]
+        ref_values = [v.detach() for v in values]
+        for v in ref_values:
+            v.requires_grad = True
+        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
+            values[0], keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(permutes.size(0)):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
+            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_shapes, out_shapes, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+        ref_loss, loss = refs[0].sum(), outputs[0].sum()
+        for i in range(1, len(refs)):
+            ref_loss += (i + 1.1) * refs[i].sum()
+            loss += (i + 1.1) * outputs[i].sum()
+        ref_loss.backward()
+        loss.backward()
+        for val, ref in zip(values, ref_values):
+            val_grad, ref_grad = val.grad, ref.grad
+            assert isinstance(val_grad, torch.Tensor)
+            self.assertTrue(torch.allclose(val_grad, ref_grad))
+
+    # pyre-ignore[56]
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "CUDA is not available",
+    )
+    def test_multi_permute_backward_gpu(self) -> None:
+        batch_size = 2048
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[96, 256], [512, 128, 768], [1024]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
+            for lens in lengths
+        ]
+        ref_values = [v.detach() for v in values]
+        for v in ref_values:
+            v.requires_grad = True
+        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
+            values[0], keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(permutes.size(0)):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
+            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_shapes, out_shapes, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+        ref_loss, loss = refs[0].sum(), outputs[0].sum()
+        for i in range(1, len(refs)):
+            ref_loss += (i + 1.1) * refs[i].sum()
+            loss += (i + 1.1) * outputs[i].sum()
+        ref_loss.backward()
+        loss.backward()
+        for val, ref in zip(values, ref_values):
+            val_grad, ref_grad = val.grad, ref.grad
+            assert isinstance(val_grad, torch.Tensor)
+            self.assertTrue(torch.allclose(val_grad, ref_grad))
+
     def test_permute_duplicates(self) -> None:
         values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
         lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0])
@@ -1673,8 +1860,6 @@ def test_string_vb(self) -> None:
             stride_per_key_per_rank=stride_per_key_per_rank,
         )
 
-        print(str(jag_tensor))
-
         self.assertEqual(
             str(jag_tensor),
             """\
