diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py index 5e7715973270..c737bf21980d 100644 --- a/experimental/torch_xla2/torch_xla2/__init__.py +++ b/experimental/torch_xla2/torch_xla2/__init__.py @@ -10,19 +10,30 @@ 'extract_jax', ] - +from jax._src import xla_bridge +os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') jax.config.update('jax_enable_x64', True) -jax.config.update( - 'jax_pjrt_client_create_options', - f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}' -) +old_pjrt_options = jax.config.jax_pjrt_client_create_options + +try: + jax.config.update( + 'jax_pjrt_client_create_options', + f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}' + ) + xla_bridge._clear_backends() + jax.devices() # open PJRT to see if it opens +except RuntimeError: + jax.config.update( + 'jax_pjrt_client_create_options', old_pjrt_options + ) + xla_bridge._clear_backends() + jax.devices() # open PJRT to see if it opens + env = None def default_env(): global env - os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') - if env is None: env = tensor.Environment() return env