Skip to content

[PyOV] __eq__ for parameter and result #26613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

akuporos
Copy link
Contributor

Details:

  • add overloads for __eq__ and __hash__

Tickets:

@akuporos akuporos requested a review from a team as a code owner September 16, 2024 11:22
@github-actions github-actions bot added the category: Python API OpenVINO Python bindings label Sep 16, 2024
@akuporos akuporos requested a review from praasz September 16, 2024 11:25
@akuporos akuporos added this to the 2024.5 milestone Sep 16, 2024
@akuporos akuporos requested a review from mitruska September 16, 2024 17:47
Copy link
Contributor

@mitruska mitruska left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, the changes are not fully backward compatible, because some of the basic python operators including "==" are overriten for Node by ov ops here:

# Extend Node class to support binary operators
Node.__add__ = opset13.add
Node.__sub__ = opset13.subtract
Node.__mul__ = opset13.multiply
Node.__div__ = opset13.divide
Node.__truediv__ = opset13.divide
Node.__radd__ = lambda left, right: opset13.add(right, left)
Node.__rsub__ = lambda left, right: opset13.subtract(right, left)
Node.__rmul__ = lambda left, right: opset13.multiply(right, left)
Node.__rdiv__ = lambda left, right: opset13.divide(right, left)
Node.__rtruediv__ = lambda left, right: opset13.divide(right, left)
Node.__eq__ = opset13.equal
Node.__ne__ = opset13.not_equal
Node.__lt__ = opset13.less
Node.__le__ = opset13.less_equal
Node.__gt__ = opset13.greater
Node.__ge__ = opset13.greater_equal
This is why the test for operator.eq case is failing:
def test_binary_operators_with_scalar(operator, expected_type):

The motivation to align .index() behavior is reasonable, but the impact and consistency should be verified.

Also there are existing methods that can be used to retrieve the index model.get_parameter_index(parameter) or model.get_result_index(parameter).

@@ -61,6 +61,7 @@ def test_graph_api():
assert list(model.get_output_shape(0)) == [2, 2]
assert (model.get_parameters()[1].get_partial_shape()) == PartialShape([3, 4, 5])
assert len(model.get_parameters()) == 3
assert model.get_parameters().index(parameter_c) == 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is also model.get_parameter_index(parameter) available in Python and C++ OV API.

model.def(
"get_parameter_index",
(int64_t(ov::Model::*)(const std::shared_ptr<ov::op::v0::Parameter>&) const) & ov::Model::get_parameter_index,
py::arg("parameter"),
R"(
Return the index position of `parameter`

@@ -748,6 +748,7 @@ def test_model_add_remove_result_parameter_sink():

results = model.get_results()
assert len(results) == 2
assert results.index(result2) == 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is also model.get_result_index(parameter) available in Python and C++ OV API.

model.def("get_result_index",
(int64_t(ov::Model::*)(const ov::Output<ov::Node>&) const) & ov::Model::get_result_index,
py::arg("value"),
R"(
Return index of result.

Comment on lines +58 to +63
parameter.def(
"__eq__",
[](const ov::op::v0::Parameter& a, const ov::op::v0::Parameter& b) {
return a.get_instance_id() == b.get_instance_id();
},
py::is_operator());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It changes behaviour of ==, but != is still overwritten by ov.op.not_equal, which can lead to inconsistency:

Node.__eq__ = opset13.equal
Node.__ne__ = opset13.not_equal

@akuporos akuporos removed this from the 2024.5 milestone Sep 18, 2024
@akuporos akuporos added the WIP work in progress label Sep 18, 2024
@akuporos
Copy link
Contributor Author

I close it for right now.
Will restore when the right time come

@akuporos akuporos closed this Sep 24, 2024
github-merge-queue bot pushed a commit that referenced this pull request Sep 27, 2024
### Details:
- #26613 caused backward
incompatible changes, so decided to deprecate comparison overloads for
comparison operations as the usage of them confuse users
 - suppress deprecation warnings in tests
 - added bindings for `get_instance_id()` to Node class 
 
### Tickets:
- result of discussion
[CVS-131039](https://jira.devtools.intel.com/browse/CVS-131039)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
category: Python API OpenVINO Python bindings WIP work in progress
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants