16 | 16 | from torch.testing import FileCheck
17 | 17 | from torchrec.fx import symbolic_trace
18 | 18 | from torchrec.sparse.jagged_tensor import (
| 19 | +     _kt_regroup_permutes,
19 | 20 |     _regroup_keyed_tensors,
20 | 21 |     ComputeJTDictToKJT,
21 | 22 |     ComputeKJTToJTDict,

@@ -1374,6 +1375,192 @@ def test_permute_vb(self) -> None:
1374 | 1375 |         )
1375 | 1376 |         self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
1376 | 1377 |
| 1378 | +     def test_kt_regroup_permutes(self) -> None:
| 1379 | +         keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
| 1380 | +         lengths = [[3, 4], [5, 6, 7], [8]]
| 1381 | +         groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
| 1382 | +         for device in ["cpu", "meta", "cuda"]:
| 1383 | +             if device == "cuda" and not torch.cuda.is_available():
| 1384 | +                 continue
| 1385 | +             device = torch.device(device)
| 1386 | +             permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
| 1387 | +                 torch.empty(0, device=device), keys, lengths, groups
| 1388 | +             )
| 1389 | +             ref_permutes = [
| 1390 | +                 [0, 0, 0, 0, 3, 4],  # f1, jump to 4, as a start
| 1391 | +                 [1, 0, 0, 3, 5, 0],  # f3
| 1392 | +                 [0, 1, 3, 0, 4, 0],  # f2
| 1393 | +                 [1, 2, 5, 0, 6, 0],  # f4
| 1394 | +                 [0, 2, 0, 6, 3, -6],  # f1 jump to 6, as in a jump sequence
| 1395 | +                 [2, 2, 0, 9, 8, 0],  # f6
| 1396 | +                 [0, 3, 0, 0, 3, -8],  # f1 jump stop, as out of boundary
| 1397 | +                 [1, 3, 11, 3, 7, 0],  # f5
| 1398 | +             ]
| 1399 | +             if device.type == "meta":
| 1400 | +                 self.assertEqual(
| 1401 | +                     permutes.shape, (len(ref_permutes), len(ref_permutes[0]))
| 1402 | +                 )
| 1403 | +                 self.assertEqual(in_shapes.shape, (3,))
| 1404 | +                 self.assertEqual(out_shapes.shape, (4,))
| 1405 | +             else:
| 1406 | +                 self.assertTrue(
| 1407 | +                     torch.equal(
| 1408 | +                         permutes,
| 1409 | +                         torch.tensor(ref_permutes, dtype=torch.int32, device=device),
| 1410 | +                     )
| 1411 | +                 )
| 1412 | +                 self.assertEqual(in_shapes.tolist(), [7, 18, 8])
| 1413 | +                 self.assertEqual(out_shapes.tolist(), [8, 4, 17, 10])
| 1414 | +             self.assertEqual(out_lengths, [8, 4, 17, 10])
| 1415 | +
| 1416 | +     def test_multi_permute_forward_cpu(self) -> None:
| 1417 | +         batch_size = 32
| 1418 | +         keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
| 1419 | +         lengths = [[3, 4], [5, 6, 7], [8]]
| 1420 | +         groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
| 1421 | +         values = [
| 1422 | +             torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
| 1423 | +             for lens in lengths
| 1424 | +         ]
| 1425 | +         permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
| 1426 | +             values[0], keys, lengths, groups
| 1427 | +         )
| 1428 | +         refs = [[] for _ in groups]
| 1429 | +         for i in range(permutes.size(0)):
| 1430 | +             in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
| 1431 | +             refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
| 1432 | +         refs = [torch.cat(ref, dim=1) for ref in refs]
| 1433 | +         outputs = torch.ops.fbgemm.permute_multi_embedding(
| 1434 | +             values, permutes, in_shapes, out_shapes, out_lengths
| 1435 | +         )
| 1436 | +         for out, ref in zip(outputs, refs):
| 1437 | +             self.assertTrue(torch.allclose(out, ref))
| 1438 | +
| 1439 | +     def test_multi_permute_forward_meta(self) -> None:
| 1440 | +         batch_size = 32
| 1441 | +         keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
| 1442 | +         lengths = [[3, 4], [5, 6, 7], [8]]
| 1443 | +         groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
| 1444 | +         values = [
| 1445 | +             torch.randn(batch_size, sum(lens), device="meta", requires_grad=True)
| 1446 | +             for lens in lengths
| 1447 | +         ]
| 1448 | +         permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
| 1449 | +             values[0], keys, lengths, groups
| 1450 | +         )
| 1451 | +         outputs = torch.ops.fbgemm.permute_multi_embedding(
| 1452 | +             values, permutes, in_shapes, out_shapes, out_lengths
| 1453 | +         )
| 1454 | +         for out, ref in zip(outputs, out_lengths):
| 1455 | +             self.assertEqual(out.shape, (batch_size, ref))
| 1456 | +
| 1457 | +     # pyre-ignore[56]
| 1458 | +     @unittest.skipIf(
| 1459 | +         torch.cuda.device_count() <= 0,
| 1460 | +         "CUDA is not available",
| 1461 | +     )
| 1462 | +     def test_multi_permute_forward_gpu(self) -> None:
| 1463 | +         batch_size = 1024
| 1464 | +         keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
| 1465 | +         lengths = [[96, 256], [512, 128, 768], [1024]]
| 1466 | +         groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
| 1467 | +         values = [
| 1468 | +             torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
| 1469 | +             for lens in lengths
| 1470 | +         ]
| 1471 | +         permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
| 1472 | +             values[0], keys, lengths, groups
| 1473 | +         )
| 1474 | +         refs = [[] for _ in groups]
| 1475 | +         for i in range(permutes.size(0)):
| 1476 | +             in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
| 1477 | +             refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
| 1478 | +         refs = [torch.cat(ref, dim=1) for ref in refs]
| 1479 | +         outputs = torch.ops.fbgemm.permute_multi_embedding(
| 1480 | +             values, permutes, in_shapes, out_shapes, out_lengths
| 1481 | +         )
| 1482 | +         for out, ref in zip(outputs, refs):
| 1483 | +             self.assertTrue(torch.allclose(out, ref))
| 1484 | +
| 1485 | +     def test_multi_permute_backward_cpu(self) -> None:
| 1486 | +         batch_size = 32
| 1487 | +         keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
| 1488 | +         lengths = [[3, 4], [5, 6, 7], [8]]
| 1489 | +         groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
| 1490 | +         values = [
| 1491 | +             torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
| 1492 | +             for lens in lengths
| 1493 | +         ]
| 1494 | +         ref_values = [v.detach() for v in values]
| 1495 | +         for v in ref_values:
| 1496 | +             v.requires_grad = True
| 1497 | +         permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
| 1498 | +             values[0], keys, lengths, groups
| 1499 | +         )
| 1500 | +         refs = [[] for _ in groups]
| 1501 | +         for i in range(permutes.size(0)):
| 1502 | +             in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
| 1503 | +             refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
| 1504 | +         refs = [torch.cat(ref, dim=1) for ref in refs]
| 1505 | +         outputs = torch.ops.fbgemm.permute_multi_embedding(
| 1506 | +             values, permutes, in_shapes, out_shapes, out_lengths
| 1507 | +         )
| 1508 | +         for out, ref in zip(outputs, refs):
| 1509 | +             self.assertTrue(torch.allclose(out, ref))
| 1510 | +
| 1511 | +         ref_loss, loss = refs[0].sum(), outputs[0].sum()
| 1512 | +         for i in range(1, len(refs)):
| 1513 | +             ref_loss += (i + 1.1) * refs[i].sum()
| 1514 | +             loss += (i + 1.1) * outputs[i].sum()
| 1515 | +         ref_loss.backward()
| 1516 | +         loss.backward()
| 1517 | +         for val, ref in zip(values, ref_values):
| 1518 | +             val_grad, ref_grad = val.grad, ref.grad
| 1519 | +             assert isinstance(val_grad, torch.Tensor)
| 1520 | +             self.assertTrue(torch.allclose(val_grad, ref_grad))
| 1521 | +
| 1522 | +     # pyre-ignore[56]
| 1523 | +     @unittest.skipIf(
| 1524 | +         torch.cuda.device_count() <= 0,
| 1525 | +         "CUDA is not available",
| 1526 | +     )
| 1527 | +     def test_multi_permute_backward_gpu(self) -> None:
| 1528 | +         batch_size = 2048
| 1529 | +         keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
| 1530 | +         lengths = [[96, 256], [512, 128, 768], [1024]]
| 1531 | +         groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
| 1532 | +         values = [
| 1533 | +             torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
| 1534 | +             for lens in lengths
| 1535 | +         ]
| 1536 | +         ref_values = [v.detach() for v in values]
| 1537 | +         for v in ref_values:
| 1538 | +             v.requires_grad = True
| 1539 | +         permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
| 1540 | +             values[0], keys, lengths, groups
| 1541 | +         )
| 1542 | +         refs = [[] for _ in groups]
| 1543 | +         for i in range(permutes.size(0)):
| 1544 | +             in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
| 1545 | +             refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
| 1546 | +         refs = [torch.cat(ref, dim=1) for ref in refs]
| 1547 | +         outputs = torch.ops.fbgemm.permute_multi_embedding(
| 1548 | +             values, permutes, in_shapes, out_shapes, out_lengths
| 1549 | +         )
| 1550 | +         for out, ref in zip(outputs, refs):
| 1551 | +             self.assertTrue(torch.allclose(out, ref))
| 1552 | +
| 1553 | +         ref_loss, loss = refs[0].sum(), outputs[0].sum()
| 1554 | +         for i in range(1, len(refs)):
| 1555 | +             ref_loss += (i + 1.1) * refs[i].sum()
| 1556 | +             loss += (i + 1.1) * outputs[i].sum()
| 1557 | +         ref_loss.backward()
| 1558 | +         loss.backward()
| 1559 | +         for val, ref in zip(values, ref_values):
| 1560 | +             val_grad, ref_grad = val.grad, ref.grad
| 1561 | +             assert isinstance(val_grad, torch.Tensor)
| 1562 | +             self.assertTrue(torch.allclose(val_grad, ref_grad))
| 1563 | +
1377 | 1564 |     def test_permute_duplicates(self) -> None:
1378 | 1565 |         values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
1379 | 1566 |         lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0])

@@ -1650,8 +1837,6 @@ def test_string_vb(self) -> None:
1650 | 1837 |             stride_per_key_per_rank=stride_per_key_per_rank,
1651 | 1838 |         )
1652 | 1839 |
1653 | | -         print(str(jag_tensor))
1654 | | -
1655 | 1840 |         self.assertEqual(
1656 | 1841 |             str(jag_tensor),
1657 | 1842 |             """\
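A note on the new tests above: the reference loop reads each row of `permutes` as `(input_tensor_index, output_tensor_index, input_offset, output_offset, length, jump)`; only the first three fields plus `length` are needed to rebuild the outputs in pure Python, and the last two columns appear to be consumed only by the fused fbgemm kernel. `out_lengths` is just the summed feature width per output group, e.g. the group `["f4", "f1", "f6"]` gives 6 + 3 + 8 = 17, matching `[8, 4, 17, 10]`. Below is a minimal sketch that mirrors the refs-building loop in these tests; it is an illustration only, and the helper name `regroup_reference` is not part of TorchRec:

```python
import torch

from torchrec.sparse.jagged_tensor import _kt_regroup_permutes


def regroup_reference(values, keys, lengths, groups):
    """Pure-Python mirror of the refs-building loop in the tests above.

    Illustration only; the fused op under test is
    torch.ops.fbgemm.permute_multi_embedding.
    """
    permutes, _, _, out_lengths = _kt_regroup_permutes(
        values[0], keys, lengths, groups
    )
    refs = [[] for _ in groups]
    # Each row: [in_idx, out_idx, in_start, out_start, length, jump].
    # out_start and jump are not needed here because torch.cat lays out
    # each output group in row order.
    for in_idx, out_idx, in_start, _, length, _ in permutes.tolist():
        refs[out_idx].append(values[in_idx][:, in_start : in_start + length])
    outputs = [torch.cat(ref, dim=1) for ref in refs]
    assert [out.shape[1] for out in outputs] == out_lengths
    return outputs
```

Calling this on the CPU inputs from `test_multi_permute_forward_cpu` should reproduce the per-group tensors that the tests compare against the output of `torch.ops.fbgemm.permute_multi_embedding`.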