Skip to content

Commit c41116b

Browse files
authored
Allow Prefix Removal for Struct and Enums (#259)
This is a sub-issue from #247 . Previously we have prefix removal feature for standalone functions. This PR expands the feature to Struct and Enums. Assuming a prefix named `_prefix_` is to be removed, a struct named `_prefix_Foo` is named `Foo` in python. A enum defined below: ```c++ enum _prefix_Fruit { _prefix_Apple, _prefix_Banana }; ``` will be exported as: ```Python Fruit(IntEnum): Apple = 0 Banana = 1 ``` in python. This is to say that the prefix removal is applicable to both enum name as well as the members. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Configurable prefix removal for enums, structs, and functions so exported Python names omit configured prefixes. * Enum symbol list is now exposed in module exports for easier discovery. * **Refactor** * Centralized prefix-removal and unified Python-facing name handling across binding generation and symbol registration. * Rendering internals now propagate Python-visible names consistently. * **Tests** * Added/updated tests to verify prefix removal and symbol exposure; tests invoke kernels directly and reset renderer/function caches. * **Documentation** * Clarified enum registration docstring. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Michael Wang <[email protected]>
1 parent ae49180 commit c41116b

File tree

11 files changed

+311
-210
lines changed

11 files changed

+311
-210
lines changed

numbast/src/numbast/static/enum.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from numbast.static.renderer import BaseRenderer, get_rendered_imports
1111
from numbast.static.types import register_enum_type_str
12+
from numbast.utils import _apply_prefix_removal
1213

1314
file_logger = getLogger(f"{__name__}")
1415
logger_path = os.path.join(tempfile.gettempdir(), "test.py")
@@ -28,28 +29,38 @@ class {enum_name}(IntEnum):
2829
"""
2930
enumerator_template = " {enumerator} = {value}"
3031

31-
def __init__(self, decl: Enum):
32+
def __init__(
33+
self, decl: Enum, enum_prefix_removal: list[str] | None = None
34+
):
3235
self._decl = decl
36+
self._enum_prefix_removal = enum_prefix_removal or []
37+
38+
self._enum_name = _apply_prefix_removal(
39+
self._decl.name, self._enum_prefix_removal
40+
)
41+
42+
self._enum_symbols.append(self._enum_name)
3343

3444
def _render(self):
3545
self.Imports.add("from enum import IntEnum")
3646
self.Imports.add("from numba.types import IntEnumMember")
3747
self.Imports.add("from numba.types import int64")
3848

39-
register_enum_type_str(self._decl.name, self._decl.name)
49+
register_enum_type_str(self._decl.name, self._enum_name)
4050

4151
enumerators = []
4252
for enumerator, value in zip(
4353
self._decl.enumerators, self._decl.enumerator_values
4454
):
55+
py_name = _apply_prefix_removal(
56+
enumerator, self._enum_prefix_removal
57+
)
4558
enumerators.append(
46-
self.enumerator_template.format(
47-
enumerator=enumerator, value=value
48-
)
59+
self.enumerator_template.format(enumerator=py_name, value=value)
4960
)
5061

5162
self._python_rendered = self.enum_template.format(
52-
enum_name=self._decl.name, enumerators="\n".join(enumerators)
63+
enum_name=self._enum_name, enumerators="\n".join(enumerators)
5364
)
5465

5566

@@ -59,18 +70,21 @@ class StaticEnumsRenderer(BaseRenderer):
5970
Since enums creates a new C++ type. It should be invoked before making struct / function bindings.
6071
"""
6172

62-
def __init__(self, decls: list[Enum]):
73+
def __init__(
74+
self, decls: list[Enum], enum_prefix_removal: list[str] | None = None
75+
):
6376
super().__init__(decls)
6477
self._decls = decls
78+
self._enum_prefix_removal = enum_prefix_removal or []
6579

66-
self._python_rendered = []
80+
self._python_rendered: list[str] = []
6781

6882
def _render(self, with_imports):
6983
"""Render python bindings for enums."""
7084
self._python_str = ""
7185

7286
for decl in self._decls:
73-
SER = StaticEnumRenderer(decl)
87+
SER = StaticEnumRenderer(decl, self._enum_prefix_removal)
7488
SER._render()
7589
self._python_rendered.append(SER._python_rendered)
7690

numbast/src/numbast/static/function.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
get_shim,
1717
)
1818
from numbast.static.types import to_numba_type_str
19-
from numbast.utils import make_function_shim
19+
from numbast.utils import make_function_shim, _apply_prefix_removal
2020
from numbast.errors import TypeNotFoundError, MangledFunctionNameConflictError
2121

