[ONNX] Fix model_dir when loading external weights #33494
base: master
Conversation
Pull request overview
This PR fixes the propagation of `model_dir` when loading ONNX models with external weights. Previously, the `model_dir` was not properly passed from `GraphIteratorProto` to `Tensor` objects, which could cause issues when loading external data files.
Key changes:
- Added `get_model_dir()` method to `InputModel` and `TensorONNXPlace` to expose the model directory
- Modified the `Tensor` constructor to retrieve `model_dir` from `TensorONNXPlace` instead of using an empty string
- Updated `GraphIteratorProto::get_model_dir()` to safely handle a null `m_model_dir`
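To make the propagation concrete, here is a minimal, self-contained C++ sketch of the pattern described above. The class names mirror the ones in the PR (`TensorONNXPlace`, `Tensor`), but they are simplified stand-ins for illustration, not the actual OpenVINO frontend types.

```cpp
#include <iostream>
#include <memory>
#include <string>

// Simplified stand-in for TensorONNXPlace: it exposes the directory of the
// model it belongs to, so consumers no longer have to guess it.
class TensorONNXPlaceSketch {
public:
    explicit TensorONNXPlaceSketch(std::string model_dir) : m_model_dir(std::move(model_dir)) {}
    std::string get_model_dir() const {
        return m_model_dir;
    }

private:
    std::string m_model_dir;
};

// Simplified stand-in for Tensor: before the fix it used an empty string as
// the model directory; after the fix it asks the place that owns it.
class TensorSketch {
public:
    explicit TensorSketch(const std::shared_ptr<TensorONNXPlaceSketch>& place)
        : m_model_dir(place ? place->get_model_dir() : std::string{}) {}
    // External weight files are resolved relative to this directory.
    const std::string& model_dir() const {
        return m_model_dir;
    }

private:
    std::string m_model_dir;
};

int main() {
    auto place = std::make_shared<TensorONNXPlaceSketch>("/path/to/model");
    TensorSketch tensor(place);
    std::cout << "external data resolved relative to: " << tensor.model_dir() << "\n";
}
```

The key point is that the tensor no longer defaults to an empty directory; it asks the place that created it, which in turn received the directory from the graph iterator.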
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| src/frontends/onnx/tests/graph_iterator.cpp | Added comprehensive test for external data loading with model directory |
| src/frontends/onnx/tests/CMakeLists.txt | Updated include directories for test compilation |
| src/frontends/onnx/frontend/src/input_model.hpp | Added get_model_dir() method declaration |
| src/frontends/onnx/frontend/src/input_model.cpp | Implemented model directory retrieval and storage |
| src/frontends/onnx/frontend/src/core/tensor.hpp | Added get_model_dir() method declaration to TensorONNXPlace |
| src/frontends/onnx/frontend/src/core/tensor.cpp | Modified Tensor constructor to use model_dir from TensorONNXPlace |
| src/frontends/onnx/frontend/src/core/node.cpp | Refactored to use local variable for input_model access |
| src/frontends/onnx/frontend/src/core/graph_iterator_proto.hpp | Added null-safety check for m_model_dir |
| src/frontends/onnx/frontend/src/core/graph_iterator_proto.cpp | Removed initialization of m_model_dir to nullptr |
Signed-off-by: Maxim Vafin <[email protected]>
bumbosiepsak left a comment
Please consider my remarks and let me know once you're done and happy.
```diff
-    std::string get_model_dir() const {
-        return *m_model_dir;
+    std::string get_model_dir() const override {
```
[nitpick] Since it's virtual, it will never be inlined (so having it in the header only increases the compilation time). Consider having all bodies of virtual functions in the cpp file.
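A minimal sketch of the suggested split, using a hypothetical simplified class rather than the real OpenVINO headers:

```cpp
// graph_iterator_sketch.hpp (hypothetical header)
#pragma once
#include <string>

class GraphIteratorSketch {
public:
    virtual ~GraphIteratorSketch() = default;
    // Declared here, defined in the .cpp file: the header stays lightweight
    // and the virtual body is compiled only once.
    virtual std::string get_model_dir() const;
};
```

```cpp
// graph_iterator_sketch.cpp (hypothetical source)
#include "graph_iterator_sketch.hpp"

std::string GraphIteratorSketch::get_model_dir() const {
    return {};  // placeholder body for the sketch
}
```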
```diff
-    std::string get_model_dir() const {
-        return *m_model_dir;
+    std::string get_model_dir() const override {
+        return m_model_dir ? *m_model_dir : std::string{};
```
Or, more safely:

```cpp
namespace {
std::string const model_dir_unset;  // Empty, always-instantiated string
}

GraphIteratorProto::GraphIteratorProto(const GraphIteratorProtoMemoryManagementMode mode)
    : m_graph(nullptr),
      m_model_dir(&model_dir_unset)
      /* ..., remaining initializers as before */ {}

std::string GraphIteratorProto::get_model_dir() const {
    return *m_model_dir;
}
```
Rationale: null pointers are inherently bad and will eventually be dereferenced by accident. Having an invariant saying: "m_model_dir is always valid" is more reliable.
```diff
-    std::string get_model_dir() const {
-        return *m_model_dir;
+    std::string get_model_dir() const override {
+        return m_model_dir ? *m_model_dir : std::string{};
```
Question: is an empty model directory one of the use cases (i.e., does it have a meaning in our domain), or is returning an empty string just a safeguard against null pointers?
```diff
     auto tensor_decoder = std::dynamic_pointer_cast<ov::frontend::onnx::DecoderBaseTensor>(
         m_decoder->get_attribute(name).as<ov::frontend::onnx::DecoderBase::Ptr>());
     const auto& tensor_meta_info = tensor_decoder->get_tensor_info();
+    auto input_model = m_translate_session ? m_translate_session->get_input_model() : nullptr;
```
Question: can a session have no input model?
[nitpick] If not, how about moving the validation into get_input_model() and using this here instead?

```cpp
FRONT_END_GENERAL_CHECK(m_translate_session != nullptr, "InputModel is not available for tensor attributes");
auto input_model = m_translate_session->get_input_model();
```
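A hedged sketch of the alternative being suggested, with hypothetical simplified types and a plain exception standing in for FRONT_END_GENERAL_CHECK; it only illustrates moving the validation into get_input_model() so callers can rely on a non-null result:

```cpp
#include <memory>
#include <stdexcept>

struct InputModelSketch {};

class TranslateSessionSketch {
public:
    explicit TranslateSessionSketch(std::shared_ptr<InputModelSketch> model)
        : m_input_model(std::move(model)) {}

    // Validation lives here, so every caller can rely on a non-null result.
    std::shared_ptr<InputModelSketch> get_input_model() const {
        if (!m_input_model) {
            throw std::runtime_error("InputModel is not available");
        }
        return m_input_model;
    }

private:
    std::shared_ptr<InputModelSketch> m_input_model;
};
```

Whether this is appropriate depends on the answer to the question above, i.e. whether a session can legitimately exist without an input model.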
```diff
     return model_onnx->get_stream_cache();
 }

+std::string TensorONNXPlace::get_model_dir() const {
```
[nitpick] Or, in more idiomatic C++:

```cpp
try {
    using Model = const unify::InputModel&;
    return dynamic_cast<Model>(m_input_model).get_model_dir();
} catch (const std::bad_cast&) {
    return {};
}
```
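For comparison, here is a self-contained sketch of the exception-based reference cast next to the pointer-based form; `BaseModel` and `DirAwareModel` are hypothetical stand-ins, not the actual frontend classes:

```cpp
#include <string>

// Hypothetical stand-ins; only the cast pattern matters here.
struct BaseModel {
    virtual ~BaseModel() = default;
};

struct DirAwareModel : BaseModel {
    std::string get_model_dir() const { return "/path/to/model"; }
};

// Exception-based form: the reference cast throws std::bad_cast on mismatch.
std::string model_dir_via_reference(const BaseModel& model) {
    try {
        return dynamic_cast<const DirAwareModel&>(model).get_model_dir();
    } catch (const std::bad_cast&) {
        return {};
    }
}

// Pointer-based form: the cast yields nullptr on mismatch, no exception involved.
std::string model_dir_via_pointer(const BaseModel& model) {
    if (const auto* dir_aware = dynamic_cast<const DirAwareModel*>(&model)) {
        return dir_aware->get_model_dir();
    }
    return {};
}
```

Both return an empty string when the concrete type does not expose a model directory; the pointer form avoids using exceptions for an expected, non-error path.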
```diff
-target_include_directories(ov_onnx_frontend_tests PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}")
+target_include_directories(ov_onnx_frontend_tests PRIVATE
```
I guess the idiomatic way of doing this is:

```cmake
# in src/frontends/onnx/frontend/CMakeLists.txt
ov_add_frontend(NAME onnx
    ...
)
target_include_directories(mylib INTERFACE
    $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/src>
)

# in src/frontends/onnx/onnx_common/CMakeLists.txt
set(TARGET_NAME "openvino_onnx_common")
# target_include_directories all good already

# in src/frontends/onnx/tests/CMakeLists.txt
target_link_libraries(ov_onnx_frontend_tests
    PRIVATE
        onnx
        openvino_onnx_common
)
```
Not tested; however, the general rule (in CMake 4) is to:
- have all targets export their (export) include directories
- link the final targets with targets from point 1 and inherit their exports
Typically, a relative path indicates something fishy is happening.
```cpp
ASSERT_NO_THROW({
    try {
        auto model = m_frontEnd->convert(m_inputModel);
        ASSERT_NE(model, nullptr);
    } catch (const std::exception& ex) {
        std::cerr << "convert failed: " << ex.what() << std::endl;
        throw;
    }
});
```
How about replacing the ASSERT_NO_THROW wrapper above with:

```cpp
try {
    auto model = m_frontEnd->convert(m_inputModel);
    ASSERT_NE(model, nullptr);
} catch (const std::exception& ex) {
    FAIL() << "convert failed: " << ex.what();
} catch (...) {
    FAIL() << "convert failed: reason unknown";
}
```
Rationale: you probably want to print the failure consistently with the rest of googletest. Assuming that googletest (always) prints to std::cerr might be a bit of a stretch.
Details:
- Propagate `model_dir` value from `GraphIteratorProto` to `Tensor`

Tickets: