Refactor converter to isolate translate_function_signature logic | feat(converter) #684

Merged (70 commits, Apr 28, 2023)
Changes from 60 commits

Commits (70)
c6505cf
Move version_utils to _internal | chore
justinchuby Apr 12, 2023
268ad8f
Merge `attrs` and `attr_protos` in `IRFunction` | chore(irbuilder)
justinchuby Apr 12, 2023
327a7d6
Auto generate OpSchema for functions | feat
justinchuby Apr 12, 2023
0883d6f
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
95c4ba6
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
962c13b
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
a353d73
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
821821c
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
fa16ca5
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 12, 2023
6b9106b
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
46fa00f
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
bbe8e7e
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
43345e2
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
26d5caa
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
f835d9b
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
b3a035d
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
e953774
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
03cc7f4
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
7fda2d1
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
7321b22
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
b760abc
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 13, 2023
efc7708
Update base for Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 14, 2023
606be97
Update on "Auto generate OpSchema for functions | feat"
justinchuby Apr 14, 2023
67a8ee0
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 14, 2023
a3f9b50
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 14, 2023
1d4a0b4
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 14, 2023
08a27fa
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 14, 2023
2d158ac
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 17, 2023
28b4a48
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 17, 2023
622b688
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 18, 2023
19f9484
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 18, 2023
31bb69c
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 18, 2023
b61543a
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 18, 2023
2c6be92
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 18, 2023
6cfa67c
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 18, 2023
a9a0845
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 20, 2023
0502482
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 20, 2023
79d3605
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 20, 2023
ad57790
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 20, 2023
1439aaf
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 20, 2023
f2455dc
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 20, 2023
ea41c8f
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 21, 2023
a2fde87
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 21, 2023
138c2ed
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 22, 2023
b334880
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 22, 2023
90208e1
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 22, 2023
cef03af
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 22, 2023
2d9627e
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 22, 2023
ed79fce
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 22, 2023
8f5f7ba
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 22, 2023
6376c93
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 22, 2023
49d8d0e
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 25, 2023
dd80bff
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 25, 2023
b3dbb7f
Update base for Update on "Auto generate OpSchema for functions | fea…
justinchuby Apr 27, 2023
14d2149
Update on "Auto generate OpSchema for functions | feat(op_schema)"
justinchuby Apr 27, 2023
dc0963b
Refactor converter to isolate translate_function_signature logic | fe…
justinchuby Apr 27, 2023
2ab1af3
Update base for Update on "Refactor converter to isolate translate_fu…
justinchuby Apr 27, 2023
dc608dc
Update on "Refactor converter to isolate translate_function_signature…
justinchuby Apr 27, 2023
b9f7a50
Update base for Update on "Refactor converter to isolate translate_fu…
justinchuby Apr 27, 2023
b60cb46
Update on "Refactor converter to isolate translate_function_signature…
justinchuby Apr 27, 2023
c77cdd8
Update base for Update on "Refactor converter to isolate translate_fu…
justinchuby Apr 27, 2023
e67f49e
Update on "Refactor converter to isolate translate_function_signature…
justinchuby Apr 27, 2023
a491eb8
Update base for Update on "Refactor converter to isolate translate_fu…
justinchuby Apr 28, 2023
59c41f1
Update on "Refactor converter to isolate translate_function_signature…
justinchuby Apr 28, 2023
88a14d0
Update base for Update on "Refactor converter to isolate translate_fu…
justinchuby Apr 28, 2023
adc8e38
Update on "Refactor converter to isolate translate_function_signature…
justinchuby Apr 28, 2023
3e60a78
Update base for Update on "Refactor converter to isolate translate_fu…
justinchuby Apr 28, 2023
c51159f
Update on "Refactor converter to isolate translate_function_signature…
justinchuby Apr 28, 2023
fe1e2d9
Update base for Update on "Refactor converter to isolate translate_fu…
justinchuby Apr 28, 2023
59f092e
Update on "Refactor converter to isolate translate_function_signature…
justinchuby Apr 28, 2023
2 changes: 1 addition & 1 deletion onnxscript/autocast.py
@@ -86,7 +86,7 @@ def cast(x, typeinfo) -> tensor.Tensor:
return cast_inputs(get_type_info, cast, op_schema, *args)


-def static_cast_inputs(converter, op_schema: OpSchema, *args):
+def static_cast_inputs(converter, op_schema: Optional[OpSchema], *args):
"""Used for autocast during script-translation."""
if op_schema is None:
return args
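One practical consequence of the widened signature: when no schema is available, static_cast_inputs is a no-op and simply returns its inputs. A small hedged illustration (the literal arguments below are placeholders, not real converter values):

from onnxscript import autocast

# With op_schema=None autocast is skipped entirely; the converter argument is
# never consulted on this path, so a placeholder is acceptable for illustration.
inputs = ("x", "y")
assert autocast.static_cast_inputs(None, None, *inputs) == inputs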
103 changes: 59 additions & 44 deletions onnxscript/converter.py
@@ -180,6 +180,13 @@ def __init__(
self.this_module = opset
self.default_opset_ = default_opset

+# States initialized by `init_function_translation`
+self._outer = []
+self._current_fn = None
+self._nextvar = 0
+self._used_vars = set()
+self._locals: List[Dict[Any, Any]] = [{}]

@property
def default_opset(self):
if self.default_opset_ is None:
@@ -222,14 +229,14 @@ def find_onnx_opset(self, node: ast.AST) -> Optional[values.Opset]:

def init_function_translation(self):
"""Initialize self for translating a new (top-level) function."""
-self.outer = []
-self.current_fn = None
-self.nextvar = 0
-self.used_vars = set()
-self.locals: List[Dict[Any, Any]] = [{}]
+self._outer = []
+self._current_fn = None
+self._nextvar = 0
+self._used_vars = set()
+self._locals: List[Dict[Any, Any]] = [{}]

def source_of(self, node: ast.AST) -> sourceinfo.SourceInfo:
-return sourceinfo.SourceInfo(node, self.source, self.current_fn.name)
+return sourceinfo.SourceInfo(node, self.source, self._current_fn.name)

def message(self, node: ast.AST, error_msg: str) -> str:
"""Constructs an error message containing source information about an ast node."""
@@ -252,28 +259,28 @@ def enter_scope(self, name, parent_node):
"""Enter a control-flow block (a loop body or if-then-else branch).
The block is translated into a nested-scope in ONNX.
"""
-self.outer.insert(0, self.current_fn)
-self.current_fn = self.ir_builder.new_function(name)
-self.locals.insert(0, {})
-logger.debug("Converter:enter_scope:%d:node:%s", len(self.locals), type(parent_node))
+self._outer.insert(0, self._current_fn)
+self._current_fn = self.ir_builder.new_function(name)
+self._locals.insert(0, {})
+logger.debug("Converter:enter_scope:%d:node:%s", len(self._locals), type(parent_node))

def exit_scope(self):
"""Exit from a control-flow block (a loop body or if-then-else branch)."""
logger.debug("Converter:exit_scope:%d", len(self.locals))
graph = self.current_fn
self.current_fn = self.outer.pop(0)
self.locals.pop(0)
logger.debug("Converter:exit_scope:%d", len(self._locals))
graph = self._current_fn
self._current_fn = self._outer.pop(0)
self._locals.pop(0)
return graph

def current_scope(self):
-return self.locals[0]
+return self._locals[0]

def bind(self, name, val):
logger.debug("Converter:bind:%s", name)
-self.locals[0][name] = val
+self._locals[0][name] = val

def lookup(self, name, info, raise_exception=True):
-for scope in self.locals:
+for scope in self._locals:
if name in scope:
return scope[name]
if name in self.globals:
@@ -285,10 +292,10 @@ def lookup(self, name, info, raise_exception=True):
def generate_unique_name(self, candidate: str = "tmp") -> str:
# TODO(justinchuby): Can we reduce the O complexity of this function?
r = candidate
-while r in self.used_vars:
-r = f"{candidate}_{self.nextvar}"
-self.nextvar = self.nextvar + 1
-self.used_vars.add(r)
+while r in self._used_vars:
+r = f"{candidate}_{self._nextvar}"
+self._nextvar = self._nextvar + 1
+self._used_vars.add(r)
return r

def to_onnx_attr_ref(self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo]):
@@ -326,11 +333,11 @@ def py_var_to_onnx_var(self, py_var, info: sourceinfo.SourceInfo):
return self.to_onnx_var(self.lookup(py_var, info), target=py_var, info=info)

def emit_docstring(self, docstring):
-self.ir_builder.add_docstring(self.current_fn, docstring)
+self.ir_builder.add_docstring(self._current_fn, docstring)

def emit(self, outputs, callee, inputs, attrs, sub_functions=None):
self.ir_builder.add_stmt(
-self.current_fn,
+self._current_fn,
outputs,
callee,
inputs,
@@ -898,7 +905,7 @@ def translate_callee_expr(self, node) -> values.Op: # pylint: disable=R1710
function_name = node.id
found = self.lookup(function_name, self.source_of(node), raise_exception=False)
if isinstance(found, onnxscript.OnnxFunction):
-self.current_fn.add_called_function(found)
+self._current_fn.add_called_function(found)
return found
if isinstance(found, values.Op):
return found
@@ -1016,7 +1023,7 @@ def ret(exp, i, suffix):
t = None
else:
t = self.returntype[i]
-self.ir_builder.add_output(self.current_fn, return_var, t, self.source_of(stmt))
+self.ir_builder.add_output(self._current_fn, return_var, t, self.source_of(stmt))
return return_var

val = stmt.value
@@ -1123,7 +1130,7 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]):
self.enter_scope("loop_body", loop_stmt)
o_loop_var = self.generate_unique_name(p_loop_var)
self.ir_builder.add_input(
-self.current_fn,
+self._current_fn,
o_loop_var,
onnx_types.INT64,
self.source_of(loop_stmt),
@@ -1134,7 +1141,7 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]):
)

self.ir_builder.add_input(
-self.current_fn,
+self._current_fn,
i_cond_var,
onnx_types.BOOL,
self.source_of(loop_stmt),
@@ -1145,7 +1152,9 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]):
# TODO: retrieve the annotation for variable pv is any is specified.
# typeinfo = self.eval_constant_expr(pv.annotation)
typeinfo = None
-self.ir_builder.add_input(self.current_fn, ov, typeinfo, self.source_of(loop_stmt))
+self.ir_builder.add_input(
+self._current_fn, ov, typeinfo, self.source_of(loop_stmt)
+)
self.bind(
pv,
values.Dynamic(ov, values.DynamicKind.Loop, self.source_of(loop_stmt)),
@@ -1201,14 +1210,14 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]):
)

self.ir_builder.add_output(
-self.current_fn,
+self._current_fn,
o_cond_out,
onnx_types.BOOL,
self.source_of(loop_stmt),
)
for pv in loop_state_vars:
ov = self.py_var_to_onnx_var(pv, self.source_of(loop_stmt))
-if ov not in self.current_fn.assigned_names:
+if ov not in self._current_fn.assigned_names:
# When converting the loop-body into a graph, we need to handle
# identity assignments of the form "x = y" inside the loop body
# specially if y represents a value computed outside the loop body.
@@ -1218,7 +1227,7 @@ def translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]):
# TODO: retrieve variable type for the annotation if any.
typeinfo = None
self.ir_builder.add_output(
-self.current_fn, ov, typeinfo, self.source_of(loop_stmt)
+self._current_fn, ov, typeinfo, self.source_of(loop_stmt)
)
body = self.exit_scope()
inputs = [o_loop_bound, o_true] + [
@@ -1245,27 +1254,27 @@ def translate_block(self, stmts, name, live_defs, parent_stmt=None):
if pvar in self.current_scope():
pv_val = self.current_scope()[pvar]
output = self.to_onnx_var(pv_val, pvar)
-if output not in self.current_fn.assigned_names:
+if output not in self._current_fn.assigned_names:
# To return an outer-scope variable, an ONNX Graph has to
# use an explicit copy via Identity.
output = self.emit_copy(output, pvar)
self.ir_builder.add_output(
-self.current_fn,
+self._current_fn,
output,
pv_val.typeinfo,
self.source_of(info_stmt),
)
else:
pv_val = None
-for scope in self.locals: # TODO: skip current_scope
+for scope in self._locals: # TODO: skip current_scope
if pvar in scope:
pv_val = scope[pvar]
break
if pv_val is None:
self.fail(
stmts[0],
f"Variable {pvar} is not assigned a value along a conditional "
f"branch, known variables: {list(self.locals)}.",
f"branch, known variables: {list(self._locals)}.",
)
# introduce a copy
ovar = self.generate_unique_name(pvar)
@@ -1278,7 +1287,7 @@ def translate_block(self, stmts, name, live_defs, parent_stmt=None):
# TODO: retrieve the annotation if any.
typeinfo = None
self.ir_builder.add_output(
-self.current_fn, ovar, typeinfo, self.source_of(info_stmt)
+self._current_fn, ovar, typeinfo, self.source_of(info_stmt)
)
graph = self.exit_scope()
return graph.to_graph_and_functions()
@@ -1294,15 +1303,15 @@ def translate_nested_function_def(self, fn: ast.FunctionDef):
]
self.bind(fn.name, function_ir)
# TODO: Does not yet handle nested functions within nested functions.
-self.current_fn.add_nested_function(function_ir)
+self._current_fn.add_nested_function(function_ir)

-def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
-logger.debug("Converter:translate_function_def:%s", fn.name)
+def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
+"""Translate a function signature."""
args = fn.args
if args.vararg or args.kwonlyargs or args.kw_defaults or args.kwarg:
warn(f"{fn.name}: Unsupported feature in function signature.")
domain = self.this_module.domain
-self.current_fn = self.ir_builder.new_function(fn.name, domain, True)
+self._current_fn = self.ir_builder.new_function(fn.name, domain, True)
for i, x in enumerate(args.args):
arg_with_default_start_index = len(args.args) - len(args.defaults)
if args.defaults and i >= arg_with_default_start_index:
@@ -1324,15 +1333,15 @@ def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
typeinfo = None
if typeinfo and ta.is_attr_type(typeinfo):
self.ir_builder.add_attr_parameter(
-self.current_fn,
+self._current_fn,
x.arg,
ta.pytype_to_attrtype(typeinfo),
default_value,
)
self.bind(x.arg, values.AttrRef(x.arg, typeinfo, self.source_of(x)))
else:
-self.ir_builder.add_input(self.current_fn, x.arg, typeinfo, self.source_of(x))
-self.used_vars.add(x.arg)
+self.ir_builder.add_input(self._current_fn, x.arg, typeinfo, self.source_of(x))
+self._used_vars.add(x.arg)
self.bind(
x.arg,
values.Dynamic(x.arg, values.DynamicKind.Input, self.source_of(x)),
@@ -1352,9 +1361,15 @@ def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
self.returntype = None
else:
self.returntype = None

+return self._current_fn
+
+def translate_function_def(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
+logger.debug("Converter:translate_function_def:%s", fn.name)
+_ = self.translate_function_signature(fn)
for i, s in enumerate(fn.body):
self.translate_stmt(s, index_of_stmt=i)
-return self.current_fn
+return self._current_fn

def top_level_stmt(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
if isinstance(stmt, ast.FunctionDef):
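Why this split matters for the rest of the stack: an OpSchema can now be generated from a function's signature alone, without translating its body. The real entry point is Converter.translate_function_signature above; the sketch below only illustrates the shape of the refactor with hypothetical, simplified code and is not onnxscript's implementation:

import ast

class SignatureOnlyTranslator:
    """Illustrative stand-in for the converter's new split (not onnxscript code)."""

    def translate_function_signature(self, fn: ast.FunctionDef) -> dict:
        # Signature-only work: record inputs and the return annotation without
        # visiting a single statement in fn.body.
        inputs = {
            arg.arg: ast.unparse(arg.annotation) if arg.annotation else None
            for arg in fn.args.args
        }
        returns = ast.unparse(fn.returns) if fn.returns else None
        return {"name": fn.name, "inputs": inputs, "returns": returns}

    def translate_function_def(self, fn: ast.FunctionDef) -> dict:
        # Full translation starts from the signature and only then looks at the
        # body (here merely counted; the real converter translates each statement).
        ir = self.translate_function_signature(fn)
        ir["num_body_statements"] = len(fn.body)
        return ir

tree = ast.parse("def add(a: FLOAT, b: FLOAT) -> FLOAT:\n    return a + b")
print(SignatureOnlyTranslator().translate_function_signature(tree.body[0]))
# {'name': 'add', 'inputs': {'a': 'FLOAT', 'b': 'FLOAT'}, 'returns': 'FLOAT'}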
4 changes: 3 additions & 1 deletion onnxscript/irbuilder.py
@@ -526,5 +526,7 @@ def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttributeValue:
proto = onnx.AttributeProto()
proto.name = attrname
proto.ref_attr_name = refname
-proto.type = ta.pytype_to_attrtype(pytype)
+attr_type = ta.pytype_to_attrtype(pytype)
+assert attr_type is not None
+proto.type = attr_type
return IRAttributeValue(proto)
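The assert is presumably there because pytype_to_attrtype can return None for Python types that have no ONNX attribute type; asserting narrows the Optional result (for type checkers) and fails loudly instead of writing an invalid value into the proto. An illustrative sketch of the same narrowing pattern, with a hypothetical stand-in for ta.pytype_to_attrtype:

from typing import Optional

import onnx

def to_attr_type(pytype: type) -> Optional[int]:
    # Stand-in for ta.pytype_to_attrtype: only some Python types map to an
    # ONNX attribute type; anything else yields None.
    mapping = {
        int: onnx.AttributeProto.INT,
        float: onnx.AttributeProto.FLOAT,
        str: onnx.AttributeProto.STRING,
    }
    return mapping.get(pytype)

proto = onnx.AttributeProto()
attr_type = to_attr_type(int)
assert attr_type is not None  # narrows Optional[int]; fails loudly for unsupported types
proto.type = attr_type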
4 changes: 4 additions & 0 deletions onnxscript/onnx_types.py
@@ -105,6 +105,10 @@ def to_type_proto(cls) -> onnx.TypeProto:
shape = [cls.shape] # example: "FLOAT[10]"
return onnx.helper.make_tensor_type_proto(cls.dtype, shape)

+@classmethod
+def to_string(cls) -> str:
+return f"tensor({cls.__name__.lower()})"


class FLOAT(TensorType, dtype=onnx.TensorProto.FLOAT):
pass
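The new to_string simply lowercases the tensor-type class name into ONNX's textual tensor(...) form. A quick example of the expected output (not taken from this PR's tests):

from onnxscript.onnx_types import FLOAT, INT64

print(FLOAT.to_string())  # "tensor(float)"
print(INT64.to_string())  # "tensor(int64)"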