-
Notifications
You must be signed in to change notification settings - Fork 2.9k
[ONNX] Fix model_dir when loading external weights #33494
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -121,8 +121,8 @@ class GraphIteratorProto : public ov::frontend::onnx::GraphIterator { | |
|
|
||
| std::map<std::string, std::string> get_metadata() const override; | ||
|
|
||
| std::string get_model_dir() const { | ||
| return *m_model_dir; | ||
| std::string get_model_dir() const override { | ||
| return m_model_dir ? *m_model_dir : std::string{}; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or, more safely: Rationale: null pointers are inherently bad and will eventually be dereferenced by accident. Having an invariant saying: "m_model_dir is always valid" is more reliable.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Question: is an empty model directory one of use cases (and has a meaning in our domain) or is returning an empty string just a safeguard for null pointers? |
||
| } | ||
|
|
||
| GraphIteratorProtoMemoryManagementMode get_memory_management_mode() const { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -792,8 +792,10 @@ Tensor Node::get_attribute_value(const std::string& name) const { | |
| auto tensor_decoder = std::dynamic_pointer_cast<ov::frontend::onnx::DecoderBaseTensor>( | ||
| m_decoder->get_attribute(name).as<ov::frontend::onnx::DecoderBase::Ptr>()); | ||
| const auto& tensor_meta_info = tensor_decoder->get_tensor_info(); | ||
| auto input_model = m_translate_session ? m_translate_session->get_input_model() : nullptr; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Question: can a [nitpick] If no, then how about moving validation to ? |
||
| FRONT_END_GENERAL_CHECK(input_model != nullptr, "InputModel is not available for tensor attributes"); | ||
mvafin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| auto tensor_place = std::make_shared<ov::frontend::onnx::TensorONNXPlace>( | ||
| *m_translate_session->get_input_model().get(), | ||
| *input_model, | ||
| tensor_meta_info.m_partial_shape, | ||
| tensor_meta_info.m_element_type, | ||
| std::vector<std::string>{*tensor_meta_info.m_tensor_name}, | ||
|
|
@@ -816,11 +818,14 @@ SparseTensor Node::get_attribute_value(const std::string& name) const { | |
| FRONT_END_GENERAL_CHECK(sparse_tensor_info.m_indices && sparse_tensor_info.m_values, | ||
| "Incomplete sparse tensors are not supported"); | ||
|
|
||
| auto input_model = m_translate_session ? m_translate_session->get_input_model() : nullptr; | ||
| FRONT_END_GENERAL_CHECK(input_model != nullptr, "InputModel is not available for sparse tensor attributes"); | ||
|
|
||
| auto values_decoder = | ||
| std::dynamic_pointer_cast<ov::frontend::onnx::DecoderBaseTensor>(sparse_tensor_info.m_values); | ||
| const auto& values_meta_info = values_decoder->get_tensor_info(); | ||
| auto values_place = std::make_shared<ov::frontend::onnx::TensorONNXPlace>( | ||
| *m_translate_session->get_input_model().get(), | ||
| *input_model, | ||
| values_meta_info.m_partial_shape, | ||
| values_meta_info.m_element_type, | ||
| std::vector<std::string>{*values_meta_info.m_tensor_name}, | ||
|
|
@@ -834,7 +839,7 @@ SparseTensor Node::get_attribute_value(const std::string& name) const { | |
| std::dynamic_pointer_cast<ov::frontend::onnx::DecoderBaseTensor>(sparse_tensor_info.m_indices); | ||
| const auto& indices_meta_info = indices_decoder->get_tensor_info(); | ||
| auto indices_place = std::make_shared<ov::frontend::onnx::TensorONNXPlace>( | ||
| *m_translate_session->get_input_model().get(), | ||
| *input_model, | ||
| indices_meta_info.m_partial_shape, | ||
| indices_meta_info.m_element_type, | ||
| std::vector<std::string>{*indices_meta_info.m_tensor_name}, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |
| #include "core/tensor.hpp" | ||
|
|
||
| #include "input_model.hpp" | ||
| #include "openvino/util/file_util.hpp" | ||
|
|
||
| namespace ov { | ||
| namespace frontend { | ||
|
|
@@ -19,10 +20,18 @@ detail::LocalStreamHandles TensorONNXPlace::get_stream_cache() { | |
| return model_onnx->get_stream_cache(); | ||
| } | ||
|
|
||
| std::string TensorONNXPlace::get_model_dir() const { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nitpick] Or, in a more idiomatic C++: |
||
| const auto model_onnx = dynamic_cast<const unify::InputModel*>(&m_input_model); | ||
| if (!model_onnx) { | ||
| return {}; | ||
| } | ||
| return model_onnx->get_model_dir(); | ||
| } | ||
|
|
||
| Tensor::Tensor(const std::shared_ptr<TensorONNXPlace>& tensor_place) { | ||
| m_tensor_proto = nullptr; | ||
| m_shape = tensor_place->get_partial_shape().get_shape(); | ||
| m_model_dir = ""; | ||
| m_model_dir = tensor_place->get_model_dir(); | ||
| m_mmap_cache = tensor_place->get_mmap_cache(); | ||
| m_tensor_place = tensor_place; | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -129,7 +129,10 @@ set_property(TEST ov_onnx_frontend_tests PROPERTY LABELS OV UNIT ONNX_FE) | |
| add_dependencies(ov_onnx_frontend_tests openvino_template_extension) | ||
|
|
||
|
|
||
| target_include_directories(ov_onnx_frontend_tests PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}") | ||
| target_include_directories(ov_onnx_frontend_tests PRIVATE | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess the idiomatic way of doing this is: Not tested, however the general rule (in CMake 4) is to:
Typically a relative path indicates something fishy happening. |
||
| "${CMAKE_CURRENT_SOURCE_DIR}" | ||
| "${CMAKE_CURRENT_SOURCE_DIR}/../frontend/src" | ||
| "${CMAKE_CURRENT_SOURCE_DIR}/../onnx_common/include") | ||
|
|
||
| target_compile_definitions(ov_onnx_frontend_tests | ||
| PRIVATE | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,46 +5,68 @@ | |||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| #include <onnx/onnx_pb.h> | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| #include <algorithm> | ||||||||||||||||||||||||||||||||||||||
| #include <filesystem> | ||||||||||||||||||||||||||||||||||||||
| #include <fstream> | ||||||||||||||||||||||||||||||||||||||
| #include <iostream> | ||||||||||||||||||||||||||||||||||||||
| #include <map> | ||||||||||||||||||||||||||||||||||||||
| #include <openvino/frontend/exception.hpp> | ||||||||||||||||||||||||||||||||||||||
| #include <openvino/frontend/graph_iterator.hpp> | ||||||||||||||||||||||||||||||||||||||
| #include <openvino/frontend/input_model.hpp> | ||||||||||||||||||||||||||||||||||||||
| #include <openvino/openvino.hpp> | ||||||||||||||||||||||||||||||||||||||
| #include <system_error> | ||||||||||||||||||||||||||||||||||||||
| #include <unordered_map> | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| #include "../frontend/src/core/graph_iterator_proto.hpp" | ||||||||||||||||||||||||||||||||||||||
| #include "common_test_utils/common_utils.hpp" | ||||||||||||||||||||||||||||||||||||||
| #include "load_from.hpp" | ||||||||||||||||||||||||||||||||||||||
| #include "onnx_utils.hpp" | ||||||||||||||||||||||||||||||||||||||
| #include "utils.hpp" | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| using ::ONNX_NAMESPACE::ModelProto; | ||||||||||||||||||||||||||||||||||||||
| using ::ONNX_NAMESPACE::Version; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| TEST_P(FrontEndLoadFromTest, testLoadUsingSimpleGraphIterator) { | ||||||||||||||||||||||||||||||||||||||
| ov::frontend::FrontEnd::Ptr fe; | ||||||||||||||||||||||||||||||||||||||
| class SimpleIterator : public ov::frontend::onnx::GraphIterator { | ||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||
| mutable size_t get_model_dir_call_count = 0; | ||||||||||||||||||||||||||||||||||||||
| mutable std::string last_returned_dir; | ||||||||||||||||||||||||||||||||||||||
| std::string model_dir; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| class SimpleIterator : public ov::frontend::onnx::GraphIterator { | ||||||||||||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||||||||||||
| size_t size() const override { | ||||||||||||||||||||||||||||||||||||||
| return 0; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| void reset() override {}; | ||||||||||||||||||||||||||||||||||||||
| void next() override {}; | ||||||||||||||||||||||||||||||||||||||
| bool is_end() const override { | ||||||||||||||||||||||||||||||||||||||
| return true; | ||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||
| std::shared_ptr<ov::frontend::onnx::DecoderBase> get_decoder() const override { | ||||||||||||||||||||||||||||||||||||||
| return nullptr; | ||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| int64_t get_opset_version(const std::string& domain) const override { | ||||||||||||||||||||||||||||||||||||||
| return 1; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| SimpleIterator() = default; | ||||||||||||||||||||||||||||||||||||||
| explicit SimpleIterator(const std::string& dir) : model_dir(dir) {} | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| std::map<std::string, std::string> get_metadata() const override { | ||||||||||||||||||||||||||||||||||||||
| return {}; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| ~SimpleIterator() override {}; | ||||||||||||||||||||||||||||||||||||||
| size_t size() const override { | ||||||||||||||||||||||||||||||||||||||
| return 0; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| void reset() override {}; | ||||||||||||||||||||||||||||||||||||||
| void next() override {}; | ||||||||||||||||||||||||||||||||||||||
| bool is_end() const override { | ||||||||||||||||||||||||||||||||||||||
| return true; | ||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||
| std::shared_ptr<ov::frontend::onnx::DecoderBase> get_decoder() const override { | ||||||||||||||||||||||||||||||||||||||
| return nullptr; | ||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| int64_t get_opset_version(const std::string& domain) const override { | ||||||||||||||||||||||||||||||||||||||
| return 1; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| std::map<std::string, std::string> get_metadata() const override { | ||||||||||||||||||||||||||||||||||||||
| return {}; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| std::string get_model_dir() const override { | ||||||||||||||||||||||||||||||||||||||
| ++get_model_dir_call_count; | ||||||||||||||||||||||||||||||||||||||
| last_returned_dir = model_dir; | ||||||||||||||||||||||||||||||||||||||
| return model_dir; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| ~SimpleIterator() override {}; | ||||||||||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| TEST_P(FrontEndLoadFromTest, testLoadUsingSimpleGraphIterator) { | ||||||||||||||||||||||||||||||||||||||
| ov::frontend::FrontEnd::Ptr fe; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| auto iter = std::make_shared<SimpleIterator>(); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -128,3 +150,36 @@ TEST_P(FrontEndLoadFromTest, testLoadUsingGraphIteratorExternalMMAP) { | |||||||||||||||||||||||||||||||||||||
| ASSERT_EQ(iter->get_mmap_cache()->size(), 1); // MMAP handle must be in cache after work finished | ||||||||||||||||||||||||||||||||||||||
| ASSERT_EQ(model->get_ordered_ops().size(), 6); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| TEST_P(FrontEndLoadFromTest, tensor_place_uses_model_dir_for_external_data) { | ||||||||||||||||||||||||||||||||||||||
| const std::string model_name = "external_data/external_data.onnx"; | ||||||||||||||||||||||||||||||||||||||
| const auto path = | ||||||||||||||||||||||||||||||||||||||
| ov::util::path_join({ov::test::utils::getExecutableDirectory(), TEST_ONNX_MODELS_DIRNAME, model_name}).string(); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const auto expected_model_dir = std::filesystem::path(path).parent_path().string(); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| auto iter = std::make_shared<SimpleIterator>(expected_model_dir); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| auto graph_iter = std::dynamic_pointer_cast<ov::frontend::onnx::GraphIterator>(iter); | ||||||||||||||||||||||||||||||||||||||
| ASSERT_NO_THROW(m_frontEnd = m_fem.load_by_framework("onnx")); | ||||||||||||||||||||||||||||||||||||||
| ASSERT_NE(m_frontEnd, nullptr); | ||||||||||||||||||||||||||||||||||||||
| ASSERT_TRUE(m_frontEnd->supported(graph_iter)); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| ASSERT_NO_THROW(m_inputModel = m_frontEnd->load(graph_iter)); | ||||||||||||||||||||||||||||||||||||||
| ASSERT_NE(m_inputModel, nullptr); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| ASSERT_NO_THROW({ | ||||||||||||||||||||||||||||||||||||||
| try { | ||||||||||||||||||||||||||||||||||||||
| auto model = m_frontEnd->convert(m_inputModel); | ||||||||||||||||||||||||||||||||||||||
| ASSERT_NE(model, nullptr); | ||||||||||||||||||||||||||||||||||||||
| } catch (const std::exception& ex) { | ||||||||||||||||||||||||||||||||||||||
| std::cerr << "convert failed: " << ex.what() << std::endl; | ||||||||||||||||||||||||||||||||||||||
| throw; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+171
to
+179
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about:
Suggested change
Rationale: you probably want to print the failure consistently with the rest of googletest. Assuming, that googletest (always) prints to |
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| ASSERT_GT(iter->get_model_dir_call_count, 0) << "get_model_dir() was never called"; | ||||||||||||||||||||||||||||||||||||||
| ASSERT_EQ(iter->last_returned_dir, expected_model_dir) | ||||||||||||||||||||||||||||||||||||||
| << "get_model_dir() returned unexpected path: " << iter->last_returned_dir | ||||||||||||||||||||||||||||||||||||||
| << " (expected: " << expected_model_dir << ")"; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] Since it's virtual, it will never be inlined (so having it in the header only increases the compilation time). Consider having all bodies of virtual functions in the cpp file.