Skip to content

Remove the RefAttr class #2328

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 22, 2025
Merged

Remove the RefAttr class #2328

merged 6 commits into from
May 22, 2025

Conversation

justinchuby
Copy link
Collaborator

@justinchuby justinchuby commented May 22, 2025

Rational

We defined the class RefAttr in the IR to represent reference attributes in ONNX. Node attributes can be Attr and RefAttr. However, since most of the time we are working with concrete attributes, the union of types creates a typing situation where we always need to assert the types before taking the values, even if we know a RefAttr cannot exist (outside of a function definition).

This additionally matches the definition of AttributeProto in ONNX.

Change

This change merged the two classes, and instead defines a is_ref() method for users to check the reference attribute.

The change is BC breaking for usage like isinstance(attr, ir.RefAttr). Fortunately all such usages exist in this code base and not in PyTorch, so we are safe to complete the change.

@@ -422,6 +423,8 @@
value: Any
doc_string: str | None

def is_ref(self) -> Literal[False]: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

Copilot Autofix

AI about 1 month ago

To fix the issue, replace the ... placeholder in the is_ref method with a concrete implementation that returns False. This aligns the method's behavior with its type annotation (Literal[False]) and ensures that the protocol is clear and unambiguous for implementers.


Suggested changeset 1
onnxscript/ir/_protocols.py

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py
--- a/onnxscript/ir/_protocols.py
+++ b/onnxscript/ir/_protocols.py
@@ -425,3 +425,4 @@
 
-    def is_ref(self) -> Literal[False]: ...
+    def is_ref(self) -> Literal[False]:
+        return False
 
EOF
@@ -425,3 +425,4 @@

def is_ref(self) -> Literal[False]: ...
def is_ref(self) -> Literal[False]:
return False

Copilot is powered by AI and may make mistakes. Always verify output.
@@ -441,6 +444,8 @@
type: _enums.AttributeType
doc_string: str | None

def is_ref(self) -> Literal[True]: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

Copilot Autofix

AI about 1 month ago

To fix the issue, the is_ref method in the ReferenceAttributeProtocol class should be updated to explicitly return True as a Literal[True]. This ensures that the method has a concrete implementation that matches its type annotation and intended behavior. The change should be made directly in the ReferenceAttributeProtocol class definition.

Suggested changeset 1
onnxscript/ir/_protocols.py

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py
--- a/onnxscript/ir/_protocols.py
+++ b/onnxscript/ir/_protocols.py
@@ -446,4 +446,4 @@
 
-    def is_ref(self) -> Literal[True]: ...
-
+    def is_ref(self) -> Literal[True]:
+        return True
 
EOF
@@ -446,4 +446,4 @@

def is_ref(self) -> Literal[True]: ...

def is_ref(self) -> Literal[True]:
return True

Copilot is powered by AI and may make mistakes. Always verify output.
Copy link

codecov bot commented May 22, 2025

❌ 3 Tests Failed:

Tests completed Failed Passed Skipped
16177 3 16174 1703
View the top 3 failed test(s) by shortest run time
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0026_test_ai_onnx_ml_tree_ensemble_set_membership
Stack Traces | 0.016s run time
onnxscript\converter.py:460: in _eval_constant_expr
    return eval(cpl, self.globals, locals)  # pylint: disable=eval-used
E   NameError: name 'nan' is not defined

The above exception was the direct cause of the following exception:
.nox\test_ort_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
<frozen importlib._bootstrap>:1204: in _gcd_import
    ???
<frozen importlib._bootstrap>:1176: in _find_and_load
    ???
<frozen importlib._bootstrap>:1147: in _find_and_load_unlocked
    ???
<frozen importlib._bootstrap>:690: in _load_unlocked
    ???
.nox\test_ort_nightly\Lib\site-packages\_pytest\assertion\rewrite.py:185: in exec_module
    exec(co, module.__dict__)
tests\onnx_backend_test_code\test_ai_onnx_ml_tree_ensemble_set_membership.py:9: in <module>
    @script()
onnxscript\main.py:94: in transform
    result = script_check(f_ast, opset, env, src, default_opset=default_opset)
onnxscript\main.py:38: in script_check
    return convert.translate_function_def(f)
onnxscript\converter.py:1452: in translate_function_def
    fn_ir = self._translate_function_def_common(stmt)
