Skip to content

Commit b294330

Browse files
committed
Merge remote-tracking branch 'upstream/main' into xiaowu/addOp(BatchNorm)
2 parents aa6df76 + 0298154 commit b294330

File tree

3 files changed

+48
-13
lines changed

3 files changed

+48
-13
lines changed

onnxscript/evaluator.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import onnx
1515
import onnx.defs
1616
import onnx.helper
17-
from typing_extensions import TypeAlias
17+
from typing_extensions import Protocol, TypeAlias, runtime_checkable
1818

1919
from onnxscript import autocast, irbuilder, onnx_opset, tensor, utils, values
2020
from onnxscript._internal import param_manipulation
@@ -117,7 +117,40 @@ def _unwrap_tensors_in_kwargs(kwargs: Mapping[str, Any]) -> dict[str, Any]:
117117
return new_kwargs
118118

119119

120-
class Evaluator(abc.ABC):
120+
@runtime_checkable
121+
class Evaluator(Protocol):
122+
"""Protocol for evaluating ONNX ops."""
123+
124+
def eval(
125+
self,
126+
schema: onnx.defs.OpSchema,
127+
inputs: Sequence[ExtendedModeValue],
128+
attributes: Mapping[str, Any],
129+
):
130+
"""Evaluates an ONNX op.
131+
132+
Args:
133+
schema: The OpSchema of the operator to evaluate.
134+
inputs: The ONNX inputs to the op.
135+
attributes: The ONNX attributes to the op.
136+
"""
137+
138+
def eval_function(
139+
self,
140+
function: values.OnnxFunction,
141+
args: Sequence[ExtendedModeValue],
142+
kwargs: Mapping[str, ExtendedModeValue],
143+
):
144+
"""Evaluates an OnnxFunction.
145+
146+
Args:
147+
function: The OnnxFunction to evaluate.
148+
args: The positional arguments to the function.
149+
kwargs: The keyword arguments to the function.
150+
"""
151+
152+
153+
class BaseEvaluator(Evaluator, abc.ABC):
121154
"""Base class for evaluation of ONNX ops.
122155
123156
The execution of onnxscript functions in eager-mode is dispatched to an Evaluator
@@ -400,7 +433,7 @@ def _schema_id(schema: onnx.defs.OpSchema) -> tuple[str, str, int]:
400433
return schema.name, schema.domain, schema.since_version
401434

402435

403-
class ORTEvaluator(Evaluator):
436+
class ORTEvaluator(BaseEvaluator):
404437
"""Evaluates ONNX ops using ONNX Runtime."""
405438

406439
def _eval(self, schema, inputs, attributes, closure):

onnxscript/function_libs/torch_aten/graph_building.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ def __init__(self, graph: TorchScriptGraph):
204204
def graph(self) -> TorchScriptGraph:
205205
return self._graph
206206

207+
def eval(self, schema, inputs, attributes):
208+
return self._graph.add_op_call(schema, inputs, attributes)
209+
207210
@beartype
208211
def eval_function( # type: ignore[override]
209212
self,
@@ -225,14 +228,6 @@ def eval_function( # type: ignore[override]
225228
attributes[name] = float(value)
226229
return self._graph.add_function_call(function, inputs, attributes)
227230

228-
def _eval(self, schema: onnx.defs.OpSchema, inputs, attributes, closure: Any):
229-
del closure # Unused
230-
231-
return self._graph.add_op_call(schema, inputs, attributes)
232-
233-
def eval(self, schema, inputs, attributes):
234-
return self._eval(schema, inputs, attributes, closure=None)
235-
236231

237232
@beartype
238233
def _add_attribute_to_torchscript_node(

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,12 @@ def aten_bitwise_not(self: TInt) -> TInt:
769769
return op.BitwiseNot(self)
770770

771771

772+
@torch_op("aten::bitwise_not", overload=True)
773+
def aten_bitwise_not_bool(self: BOOL) -> BOOL:
774+
# bitwise_not(Tensor self) -> Tensor
775+
return op.Not(self)
776+
777+
772778
@torch_op("aten::bitwise_or")
773779
def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
774780
# bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor
@@ -3162,9 +3168,10 @@ def aten_margin_ranking_loss(
31623168
@torch_op("aten::masked_fill")
31633169
def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:
31643170
# masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor
3165-
mask_cast = op.Cast(mask, to=BOOL.dtype)
3171+
# NOTE: Do not attempt to cast `mask` to BOOL because mask should not take any other types.
3172+
# `mask` coming in as other types is often an error and should fail the model.
31663173
value_cast = op.CastLike(value, self)
3167-
return op.Where(mask_cast, value_cast, self)
3174+
return op.Where(mask, value_cast, self)
31683175

31693176

31703177
def aten_masked_scatter(self: TensorType, mask: TensorType, source: TensorType) -> TensorType:

0 commit comments

Comments
 (0)