
Commit 3ca8e8b

Ivan Kobzarev authored and facebook-github-bot committed
Setting to use torch dynamo compiling path in eager (#2045)
Summary:
Pull Request resolved: #2045

1. Refactoring to use is_torchdynamo_compiling() from torchrec.pt2.checks instead of duplicating the fallback definition in each module.
2. is_torchdynamo_compiling() guards an alternative logic path that our tests never exercise without compilation, so a shape mismatch or similar bug on that path is easy to miss. We need a way to cover it with eager tests, without compilation. This change introduces a global setting that forces the is_torchdynamo_compiling() path in eager mode, for test coverage and debugging. The setting is enabled in test_pt2_multiprocess, so the first eager iteration runs on the is_torchdynamo_compiling() path.

Reviewed By: PaulZhang12, gnahzg

Differential Revision: D57860075

fbshipit-source-id: a033be81367b814afa47a7b22bd68d7eccf4f991
1 parent 9a50de2 commit 3ca8e8b
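
The gist of the change: torchrec.pt2.checks gains a process-wide flag that forces is_torchdynamo_compiling() to report True even when running eagerly, so compile-only branches get exercised by plain eager tests. A minimal sketch of the semantics, assuming a torchrec build that includes this commit:

    import torchrec.pt2.checks
    from torchrec.pt2.checks import is_torchdynamo_compiling

    # Default eager behavior: the dynamo branch is skipped.
    assert not is_torchdynamo_compiling()

    # Force the dynamo branch while still running eagerly, so shape
    # mismatches on that path surface in ordinary eager tests.
    torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True)
    assert is_torchdynamo_compiling()

    # Restore the default so later code takes the normal eager path.
    torchrec.pt2.checks.set_use_torchdynamo_compiling_path(False)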

File tree

6 files changed: +43 −51 lines

torchrec/distributed/dist_data.py

Lines changed: 1 addition & 9 deletions
@@ -28,6 +28,7 @@
 from torchrec.distributed.global_settings import get_propogate_device
 from torchrec.distributed.types import Awaitable, QuantizedCommCodecs, rank_device
 from torchrec.fx.utils import fx_marker
+from torchrec.pt2.checks import is_torchdynamo_compiling
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

 try:
@@ -47,15 +48,6 @@
     pass


-try:
-    from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
-
-except Exception:
-
-    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
-        return False
-
-
 logger: logging.Logger = logging.getLogger()
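
The same vendored fallback was copy-pasted across several modules; after this refactor each of them imports the one shared helper and simply branches on it. An illustrative consumption pattern (the function and return labels below are made up, not from the diff):

    from torchrec.pt2.checks import is_torchdynamo_compiling

    def _select_path() -> str:
        # Hypothetical example: pick the compile-friendly logic under
        # dynamo tracing, the plain eager logic otherwise.
        if is_torchdynamo_compiling():
            return "dynamo-friendly path"
        return "eager path"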

torchrec/distributed/quant_embeddingbag.py

Lines changed: 1 addition & 7 deletions
@@ -59,20 +59,14 @@
 )
 from torchrec.modules.embedding_modules import EmbeddingBagCollectionInterface
 from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
+from torchrec.pt2.checks import is_torchdynamo_compiling
 from torchrec.quant.embedding_modules import (
     EmbeddingBagCollection as QuantEmbeddingBagCollection,
     FeatureProcessedEmbeddingBagCollection as QuantFeatureProcessedEmbeddingBagCollection,
     MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor

-try:
-    from torch._dynamo import is_compiling as is_torchdynamo_compiling
-except Exception:
-
-    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
-        return False
-

 def get_device_from_parameter_sharding(ps: ParameterSharding) -> str:
     # pyre-ignore

torchrec/distributed/tests/test_pt2_multiprocess.py

Lines changed: 12 additions & 4 deletions
@@ -17,6 +17,7 @@

 import torch
 import torchrec
+import torchrec.pt2.checks
 from hypothesis import given, settings, strategies as st, Verbosity
 from torchrec.distributed.embedding import EmbeddingCollectionSharder
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
@@ -94,6 +95,11 @@ class _InputType(Enum):
     VARIABLE_BATCH = 2


+class _ConvertToVariableBatch(Enum):
+    FALSE = 0
+    TRUE = 1
+
+
 class EBCSharderFixedShardingType(EmbeddingBagCollectionSharder):
     def __init__(
         self,
@@ -333,6 +339,8 @@ def _test_compile_rank_fn(
     kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb)

     torchrec.distributed.comm_ops.set_use_sync_collectives(True)
+    torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True)
+
     dmp.train(True)

     eager_out = dmp(kjt_ft)
@@ -385,14 +393,14 @@ def disable_cuda_tf32(self) -> bool:
             _ModelType.EBC,
             ShardingType.TABLE_WISE.value,
             _InputType.SINGLE_BATCH,
-            True,
+            _ConvertToVariableBatch.TRUE,
             "eager",
         ),
         (
             _ModelType.EBC,
             ShardingType.COLUMN_WISE.value,
             _InputType.SINGLE_BATCH,
-            True,
+            _ConvertToVariableBatch.TRUE,
             "eager",
         ),
     ]
@@ -406,7 +414,7 @@ def test_compile_multiprocess(
             _ModelType,
             str,
             _InputType,
-            bool,
+            _ConvertToVariableBatch,
             str,
         ],
     ) -> None:
@@ -421,6 +429,6 @@ def test_compile_multiprocess(
             sharding_type=sharding_type,
             kernel_type=kernel_type,
             input_type=input_type,
-            convert_to_vb=tovb,
+            convert_to_vb=tovb == _ConvertToVariableBatch.TRUE,
             torch_compile_backend=compile_backend,
         )
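
The bool test parameter becomes a two-member Enum, presumably so generated test variants are self-describing in logs, and the call site converts it back to the bool that the underlying function expects. A small sketch of that conversion, using the names from the diff:

    from enum import Enum

    class _ConvertToVariableBatch(Enum):
        FALSE = 0
        TRUE = 1

    # An Enum member prints as "_ConvertToVariableBatch.TRUE" in a
    # parameterized test name, where a bare bool renders only as "True".
    tovb = _ConvertToVariableBatch.TRUE
    convert_to_vb: bool = tovb == _ConvertToVariableBatch.TRUE
    assert convert_to_vb is True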

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 1 addition & 8 deletions
@@ -49,17 +49,10 @@
     TrainPipelineContext,
 )
 from torchrec.distributed.types import Awaitable
+from torchrec.pt2.checks import is_torchdynamo_compiling
 from torchrec.streamable import Multistreamable


-try:
-    from torch._dynamo import is_compiling as is_torchdynamo_compiling
-except Exception:
-
-    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
-        return False
-
-
 logger: logging.Logger = logging.getLogger(__name__)

torchrec/pt2/checks.py

Lines changed: 22 additions & 2 deletions
@@ -11,22 +11,42 @@

 import torch

+USE_TORCHDYNAMO_COMPILING_PATH: bool = False
+
+
+def set_use_torchdynamo_compiling_path(val: bool) -> None:
+    global USE_TORCHDYNAMO_COMPILING_PATH
+    USE_TORCHDYNAMO_COMPILING_PATH = val
+
+
+def get_use_torchdynamo_compiling_path() -> bool:
+    global USE_TORCHDYNAMO_COMPILING_PATH
+    return USE_TORCHDYNAMO_COMPILING_PATH
+

 try:
     if torch.jit.is_scripting():
         raise Exception()

     from torch.compiler import (
         is_compiling as is_compiler_compiling,
-        is_dynamo_compiling as is_torchdynamo_compiling,
+        is_dynamo_compiling as _is_torchdynamo_compiling,
     )

+    def is_torchdynamo_compiling() -> bool:
+        if torch.jit.is_scripting():
+            return False
+
+        # Cannot read the global directly here: TorchScript does not support
+        # globals and parses the full method source even behind the
+        # torch.jit.is_scripting() guard.
+        return get_use_torchdynamo_compiling_path() or _is_torchdynamo_compiling()
+
     def is_non_strict_exporting() -> bool:
         return not is_torchdynamo_compiling() and is_compiler_compiling()

 except Exception:
     # BC for torch versions without compiler and torch deploy path
-    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
+    def is_torchdynamo_compiling() -> bool:
        return False

     def is_non_strict_exporting() -> bool:
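
Because the flag is a module-level global, a test that sets it and then fails can leak the setting into later tests. The commit exposes only the raw setter/getter; a scoped wrapper like the one below is not part of the diff, just a hypothetical convenience:

    from contextlib import contextmanager
    from typing import Iterator

    import torchrec.pt2.checks

    @contextmanager
    def force_dynamo_compiling_path() -> Iterator[None]:
        # Hypothetical helper (not in this commit): flip the global flag
        # for the duration of a block, restoring the prior value on exit.
        prev = torchrec.pt2.checks.get_use_torchdynamo_compiling_path()
        torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True)
        try:
            yield
        finally:
            torchrec.pt2.checks.set_use_torchdynamo_compiling_path(prev)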

torchrec/sparse/jagged_tensor.py

Lines changed: 6 additions & 21 deletions
@@ -17,7 +17,12 @@
 from torch.autograd.profiler import record_function
 from torch.fx._pytree import register_pytree_flatten_spec, TreeSpec
 from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node
-from torchrec.pt2.checks import pt2_checks_all_is_size, pt2_checks_tensor_slice
+from torchrec.pt2.checks import (
+    is_non_strict_exporting,
+    is_torchdynamo_compiling,
+    pt2_checks_all_is_size,
+    pt2_checks_tensor_slice,
+)
 from torchrec.streamable import Pipelineable

 try:
@@ -38,26 +43,6 @@
 except ImportError:
     pass

-try:
-    if torch.jit.is_scripting():
-        raise Exception()
-
-    from torch.compiler import (
-        is_compiling as is_compiler_compiling,
-        is_dynamo_compiling as is_torchdynamo_compiling,
-    )
-
-    def is_non_strict_exporting() -> bool:
-        return not is_torchdynamo_compiling() and is_compiler_compiling()
-
-except Exception:
-    # BC for torch versions without compiler and torch deploy path
-    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
-        return False
-
-    def is_non_strict_exporting() -> bool:
-        return False
-

 def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
     if is_torchdynamo_compiling():
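
The trailing context shows why the import matters here: _pin_and_move branches on is_torchdynamo_compiling(). The body below is an illustrative reconstruction of that guard pattern, not the exact code from the file; the usual motivation is that pin_memory()/non_blocking host-to-device copies do not trace well under dynamo:

    import torch
    from torchrec.pt2.checks import is_torchdynamo_compiling

    def _pin_and_move_sketch(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
        # Illustrative: under dynamo tracing, fall back to a plain .to(device).
        if is_torchdynamo_compiling():
            return tensor.to(device=device)
        # Eager: pin host memory and copy asynchronously when targeting CUDA.
        if device.type == "cpu":
            return tensor
        return tensor.pin_memory().to(device=device, non_blocking=True)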