onnxscript\converter.py:1439: in _translate_function_def_common
    self._translate_stmt(s, index_of_stmt=i)
onnxscript\converter.py:961: in _translate_stmt
    return self._translate_assign_stmt(node)
onnxscript\converter.py:1048: in _translate_assign_stmt
    assign(lhs, rhs)
onnxscript\converter.py:992: in assign
    t = self._translate_expr(rhs, lhs).name
onnxscript\converter.py:546: in _translate_expr
    r = self._translate_call_expr(node)
onnxscript\converter.py:825: in _translate_call_expr
    attrs = [
onnxscript\converter.py:826: in <listcomp>
    self._translate_attr(x, y, callee.op_schema.attributes[x])
onnxscript\converter.py:510: in _translate_attr
    val = self._eval_constant_expr(expr)
onnxscript\converter.py:462: in _eval_constant_expr
    raise NameError(
E   NameError: ERROR: Missing names, globals contains ['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__file__', '__cached__', '__builtins__', '@py_builtins', '@pytest_ar', 'numpy', 'TensorProto', 'make_tensor', 'script', 'external_tensor', 'Opset', 'FLOAT', 'ai_onnx_ml5'], locals [].
E   at: Function 'bck_test_ai_onnx_ml_tree_ensemble_set_membership', line 3
E       Y = ai_onnx_ml5.TreeEnsemble(X, aggregate_function=1, leaf_targetids=[0, 1, 2, 3], leaf_weights=make_tensor("value", 1, dims=[4], vals=[1.0, 10.0, 1000.0, 100.0]), membership_values=make_tensor("value", 1, dims=[8], vals=[1.2000000476837158, 3.700000047683716, 8.0, 9.0, nan, 12.0, 7.0, nan]), n_targets=4, nodes_falseleafs=[1, 0, 1], nodes_falsenodeids=[2, 2, 3], nodes_featureids=[0, 0, 0], nodes_modes=make_tensor("value", 2, dims=[3], vals=[0, 6, 6]), nodes_splits=make_tensor("value", 1, dims=[3], vals=[11.0, 232344.0, nan]), nodes_trueleafs=[0, 1, 1], nodes_truenodeids=[1, 0, 1], post_transform=0, tree_roots=[0])
E                                                                                                                                                                                             ^
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0815_test_ai_onnx_ml_tree_ensemble_set_membership
Stack Traces | 0.019s run time
onnxscript/converter.py:460: in _eval_constant_expr
    return eval(cpl, self.globals, locals)  # pylint: disable=eval-used
E   NameError: name 'nan' is not defined

The above exception was the direct cause of the following exception:
..../test_ort_nightly/lib/python3.11.../site-packages/parameterized/parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript/backend/onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript/backend/onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
.../hostedtoolcache/Python/3.11.12.../x64/lib/python3.11/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
<frozen importlib._bootstrap>:1204: in _gcd_import
    ???
<frozen importlib._bootstrap>:1176: in _find_and_load
    ???
<frozen importlib._bootstrap>:1147: in _find_and_load_unlocked
    ???
<frozen importlib._bootstrap>:690: in _load_unlocked
    ???
..../test_ort_nightly/lib/python3.11.../_pytest/assertion/rewrite.py:185: in exec_module
    exec(co, module.__dict__)
tests/onnx_backend_test_code/test_ai_onnx_ml_tree_ensemble_set_membership.py:9: in <module>
    @script()
onnxscript/main.py:94: in transform
    result = script_check(f_ast, opset, env, src, default_opset=default_opset)
onnxscript/main.py:38: in script_check
    return convert.translate_function_def(f)
onnxscript/converter.py:1452: in translate_function_def
    fn_ir = self._translate_function_def_common(stmt)
onnxscript/converter.py:1439: in _translate_function_def_common
    self._translate_stmt(s, index_of_stmt=i)
onnxscript/converter.py:961: in _translate_stmt
    return self._translate_assign_stmt(node)
onnxscript/converter.py:1048: in _translate_assign_stmt
    assign(lhs, rhs)
onnxscript/converter.py:992: in assign
    t = self._translate_expr(rhs, lhs).name
onnxscript/converter.py:546: in _translate_expr
    r = self._translate_call_expr(node)
onnxscript/converter.py:825: in _translate_call_expr
    attrs = [
onnxscript/converter.py:826: in <listcomp>
    self._translate_attr(x, y, callee.op_schema.attributes[x])
onnxscript/converter.py:510: in _translate_attr
    val = self._eval_constant_expr(expr)
onnxscript/converter.py:462: in _eval_constant_expr
    raise NameError(
E   NameError: ERROR: Missing names, globals contains ['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__file__', '__cached__', '__builtins__', '@py_builtins', '@pytest_ar', 'numpy', 'TensorProto', 'make_tensor', 'script', 'external_tensor', 'Opset', 'FLOAT', 'ai_onnx_ml5'], locals [].
E   at: Function 'bck_test_ai_onnx_ml_tree_ensemble_set_membership', line 3
E       Y = ai_onnx_ml5.TreeEnsemble(X, aggregate_function=1, leaf_targetids=[0, 1, 2, 3], leaf_weights=make_tensor("value", 1, dims=[4], vals=[1.0, 10.0, 1000.0, 100.0]), membership_values=make_tensor("value", 1, dims=[8], vals=[1.2000000476837158, 3.700000047683716, 8.0, 9.0, nan, 12.0, 7.0, nan]), n_targets=4, nodes_falseleafs=[1, 0, 1], nodes_falsenodeids=[2, 2, 3], nodes_featureids=[0, 0, 0], nodes_modes=make_tensor("value", 2, dims=[3], vals=[0, 6, 6]), nodes_splits=make_tensor("value", 1, dims=[3], vals=[11.0, 232344.0, nan]), nodes_trueleafs=[0, 1, 1], nodes_truenodeids=[1, 0, 1], post_transform=0, tree_roots=[0])
E                                                                                                                                                                                             ^
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0125_test_ai_onnx_ml_tree_ensemble_set_membership
Stack Traces | 0.025s run time
onnxscript/converter.py:460: in _eval_constant_expr
    return eval(cpl, self.globals, locals)  # pylint: disable=eval-used
E   NameError: name 'nan' is not defined

The above exception was the direct cause of the following exception:
..../test_ort_nightly/lib/python3.11.../site-packages/parameterized/parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript/backend/onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript/backend/onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
.../Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
<frozen importlib._bootstrap>:1204: in _gcd_import
    ???
<frozen importlib._bootstrap>:1176: in _find_and_load
    ???
<frozen importlib._bootstrap>:1147: in _find_and_load_unlocked
    ???
<frozen importlib._bootstrap>:690: in _load_unlocked
    ???
..../test_ort_nightly/lib/python3.11.../_pytest/assertion/rewrite.py:185: in exec_module
    exec(co, module.__dict__)
tests/onnx_backend_test_code/test_ai_onnx_ml_tree_ensemble_set_membership.py:9: in <module>
    @script()
onnxscript/main.py:94: in transform
    result = script_check(f_ast, opset, env, src, default_opset=default_opset)
onnxscript/main.py:38: in script_check
    return convert.translate_function_def(f)
onnxscript/converter.py:1452: in translate_function_def
    fn_ir = self._translate_function_def_common(stmt)
onnxscript/converter.py:1439: in _translate_function_def_common
    self._translate_stmt(s, index_of_stmt=i)
onnxscript/converter.py:961: in _translate_stmt
    return self._translate_assign_stmt(node)
onnxscript/converter.py:1048: in _translate_assign_stmt
    assign(lhs, rhs)
onnxscript/converter.py:992: in assign
    t = self._translate_expr(rhs, lhs).name
onnxscript/converter.py:546: in _translate_expr
    r = self._translate_call_expr(node)
onnxscript/converter.py:825: in _translate_call_expr
    attrs = [
onnxscript/converter.py:826: in <listcomp>
    self._translate_attr(x, y, callee.op_schema.attributes[x])
onnxscript/converter.py:510: in _translate_attr
    val = self._eval_constant_expr(expr)
onnxscript/converter.py:462: in _eval_constant_expr
    raise NameError(
E   NameError: ERROR: Missing names, globals contains ['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__file__', '__cached__', '__builtins__', '@py_builtins', '@pytest_ar', 'numpy', 'TensorProto', 'make_tensor', 'script', 'external_tensor', 'Opset', 'FLOAT', 'ai_onnx_ml5'], locals [].
E   at: Function 'bck_test_ai_onnx_ml_tree_ensemble_set_membership', line 3
E       Y = ai_onnx_ml5.TreeEnsemble(X, aggregate_function=1, leaf_targetids=[0, 1, 2, 3], leaf_weights=make_tensor("value", 1, dims=[4], vals=[1.0, 10.0, 1000.0, 100.0]), membership_values=make_tensor("value", 1, dims=[8], vals=[1.2000000476837158, 3.700000047683716, 8.0, 9.0, nan, 12.0, 7.0, nan]), n_targets=4, nodes_falseleafs=[1, 0, 1], nodes_falsenodeids=[2, 2, 3], nodes_featureids=[0, 0, 0], nodes_modes=make_tensor("value", 2, dims=[3], vals=[0, 6, 6]), nodes_splits=make_tensor("value", 1, dims=[3], vals=[11.0, 232344.0, nan]), nodes_trueleafs=[0, 1, 1], nodes_truenodeids=[1, 0, 1], post_transform=0, tree_roots=[0])
E                                                                                                                                                                                             ^

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

@justinchuby justinchuby requested review from Copilot May 22, 2025 15:26
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR removes the separate RefAttr class by merging its functionality into the existing Attr class and introduces an is_ref() method for reference checks.
Key changes:

  • Consolidated attribute classes and removed RefAttr across the IR.
  • Updated all type annotations and isinstance(attr, RefAttr) checks to use Attr and attr.is_ref().
  • Adjusted serialization/deserialization and protocol definitions to handle reference attributes via is_ref().

Reviewed Changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated no comments.

Show a summary per file
File Description
onnxscript/version_converter/_version_converter.py Changed visit_attribute signature to ir.Attr.
onnxscript/rewriter/llama_rule_sets.py Replaced isinstance(..., RefAttr) with is_ref().
onnxscript/rewriter/_rewrite_rule.py Simplified check signatures; updated copy_attr_value.
onnxscript/rewriter/_pattern_ir.py Narrowed AttrPattern to ir.Attr only.
onnxscript/optimizer/_constant_folding.py Changed visit_attribute signature to ir.Attr.
onnxscript/ir/serde.py Split serialization paths using is_ref().
onnxscript/ir/passes/common/inliner.py Updated clone_attr and attr_map types.
onnxscript/ir/external_data.py Replaced isinstance(..., RefAttr) with is_ref().
onnxscript/ir/_tape.py Updated attribute sequences to ir.Attr.
onnxscript/ir/_protocols.py Added Literal import and is_ref() to protocols.
onnxscript/ir/_core.py Removed RefAttr class; merged into Attr.
onnxscript/ir/_convenience/_constructors.py Updated attribute sequences to ir.Attr.
onnxscript/ir/_convenience/init.py Removed RefAttr import and tightened types.
Comments suppressed due to low confidence (3)

onnxscript/version_converter/_version_converter.py:265

  • Reference attributes are no longer skipped here—since all attributes are now Attr, isinstance(attr, ir.Attr) will also match refs. Add a guard to skip ref attrs, e.g., if not attr.is_ref(): before processing.
if isinstance(attr, ir.Attr):

onnxscript/optimizer/_constant_folding.py:1063

  • Same as above, reference attributes will match ir.Attr and may cause visit_graph on a ref. Add not attr.is_ref() to skip processing reference attrs.
if isinstance(attr, ir.Attr):

onnxscript/rewriter/_rewrite_rule.py:371

  • The exception message still references the removed RefAttr class. Update it to something like "Reference attributes are not supported." for clarity.
raise NotImplementedError("RefAttr not supported.")

Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

@justinchuby justinchuby added this to the 0.2.7 milestone May 22, 2025
@justinchuby justinchuby requested a review from Copilot May 22, 2025 16:45
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

@justinchuby justinchuby merged commit f46004e into main May 22, 2025
26 of 29 checks passed
@justinchuby justinchuby deleted the justinchu/combine-ref branch May 22, 2025 17:43
@justinchuby justinchuby mentioned this pull request May 21, 2025
9 tasks
@justinchuby justinchuby modified the milestones: 0.2.7, 0.2.6 May 22, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
Development

Successfully merging this pull request may close these issues.

2 participants