EBC backward numeric correctness tests #2162

Closed
188 changes: 136 additions & 52 deletions torchrec/distributed/tests/test_pt2_multiprocess.py
@@ -10,6 +10,7 @@
#!/usr/bin/env python3

import unittest
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

@@ -18,7 +19,11 @@
import torch
import torchrec
import torchrec.pt2.checks
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
SplitTableBatchedEmbeddingBagsCodegen,
)
from hypothesis import given, settings, strategies as st, Verbosity
from torch import distributed as dist
from torch._dynamo.testing import reduce_to_scalar_loss
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
@@ -102,6 +107,11 @@ class _ConvertToVariableBatch(Enum):
TRUE = 1


@dataclass
class _TestConfig:
n_extra_numerics_checks_inputs: int = 2


class EBCSharderFixedShardingType(EmbeddingBagCollectionSharder):
def __init__(
self,
@@ -227,6 +237,7 @@ def _test_compile_rank_fn(
kernel_type: str,
input_type: _InputType,
convert_to_vb: bool,
config: _TestConfig,
torch_compile_backend: Optional[str] = None,
local_size: Optional[int] = None,
) -> None:
@@ -240,8 +251,9 @@ def _test_compile_rank_fn(
num_float_features: int = 8
num_weighted_features: int = 1

device = torch.device("cuda")
pg = ctx.pg
# pyre-ignore
device: torch.Device = torch.device("cuda")
pg: Optional[dist.ProcessGroup] = ctx.pg
assert pg is not None

topology: Topology = Topology(world_size=world_size, compute_device="cuda")
@@ -316,58 +328,96 @@ def _test_compile_rank_fn(
pg,
)

dmp = DistributedModelParallel(
model,
env=ShardingEnv.from_process_group(pg),
plan=plan,
# pyre-ignore
sharders=sharders,
device=device,
init_data_parallel=False,
)

if input_type == _InputType.VARIABLE_BATCH:
(
global_model_input,
local_model_inputs,
) = ModelInput.generate_variable_batch_input(
average_batch_size=batch_size,
world_size=world_size,
num_float_features=num_float_features,
# pyre-ignore
def _dmp(m: torch.nn.Module) -> DistributedModelParallel:
return DistributedModelParallel(
m,
# pyre-ignore
tables=mi.tables,
)
else:
(
global_model_input,
local_model_inputs,
) = ModelInput.generate(
batch_size=batch_size,
world_size=world_size,
num_float_features=num_float_features,
tables=mi.tables,
weighted_tables=mi.weighted_tables,
variable_batch_size=False,
env=ShardingEnv.from_process_group(pg),
plan=plan,
sharders=sharders,
device=device,
init_data_parallel=False,
)

local_model_input = local_model_inputs[0].to(device)
dmp = _dmp(model)
dmp_compile = _dmp(model)

# TODO: Fix some data dependent failures on subsequent inputs
n_extra_numerics_checks = config.n_extra_numerics_checks_inputs
ins = []

for _ in range(1 + n_extra_numerics_checks):
if input_type == _InputType.VARIABLE_BATCH:
(
_,
local_model_inputs,
) = ModelInput.generate_variable_batch_input(
average_batch_size=batch_size,
world_size=world_size,
num_float_features=num_float_features,
# pyre-ignore
tables=mi.tables,
)
else:
(
_,
local_model_inputs,
) = ModelInput.generate(
batch_size=batch_size,
world_size=world_size,
num_float_features=num_float_features,
tables=mi.tables,
weighted_tables=mi.weighted_tables,
variable_batch_size=False,
)
ins.append(local_model_inputs)

local_model_input = ins[0][rank].to(device)

kjt = local_model_input.idlist_features
ff = local_model_input.float_features
ff.requires_grad = True
kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb)

compile_input_ff = ff.clone().detach()
compile_input_ff.requires_grad = True

torchrec.distributed.comm_ops.set_use_sync_collectives(True)
torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True)

dmp.train(True)
dmp_compile.train(True)

def get_weights(dmp: DistributedModelParallel) -> torch.Tensor:
tbe = None
if test_model_type == _ModelType.EBC:
tbe = (
dmp._dmp_wrapped_module._ebc._lookups[0]._emb_modules[0]._emb_module
)
elif test_model_type == _ModelType.FPEBC:
tbe = (
dmp._dmp_wrapped_module._fpebc._lookups[0]
._emb_modules[0]
._emb_module
)
elif test_model_type == _ModelType.EC:
tbe = (
dmp._dmp_wrapped_module._ec._lookups[0]._emb_modules[0]._emb_module
)
assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen)

return tbe.weights_dev.clone().detach()

original_weights = get_weights(dmp)
original_weights.zero_()
original_compile_weights = get_weights(dmp_compile)
original_compile_weights.zero_()

eager_out = dmp(kjt_ft, ff)

eager_loss = reduce_to_scalar_loss(eager_out)
eager_loss.backward()
reduce_to_scalar_loss(eager_out).backward()
eager_weights_diff = get_weights(dmp) - original_weights

if torch_compile_backend is None:
return
@@ -378,7 +428,7 @@ def _test_compile_rank_fn(
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
opt_fn = torch.compile(
dmp,
dmp_compile,
backend=torch_compile_backend,
fullgraph=True,
)
@@ -387,32 +437,58 @@ def _test_compile_rank_fn(
)
torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)
if run_compile_backward:
loss = reduce_to_scalar_loss(compile_out)
loss.backward()
reduce_to_scalar_loss(compile_out).backward()
compile_weights_diff = (
get_weights(dmp_compile) - original_compile_weights
)
# Checking TBE weights updated inplace
torch.testing.assert_close(
eager_weights_diff, compile_weights_diff, atol=1e-3, rtol=1e-3
)
# Check float inputs gradients
torch.testing.assert_close(
ff.grad, compile_input_ff.grad, atol=1e-3, rtol=1e-3
)

##### COMPILE END #####

##### NUMERIC CHECK #####
with dynamo_skipfiles_allow("torchrec"):
n = len(local_model_inputs)
for i in range(n - 1):
local_model_input = local_model_inputs[1 + i].to(device)
for i in range(n_extra_numerics_checks):
local_model_input = ins[1 + i][rank].to(device)
kjt = local_model_input.idlist_features
kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb)
ff = local_model_input.float_features
ff.requires_grad = True
eager_out_i = dmp(kjt_ft, ff)
eager_loss_i = reduce_to_scalar_loss(eager_out_i)
eager_loss_i.backward()
reduce_to_scalar_loss(eager_out_i).backward()
eager_weights_diff = get_weights(dmp) - original_weights

compile_input_ff = ff.detach().clone()
compile_out_i = opt_fn(kjt_ft, ff)
compile_input_ff = ff.clone().detach()
compile_input_ff.requires_grad = True
compile_out_i = opt_fn(kjt_ft, compile_input_ff)
torch.testing.assert_close(
eager_out_i, compile_out_i, atol=1e-3, rtol=1e-3
)
if run_compile_backward:
loss_i = torch._dynamo.testing.reduce_to_scalar_loss(compile_out_i)
loss_i.backward()
torch._dynamo.testing.reduce_to_scalar_loss(
compile_out_i
).backward()
compile_weights_diff = (
get_weights(dmp_compile) - original_compile_weights
)
# Checking TBE weights updated inplace
torch.testing.assert_close(
eager_weights_diff,
compile_weights_diff,
atol=1e-3,
rtol=1e-3,
)
# Check float inputs gradients
torch.testing.assert_close(
ff.grad, compile_input_ff.grad, atol=1e-3, rtol=1e-3
)

##### NUMERIC CHECK END #####


@@ -431,35 +507,41 @@ def disable_cuda_tf32(self) -> bool:
EmbeddingComputeKernel.DENSE.value,
],
),
model_type_sharding_type_input_type_tovb_backend=st.sampled_from(
given_config_tuple=st.sampled_from(
[
(
_ModelType.EBC,
ShardingType.TABLE_WISE.value,
_InputType.SINGLE_BATCH,
_ConvertToVariableBatch.TRUE,
"inductor",
# TODO: Debug data dependent numerics checks failures on subsequent inputs for VB backward
_TestConfig(n_extra_numerics_checks_inputs=0),
),
(
_ModelType.EBC,
ShardingType.COLUMN_WISE.value,
_InputType.SINGLE_BATCH,
_ConvertToVariableBatch.TRUE,
"inductor",
# TODO: Debug data dependent numerics checks failures on subsequent inputs for VB backward
_TestConfig(n_extra_numerics_checks_inputs=0),
),
(
_ModelType.EBC,
ShardingType.TABLE_WISE.value,
_InputType.SINGLE_BATCH,
_ConvertToVariableBatch.FALSE,
"eager",
_TestConfig(),
),
(
_ModelType.EBC,
ShardingType.COLUMN_WISE.value,
_InputType.SINGLE_BATCH,
_ConvertToVariableBatch.FALSE,
"eager",
_TestConfig(),
),
]
),
@@ -468,16 +550,17 @@ def disable_cuda_tf32(self) -> bool:
def test_compile_multiprocess(
self,
kernel_type: str,
model_type_sharding_type_input_type_tovb_backend: Tuple[
given_config_tuple: Tuple[
_ModelType,
str,
_InputType,
_ConvertToVariableBatch,
Optional[str],
_TestConfig,
],
) -> None:
model_type, sharding_type, input_type, tovb, compile_backend = (
model_type_sharding_type_input_type_tovb_backend
model_type, sharding_type, input_type, tovb, compile_backend, config = (
given_config_tuple
)
self._run_multi_process_test(
callable=_test_compile_rank_fn,
Expand All @@ -488,5 +571,6 @@ def test_compile_multiprocess(
kernel_type=kernel_type,
input_type=input_type,
convert_to_vb=tovb == _ConvertToVariableBatch.TRUE,
config=config,
torch_compile_backend=compile_backend,
)
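
Note: the backward numeric check added by this diff boils down to the following pattern. This is an illustrative sketch only, not code from the PR; build_model and make_inputs are hypothetical stand-ins for the DistributedModelParallel and ModelInput setup in the test above.

    import torch
    from torch._dynamo.testing import reduce_to_scalar_loss

    def check_eager_vs_compile_backward(build_model, make_inputs, backend="inductor"):
        # Two identically initialized models: one run eagerly, one through torch.compile.
        eager_model = build_model()
        compiled_model = torch.compile(build_model(), backend=backend, fullgraph=True)

        sparse_input, dense_input = make_inputs()
        dense_eager = dense_input.clone().detach().requires_grad_(True)
        dense_compile = dense_input.clone().detach().requires_grad_(True)

        # Forward/backward in eager mode.
        eager_out = eager_model(sparse_input, dense_eager)
        reduce_to_scalar_loss(eager_out).backward()

        # Forward/backward through the compiled module; outputs must match eager.
        compile_out = compiled_model(sparse_input, dense_compile)
        torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)
        reduce_to_scalar_loss(compile_out).backward()

        # Gradients flowing into the dense float features must also match.
        torch.testing.assert_close(dense_eager.grad, dense_compile.grad, atol=1e-3, rtol=1e-3)

In the test itself the same tolerance-based comparison is additionally applied to the TBE weights (via get_weights on SplitTableBatchedEmbeddingBagsCodegen.weights_dev), since the fused optimizer updates embedding weights in place during backward instead of exposing gradients.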