Skip to content

Commit c275371

Browse files
authored
Wrap adding logging options in try / catch. (#7307)
1 parent 91389c7 commit c275371

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

experimental/torch_xla2/torch_xla2/__init__.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,30 @@
1010
'extract_jax',
1111
]
1212

13-
13+
from jax._src import xla_bridge
14+
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
1415
jax.config.update('jax_enable_x64', True)
15-
jax.config.update(
16-
'jax_pjrt_client_create_options',
17-
f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}'
18-
)
16+
old_pjrt_options = jax.config.jax_pjrt_client_create_options
17+
18+
try:
19+
jax.config.update(
20+
'jax_pjrt_client_create_options',
21+
f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}'
22+
)
23+
xla_bridge._clear_backends()
24+
jax.devices() # open PJRT to see if it opens
25+
except RuntimeError:
26+
jax.config.update(
27+
'jax_pjrt_client_create_options', old_pjrt_options
28+
)
29+
xla_bridge._clear_backends()
30+
jax.devices() # open PJRT to see if it opens
31+
1932

2033
env = None
2134
def default_env():
2235
global env
2336

24-
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
25-
2637
if env is None:
2738
env = tensor.Environment()
2839
return env

0 commit comments

Comments
 (0)