Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions numbast/src/numbast/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
deduplicate_overloads,
make_device_caller_with_nargs,
make_function_shim,
sanitize_param_names,
)
from numbast.shim_writer import MemoryShimWriter as ShimWriter

Expand Down Expand Up @@ -80,9 +81,10 @@ def bind_cxx_operator_overload_function(
assert py_op is not None

# Crossing C / C++ boundary, pass argument by pointers.
param_names = sanitize_param_names(func_decl.params)
arglist = ", ".join(
f"{arg.type_.unqualified_non_ref_type_name}* {arg.name}"
for arg in func_decl.params
f"{arg.type_.unqualified_non_ref_type_name}* {name}"
for name, arg in zip(param_names, func_decl.params)
)
if arglist:
arglist = ", " + arglist
Expand All @@ -91,7 +93,7 @@ def bind_cxx_operator_overload_function(
return_type=return_type_name,
arglist=arglist,
method_name=func_decl.name,
args=", ".join("*" + arg.name for arg in func_decl.params),
args=", ".join("*" + name for name in param_names),
)

# Typing
Expand Down
10 changes: 6 additions & 4 deletions numbast/src/numbast/static/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
get_shim,
)
from numbast.static.types import to_numba_type_str
from numbast.utils import make_function_shim
from numbast.utils import make_function_shim, sanitize_param_names
from numbast.errors import TypeNotFoundError, MangledFunctionNameConflictError

from ast_canopy.decl import Function
Expand Down Expand Up @@ -193,10 +193,12 @@ def wrap_pointer(typ):

self._lower_scope_name = f"_lower_{self._deduplicated_shim_name}"

self._param_names = sanitize_param_names(self._decl.params)

# Cache the list of parameter types in C++ pointer types
c_ptr_arglist = ", ".join(
f"{arg.type_.unqualified_non_ref_type_name}* {arg.name}"
for arg in self._decl.params
f"{arg.type_.unqualified_non_ref_type_name}* {name}"
for name, arg in zip(self._param_names, self._decl.params)
)
if c_ptr_arglist:
c_ptr_arglist = ", " + c_ptr_arglist
Expand All @@ -205,7 +207,7 @@ def wrap_pointer(typ):

# Cache the list of dereferenced arguments
self._deref_args_str = ", ".join(
"*" + arg.name for arg in self._decl.params
"*" + name for name in self._param_names
)

# Track the public symbols from a function binding
Expand Down
9 changes: 6 additions & 3 deletions numbast/src/numbast/static/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
deduplicate_overloads,
make_struct_ctor_shim,
make_struct_conversion_operator_shim,
sanitize_param_names,
)
from numbast.errors import TypeNotFoundError

Expand Down Expand Up @@ -173,10 +174,12 @@ def wrap_pointer(typ):
_pointer_wrapped_param_types
)

self._param_names = sanitize_param_names(self._ctor_decl.params)

# Cache the list of parameter types in C++ pointer types
c_ptr_arglist = ", ".join(
f"{arg.type_.unqualified_non_ref_type_name}* {arg.name}"
for arg in self._ctor_decl.params
f"{arg.type_.unqualified_non_ref_type_name}* {name}"
for name, arg in zip(self._param_names, self._ctor_decl.params)
)
if c_ptr_arglist:
c_ptr_arglist = ", " + c_ptr_arglist
Expand All @@ -185,7 +188,7 @@ def wrap_pointer(typ):

# Cache the list of dereferenced arguments
self._deref_args_str = ", ".join(
"*" + arg.name for arg in self._ctor_decl.params
"*" + name for name in self._param_names
)

# Cache the unique shim name
Expand Down
8 changes: 5 additions & 3 deletions numbast/src/numbast/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from numbast.utils import (
deduplicate_overloads,
make_device_caller_with_nargs,
sanitize_param_names,
)
from numbast.shim_writer import MemoryShimWriter as ShimWriter

