 from torch.testing import FileCheck
 from torchrec.fx import symbolic_trace
 from torchrec.sparse.jagged_tensor import (
+    _kt_regroup_permutes,
     _regroup_keyed_tensors,
     ComputeJTDictToKJT,
     ComputeKJTToJTDict,
@@ -1397,6 +1398,192 @@ def test_permute_vb(self) -> None:
         )
         self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
 
+    def test_kt_regroup_permutes(self) -> None:
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        for device in ["cpu", "meta", "cuda"]:
+            if device == "cuda" and not torch.cuda.is_available():
+                continue
+            device = torch.device(device)
+            permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
+                torch.empty(0, device=device), keys, lengths, groups
+            )
+            ref_permutes = [
+                [0, 0, 0, 0, 3, 4],  # f1, jump to 4, as a start
+                [1, 0, 0, 3, 5, 0],  # f3
+                [0, 1, 3, 0, 4, 0],  # f2
+                [1, 2, 5, 0, 6, 0],  # f4
+                [0, 2, 0, 6, 3, -6],  # f1 jump to 6, as in a jump sequence
+                [2, 2, 0, 9, 8, 0],  # f6
+                [0, 3, 0, 0, 3, -8],  # f1 jump stop, as out of boundary
+                [1, 3, 11, 3, 7, 0],  # f5
+            ]
+            if device.type == "meta":
+                self.assertEqual(
+                    permutes.shape, (len(ref_permutes), len(ref_permutes[0]))
+                )
+                self.assertEqual(in_shapes.shape, (3,))
+                self.assertEqual(out_shapes.shape, (4,))
+            else:
+                self.assertTrue(
+                    torch.equal(
+                        permutes,
+                        torch.tensor(ref_permutes, dtype=torch.int32, device=device),
+                    )
+                )
+                self.assertEqual(in_shapes.tolist(), [7, 18, 8])
+                self.assertEqual(out_shapes.tolist(), [8, 4, 17, 10])
+            self.assertEqual(out_lengths, [8, 4, 17, 10])
+
+    def test_multi_permute_forward_cpu(self) -> None:
+        batch_size = 32
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
+            values[0], keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(permutes.size(0)):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
+            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_shapes, out_shapes, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+    def test_multi_permute_forward_meta(self) -> None:
+        batch_size = 32
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="meta", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
+            values[0], keys, lengths, groups
+        )
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_shapes, out_shapes, out_lengths
+        )
+        for out, ref in zip(outputs, out_lengths):
+            self.assertEqual(out.shape, (batch_size, ref))
+
+    # pyre-ignore[56]
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "CUDA is not available",
+    )
+    def test_multi_permute_forward_gpu(self) -> None:
+        batch_size = 1024
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[96, 256], [512, 128, 768], [1024]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
+            values[0], keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(permutes.size(0)):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
+            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_shapes, out_shapes, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+    def test_multi_permute_backward_cpu(self) -> None:
+        batch_size = 32
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
+            for lens in lengths
+        ]
+        ref_values = [v.detach() for v in values]
+        for v in ref_values:
+            v.requires_grad = True
+        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
+            values[0], keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(permutes.size(0)):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
+            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_shapes, out_shapes, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+        ref_loss, loss = refs[0].sum(), outputs[0].sum()
+        for i in range(1, len(refs)):
+            ref_loss += (i + 1.1) * refs[i].sum()
+            loss += (i + 1.1) * outputs[i].sum()
+        ref_loss.backward()
+        loss.backward()
+        for val, ref in zip(values, ref_values):
+            val_grad, ref_grad = val.grad, ref.grad
+            assert isinstance(val_grad, torch.Tensor)
+            self.assertTrue(torch.allclose(val_grad, ref_grad))
+
+    # pyre-ignore[56]
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "CUDA is not available",
+    )
+    def test_multi_permute_backward_gpu(self) -> None:
+        batch_size = 2048
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[96, 256], [512, 128, 768], [1024]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
+            for lens in lengths
+        ]
+        ref_values = [v.detach() for v in values]
+        for v in ref_values:
+            v.requires_grad = True
+        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
+            values[0], keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for i in range(permutes.size(0)):
+            in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
+            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_shapes, out_shapes, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+        ref_loss, loss = refs[0].sum(), outputs[0].sum()
+        for i in range(1, len(refs)):
+            ref_loss += (i + 1.1) * refs[i].sum()
+            loss += (i + 1.1) * outputs[i].sum()
+        ref_loss.backward()
+        loss.backward()
+        for val, ref in zip(values, ref_values):
+            val_grad, ref_grad = val.grad, ref.grad
+            assert isinstance(val_grad, torch.Tensor)
+            self.assertTrue(torch.allclose(val_grad, ref_grad))
+
     def test_permute_duplicates(self) -> None:
         values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
         lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0])
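The reference loops in the forward and backward tests rebuild each grouped output by slicing the input tensors according to the rows of `permutes`. For readability, here is a minimal sketch of that decoding, assuming the column order `(input_tensor, output_tensor, input_offset, output_offset, length, jump)` inferred from the unpacking `in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()`; `regroup_reference` is a hypothetical helper used only for illustration, not an API from this change:

import torch
from typing import List


def regroup_reference(
    values: List[torch.Tensor],
    permutes: torch.Tensor,
    out_lengths: List[int],
) -> List[torch.Tensor]:
    # One output slab per group, sized from the precomputed output lengths.
    batch_size = values[0].size(0)
    outputs = [values[0].new_empty(batch_size, length) for length in out_lengths]
    # Each permutes row copies one contiguous key slice into its output slot.
    for in_idx, out_idx, in_start, out_start, length, _jump in permutes.tolist():
        outputs[out_idx][:, out_start : out_start + length] = values[in_idx][
            :, in_start : in_start + length
        ]
    return outputs

Called with the same `values`, `permutes`, and `out_lengths` as in the CPU forward test, this should reproduce the `refs` that the test assembles with `torch.cat`.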
@@ -1673,8 +1860,6 @@ def test_string_vb(self) -> None:
             stride_per_key_per_rank=stride_per_key_per_rank,
         )
 
-        print(str(jag_tensor))
-
         self.assertEqual(
             str(jag_tensor),
             """\
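One detail worth noting in `test_kt_regroup_permutes`: "f1" is requested by three of the four output groups, and the comments on `ref_permutes` suggest the last column chains those duplicate reads together (row 0 jumps to row 4, row 4 to row 6, and row 6 points past the end of the table to stop). Below is a small sketch that walks that chain over the `ref_permutes` values from the test; the exact stop rule (zero means no duplicate, an index at or past the row count ends the chain) is an assumption read off the in-code comments, not documented behavior:

# Rows of ref_permutes from test_kt_regroup_permutes above.
ref_permutes = [
    [0, 0, 0, 0, 3, 4],   # f1, first occurrence, jump to row 4
    [1, 0, 0, 3, 5, 0],   # f3
    [0, 1, 3, 0, 4, 0],   # f2
    [1, 2, 5, 0, 6, 0],   # f4
    [0, 2, 0, 6, 3, -6],  # f1, jump to row 6
    [2, 2, 0, 9, 8, 0],   # f6
    [0, 3, 0, 0, 3, -8],  # f1, chain ends (8 is out of range)
    [1, 3, 11, 3, 7, 0],  # f5
]

row, chain = 0, [0]  # start at the first "f1" row
while True:
    jump = abs(ref_permutes[row][5])
    if jump == 0 or jump >= len(ref_permutes):
        break  # no duplicate left, or the next index falls outside the table
    chain.append(jump)
    row = jump

print(chain)  # [0, 4, 6]: the three rows that all read the same "f1" slice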