diff --git a/onnxscript/function_libs/torch_lib/_flags.py b/onnxscript/function_libs/torch_lib/_flags.py index 209b34fa4e..690246da1d 100644 --- a/onnxscript/function_libs/torch_lib/_flags.py +++ b/onnxscript/function_libs/torch_lib/_flags.py @@ -39,3 +39,7 @@ def _load_boolean_flag( "TORCHLIB_EXPERIMENTAL_INITIALIZERS_AS_INPUTS", this_will="make initializers as inputs to the model graph", ) +EXPERIMENTAL_PREFER_TRACING: bool = _load_boolean_flag( + "TORCHLIB_EXPERIMENTAL_PREFER_TRACING", + this_will="trace all traceable functions to fold if branches and collapse constant expressions", +) diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index f9c9a9fc7a..05d8f62179 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -99,6 +99,7 @@ def torch_op( trace_only: bool = False, private: bool = False, complex: bool = False, + traceable: bool = False, ) -> Callable[[FunctionType], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]: """Register a torch op. @@ -112,6 +113,18 @@ def torch_op( private: Whether the function is private (not directly exposed). It should be true for all functions with names starting with "_". complex: Whether the function expects complex-valued inputs. + traceable: Whether the function can also be traced. This is an **experimental** flag. + A function is traceable if it can both be scripted and traced to produce + the same result for a given input. Specifically: + + - A function _can_ be tagged with traceable if its if branches (if any) + can be statically evaluated. + - A function _should_ be tagged with traceable if it contains if branches + and/or CastLike nodes so that they can be evaluated away with the + EXPERIMENTAL_PREFER_TRACING on. + - A function without if branches or CastLike nodes _should not_ be tagged + with traceable because inlining will do the same thing. + - A function with `@graph` defined for a `Scan` op is not traceable yet. """ if registry is None: registry = default_registry @@ -128,6 +141,7 @@ def wrapper( else: assert isinstance(func, FunctionType) processed_func = onnxscript.script(opset=custom_opset)(func) + processed_func.experimental_traceable = traceable assert registry is not None for name_ in _check_and_normalize_names(name): diff --git a/onnxscript/values.py b/onnxscript/values.py index db52e9dc1b..23f1a48d75 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -479,6 +479,9 @@ def __init__( self._param_schemas: Optional[tuple[ParamSchema, ...]] = None self._op_schema: Optional[onnx.defs.OpSchema] = None + # Experimental fields + self.experimental_traceable = False + @property @deprecation.deprecated( since="0.1",