Skip to content

Commit 490ac84

Browse files
committed
Auto generate OpSchema for functions | feat
This change adds the capability to auto generate `OpSchema`. ### Changes - Implement the `opschema` property in `OnnxFunction` ghstack-source-id: 07f9122 Pull Request resolved: #626
1 parent 84a348e commit 490ac84

File tree

4 files changed

+248
-27
lines changed

4 files changed

+248
-27
lines changed

onnxscript/autocast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def cast(x, typeinfo) -> tensor.Tensor:
8686
return cast_inputs(get_type_info, cast, op_schema, *args)
8787

8888

89-
def static_cast_inputs(converter, op_schema: OpSchema, *args):
89+
def static_cast_inputs(converter, op_schema: Optional[OpSchema], *args):
9090
"""Used for autocast during script-translation."""
9191
if op_schema is None:
9292
return args

onnxscript/onnx_types.py

Lines changed: 71 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -105,69 +105,105 @@ def to_type_proto(cls) -> onnx.TypeProto:
105105
shape = [cls.shape] # example: "FLOAT[10]"
106106
return onnx.helper.make_tensor_type_proto(cls.dtype, shape)
107107

108+
@classmethod
109+
def to_string(cls) -> str:
110+
raise NotImplementedError()
111+
108112

109113
class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT):
110-
pass
114+
@classmethod
115+
def to_string(cls):
116+
return "tensor(float)"
111117

112118

113119
class UINT8(TensorType, dtype=onnx.TensorProto.UINT8):
114-
pass
120+
@classmethod
121+
def to_string(cls):
122+
return "tensor(uint8)"
115123

116124

117125
class INT8(TensorType, dtype=onnx.TensorProto.INT8):
118-
pass
126+
@classmethod
127+
def to_string(cls):
128+
return "tensor(int8)"
119129

120130

121131
class UINT16(TensorType, dtype=onnx.TensorProto.UINT16):
122-
pass
132+
@classmethod
133+
def to_string(cls):
134+
return "tensor(uint16)"
123135

124136

125137
class INT16(TensorType, dtype=onnx.TensorProto.INT16):
126-
pass
138+
@classmethod
139+
def to_string(cls):
140+
return "tensor(int16)"
127141

128142

129143
class INT32(TensorType, dtype=onnx.TensorProto.INT32):
130-
pass
144+
@classmethod
145+
def to_string(cls):
146+
return "tensor(int32)"
131147

132148

133149
class INT64(TensorType, dtype=onnx.TensorProto.INT64):
134-
pass
150+
@classmethod
151+
def to_string(cls):
152+
return "tensor(int64)"
135153

136154

137155
class STRING(TensorType, dtype=onnx.TensorProto.STRING):
138-
pass
156+
@classmethod
157+
def to_string(cls):
158+
return "tensor(string)"
139159

140160

141161
class BOOL(TensorType, dtype=onnx.TensorProto.BOOL):
142-
pass
162+
@classmethod
163+
def to_string(cls):
164+
return "tensor(bool)"
143165

144166

145167
class FLOAT16(TensorType, dtype=onnx.TensorProto.FLOAT16):
146-
pass
168+
@classmethod
169+
def to_string(cls):
170+
return "tensor(float16)"
147171

148172

149173
class DOUBLE(TensorType, dtype=onnx.TensorProto.DOUBLE):
150-
pass
174+
@classmethod
175+
def to_string(cls):
176+
return "tensor(double)"
151177

152178

153179
class UINT32(TensorType, dtype=onnx.TensorProto.UINT32):
154-
pass
180+
@classmethod
181+
def to_string(cls):
182+
return "tensor(uint32)"
155183

156184

157185
class UINT64(TensorType, dtype=onnx.TensorProto.UINT64):
158-
pass
186+
@classmethod
187+
def to_string(cls):
188+
return "tensor(uint64)"
159189

160190

161191
class COMPLEX64(TensorType, dtype=onnx.TensorProto.COMPLEX64):
162-
pass
192+
@classmethod
193+
def to_string(cls):
194+
return "tensor(complex64)"
163195

164196

165197
class COMPLEX128(TensorType, dtype=onnx.TensorProto.COMPLEX128):
166-
pass
198+
@classmethod
199+
def to_string(cls):
200+
return "tensor(complex128)"
167201

168202

169203
class BFLOAT16(TensorType, dtype=onnx.TensorProto.BFLOAT16):
170-
pass
204+
@classmethod
205+
def to_string(cls):
206+
return "tensor(bfloat16)"
171207

172208

