@@ -1,5 +1,7 @@
 #include "torch_xla/csrc/xla_graph_executor.h"
 
+#include <Python.h>
+
 #include <algorithm>
 #include <atomic>
 #include <cmath>

@@ -635,9 +637,28 @@ std::vector<at::Tensor> XLAGraphExecutor::GetTensorsFused(
       *tensors, async != nullptr ? async->indices : absl::Span<const size_t>(),
       async != nullptr ? async->tensors_data
                        : absl::Span<const torch::lazy::BackendDataPtr>());
+
+  // Execution is async in PJRT, so TransferFromServer may block until execution
+  // completes. Release the GIL so other threads can proceed and unblock any
+  // collective computations.
+  // HACK: This method may be called outside of python (mainly in C++ tests) or
+  // when the GIL is already released, so we must check both cases here. If
+  // possible, prefer to release the GIL in the python bindings before copying
+  // this pattern.
+  PyThreadState* save = nullptr;
+  // TODO(wcromar): Remove this setting when we are more confident
+  static const bool release_gil =
+      xla::sys_util::GetEnvBool("XLA_RELEASE_GIL_DURING_TRANSFER", true);
+  if (release_gil && Py_IsInitialized() && PyGILState_Check()) {
+    save = PyEval_SaveThread();
+  }
   std::vector<xla::Literal> literals =
       xla::ComputationClient::Get()->TransferFromServer(
           UnwrapXlaData(tensors_data));
+  if (save) {
+    PyEval_RestoreThread(save);
+  }
+
   return FetchTensors(tensors, literals,
                       async != nullptr ? &async->indices : nullptr);
 }
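
The comment in the hunk above recommends releasing the GIL at the Python-binding layer rather than copying the save/restore pattern into more C++ call sites. Below is a minimal sketch of that preferred approach, assuming a pybind11 binding (torch_xla's Python bindings are pybind11-based); the module name, binding name, and BlockingTransfer helper are hypothetical stand-ins, not real torch_xla symbols.

// Sketch only: _EXAMPLE, _blocking_transfer, and BlockingTransfer are
// hypothetical names used for illustration.
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Stand-in for a C++ entry point that may block on device execution or a
// device-to-host transfer (e.g. TransferFromServer).
void BlockingTransfer() { /* potentially long-running work */ }

PYBIND11_MODULE(_EXAMPLE, m) {
  m.def("_blocking_transfer", []() {
    // Release the GIL for the duration of the blocking call so other Python
    // threads (e.g. those driving collective computations) keep running; the
    // GIL is reacquired automatically when `release` goes out of scope.
    py::gil_scoped_release release;
    BlockingTransfer();
  });
}

When the GIL is dropped at the binding boundary like this, the executor-level code never runs while holding the GIL and does not need the Py_IsInitialized()/PyGILState_Check() guards. If the behavior added by this patch causes problems, it can be switched off through the environment variable the patch introduces, e.g. XLA_RELEASE_GIL_DURING_TRANSFER=0.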