File tree 1 file changed +18
-7
lines changed
experimental/torch_xla2/torch_xla2
1 file changed +18
-7
lines changed Original file line number Diff line number Diff line change 10
10
'extract_jax' ,
11
11
]
12
12
13
-
13
+ from jax ._src import xla_bridge
14
+ os .environ .setdefault ('ENABLE_RUNTIME_UPTIME_TELEMETRY' , '1' )
14
15
jax .config .update ('jax_enable_x64' , True )
15
- jax .config .update (
16
- 'jax_pjrt_client_create_options' ,
17
- f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{ "v0.0.1" } '
18
- )
16
+ old_pjrt_options = jax .config .jax_pjrt_client_create_options
17
+
18
+ try :
19
+ jax .config .update (
20
+ 'jax_pjrt_client_create_options' ,
21
+ f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{ "v0.0.1" } '
22
+ )
23
+ xla_bridge ._clear_backends ()
24
+ jax .devices () # open PJRT to see if it opens
25
+ except RuntimeError :
26
+ jax .config .update (
27
+ 'jax_pjrt_client_create_options' , old_pjrt_options
28
+ )
29
+ xla_bridge ._clear_backends ()
30
+ jax .devices () # open PJRT to see if it opens
31
+
19
32
20
33
env = None
21
34
def default_env ():
22
35
global env
23
36
24
- os .environ .setdefault ('ENABLE_RUNTIME_UPTIME_TELEMETRY' , '1' )
25
-
26
37
if env is None :
27
38
env = tensor .Environment ()
28
39
return env
You can’t perform that action at this time.
0 commit comments