173209
def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:
@@ -203,3 +239,22 @@ def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:
203239

204240
# Currently, only tensor types are supported. Need to expand support for other ONNX types.
205241
ONNXType = TensorType
242+
243+
ALL_TENSOR_TYPES = (
244+
BFLOAT16,
245+
BOOL,
246+
COMPLEX128,
247+
COMPLEX64,
248+
DOUBLE,
249+
FLOAT,
250+
FLOAT16,
251+
INT16,
252+
INT32,
253+
INT64,
254+
INT8,
255+
STRING,
256+
UINT16,
257+
UINT32,
258+
UINT64,
259+
UINT8,
260+
)

onnxscript/type_annotation.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
import collections
88
import inspect
99
import typing
10+
from typing import Any, TypeVar, Union
1011

1112
import onnx
1213
from typing_extensions import get_args, get_origin
1314

14-
from onnxscript.onnx_types import TensorType
15+
from onnxscript import onnx_types
1516

1617
# TypeAnnotationValue represents the (value of) valid type-annotations recognized
1718
# by ONNX Script. TODO: Flesh out a formal definition. Currently, it supports
@@ -59,7 +60,7 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool:
5960

6061
def pytype_to_attrtype(
6162
pytype: TypeAnnotationValue,
62-
) -> onnx.AttributeProto.AttributeType:
63+
) -> typing.Optional[onnx.AttributeProto.AttributeType]:
6364
pytype = _remove_annotation(pytype)
6465
if pytype in _PYTYPE_TO_ATTRTYPE_MAP:
6566
return _PYTYPE_TO_ATTRTYPE_MAP[pytype]
@@ -74,13 +75,13 @@ def pytype_to_attrtype(
7475
elt_type = get_args(pytype)[0]
7576
if elt_type in _LISTTYPE_TO_ATTRTYPE_MAP:
7677
return _LISTTYPE_TO_ATTRTYPE_MAP[elt_type]
77-
return onnx.AttributeProto.UNDEFINED
78+
return None
7879

7980

8081
def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool:
81-
if isinstance(typeinfo, TensorType):
82+
if isinstance(typeinfo, onnx_types.TensorType):
8283
return True
83-
if inspect.isclass(typeinfo) and issubclass(typeinfo, TensorType):
84+
if inspect.isclass(typeinfo) and issubclass(typeinfo, onnx_types.TensorType):
8485
return True
8586
return False
8687

@@ -146,3 +147,52 @@ def get_return_types(typeinfo: type | typing.Sequence[type]) -> typing.Sequence[
146147
if get_origin(typeinfo) is tuple:
147148
return get_args(typeinfo)
148149
return (typeinfo,)
150+
151+
152+
def _reduce_type_var_to_union(hint: typing.TypeVar):
153+
"""Reduce a TypeVar to a Union type on which we can use issubclass to check membership."""
154+
assert isinstance(hint, TypeVar)
155+
156+
# If the TypeVar has a bound, use that.
157+
if hint.__bound__ is not None:
158+
return hint.__bound__
159+
160+
# If the TypeVar has no bound, use the first constraint.
161+
if hint.__constraints__:
162+
return Union.__getitem__(hint.__constraints__)
163+
164+
return Any
165+
166+
167+
def get_supported_input_types(pytype) -> list[str]:
168+
"""Returns a list of all supported input types for a given type annotation.
169+
170+
Args:
171+
pytype: A type annotation.
172+
173+
Returns:
174+
A list of all supported input types for the given type annotation.
175+
"""
176+
supported_types: list[str] = []
177+
if typing.get_origin(pytype) is Union and isinstance(typing.get_args(pytype)[0], TypeVar):
178+
# Recursively unpack TypeVars inside an Optional
179+
for arg in typing.get_args(pytype):
180+
supported_types.extend(get_supported_input_types(arg))
181+
return supported_types
182+
183+
if isinstance(pytype, TypeVar):
184+
pytype = _reduce_type_var_to_union(pytype)
185+
186+
for tensor_type in onnx_types.ALL_TENSOR_TYPES:
187+
if pytype is None:
188+
# The same as Any
189+
supported_types.append(tensor_type.to_string())
190+
elif pytype == onnx_types.TensorType:
191+
supported_types.append(tensor_type.to_string())
192+
elif isinstance(pytype, tensor_type):
193+
supported_types.append(tensor_type.to_string())
194+
elif issubclass(tensor_type, pytype):
195+
supported_types.append(tensor_type.to_string())
196+
# TODO(justinchuby): Handle sequence types
197+
198+
return supported_types

0 commit comments

Comments
 (0)