
Commit d41c5ba

[PJRT] Release the GIL during TransferFromServer (#4504)

* Release the GIL during `TransferFromServer`
* Check for GIL before releasing it
* Add comment
* Formatting
* Check for Python initialization before GIL status
* Add explanation and warning
* Add temporary switch to hold onto GIL
* Formatting

1 parent efb5353 commit d41c5ba

File tree

2 files changed: +26 −0 lines changed

configuration.yaml

Lines changed: 5 additions & 0 deletions

@@ -497,3 +497,8 @@ variables:
       - List of metrics percentiles to record.
     type: string
     default_value: "0.01:0.05:0.1:0.2:0.5:0.8:0.9:0.95:0.99"
+  XLA_RELEASE_GIL_DURING_TRANSFER:
+    descripton:
+      - Release Python's GIL when transferring data from the runtime.
+    type: bool
+    default_value: true

torch_xla/csrc/xla_graph_executor.cpp

Lines changed: 21 additions & 0 deletions

@@ -1,5 +1,7 @@
 #include "torch_xla/csrc/xla_graph_executor.h"
 
+#include <Python.h>
+
 #include <algorithm>
 #include <atomic>
 #include <cmath>
@@ -635,9 +637,28 @@ std::vector<at::Tensor> XLAGraphExecutor::GetTensorsFused(
       *tensors, async != nullptr ? async->indices : absl::Span<const size_t>(),
       async != nullptr ? async->tensors_data
                        : absl::Span<const torch::lazy::BackendDataPtr>());
+
+  // Execution is async in PJRT, so TransferFromServer may block until execution
+  // completes. Release the GIL so other threads can proceed and unblock any
+  // collective computations.
+  // HACK: This method may be called outside of python (mainly in C++ tests) or
+  // when the GIL is already released, so we must check both cases here. If
+  // possible, prefer to release the GIL in the python bindings before copying
+  // this pattern.
+  PyThreadState* save = nullptr;
+  // TODO(wcromar): Remove this setting when we are more confident
+  static const bool release_gil =
+      xla::sys_util::GetEnvBool("XLA_RELEASE_GIL_DURING_TRANSFER", true);
+  if (release_gil && Py_IsInitialized() && PyGILState_Check()) {
+    save = PyEval_SaveThread();
+  }
   std::vector<xla::Literal> literals =
       xla::ComputationClient::Get()->TransferFromServer(
           UnwrapXlaData(tensors_data));
+  if (save) {
+    PyEval_RestoreThread(save);
+  }
+
   return FetchTensors(tensors, literals,
                       async != nullptr ? &async->indices : nullptr);
 }

0 commit comments