2222
from ast_canopy.decl import Function
@@ -413,8 +413,9 @@ def __init__(
413413
function_prefix_removal: list[str] = [],
414414
):
415415
super().__init__(decl, header_path, use_cooperative)
416-
self._function_prefix_removal = function_prefix_removal
417-
self._python_func_name = self._apply_prefix_removal(self._decl.name)
416+
self._python_func_name = _apply_prefix_removal(
417+
decl.name, function_prefix_removal
418+
)
418419

419420
# Override the base class symbol tracking to use the Python function name
420421
# Remove the original name that was added by the base class
@@ -423,25 +424,6 @@ def __init__(
423424
# Add the Python function name (with prefix removal applied)
424425
self._function_symbols.append(self._python_func_name)
425426

426-
def _apply_prefix_removal(self, name: str) -> str:
427-
"""Apply prefix removal to a function name based on the configuration.
428-
429-
Parameters
430-
----------
431-
name : str
432-
The original function name
433-
434-
Returns
435-
-------
436-
str
437-
The function name with prefixes removed
438-
"""
439-
for prefix in self._function_prefix_removal:
440-
if name.startswith(prefix):
441-
return name[len(prefix) :]
442-
443-
return name
444-
445427
@property
446428
def func_name_python(self):
447429
"""The name of the function in python with prefix removal applied."""

numbast/src/numbast/static/renderer.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ def reset(self):
8383
_function_symbols: list[str] = []
8484
"""List of new function handles to expose."""
8585

86+
_enum_symbols: list[str] = []
87+
"""List of new enum handles to expose."""
88+
8689
def __init__(self, decl):
8790
self.Imports.add("import numba")
8891
self.Imports.add("import io")
@@ -128,6 +131,7 @@ def clear_base_renderer_cache():
128131
BaseRenderer._nbtype_symbols.clear()
129132
BaseRenderer._record_symbols.clear()
130133
BaseRenderer._function_symbols.clear()
134+
BaseRenderer._enum_symbols.clear()
131135

132136

133137
def get_reproducible_info(
@@ -245,6 +249,18 @@ def _get_function_symbols() -> str:
245249
return code
246250

247251

252+
def _get_enum_symbols() -> str:
253+
template = """
254+
_ENUM_SYMBOLS = [{enum_symbols}]
255+
"""
256+
257+
symbols = BaseRenderer._enum_symbols
258+
quote_wrapped = [f'"{s}"' for s in symbols]
259+
concat = ",".join(quote_wrapped)
260+
code = template.format(enum_symbols=concat)
261+
return code
262+
263+
248264
def get_all_exposed_symbols() -> str:
249265
"""Return the definition of all exposed symbols via `__all__`.
250266
@@ -257,13 +273,14 @@ def get_all_exposed_symbols() -> str:
257273
nbtype_symbols = _get_nbtype_symbols()
258274
record_symbols = _get_record_symbols()
259275
function_symbols = _get_function_symbols()
276+
enum_symbols = _get_enum_symbols()
260277

261278
all_symbols = f"""
262279
{nbtype_symbols}
263280
{record_symbols}
264281
{function_symbols}
265-
266-
__all__ = _NBTYPE_SYMBOLS + _RECORD_SYMBOLS + _FUNCTION_SYMBOLS
282+
{enum_symbols}
283+
__all__ = _NBTYPE_SYMBOLS + _RECORD_SYMBOLS + _FUNCTION_SYMBOLS + _ENUM_SYMBOLS
267284
"""
268285

269286
return all_symbols

0 commit comments

Comments
 (0)