
Commit 3a8c033

Merge pull request pytorch#11 from ynimmaga/portable_kernels

Get openvino backend device from compile specs

2 parents 1a20242 + e9bac9a
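
This merge makes the OpenVINO delegate read its target device from the compile specs passed at export time instead of hard-coding "CPU". As a minimal sketch of how such a spec might be constructed on the Python side (the payload encoding is an assumption inferred from the C-string handling in the runtime diff below, hence the explicit NUL terminator):

```python
# Sketch: building the "device" compile spec consumed by OpenvinoBackend::init.
# The value is raw bytes; the runtime casts it to a C string, so a trailing
# NUL byte is assumed to be required.
from executorch.exir.backend.compile_spec_schema import CompileSpec

device_spec = CompileSpec("device", b"GPU\x00")  # or b"CPU\x00", b"NPU\x00"
```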

File tree

7 files changed: +129 −119 lines changed


backends/openvino/CMakeLists.txt

Lines changed: 0 additions & 1 deletion

```diff
@@ -69,4 +69,3 @@ target_link_options(openvino_backend PRIVATE -Wl,-rpath=${OPENVINO_LIB_PATH})
 
 # Install OpenVINO backend library to the lib directory
 install(TARGETS openvino_backend DESTINATION lib)
-
```

backends/openvino/partitioner.py

Lines changed: 7 additions & 7 deletions

```diff
@@ -7,6 +7,7 @@
 from typing import Callable, final, List, Optional, Tuple
 
 import torch
+import torch.fx as fx
 from executorch.backends.openvino.preprocess import OpenvinoBackend
 from executorch.exir.backend.backend_details import CompileSpec
 from executorch.exir.backend.partitioner import (
@@ -15,12 +16,12 @@
     PartitionResult,
 )
 from executorch.exir.backend.utils import tag_constant_data
+from openvino.frontend.pytorch.torchdynamo.op_support import OperatorSupport
 
 from torch.export.exported_program import ExportedProgram
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
 from torch.fx.passes.operator_support import OperatorSupportBase
-import torch.fx as fx
-from openvino.frontend.pytorch.torchdynamo.op_support import OperatorSupport
+
 
 class OpenvinoOperatorsSupport(OperatorSupportBase):
 
@@ -44,10 +45,10 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
         options = []
         op_type = node.target.__name__
         supported_ops = OperatorSupport(options)._support_dict
-        if (op_type == "getitem"):
+        if op_type == "getitem":
             return True
 
-        if ("torch.ops." + str(op_type) in supported_ops):
+        if "torch.ops." + str(op_type) in supported_ops:
             return True
         else:
             print("Op not supported: ", "torch.ops." + str(op_type))
@@ -58,7 +59,7 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
             )
             return False
 
-        return False
+        return False
 
 
 @final
@@ -88,13 +89,12 @@ def ops_to_not_decompose(
         return (ops_not_decompose, None)
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
-        options = {}
         gm = fx.symbolic_trace(exported_program.graph_module)
 
         partitioner = CapabilityBasedPartitioner(
             exported_program.graph_module,
             OpenvinoOperatorsSupport(self._op_types_to_skip, self._op_names_to_skip),
-            allows_single_node_partition=True
+            allows_single_node_partition=True,
         )
         partition_list = partitioner.propose_partitions()
 
```
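
For readers tracing the partitioning flow: `CapabilityBasedPartitioner` walks the FX graph and groups contiguous nodes that the supplied `OperatorSupportBase` implementation accepts, which is how `OpenvinoOperatorsSupport` is used above. A self-contained sketch of that mechanism (the toy support class and model below are illustrative, not part of this commit):

```python
# Minimal sketch of the CapabilityBasedPartitioner pattern used above.
import torch
import torch.fx as fx
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase


class AddOnlySupport(OperatorSupportBase):
    # Accept only add nodes; everything else stays in the host graph.
    def is_node_supported(self, submodules, node: fx.Node) -> bool:
        return node.op == "call_function" and "add" in str(node.target)


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1) + 2


gm = fx.symbolic_trace(TinyModel())
partitioner = CapabilityBasedPartitioner(
    gm,
    AddOnlySupport(),
    allows_single_node_partition=True,
)
partitions = partitioner.propose_partitions()
print(f"proposed {len(partitions)} partition(s)")
```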

backends/openvino/preprocess.py

Lines changed: 3 additions & 5 deletions

```diff
@@ -7,7 +7,7 @@
 import contextlib
 import struct
 
-from typing import final, List, cast
+from typing import cast, final, List
 
 import torch
 from executorch.exir.backend.backend_details import (
@@ -18,8 +18,6 @@
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from openvino.frontend.pytorch.torchdynamo.compile import openvino_compile
 
-SKIP_COMPILE_SPEC_KEYS = {"ImportForever"}
-
 
 @final
 class OpenvinoBackend(BackendDetails):
@@ -34,8 +32,8 @@ def preprocess(
         output_names = edge_program.graph_signature.user_outputs
         args = []
         for node in edge_program.graph.nodes:
-            if (node.target in input_names):
-                args.append( node.meta["val"])
+            if node.target in input_names:
+                args.append(node.meta["val"])
 
         input_shapes = []
         output_shapes = []
```
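
The rewritten loop in `preprocess` collects `node.meta["val"]` — the FakeTensor metadata attached to each user-input placeholder — as example arguments for compilation. A small standalone illustration of where that metadata lives in an exported program (the toy model is illustrative only):

```python
# Illustrative only: inspect placeholder metadata in an exported program.
import torch
from torch.export import export


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2


ep = export(TinyModel(), (torch.randn(2, 3),))
input_names = ep.graph_signature.user_inputs

for node in ep.graph.nodes:
    if node.target in input_names:
        # node.meta["val"] is a FakeTensor carrying shape and dtype.
        print(node.name, node.meta["val"].shape, node.meta["val"].dtype)
```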

backends/openvino/runtime/OpenvinoBackend.cpp

Lines changed: 115 additions & 101 deletions

```diff
@@ -7,8 +7,8 @@
 */
 
 #include <cstring>
-#include <memory>
 #include <iostream>
+#include <memory>
 
 #include <openvino/openvino.hpp>
 
@@ -39,143 +39,159 @@ namespace backends {
 namespace openvino {
 
 OpenvinoBackend::OpenvinoBackend() {
-    if (!is_available()) {
-        //ET_LOG(Error, "OpenVINO runtime is not available. Initialization failed.");
-        throw std::runtime_error("OpenVINO runtime not available");
-    }
+  if (!is_available()) {
+    // ET_LOG(Error, "OpenVINO runtime is not available. Initialization
+    // failed.");
+    throw std::runtime_error("OpenVINO runtime not available");
+  }
 
-    //ET_LOG(Info, "OpenVINO runtime successfully verified and initialized.");
+  // ET_LOG(Info, "OpenVINO runtime successfully verified and initialized.");
 }
 
 bool OpenvinoBackend::is_available() const {
-    try {
-        // Create an OpenVINO Core object to verify runtime availability
-        ov::Core core;
-
-        // Check if at least one device is available
-        auto devices = core.get_available_devices();
-        if (!devices.empty()) {
-            return true; // OpenVINO is available
-        }
-    } catch (const std::exception& e) {
-        // Log the exception if OpenVINO runtime is not available
-        ET_LOG(Error, "OpenVINO is not available: %s", e.what());
-    } catch (...) {
-        // Handle any unexpected errors
-        ET_LOG(Error, "OpenVINO availability check failed due to an unknown error.");
-    }
+  try {
+    // Create an OpenVINO Core object to verify runtime availability
+    ov::Core core;
 
-    return false; // OpenVINO is not available
+    // Check if at least one device is available
+    auto devices = core.get_available_devices();
+    if (!devices.empty()) {
+      return true; // OpenVINO is available
+    }
+  } catch (const std::exception& e) {
+    // Log the exception if OpenVINO runtime is not available
+    ET_LOG(Error, "OpenVINO is not available: %s", e.what());
+  } catch (...) {
+    // Handle any unexpected errors
+    ET_LOG(
+        Error, "OpenVINO availability check failed due to an unknown error.");
+  }
+
+  return false; // OpenVINO is not available
 }
 
 Result<DelegateHandle*> OpenvinoBackend::init(
     BackendInitContext& context,
     FreeableBuffer* processed,
     ArrayRef<CompileSpec> compile_specs) const {
+  ET_LOG(Info, "OpenvinoBackend::init %p", processed->data());
 
-    ET_LOG(Info, "OpenvinoBackend::init %p", processed->data());
+  ov::Core core;
+  const char* data_ptr = static_cast<const char*>(processed->data());
+  size_t data_size = processed->size();
 
-    ov::Core core;
-    const char* data_ptr = static_cast<const char*>(processed->data());
-    size_t data_size = processed->size();
+  // Copy data to a string or vector
+  std::string data_string(data_ptr, data_size);
 
-    // Copy data to a string or vector
-    std::string data_string(data_ptr, data_size);
+  // Wrap the data in a stream
+  std::istringstream compiled_stream(data_string);
 
-    // Wrap the data in a stream
-    std::istringstream compiled_stream(data_string);
+  auto device = "CPU";
+  // Get the device value, if provided in compile specs
+  for (auto& compile_spec : compile_specs) {
+    if (std::strcmp(compile_spec.key, "device") == 0)
+      device = static_cast<char*>(compile_spec.value.buffer);
+  }
 
-    // Import the model
-    auto compiled_model = core.import_model(compiled_stream, "CPU");
+  // Import the model
+  auto compiled_model = core.import_model(compiled_stream, device);
 
-    // Allocate an infer request
-    std::shared_ptr<ov::InferRequest> infer_request = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
+  // Allocate an infer request
+  std::shared_ptr<ov::InferRequest> infer_request =
+      std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
 
-    // Allocate execution handle
-    MemoryAllocator* allocator = context.get_runtime_allocator();
-    ExecutionHandle* handle = ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(allocator, ExecutionHandle);
-    handle->compiled_model = std::make_shared<ov::CompiledModel>(compiled_model);
-    handle->infer_request = infer_request;
+  // Allocate execution handle
+  MemoryAllocator* allocator = context.get_runtime_allocator();
+  ExecutionHandle* handle =
+      ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(allocator, ExecutionHandle);
+  handle->compiled_model = std::make_shared<ov::CompiledModel>(compiled_model);
+  handle->infer_request = infer_request;
 
-    return handle;
+  return handle;
 }
 
 Error OpenvinoBackend::execute(
     BackendExecutionContext& context,
     DelegateHandle* input_handle,
     EValue** args) const {
+  ExecutionHandle* execution_handle = (ExecutionHandle*)input_handle;
 
-    ExecutionHandle* execution_handle = (ExecutionHandle*)input_handle;
+  auto infer_request = execution_handle->infer_request;
 
-    auto infer_request = execution_handle->infer_request;
+  size_t num_inputs = infer_request->get_compiled_model().inputs().size();
+  size_t num_outputs = infer_request->get_compiled_model().outputs().size();
 
-    size_t num_inputs = infer_request->get_compiled_model().inputs().size();
-    size_t num_outputs = infer_request->get_compiled_model().outputs().size();
+  // Set inputs
+  for (size_t i = 0; i < num_inputs; i++) {
+    auto input_tensor = args[i]->toTensor();
+    ov::Shape input_shape(
+        input_tensor.sizes().begin(), input_tensor.sizes().end());
 
-    // Set inputs
-    for (size_t i = 0; i < num_inputs; i++) {
-        auto input_tensor = args[i]->toTensor();
-        ov::Shape input_shape(input_tensor.sizes().begin(), input_tensor.sizes().end());
+    // Convert input tensor to OpenVINO tensor
+    ov::element::Type ov_type =
+        convert_to_openvino_type(input_tensor.scalar_type());
+    ov::Tensor ov_input_tensor(
+        ov_type, input_shape, input_tensor.mutable_data_ptr());
 
-        // Convert input tensor to OpenVINO tensor
-        ov::element::Type ov_type = convert_to_openvino_type(input_tensor.scalar_type());
-        ov::Tensor ov_input_tensor(ov_type, input_shape, input_tensor.mutable_data_ptr());
+    infer_request->set_input_tensor(i, ov_input_tensor);
+  }
 
-        infer_request->set_input_tensor(i, ov_input_tensor);
-    }
+  // Set outputs
+  for (size_t i = 0; i < num_outputs; i++) {
+    auto output_tensor = args[num_inputs + i]->toTensor();
+    ov::Shape output_shape(
+        output_tensor.sizes().begin(), output_tensor.sizes().end());
 
-    // Set outputs
-    for (size_t i = 0; i < num_outputs; i++) {
-        auto output_tensor = args[num_inputs+i]->toTensor();
-        ov::Shape output_shape(output_tensor.sizes().begin(), output_tensor.sizes().end());
+    // Convert output tensor to OpenVINO tensor
+    ov::element::Type ov_type =
+        convert_to_openvino_type(output_tensor.scalar_type());
+    ov::Tensor ov_output_tensor(
+        ov_type, output_shape, output_tensor.mutable_data_ptr());
 
-        // Convert input tensor to OpenVINO tensor
-        ov::element::Type ov_type = convert_to_openvino_type(output_tensor.scalar_type());
-        ov::Tensor ov_output_tensor(ov_type, output_shape, output_tensor.mutable_data_ptr());
-
-        infer_request->set_output_tensor(i, ov_output_tensor);
-    }
+    infer_request->set_output_tensor(i, ov_output_tensor);
+  }
 
-    // Execute the inference
-    infer_request->infer();
+  // Execute the inference
+  infer_request->infer();
 
-    return Error::Ok;
+  return Error::Ok;
 }
 
 void OpenvinoBackend::destroy(DelegateHandle* handle) const {
-    if (!handle) {
-        ET_LOG(Info, "Attempted to destroy a null handle.");
-        return;
-    }
-
-    // Cast the handle to the appropriate type
-    ExecutionHandle* execution_handle = static_cast<ExecutionHandle*>(handle);
-
-    // Clean up resources
-    if (execution_handle->infer_request) {
-        execution_handle->infer_request.reset(); // Release the infer request
-        ET_LOG(Info, "Infer request successfully destroyed.");
-    }
-
-    if (execution_handle->compiled_model) {
-        execution_handle->compiled_model.reset(); // Release the compiled model
-        ET_LOG(Info, "Compiled model successfully destroyed.");
-    }
-
-    ET_LOG(Info, "Delegate handle destroyed successfully.");
+  if (!handle) {
+    ET_LOG(Info, "Attempted to destroy a null handle.");
+    return;
+  }
+
+  // Cast the handle to the appropriate type
+  ExecutionHandle* execution_handle = static_cast<ExecutionHandle*>(handle);
+
+  // Clean up resources
+  if (execution_handle->infer_request) {
+    execution_handle->infer_request.reset(); // Release the infer request
+    ET_LOG(Info, "Infer request successfully destroyed.");
+  }
+
+  if (execution_handle->compiled_model) {
+    execution_handle->compiled_model.reset(); // Release the compiled model
+    ET_LOG(Info, "Compiled model successfully destroyed.");
+  }
+
+  ET_LOG(Info, "Delegate handle destroyed successfully.");
}
 
-ov::element::Type OpenvinoBackend::convert_to_openvino_type(ScalarType scalar_type) const {
-    switch (scalar_type) {
-        case ScalarType::Float:
-            return ov::element::f32;
-        case ScalarType::Int:
-            return ov::element::i32;
-        case ScalarType::Char:
-            return ov::element::i8;
-        default:
-            throw std::runtime_error("Unsupported scalar type");
-    }
+ov::element::Type OpenvinoBackend::convert_to_openvino_type(
+    ScalarType scalar_type) const {
+  switch (scalar_type) {
+    case ScalarType::Float:
+      return ov::element::f32;
+    case ScalarType::Int:
+      return ov::element::i32;
+    case ScalarType::Char:
+      return ov::element::i8;
+    default:
+      throw std::runtime_error("Unsupported scalar type");
+  }
 }
 
 } // namespace openvino
@@ -185,7 +201,5 @@ ov::element::Type OpenvinoBackend::convert_to_openvino_type(ScalarType scalar_ty
 namespace {
 auto backend = executorch::backends::openvino::OpenvinoBackend();
 executorch::runtime::Backend backend_id{"OpenvinoBackend", &backend};
-static auto registered = executorch::runtime::register_backend(backend_id);
+static auto registered = executorch::runtime::register_backend(backend_id);
 } // namespace
-
-
```
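
Putting the pieces together, the runtime change above means the device chosen at export time flows through `compile_specs` into `core.import_model`. A hypothetical end-to-end sketch of the ahead-of-time side (the `OpenvinoPartitioner` import and constructor are assumptions and left commented out; `to_edge` and `to_executorch` are standard ExecuTorch APIs):

```python
# Hypothetical sketch: lower a model toward the OpenVINO backend with a
# "device" compile spec, producing a .pte the C++ runtime above can load.
import torch
from executorch.exir import to_edge
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export import export


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


ep = export(TinyModel(), (torch.randn(1, 3),))
edge = to_edge(ep)

# Assumed: the partitioner forwards compile specs (including "device")
# to OpenvinoBackend.preprocess and, ultimately, to the runtime init().
# from executorch.backends.openvino.partitioner import OpenvinoPartitioner
# edge = edge.to_backend(OpenvinoPartitioner([CompileSpec("device", b"CPU\x00")]))

prog = edge.to_executorch()
with open("model.pte", "wb") as f:
    f.write(prog.buffer)
```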

examples/openvino/CMakeLists.txt

Lines changed: 0 additions & 2 deletions

```diff
@@ -94,5 +94,3 @@ set_target_properties(openvino_executor_runner PROPERTIES INSTALL_RPATH "$ORIGIN
 get_filename_component(
   EXECUTORCH_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../.." ABSOLUTE
 )
-
-
```

examples/openvino/aot/README.md

Lines changed: 0 additions & 2 deletions

```diff
@@ -113,5 +113,3 @@ python aot_openvino_compiler.py --suite timm --model vgg16 --input_shape [1, 3,
 
 - **Unsupported Input Shape**:
   Ensure `--input_shape` is provided as a valid list or tuple.
-
-
```
