-
Notifications
You must be signed in to change notification settings - Fork 2.6k
[PyOV] __eq__ for parameter and result #26613
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately, the changes are not fully backward compatible, because some of the basic python operators including "==" are overriten for Node by ov ops here:
openvino/src/bindings/python/src/openvino/runtime/__init__.py
Lines 69 to 85 in 68f7d33
# Extend Node class to support binary operators | |
Node.__add__ = opset13.add | |
Node.__sub__ = opset13.subtract | |
Node.__mul__ = opset13.multiply | |
Node.__div__ = opset13.divide | |
Node.__truediv__ = opset13.divide | |
Node.__radd__ = lambda left, right: opset13.add(right, left) | |
Node.__rsub__ = lambda left, right: opset13.subtract(right, left) | |
Node.__rmul__ = lambda left, right: opset13.multiply(right, left) | |
Node.__rdiv__ = lambda left, right: opset13.divide(right, left) | |
Node.__rtruediv__ = lambda left, right: opset13.divide(right, left) | |
Node.__eq__ = opset13.equal | |
Node.__ne__ = opset13.not_equal | |
Node.__lt__ = opset13.less | |
Node.__le__ = opset13.less_equal | |
Node.__gt__ = opset13.greater | |
Node.__ge__ = opset13.greater_equal |
operator.eq
case is failing:def test_binary_operators_with_scalar(operator, expected_type): |
The motivation to align .index()
behavior is reasonable, but the impact and consistency should be verified.
Also there are existing methods that can be used to retrieve the index model.get_parameter_index(parameter)
or model.get_result_index(parameter)
.
@@ -61,6 +61,7 @@ def test_graph_api(): | |||
assert list(model.get_output_shape(0)) == [2, 2] | |||
assert (model.get_parameters()[1].get_partial_shape()) == PartialShape([3, 4, 5]) | |||
assert len(model.get_parameters()) == 3 | |||
assert model.get_parameters().index(parameter_c) == 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is also model.get_parameter_index(parameter)
available in Python and C++ OV API.
openvino/src/bindings/python/src/pyopenvino/graph/model.cpp
Lines 1094 to 1099 in 68f7d33
model.def( | |
"get_parameter_index", | |
(int64_t(ov::Model::*)(const std::shared_ptr<ov::op::v0::Parameter>&) const) & ov::Model::get_parameter_index, | |
py::arg("parameter"), | |
R"( | |
Return the index position of `parameter` |
@@ -748,6 +748,7 @@ def test_model_add_remove_result_parameter_sink(): | |||
|
|||
results = model.get_results() | |||
assert len(results) == 2 | |||
assert results.index(result2) == 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is also model.get_result_index(parameter)
available in Python and C++ OV API.
openvino/src/bindings/python/src/pyopenvino/graph/model.cpp
Lines 848 to 852 in 68f7d33
model.def("get_result_index", | |
(int64_t(ov::Model::*)(const ov::Output<ov::Node>&) const) & ov::Model::get_result_index, | |
py::arg("value"), | |
R"( | |
Return index of result. |
parameter.def( | ||
"__eq__", | ||
[](const ov::op::v0::Parameter& a, const ov::op::v0::Parameter& b) { | ||
return a.get_instance_id() == b.get_instance_id(); | ||
}, | ||
py::is_operator()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It changes behaviour of ==
, but !=
is still overwritten by ov.op.not_equal
, which can lead to inconsistency:
openvino/src/bindings/python/src/openvino/runtime/__init__.py
Lines 80 to 81 in 68f7d33
Node.__eq__ = opset13.equal | |
Node.__ne__ = opset13.not_equal |
I close it for right now. |
### Details: - #26613 caused backward incompatible changes, so decided to deprecate comparison overloads for comparison operations as the usage of them confuse users - suppress deprecation warnings in tests - added bindings for `get_instance_id()` to Node class ### Tickets: - result of discussion [CVS-131039](https://jira.devtools.intel.com/browse/CVS-131039)
Details:
__eq__
and__hash__
Tickets: