From c96559542e445653801a8d44bd8460f211d55402 Mon Sep 17 00:00:00 2001
From: Ivan Kobzarev
Date: Sun, 23 Jun 2024 15:38:58 -0700
Subject: [PATCH] Fwd-Bwd correctness tests for TBEs, kernels (#2152)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2152

Adding more tests for kernel coverage, testing inductor compilation and forward-backward numerical correctness.

Reviewed By: TroyGarden, gnahzg

Differential Revision: D58869080
---
 torchrec/distributed/tests/test_pt2.py | 367 ++++++++++++++++++++++++-
 1 file changed, 362 insertions(+), 5 deletions(-)

diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py
index 5440dba18..032a9a0fb 100644
--- a/torchrec/distributed/tests/test_pt2.py
+++ b/torchrec/distributed/tests/test_pt2.py
@@ -7,6 +7,7 @@
 
 # pyre-ignore-all-errors
 
+import copy
 import itertools
 import sys
 import unittest
@@ -14,14 +15,24 @@
 from typing import Any, Dict, List, Tuple
 
 import torch
+from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from fbgemm_gpu.permute_pooled_embedding_modules_split import (
+    PermutePooledEmbeddingsSplit,
+)
+from fbgemm_gpu.split_embedding_utils import get_table_batched_offsets_from_dense
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+    ComputeDevice,
+    EmbeddingLocation,
+    SplitTableBatchedEmbeddingBagsCodegen,
+)
 from hypothesis import given, settings, strategies as st
+from torch._dynamo.testing import reduce_to_scalar_loss
 from torchrec.distributed.test_utils.infer_utils import (
     KJTInputExportDynamicShapeWrapper,
     KJTInputExportWrapperWithStrides,
     TestQuantFPEBCSharder,
 )
 from torchrec.pt2.utils import kjt_for_pt2_tracing
-from torchrec.sparse.jagged_tensor import KeyedTensor
 
 try:
     # pyre-ignore
@@ -50,13 +61,16 @@
     ComputeKJTToJTDict,
     JaggedTensor,
     KeyedJaggedTensor,
+    KeyedTensor,
 )
 
 
-def make_kjt(values: List[int], lengths: List[int]) -> KeyedJaggedTensor:
-    values_tensor = torch.tensor(values, dtype=torch.int32)
-    lengths_tensor = torch.tensor(lengths, dtype=torch.int32)
-    weights_tensor = torch.randn(len(values), dtype=torch.float32)
+def make_kjt(
+    values: List[int], lengths: List[int], device: str = "cpu"
+) -> KeyedJaggedTensor:
+    values_tensor = torch.tensor(values, dtype=torch.int32, device=device)
+    lengths_tensor = torch.tensor(lengths, dtype=torch.int32, device=device)
+    weights_tensor = torch.randn(len(values), dtype=torch.float32, device=device)
     torch._check(torch.sum(lengths_tensor).item() == values_tensor.size(0))
     kjt = KeyedJaggedTensor(
         keys=[f"key{i}" for i in range(len(lengths))],
@@ -149,6 +163,102 @@ class _TestType(Enum):
     DYNAMO_COMPILE = auto()
 
 
+# pyre-ignore
+def _copy_input_tensors(t, device):
+    if isinstance(t, torch.Tensor):
+        ret = t.detach().clone().to(device)
+        if ret.dtype in [torch.float, torch.double]:
+            ret.requires_grad = True
+            ret.retain_grad()
+        return ret
+    elif isinstance(t, (list, tuple)):
+        return [_copy_input_tensors(_t, device) for _t in t]
+    elif isinstance(t, int):
+        return t
+    else:
+        raise ValueError(f"Unsupported type {type(t)}")
+
+
+# pyre-ignore
+def _grad_detach_clone(t):
+    if isinstance(t, torch.Tensor):
+        # pyre-ignore
+        if t.grad is None:
+            return None
+        return t.grad.detach().clone()
+    elif isinstance(t, (list, tuple)):
+        return [_grad_detach_clone(_t) for _t in t]
+    elif isinstance(t, int):
+        return t
+    else:
+        raise ValueError(f"Unsupported type {type(t)}")
+
+
+# pyre-ignore
+def _assert_close(actual, expected) -> None:
+    if actual is None and expected is None:
+        return
+
+    if isinstance(expected, torch.Tensor):
+        assert isinstance(actual, torch.Tensor)
+        torch.testing.assert_close(actual, expected, atol=1e-3, rtol=1e-3)
+    elif isinstance(expected, (list, tuple)):
+        assert type(expected) is type(actual)
+        for _a, _e in zip(actual, expected):
+            _assert_close(_a, _e)
+    elif isinstance(expected, int):
+        assert type(expected) is type(actual)
+        assert expected == actual
+    else:
+        raise ValueError(f"Unsupported type {type(expected)}")
+
+
+def test_compile_fwd_bwd(
+    fn,
+    inp,
+    device: torch.device,
+    unpack_inp: bool = False,
+    backend: str = "inductor",
+    fullgraph: bool = True,
+    skip_backward: bool = False,
+    *args,
+    **kwargs,
+):
+    eager_input = _copy_input_tensors(inp, device)
+    compile_input = _copy_input_tensors(inp, device)
+
+    if unpack_inp:
+        eager_out = fn(*eager_input, *args, **kwargs)
+    else:
+        eager_out = fn(eager_input, *args, **kwargs)
+
+    if not skip_backward:
+        eager_loss = reduce_to_scalar_loss(eager_out)
+        eager_loss.backward()
+        eager_bwd_out = _grad_detach_clone(eager_input)
+
+    with dynamo_skipfiles_allow("torchrec"):
+        torch._dynamo.config.capture_scalar_outputs = True
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
+        if unpack_inp:
+            compile_out = torch.compile(fn, backend=backend, fullgraph=fullgraph)(
+                *compile_input
+            )
+        else:
+            compile_out = torch.compile(fn, backend=backend, fullgraph=fullgraph)(
+                compile_input
+            )
+
+        if not skip_backward:
+            reduce_to_scalar_loss(compile_out).backward()
+            compile_bwd_out = _grad_detach_clone(compile_input)
+
+        _assert_close(compile_out, eager_out)
+        if not skip_backward:
+            _assert_close(compile_bwd_out, eager_bwd_out)
+
+
 class TestPt2(unittest.TestCase):
     def setUp(self):
         super().setUp()
@@ -393,6 +503,26 @@ def forward(self, inputs: List[KeyedTensor]) -> Dict[str, torch.Tensor]:
         compile_output = opt_fn(inputs)
         torch.testing.assert_close(eager_output, compile_output)
 
+    def test_kjt_permute_dynamo_compile(self) -> None:
+        class M(torch.nn.Module):
+            def forward(self, kjt: KeyedJaggedTensor, indices: List[int]):
+                return kjt.permute(indices)
+
+        device = "cuda"
+        kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1], device=device)
+        indices: List[int] = [1, 0, 3, 2]
+        # pyre-ignore
+        inputs_fn = lambda kjt: (
+            *kjt_module_kjt_inputs_with_strides(kjt),
+            indices,
+        )
+        self._test_kjt_input_module_dynamo_compile(
+            M(),
+            kjt.keys(),
+            inputs_fn(kjt_for_pt2_tracing(kjt)),
+            backend="inductor",
+        )
+
     def test_kjt_length_per_key(self) -> None:
         class M(torch.nn.Module):
             def forward(self, kjt: KeyedJaggedTensor):
@@ -666,3 +796,230 @@ def f(kjt):
         ).sync()
         f(kjt_for_pt2_tracing(kjt1))
         self.assertEqual(counter.frame_count, 1)
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs available",
+    )
+    @settings(deadline=None)
+    def test_ebc_vb_reindex(self) -> None:
+        device = "cuda"
+
+        def fn(
+            embs: torch.Tensor,
+            indices: torch.Tensor,
+            input_num_indices: List[int],
+            input_rows: List[int],
+            input_columns: List[int],
+        ):
+            reindex_output = torch.ops.fbgemm.batch_index_select_dim0_tensor(
+                inputs=embs,
+                indices=indices.view(-1),
+                input_num_indices=torch.tensor(input_num_indices, dtype=torch.int64),
+                input_rows=torch.tensor(input_rows, dtype=torch.int64),
+                input_columns=torch.tensor(input_columns, dtype=torch.int64),
+                permute_output_dim_0_1=True,
+            )
+            return reindex_output
+
+        N = 5
+        batch_size = 10
+        emb_dim = 12
+        embs: torch.Tensor = torch.randn(
+            [N * batch_size * emb_dim], device=device, requires_grad=True
+        )
+        torch._dynamo.mark_dynamic(embs, 0)
+        input_num_indices = [batch_size] * N
+        input_rows = [batch_size] * N
+        input_columns = [emb_dim] * N
+        indices: torch.Tensor = (
+            torch.arange(batch_size)
+            .expand(N, batch_size)
+            .contiguous()
+            .to(device=device)
+        )
+
+        ins = (embs, indices, input_num_indices, input_rows, input_columns)
+        test_compile_fwd_bwd(fn, ins, device, unpack_inp=True)
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    @settings(deadline=None)
+    def test_permute_pooled_embs(self) -> None:
+        device = "cuda"
+        m = PermutePooledEmbeddings(
+            embs_dims=[12, 12, 12],
+            permute=[2, 1, 0],
+        )
+        inp = torch.randn(12, 3)
+        test_compile_fwd_bwd(m, inp, device, backend="aot_eager")
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    @settings(deadline=None)
+    def test_permute_pooled_embs_split(self) -> None:
+        device = "cuda"
+        m = PermutePooledEmbeddingsSplit(
+            embs_dims=[12, 12, 12],
+            permute=[2, 1, 0],
+        )
+        inp = torch.randn(12, 3)
+        test_compile_fwd_bwd(m, inp, device)
+
+    @settings(deadline=None)
+    def test_tbe_compile(self) -> None:
+        D = 4
+        T = 2
+        E = 10
+        Ds = [D] * T
+        Es = [E] * T
+
+        device = "cuda"
+        tbe = SplitTableBatchedEmbeddingBagsCodegen(
+            embedding_specs=[
+                (
+                    E,
+                    D,
+                    (
+                        EmbeddingLocation.MANAGED
+                        if device == "cuda"
+                        else EmbeddingLocation.HOST
+                    ),
+                    ComputeDevice.CUDA if device == "cuda" else ComputeDevice.CPU,
+                )
+                for (E, D) in zip(Es, Ds)
+            ],
+        )
+        tbe.init_embedding_weights_uniform(0, 1)
+
+        class M(torch.nn.Module):
+            def __init__(self, tbe) -> None:
+                super().__init__()
+                self.tbe = tbe
+
+            def forward(self, indices, offsets, f) -> torch.Tensor:
+                e = self.tbe(indices, offsets)
+                return torch.mul(torch.mean(e, dim=1), f)
+
+        m = M(tbe)
+        m.train(True)
+        m_compile = copy.deepcopy(m)
+        m_compile.train(True)
+
+        def get_weights(m):
+            return m.tbe.weights_uvm.clone().detach()
+
+        original_weights = get_weights(m)
+
+        x = torch.Tensor(
+            [
+                [
+                    [1],
+                    [1],
+                ],
+                [[3], [4]],
+            ]
+        ).to(dtype=torch.int64, device=device)
+        (indices, offsets) = get_table_batched_offsets_from_dense(
+            x, use_cpu=device == "cpu"
+        )
+        inp_f = torch.randn(T, requires_grad=True, device=device)
+
+        # EAGER
+        out = m(indices, offsets, inp_f.clone())
+        reduce_to_scalar_loss(out).backward()
+        eager_weights_diff = get_weights(m) - original_weights
+
+        # COMPILE
+        orig_compile_weights = get_weights(m_compile)
+        with dynamo_skipfiles_allow("torchrec"):
+            torch._dynamo.config.capture_scalar_outputs = True
+            torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
+            compile_out = torch.compile(m_compile, backend="aot_eager", fullgraph=True)(
+                indices, offsets, inp_f.clone()
+            )
+            reduce_to_scalar_loss(compile_out).backward()
+            compile_weights_diff = get_weights(m_compile) - orig_compile_weights
+
+            assert_close(eager_weights_diff, compile_weights_diff)
+
+    @settings(deadline=None)
+    def test_tbe_compile_vb(self) -> None:
+        D = 4
+        T = 2
+        E = 10
+        Ds = [D] * T
+        Es = [E] * T
+
+        device = "cuda"
+        tbe = SplitTableBatchedEmbeddingBagsCodegen(
+            embedding_specs=[
+                (
+                    E,
+                    D,
+                    (
+                        EmbeddingLocation.MANAGED
+                        if device == "cuda"
+                        else EmbeddingLocation.HOST
+                    ),
+                    ComputeDevice.CUDA if device == "cuda" else ComputeDevice.CPU,
+                )
+                for (E, D) in zip(Es, Ds)
+            ],
+        )
+        tbe.init_embedding_weights_uniform(0, 1)
+
+        class M(torch.nn.Module):
+            def __init__(self, tbe) -> None:
+                super().__init__()
+                self.tbe = tbe
+
+            def forward(
+                self, indices, offsets, batch_size_per_feature_per_rank, f
+            ) -> torch.Tensor:
+                e = self.tbe(
+                    indices,
+                    offsets,
+                    batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+                )
+                return torch.mul(torch.mean(e, dim=0), f)
+
+        m = M(tbe)
+        m.train(True)
+        m_compile = copy.deepcopy(m)
+        m_compile.train(True)
+
+        def get_weights(m):
+            return m.tbe.weights_uvm.clone().detach()
+
+        original_weights = get_weights(m)
+
+        indices = torch.Tensor([1, 2, 0, 1, 2]).to(dtype=torch.int64, device=device)
+        lengths = torch.Tensor([2, 3]).to(dtype=torch.int64, device=device)
+        offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
+        batch_size_per_feature_per_rank = [[1], [2]]
+        inp_f = torch.randn(1, requires_grad=True, device=device)
+
+        # EAGER
+        out = m(indices, offsets, batch_size_per_feature_per_rank, inp_f.clone())
+        reduce_to_scalar_loss(out).backward()
+        eager_weights_diff = get_weights(m) - original_weights
+
+        # COMPILE
+        orig_compile_weights = get_weights(m_compile)
+        with dynamo_skipfiles_allow("torchrec"):
+            torch._dynamo.config.capture_scalar_outputs = True
+            torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
+            compile_out = torch.compile(m_compile, backend="aot_eager", fullgraph=True)(
+                indices, offsets, batch_size_per_feature_per_rank, inp_f.clone()
+            )
+            reduce_to_scalar_loss(compile_out).backward()
+            compile_weights_diff = get_weights(m_compile) - orig_compile_weights
+
+            assert_close(eager_weights_diff, compile_weights_diff)
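
Note: the new tests above all follow the same eager-vs-compiled forward/backward comparison recipe. The following standalone sketch (not part of the patch; the Linear module, shapes, and backend are illustrative placeholders) shows the same pattern using only public PyTorch APIs, for readers who want to reproduce it outside torchrec:

    import torch
    from torch._dynamo.testing import reduce_to_scalar_loss

    model = torch.nn.Linear(8, 4)
    inp_eager = torch.randn(2, 8, requires_grad=True)
    inp_compiled = inp_eager.detach().clone().requires_grad_(True)

    # Eager forward/backward.
    eager_out = model(inp_eager)
    reduce_to_scalar_loss(eager_out).backward()

    # Compiled forward/backward on the same weights with a copy of the input.
    compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
    compiled_out = compiled_model(inp_compiled)
    reduce_to_scalar_loss(compiled_out).backward()

    # Outputs and input gradients should agree numerically.
    torch.testing.assert_close(compiled_out, eager_out, atol=1e-3, rtol=1e-3)
    torch.testing.assert_close(inp_compiled.grad, inp_eager.grad, atol=1e-3, rtol=1e-3)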