@@ -16,6 +16,7 @@
 from torch.testing import FileCheck
 from torchrec.fx import symbolic_trace
 from torchrec.sparse.jagged_tensor import (
+    _multi_remap_to_groups,
     _regroup_keyed_tensors,
     ComputeJTDictToKJT,
     ComputeKJTToJTDict,
@@ -1374,6 +1375,180 @@ def test_permute_vb(self) -> None:
         )
         self.assertEqual(permuted_jag_tensor.weights_or_none(), None)

+    def test_multi_remap_to_group(self) -> None:
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        ref_permutes = [
+            [0, 0, 0, 0, 3, 4],  # f1, jump to 4, as a start
+            [1, 0, 0, 3, 5, 0],  # f3
+            [0, 1, 3, 0, 4, 0],  # f2
+            [1, 2, 5, 0, 6, 0],  # f4
+            [0, 2, 0, 6, 3, -6],  # f1, jump to 6, as in a jump sequence
+            [2, 2, 0, 9, 8, 0],  # f6
+            [0, 3, 0, 0, 3, -8],  # f1, jump stop, as out of boundary
+            [1, 3, 11, 3, 7, 0],  # f5
+        ]
+        self.assertEqual(permutes, [i for p in ref_permutes for i in p])
+        self.assertEqual(in_lengths, [7, 18, 8])
+        self.assertEqual(out_lengths, [8, 4, 17, 10])
+
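Judging from the reference loops in the tests below, each 6-int record in the flat permutes buffer appears to mean (input tensor index, output tensor index, input column offset, output column offset, copy length, duplicate-jump), with the non-zero last field chaining repeated keys such as f1. A minimal sketch of a decoder under that assumption; reference_regroup is a hypothetical helper for illustration, not part of the torchrec or fbgemm API:

import torch
from typing import List

# Sketch only: decodes the flat ``permutes`` buffer, assuming each 6-int
# record is (in_idx, out_idx, in_start, out_start, length, jump); field
# names are inferred from the reference loops in these tests.
def reference_regroup(
    values: List[torch.Tensor],  # one (batch, dim) tensor per input key group
    permutes: List[int],         # flat buffer, 6 ints per copy record
    out_lengths: List[int],      # total embedding dim of each output group
) -> List[torch.Tensor]:
    batch_size = values[0].size(0)
    outputs = [values[0].new_empty(batch_size, n) for n in out_lengths]
    for i in range(0, len(permutes), 6):
        in_idx, out_idx, in_start, out_start, length, _jump = permutes[i : i + 6]
        # copy one key's columns from its input tensor into its output slot
        outputs[out_idx][:, out_start : out_start + length] = values[in_idx][
            :, in_start : in_start + length
        ]
    return outputs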
+    def test_multi_permute_forward_cpu(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(len(permutes) // 6):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
+            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_lengths, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+    def test_multi_permute_forward_meta(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="meta", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(len(permutes) // 6):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
+            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_lengths, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertEqual(out.shape, ref.shape)
+
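The meta-device variant can only compare shapes: meta tensors carry sizes, strides, and dtype but allocate no storage, so there are no values for allclose to inspect. A quick standalone illustration:

import torch

# Meta tensors track metadata only; slicing and cat work symbolically,
# but any attempt to read the data (e.g. t.cpu()) raises an error.
t = torch.randn(5, 7, device="meta")
print(t.shape, t.dtype)  # torch.Size([5, 7]) torch.float32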
+    # pyre-ignore[56]
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "CUDA is not available",
+    )
+    def test_multi_permute_forward_gpu(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(len(permutes) // 6):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
+            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_lengths, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+    def test_multi_permute_backward_cpu(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
+            for lens in lengths
+        ]
+        ref_values = [v.detach() for v in values]
+        for v in ref_values:
+            v.requires_grad = True
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(len(permutes) // 6):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
+            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_lengths, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+        ref_loss, loss = refs[0].sum(), outputs[0].sum()
+        for i in range(1, len(refs)):
+            ref_loss += (i + 1.1) * refs[i].sum()
+            loss += (i + 1.1) * outputs[i].sum()
+        ref_loss.backward()
+        loss.backward()
+        for val, ref in zip(values, ref_values):
+            val_grad, ref_grad = val.grad, ref.grad
+            assert isinstance(val_grad, torch.Tensor)
+            self.assertTrue(torch.allclose(val_grad, ref_grad))
+
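Because f1 is fanned out to output groups 0, 2, and 3, its gradient accumulates the contributions flowing back from every copy. With the weighted loss above (coefficients 1.0, 2.1, 3.1, 4.1 for groups 0 through 3), each element of the f1 slice should receive gradient 1.0 + 3.1 + 4.1 = 8.2. A standalone sanity check of that accumulation rule, independent of the fbgemm op:

import torch

# Coefficients used by the loss above: 1.0 for group 0, (i + 1.1) for i >= 1.
coeffs = [1.0, 2.1, 3.1, 4.1]
f1 = torch.randn(5, 3, requires_grad=True)
# f1 is copied into output groups 0, 2 and 3, each term scaled in the loss:
loss = coeffs[0] * f1.sum() + coeffs[2] * f1.sum() + coeffs[3] * f1.sum()
loss.backward()
expected = sum(coeffs[i] for i in (0, 2, 3))  # 8.2
assert torch.allclose(f1.grad, torch.full_like(f1, expected))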
+    # pyre-ignore[56]
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "CUDA is not available",
+    )
+    def test_multi_permute_backward_gpu(self) -> None:
+        batch_size = 2048
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[96, 256], [512, 128, 768], [1024]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
+            for lens in lengths
+        ]
+        ref_values = [v.detach() for v in values]
+        for v in ref_values:
+            v.requires_grad = True
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(len(permutes) // 6):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i * 6 : i * 6 + 6]
+            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_lengths, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+        ref_loss, loss = refs[0].sum(), outputs[0].sum()
+        for i in range(1, len(refs)):
+            ref_loss += (i + 1.1) * refs[i].sum()
+            loss += (i + 1.1) * outputs[i].sum()
+        ref_loss.backward()
+        loss.backward()
+        for val, ref in zip(values, ref_values):
+            val_grad, ref_grad = val.grad, ref.grad
+            assert isinstance(val_grad, torch.Tensor)
+            self.assertTrue(torch.allclose(val_grad, ref_grad))
+
     def test_permute_duplicates(self) -> None:
         values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
         lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0])
@@ -1650,8 +1825,6 @@ def test_string_vb(self) -> None:
             stride_per_key_per_rank=stride_per_key_per_rank,
         )

-        print(str(jag_tensor))
-
         self.assertEqual(
             str(jag_tensor),
             """\