
Commit d75d46b

Ivan Kobzarev authored and facebook-github-bot committed
Setting to use torch dynamo compiling path in eager (#2045)
Summary:
1/ Refactor to import is_torchdynamo_compiling() from torchrec.pt2.checks instead of duplicating the try/except fallback in each module.
2/ The is_torchdynamo_compiling() branches form an alternative logic path, and our tests only exercise it under compilation, so bugs on that path (e.g. a shape mismatch) are easy to miss in eager tests. We need a way to cover it with eager tests, without compilation. This diff introduces a global setting that forces the is_torchdynamo_compiling() path in eager mode, for test coverage and debugging. The setting is enabled in test_pt2_multiprocess so that the first eager iteration runs on the is_torchdynamo_compiling() path.

Differential Revision: D57860075
1 parent 8c7fa2f commit d75d46b
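
The mechanism in one self-contained sketch (assumes a torchrec build that includes this commit; `ToyModule` is a hypothetical stand-in for the real sharded model used in the tests):

```python
import torch
import torchrec.pt2.checks
from torchrec.pt2.checks import is_torchdynamo_compiling


class ToyModule(torch.nn.Module):
    # Hypothetical module with a branch that is normally reachable
    # only while torch.compile / dynamo is tracing.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if is_torchdynamo_compiling():
            return x * 2.0  # compile-path logic
        return x + x  # eager-path logic


m = ToyModule()
# Force the is_torchdynamo_compiling() path even though we run in eager:
torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True)
out = m(torch.ones(4))  # eager call now exercises the compile-path branch
torchrec.pt2.checks.set_use_torchdynamo_compiling_path(False)  # restore default
```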

File tree: 6 files changed (+46, −53 lines)


torchrec/distributed/dist_data.py

Lines changed: 6 additions & 9 deletions
```diff
@@ -27,6 +27,12 @@
 from torchrec.distributed.embedding_types import KJTList
 from torchrec.distributed.types import Awaitable, QuantizedCommCodecs
 from torchrec.fx.utils import fx_marker
+from torchrec.pt2.checks import (
+    is_non_strict_exporting,
+    is_torchdynamo_compiling,
+    pt2_checks_all_is_size,
+    pt2_checks_tensor_slice,
+)
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 try:
@@ -46,15 +52,6 @@
     pass
 
 
-try:
-    from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
-
-except Exception:
-
-    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
-        return False
-
-
 logger: logging.Logger = logging.getLogger()
```

torchrec/distributed/quant_embeddingbag.py

Lines changed: 1 addition & 7 deletions
```diff
@@ -58,20 +58,14 @@
 )
 from torchrec.modules.embedding_modules import EmbeddingBagCollectionInterface
 from torchrec.modules.feature_processor_ import FeatureProcessorsCollection
+from torchrec.pt2.checks import is_torchdynamo_compiling
 from torchrec.quant.embedding_modules import (
     EmbeddingBagCollection as QuantEmbeddingBagCollection,
     FeatureProcessedEmbeddingBagCollection as QuantFeatureProcessedEmbeddingBagCollection,
     MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
 )
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
 
-try:
-    from torch._dynamo import is_compiling as is_torchdynamo_compiling
-except Exception:
-
-    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
-        return False
-
 
 def get_device_from_parameter_sharding(ps: ParameterSharding) -> str:
     # pyre-ignore
```

torchrec/distributed/tests/test_pt2_multiprocess.py

Lines changed: 12 additions & 4 deletions
```diff
@@ -17,6 +17,7 @@
 
 import torch
 import torchrec
+import torchrec.pt2.checks
 from hypothesis import given, settings, strategies as st, Verbosity
 from torchrec.distributed.embedding import EmbeddingCollectionSharder
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
@@ -94,6 +95,11 @@ class _InputType(Enum):
     VARIABLE_BATCH = 2
 
 
+class _ConvertToVariableBatch(Enum):
+    FALSE = 0
+    TRUE = 1
+
+
 class EBCSharderFixedShardingType(EmbeddingBagCollectionSharder):
     def __init__(
         self,
@@ -333,6 +339,8 @@ def _test_compile_rank_fn(
         kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb)
 
         torchrec.distributed.comm_ops.set_use_sync_collectives(True)
+        torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True)
+
         dmp.train(True)
 
         eager_out = dmp(kjt_ft)
@@ -385,14 +393,14 @@ def disable_cuda_tf32(self) -> bool:
                 _ModelType.EBC,
                 ShardingType.TABLE_WISE.value,
                 _InputType.SINGLE_BATCH,
-                True,
+                _ConvertToVariableBatch.TRUE,
                 "eager",
             ),
             (
                 _ModelType.EBC,
                 ShardingType.COLUMN_WISE.value,
                 _InputType.SINGLE_BATCH,
-                True,
+                _ConvertToVariableBatch.TRUE,
                 "eager",
             ),
         ]
@@ -406,7 +414,7 @@ def test_compile_multiprocess(
             _ModelType,
             str,
             _InputType,
-            bool,
+            _ConvertToVariableBatch,
             str,
         ],
     ) -> None:
@@ -421,6 +429,6 @@ def test_compile_multiprocess(
             sharding_type=sharding_type,
             kernel_type=kernel_type,
             input_type=input_type,
-            convert_to_vb=tovb,
+            convert_to_vb=tovb == _ConvertToVariableBatch.TRUE,
             torch_compile_backend=compile_backend,
         )
```
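
The bool → `_ConvertToVariableBatch` switch makes the parametrized cases self-describing while the underlying test function still takes a `bool`; the enum is collapsed back to a `bool` only at the call boundary. A minimal sketch of that pattern (enum names from the diff; `run_test` is a hypothetical stand-in for `_test_compile_rank_fn`, and the test harness is elided):

```python
from enum import Enum


class _ConvertToVariableBatch(Enum):
    FALSE = 0
    TRUE = 1


def run_test(convert_to_vb: bool) -> None:
    # Hypothetical stand-in for the real test body, which takes a bool.
    print(f"convert_to_vb={convert_to_vb}")


for tovb in (_ConvertToVariableBatch.FALSE, _ConvertToVariableBatch.TRUE):
    # Collapse the enum to a bool exactly once, at the call boundary:
    run_test(convert_to_vb=(tovb == _ConvertToVariableBatch.TRUE))
```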

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 1 addition & 8 deletions
```diff
@@ -49,17 +49,10 @@
     TrainPipelineContext,
 )
 from torchrec.distributed.types import Awaitable
+from torchrec.pt2.checks import is_torchdynamo_compiling
 from torchrec.streamable import Multistreamable
 
 
-try:
-    from torch._dynamo import is_compiling as is_torchdynamo_compiling
-except Exception:
-
-    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
-        return False
-
-
 logger: logging.Logger = logging.getLogger(__name__)
```

torchrec/pt2/checks.py

Lines changed: 20 additions & 4 deletions
```diff
@@ -8,26 +8,42 @@
 # pyre-strict
 
 from typing import List
-
 import torch
 
+USE_TORCHDYNAMO_COMPILING_PATH: bool = False
+
+
+def set_use_torchdynamo_compiling_path(val: bool) -> None:
+    global USE_TORCHDYNAMO_COMPILING_PATH
+    USE_TORCHDYNAMO_COMPILING_PATH = val
+
+
+def get_use_torchdynamo_compiling_path() -> bool:
+    global USE_TORCHDYNAMO_COMPILING_PATH
+    return USE_TORCHDYNAMO_COMPILING_PATH
+
 
 try:
     if torch.jit.is_scripting():
         raise Exception()
 
     from torch.compiler import (
         is_compiling as is_compiler_compiling,
-        is_dynamo_compiling as is_torchdynamo_compiling,
+        is_dynamo_compiling as _is_torchdynamo_compiling,
     )
 
+    def is_torchdynamo_compiling() -> bool:
+        global USE_TORCHDYNAMO_COMPILING_PATH
+        return USE_TORCHDYNAMO_COMPILING_PATH or _is_torchdynamo_compiling()
+
     def is_non_strict_exporting() -> bool:
         return not is_torchdynamo_compiling() and is_compiler_compiling()
 
 except Exception:
     # BC for torch versions without compiler and torch deploy path
-    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
-        return False
+    def is_torchdynamo_compiling() -> bool:
+        global USE_TORCHDYNAMO_COMPILING_PATH
+        return USE_TORCHDYNAMO_COMPILING_PATH
 
     def is_non_strict_exporting() -> bool:
         return False
```
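
After this change, both branches of the try/except resolve `is_torchdynamo_compiling()` through the global flag, so the override works whether or not `torch.compiler` is available. A small round-trip check (assumes a torchrec build that includes this commit):

```python
from torchrec.pt2.checks import (
    get_use_torchdynamo_compiling_path,
    is_torchdynamo_compiling,
    set_use_torchdynamo_compiling_path,
)

assert not get_use_torchdynamo_compiling_path()  # default: off
assert not is_torchdynamo_compiling()  # plain eager, no dynamo tracing

set_use_torchdynamo_compiling_path(True)
assert is_torchdynamo_compiling()  # forced on, still running in eager

set_use_torchdynamo_compiling_path(False)  # restore the default
```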

torchrec/sparse/jagged_tensor.py

Lines changed: 6 additions & 21 deletions
```diff
@@ -17,7 +17,12 @@
 from torch.autograd.profiler import record_function
 from torch.fx._pytree import register_pytree_flatten_spec, TreeSpec
 from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node
-from torchrec.pt2.checks import pt2_checks_all_is_size, pt2_checks_tensor_slice
+from torchrec.pt2.checks import (
+    is_non_strict_exporting,
+    is_torchdynamo_compiling,
+    pt2_checks_all_is_size,
+    pt2_checks_tensor_slice,
+)
 from torchrec.streamable import Pipelineable
 
 try:
@@ -38,26 +43,6 @@
 except ImportError:
     pass
 
-try:
-    if torch.jit.is_scripting():
-        raise Exception()
-
-    from torch.compiler import (
-        is_compiling as is_compiler_compiling,
-        is_dynamo_compiling as is_torchdynamo_compiling,
-    )
-
-    def is_non_strict_exporting() -> bool:
-        return not is_torchdynamo_compiling() and is_compiler_compiling()
-
-except Exception:
-    # BC for torch versions without compiler and torch deploy path
-    def is_torchdynamo_compiling() -> bool:  # type: ignore[misc]
-        return False
-
-    def is_non_strict_exporting() -> bool:
-        return False
-
 
 def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
     if is_torchdynamo_compiling():
```
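
The trailing context line shows why call sites need this guard: host-side memory pinning is an eager-only optimization that dynamo tracing should skip. A sketch of what such a guard typically looks like (the branch bodies here are illustrative, not necessarily this file's exact code):

```python
import torch
from torchrec.pt2.checks import is_torchdynamo_compiling


def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    if is_torchdynamo_compiling():
        # Under dynamo, take the simple traceable path.
        return tensor.to(device=device)
    # In eager, pin host memory and issue an async host-to-device copy.
    return (
        tensor
        if device.type == "cpu"
        else tensor.pin_memory().to(device=device, non_blocking=True)
    )
```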