Expand Down Expand Up @@ -109,9 +110,10 @@ def bind_cxx_struct_ctor(
# FIXME: All params are passed by pointers, then dereferenced in shim.
# temporary solution for mismatching function prototype against definition.
# See above lowering for details.
param_names = sanitize_param_names(ctor.params)
arglist = ", ".join(
f"{arg.type_.unqualified_non_ref_type_name}* {arg.name}"
for arg in ctor.params
f"{arg.type_.unqualified_non_ref_type_name}* {name}"
for name, arg in zip(param_names, ctor.params)
)
if arglist:
arglist = ", " + arglist
Expand All @@ -120,7 +122,7 @@ def bind_cxx_struct_ctor(
func_name=func_name,
name=struct_name,
arglist=arglist,
args=", ".join("*" + arg.name for arg in ctor.params),
args=", ".join("*" + name for name in param_names),
)

@lower(S, *param_types)
Expand Down
43 changes: 34 additions & 9 deletions numbast/src/numbast/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from typing import Callable
from typing import Callable, Iterable
from collections import defaultdict
import re

Expand Down Expand Up @@ -66,7 +66,26 @@ def deduplicate_overloads(func_name: str) -> str:
return func_name + f"_{OVERLOADS_CNT[func_name]}"


def paramvar_to_str(arg: pylibastcanopy.ParamVar):
def sanitize_param_names(params: Iterable[pylibastcanopy.ParamVar]) -> list[str]:
"""Return stable, unique parameter identifiers for use in generated code."""

names: list[str] = []
seen: set[str] = set()

for index, arg in enumerate(params):
candidate = (arg.name or "").strip() or f"arg{index}"
base = candidate
suffix = 1
while candidate in seen:
candidate = f"{base}_{suffix}"
suffix += 1
names.append(candidate)
seen.add(candidate)

return names


def paramvar_to_str(arg: pylibastcanopy.ParamVar, name: str):
"""Convert a ParamVar to a string type name.

Perform necessary downcasting of array type ParamVar to a pointer type.
Expand All @@ -82,13 +101,13 @@ def paramvar_to_str(arg: pylibastcanopy.ParamVar):
# Pointer to array type: int (*arr)[10]
loc = base_ty.rfind("*")
fml_arg = (
base_ty[: loc + 1] + f"*{arg.name}" + base_ty[loc + 1 :] + sizes
base_ty[: loc + 1] + f"*{name}" + base_ty[loc + 1 :] + sizes
)
else:
# Regular array type: int arr[10]
fml_arg = base_ty + f" (*{arg.name})" + sizes
fml_arg = base_ty + f" (*{name})" + sizes
else:
fml_arg = f"{arg.type_.unqualified_non_ref_type_name}* {arg.name}"
fml_arg = f"{arg.type_.unqualified_non_ref_type_name}* {name}"

return fml_arg

Expand Down Expand Up @@ -135,15 +154,18 @@ def make_function_shim(
else:
retval = "retval = "

formal_args = [paramvar_to_str(arg) for arg in params]
param_names = sanitize_param_names(params)
formal_args = [
paramvar_to_str(arg, name) for arg, name in zip(params, param_names)
]

formal_args_str = ", ".join(formal_args)
if formal_args_str:
# If there are formal arguments, add a comma before them
# otherwise it's an empty string.
formal_args_str = ", " + formal_args_str

acutal_args_str = ", ".join("*" + arg.name for arg in params)
acutal_args_str = ", ".join("*" + name for name in param_names)

include_str = "\n".join([f"#include <{include}>" for include in includes])

Expand Down Expand Up @@ -193,15 +215,18 @@ def make_struct_ctor_shim(
}}
"""

formal_args = [paramvar_to_str(arg) for arg in params]
param_names = sanitize_param_names(params)
formal_args = [
paramvar_to_str(arg, name) for arg, name in zip(params, param_names)
]

formal_args_str = ", ".join(formal_args)
if formal_args_str:
# If there are formal arguments, add a comma before them
# otherwise it's an empty string.
formal_args_str = ", " + formal_args_str

acutal_args_str = ", ".join("*" + arg.name for arg in params)
acutal_args_str = ", ".join("*" + name for name in param_names)

include_str = "\n".join([f"#include <{include}>" for include in includes])

Expand Down
36 changes: 36 additions & 0 deletions numbast/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from numbast.utils import make_function_shim


class _DummyType:
def __init__(self, name: str):
self.unqualified_non_ref_type_name = name


class _DummyParam:
def __init__(self, type_name: str, name: str = ""):
self.type_ = _DummyType(type_name)
self.name = name
self.unqualified_non_ref_type_name = type_name


def test_make_function_shim_names_unnamed_parameters():
params = [_DummyParam("Foo", "")]

shim = make_function_shim("shim", "useFoo", "bool", params)

assert "Foo* arg0" in shim
assert "useFoo(*arg0);" in shim


def test_make_function_shim_disambiguates_duplicate_names():
params = [_DummyParam("Foo", "x"), _DummyParam("Bar", "x")]

shim = make_function_shim("shim", "useFoo", "bool", params)

assert "Foo* x" in shim
# The second argument should be suffixed to avoid clashing with the first.
assert "Bar* x_1" in shim
assert "useFoo(*x, *x_1);" in shim