Skip to content

Commit 9e7822a

Browse files
committed
Updated runtime device location code to run regardless of whether a switch is required or not.
1 parent a8fbc5e commit 9e7822a

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

core/runtime/execute_engine.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -80,22 +80,22 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
8080
} else {
8181
// Target device is current device
8282
target_device += std::to_string(curr_device.id);
83+
}
84+
85+
// For each input, ensure its current device is the desired target device
86+
for (size_t i = 0; i < inputs.size(); i++) {
87+
at::Tensor* in = &inputs[i];
88+
std::string current_tensor_device = in->device().str();
8389

84-
// For each input, ensure its current device is the desired target device
85-
for (size_t i = 0; i < inputs.size(); i++) {
86-
at::Tensor* in = &inputs[i];
87-
std::string current_tensor_device = in->device().str();
88-
89-
// If current device string does not match target device, display warning and move tensor accordingly
90-
if (current_tensor_device != target_device) {
91-
LOG_WARNING(
92-
"Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device
93-
<< " but should be on " << target_device << ". This tensor is being moved by the runtime but "
94-
<< "for performance considerations, ensure your inputs are all on GPU "
95-
<< "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this "
96-
<< "warning persists.");
97-
*in = in->to(torch::Device(target_device));
98-
}
90+
// If current device string does not match target device, display warning and move tensor accordingly
91+
if (current_tensor_device != target_device) {
92+
LOG_WARNING(
93+
"Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device
94+
<< " but should be on " << target_device << ". This tensor is being moved by the runtime but "
95+
<< "for performance considerations, ensure your inputs are all on GPU "
96+
<< "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this "
97+
<< "warning persists.");
98+
*in = in->to(torch::Device(target_device));
9999
}
100100
}
101101

0 commit comments

Comments
 (0)