-
Notifications
You must be signed in to change notification settings - Fork 527
Wrap adding logging options in try / catch. #7307
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,19 +10,30 @@ | |
'extract_jax', | ||
] | ||
|
||
|
||
from jax._src import xla_bridge | ||
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') | ||
jax.config.update('jax_enable_x64', True) | ||
jax.config.update( | ||
'jax_pjrt_client_create_options', | ||
f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}' | ||
) | ||
old_pjrt_options = jax.config.jax_pjrt_client_create_options | ||
|
||
try: | ||
jax.config.update( | ||
'jax_pjrt_client_create_options', | ||
f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}' | ||
) | ||
xla_bridge._clear_backends() | ||
jax.devices() # open PJRT to see if it opens | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could be a problem in the future. If you import Is JAX going to support this option on GPU (or at least not crash if this option is given on GPU) in the near future? If so, I think this is an okay hack for now and we can fix it later. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. PJRT for CUDA need to take those args and ignore. PJRT for CPU already does that. |
||
except RuntimeError: | ||
jax.config.update( | ||
'jax_pjrt_client_create_options', old_pjrt_options | ||
) | ||
xla_bridge._clear_backends() | ||
jax.devices() # open PJRT to see if it opens | ||
|
||
|
||
env = None | ||
def default_env(): | ||
global env | ||
|
||
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') | ||
|
||
if env is None: | ||
env = tensor.Environment() | ||
return env | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this have a significant default value?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nop, it's empty.