Skip to content

[FRONTEND][NFC] Fix type checking, conditional logic, and loop structures for improved readability and performance #4208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 26 additions & 29 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def mangle_ty(ty):
return f'{elt}S{shape}S'
if ty.is_void():
return 'V'
assert False, "Unsupported type"
raise TypeError(f'Unsupported type {ty}')


def mangle_fn(name, arg_tys, constants):
Expand Down Expand Up @@ -121,10 +121,7 @@ def __init__(self, gscope):
self.gscope = gscope

def _visit_stmts(self, body) -> bool:
for s in body:
if self.visit(s):
return True
return False
return any(self.visit(s) for s in body)

def _visit_function(self, fn) -> bool:
# Currently we only support JITFunctions defined in the global scope
Expand Down Expand Up @@ -160,7 +157,7 @@ def visit_Attribute(self, node: ast.Attribute) -> bool:
return self.visit(node.value)

def visit_Name(self, node: ast.Name) -> bool:
if type(node.ctx) == ast.Store:
if type(node.ctx) is ast.Store:
return False
if node.id in self.gscope:
fn = self.gscope[node.id]
Expand Down Expand Up @@ -226,7 +223,7 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n
self.function_ret_types = {} if function_types is None else function_types
self.prototype = prototype
self.gscope = gscope
self.lscope = dict()
self.lscope = {}
self.attributes = attributes
self.constants = constants
self.jit_fn = jit_fn
Expand Down Expand Up @@ -281,19 +278,20 @@ def global_lookup(name: str, absent):
# The high-level rule is that only constexpr globals are allowed.
# But actually a bunch of other things, such as module imports, are
# technically Python globals. We have to allow these too!
if (val is absent #
or name in self.builtin_namespace #
or type(val) == ModuleType #
or isinstance(val, JITFunction) #
or getattr(val, "__triton_builtin__", False) #
or getattr(val, "__module__", "").startswith("triton.language") #
or isinstance(val, language.dtype) #
or self._is_constexpr_global(name) #
if any([
val is absent, name in self.builtin_namespace, #
type(val) is ModuleType, #
isinstance(val, JITFunction), #
getattr(val, "__triton_builtin__", False), #
getattr(val, "__module__", "").startswith("triton.language"), #
isinstance(val, language.dtype), #
self._is_constexpr_global(name), #
# Allow accesses to globals while visiting an ast.arg
# because you should be able to do
# @triton.jit def fn(x: tl.constexpr = GLOBAL): ...
or self.visiting_arg_default_value #
or os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1"):
self.visiting_arg_default_value, #
os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1"
]):
return val
raise NameError(
textwrap.dedent(f"""\
Expand Down Expand Up @@ -418,7 +416,7 @@ def visit_FunctionDef(self, node):
entry = self.fn.add_entry_block()
arg_values = []
idx = 0
for i, arg_name in enumerate(arg_names):
for i in range(len(arg_names)):
if i in self.constants:
cst = self.constants[i]
if not _is_constexpr(cst):
Expand Down Expand Up @@ -514,7 +512,7 @@ def visit_AugAssign(self, node):
return self.dereference_name(name)

def visit_Name(self, node):
if type(node.ctx) == ast.Store:
if type(node.ctx) is ast.Store:
return node.id
return self.dereference_name(node.id)

Expand Down Expand Up @@ -770,9 +768,9 @@ def visit_Compare(self, node):
rhs = self.visit(node.comparators[0])
lhs_value = _unwrap_if_constexpr(lhs)
rhs_value = _unwrap_if_constexpr(rhs)
if type(node.ops[0]) == ast.Is:
if type(node.ops[0]) is ast.Is:
return constexpr(lhs_value is rhs_value)
if type(node.ops[0]) == ast.IsNot:
if type(node.ops[0]) is ast.IsNot:
return constexpr(lhs_value is not rhs_value)
method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
if method_name is None:
Expand Down Expand Up @@ -1048,7 +1046,7 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
args = [args[name] for name in fn.arg_names]
args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args]
# generate function def
attributes = dict()
attributes = {}
constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
constants = {i: args[i] for i in constexprs}
# generate call
Expand Down Expand Up @@ -1098,14 +1096,14 @@ def visit_Call(self, node):

kws = dict(self.visit(keyword) for keyword in node.keywords)
args = [self.visit(arg) for arg in node.args]
if fn is language.core.device_assert: # TODO: this should not be so hardcoded
if not self.debug:
return
# TODO: this should not be so hardcoded
if fn is language.core.device_assert and not self.debug:
return
if isinstance(fn, JITFunction):
_check_fn_args(node, fn, args)
return self.call_JitFunction(fn, args, kws)
if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn):
extra_kwargs = dict(_builder=self.builder)
extra_kwargs = {"_builder": self.builder}
sig = inspect.signature(fn)
if '_generator' in sig.parameters:
extra_kwargs['_generator'] = self
Expand Down Expand Up @@ -1154,9 +1152,8 @@ def visit_Str(self, node):

def visit_Attribute(self, node):
lhs = self.visit(node.value)
if _is_triton_tensor(lhs):
if node.attr == "T":
return language.semantic.permute(lhs, (1, 0), builder=self.builder)
if _is_triton_tensor(lhs) and node.attr == "T":
return language.semantic.permute(lhs, (1, 0), builder=self.builder)
return getattr(lhs, node.attr)

def visit_Expr(self, node):
Expand Down
9 changes: 4 additions & 5 deletions python/triton/runtime/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _update_hash(self, func):
self.hasher.update(func_key.encode("utf-8"))

def visit_Name(self, node):
if type(node.ctx) == ast.Store:
if type(node.ctx) is ast.Store:
return node.id

if node.id in self.local_names:
Expand All @@ -117,12 +117,11 @@ def visit_Name(self, node):
and not self.visiting_arg_default_value
# It would be pretty evil if someone did `import x` and then
# `x = blah`.
and type(val) != ModuleType
and type(val) is not ModuleType
# It would be pretty evil if we used function `foo` inside of
# `bar` and then someone did `foo = baz`.
and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) #
and node.id not in self.supported_python_builtins #
):
and node.id not in self.supported_python_builtins):
self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals)

self._update_hash(val)
Expand Down Expand Up @@ -650,7 +649,7 @@ def run(self, *args, grid, warmup, **kwargs):

# Check that used global values have not changed.
not_present = object()
for (name, globals_dict_id), (val, globals_dict) in self.used_global_vals.items():
for (name, _), (val, globals_dict) in self.used_global_vals.items():
if (newVal := globals_dict.get(name, not_present)) != val:
raise RuntimeError(
f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}")
Expand Down
Loading