Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 38 additions & 9 deletions ast_canopy/ast_canopy/decl.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,12 @@
}


class Function:
class Decl:
def __init__(self, source_location: bindings.SourceLocation):
self.source_location = source_location


class Function(Decl):
"""
Represents a C++ function.

Expand All @@ -75,7 +80,9 @@ def __init__(
is_constexpr: bool,
mangled_name: str,
parse_entry_point: str,
source_location: bindings.SourceLocation,
):
super().__init__(source_location)
self.name = name
self.return_type = return_type
self.params = params
Expand Down Expand Up @@ -163,6 +170,7 @@ def from_c_obj(cls, c_obj: bindings.Function, parse_entry_point: str):
c_obj.is_constexpr,
c_obj.mangled_name,
parse_entry_point,
c_obj.source_location,
)


Expand All @@ -172,15 +180,17 @@ def __init__(self, template_parameters, num_min_required_args):
self.num_min_required_args = num_min_required_args


class FunctionTemplate(Template):
class FunctionTemplate(Decl, Template):
def __init__(
self,
template_parameters: list[bindings.TemplateParam],
num_min_required_args: int,
function: Function,
parse_entry_point: str,
source_location: bindings.SourceLocation,
):
super().__init__(template_parameters, num_min_required_args)
Decl.__init__(self, source_location)
Template.__init__(self, template_parameters, num_min_required_args)
self.function = function

self.parse_entry_point = parse_entry_point
Expand All @@ -194,6 +204,7 @@ def from_c_obj(
c_obj.num_min_required_args,
Function.from_c_obj(c_obj.function, parse_entry_point),
parse_entry_point,
c_obj.source_location,
)

def instantiate(self, **kwargs):
Expand All @@ -213,6 +224,7 @@ def __init__(
is_move_constructor: bool,
mangled_name: str,
parse_entry_point: str,
source_location: bindings.SourceLocation,
):
super().__init__(
name,
Expand All @@ -222,6 +234,7 @@ def __init__(
is_constexpr,
mangled_name,
parse_entry_point,
source_location,
)
self.kind = kind
self.is_move_constructor = is_move_constructor
Expand Down Expand Up @@ -252,6 +265,7 @@ def from_c_obj(cls, c_obj: bindings.Method, parse_entry_point: str):
c_obj.is_move_constructor(),
c_obj.mangled_name,
parse_entry_point,
c_obj.source_location,
)


Expand All @@ -273,7 +287,7 @@ def decl_name(self):
return self.name


class Struct:
class Struct(Decl):
def __init__(
self,
name: str,
Expand All @@ -285,7 +299,9 @@ def __init__(
sizeof_: int,
alignof_: int,
parse_entry_point: str,
source_location: bindings.SourceLocation,
):
super().__init__(source_location)
self.name = name
self.fields = fields
self.methods = methods
Expand Down Expand Up @@ -330,6 +346,7 @@ def from_c_obj(cls, c_obj: bindings.Record, parse_entry_point: str):
c_obj.sizeof_,
c_obj.alignof_,
parse_entry_point,
c_obj.source_location,
)


Expand All @@ -354,18 +371,21 @@ def from_c_obj(cls, c_obj: bindings.Record, parse_entry_point: str):
c_obj.sizeof_,
c_obj.alignof_,
parse_entry_point,
c_obj.source_location,
)


class ClassTemplate(Template):
class ClassTemplate(Decl, Template):
def __init__(
self,
record: TemplatedStruct,
template_parameters: list[bindings.TemplateParam],
num_min_required_args: int,
parse_entry_point: str,
source_location: bindings.SourceLocation,
):
super().__init__(template_parameters, num_min_required_args)
Decl.__init__(self, source_location)
Template.__init__(self, template_parameters, num_min_required_args)
self.record = record

self.parse_entry_point = parse_entry_point
Expand All @@ -377,22 +397,31 @@ def from_c_obj(cls, c_obj: bindings.ClassTemplate, parse_entry_point: str):
c_obj.template_parameters,
c_obj.num_min_required_args,
parse_entry_point,
c_obj.source_location,
)

def instantiate(self, **kwargs):
tstruct = ClassInstantiation(self)
return tstruct.instantiate(**kwargs)


class ConstExprVar:
def __init__(self, name: str, type_: bindings.Type, value_serialized: str):
class ConstExprVar(Decl):
def __init__(
self,
name: str,
type_: bindings.Type,
value_serialized: str,
source_location: bindings.SourceLocation,
):
super().__init__(source_location)

self.name = name
self.type_ = type_
self.value_serialized = value_serialized

@classmethod
def from_c_obj(cls, c_obj: bindings.ConstExprVar):
return cls(c_obj.name, c_obj.type_, c_obj.value)
return cls(c_obj.name, c_obj.type_, c_obj.value, c_obj.source_location)

@property
def value(self):
Expand Down
34 changes: 24 additions & 10 deletions ast_canopy/ast_canopy/pylibastcanopy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,21 @@ PYBIND11_MODULE(pylibastcanopy, m) {
.value("protected_", access_kind::protected_)
.value("private_", access_kind::private_);

py::class_<Enum>(m, "Enum")
py::class_<SourceLocation>(m, "SourceLocation")
.def(py::init<>())
.def(py::init<const std::string &, const unsigned int, const unsigned int,
const bool>())
.def_property_readonly("file_name", &SourceLocation::file_name)
.def_property_readonly("line", &SourceLocation::line)
.def_property_readonly("column", &SourceLocation::column)
.def_property_readonly("is_valid", &SourceLocation::is_valid);

py::class_<Decl>(m, "Decl")
.def(py::init())
.def(py::init<const clang::Decl *>())
.def_readwrite("source_location", &Decl::source_location);

py::class_<Enum, Decl>(m, "Enum")
.def(py::init<const clang::EnumDecl *>())
.def_readwrite("name", &Enum::name)
.def_readwrite("enumerators", &Enum::enumerators)
Expand Down Expand Up @@ -82,13 +96,13 @@ PYBIND11_MODULE(pylibastcanopy, m) {
t[2].cast<bool>(), t[3].cast<bool>()};
}));

py::class_<ConstExprVar>(m, "ConstExprVar")
py::class_<ConstExprVar, Decl>(m, "ConstExprVar")
.def(py::init<>())
.def_readwrite("type_", &ConstExprVar::type_)
.def_readwrite("name", &ConstExprVar::name)
.def_readwrite("value", &ConstExprVar::value);

py::class_<Field>(m, "Field")
py::class_<Field, Decl>(m, "Field")
.def_readwrite("name", &Field::name)
.def_readwrite("type_", &Field::type)
.def_readwrite("access", &Field::access)
Expand All @@ -107,7 +121,7 @@ PYBIND11_MODULE(pylibastcanopy, m) {
t[2].cast<access_kind>()};
}));

py::class_<ParamVar>(m, "ParamVar")
py::class_<ParamVar, Decl>(m, "ParamVar")
.def(py::init<std::string, Type>())
.def_readwrite("name", &ParamVar::name)
.def_readwrite("type_", &ParamVar::type)
Expand All @@ -124,7 +138,7 @@ PYBIND11_MODULE(pylibastcanopy, m) {
return ParamVar{t[0].cast<std::string>(), t[1].cast<Type>()};
}));

py::class_<TemplateParam>(m, "TemplateParam")
py::class_<TemplateParam, Decl>(m, "TemplateParam")
.def_readwrite("name", &TemplateParam::name)
.def_readwrite("type_", &TemplateParam::type)
.def_readwrite("kind", &TemplateParam::kind)
Expand All @@ -145,7 +159,7 @@ PYBIND11_MODULE(pylibastcanopy, m) {
t[2].cast<Type>()};
}));

py::class_<Function>(m, "Function")
py::class_<Function, Decl>(m, "Function")
.def_readwrite("name", &Function::name)
.def_readwrite("return_type", &Function::return_type)
.def_readwrite("params", &Function::params)
Expand Down Expand Up @@ -183,7 +197,7 @@ PYBIND11_MODULE(pylibastcanopy, m) {
t[1].cast<std::size_t>()};
}));

py::class_<FunctionTemplate, Template>(m, "FunctionTemplate")
py::class_<FunctionTemplate, Decl, Template>(m, "FunctionTemplate")
.def_readwrite("function", &FunctionTemplate::function)
.def_readwrite("num_min_required_args",
&FunctionTemplate::num_min_required_args)
Expand All @@ -202,7 +216,7 @@ PYBIND11_MODULE(pylibastcanopy, m) {
t[1].cast<Function>()};
}));

py::class_<ClassTemplate, Template>(m, "ClassTemplate")
py::class_<ClassTemplate, Decl, Template>(m, "ClassTemplate")
.def_readwrite("num_min_required_args",
&ClassTemplate::num_min_required_args)
.def_readwrite("record", &ClassTemplate::record);
Expand All @@ -224,7 +238,7 @@ PYBIND11_MODULE(pylibastcanopy, m) {
t[1].cast<method_kind>()};
}));

py::class_<Record>(m, "Record")
py::class_<Record, Decl>(m, "Record")
.def_readwrite("name", &Record::name)
.def_readwrite("fields", &Record::fields)
.def_readwrite("methods", &Record::methods)
Expand Down Expand Up @@ -254,7 +268,7 @@ PYBIND11_MODULE(pylibastcanopy, m) {
t[8].cast<std::string>()};
}));

py::class_<Typedef>(m, "Typedef")
py::class_<Typedef, Decl>(m, "Typedef")
.def_readwrite("name", &Typedef::name)
.def_readwrite("underlying_name", &Typedef::underlying_name)
.def(py::pickle(
Expand Down
45 changes: 35 additions & 10 deletions ast_canopy/ast_canopy/pylibastcanopy.pyi
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
from _typeshed import Incomplete
from typing import ClassVar, overload

class ClassTemplate(Template):
class ClassTemplate(Decl, Template):
num_min_required_args: int
record: Incomplete
def __init__(self, *args, **kwargs) -> None: ...

class ConstExprVar:
class ConstExprVar(Decl):
name: str
type_: Type
value: str
def __init__(self) -> None: ...

class Decl:
source_location: SourceLocation
@overload
def __init__(self) -> None: ...
@overload
def __init__(self, arg0) -> None: ...

class Declarations:
class_templates: list[ClassTemplate]
enums: list[Enum]
Expand All @@ -21,26 +28,28 @@ class Declarations:
typedefs: list[Typedef]
def __init__(self, *args, **kwargs) -> None: ...

class Enum:
class Enum(Decl):
enumerator_values: list[str]
enumerators: list[str]
name: str
def __init__(self, arg0) -> None: ...

class Field:
class Field(Decl):
access: access_kind
name: str
type_: Type
def __init__(self, *args, **kwargs) -> None: ...

class Function:
class Function(Decl):
exec_space: execution_space
is_constexpr: bool
mangled_name: str
name: str
params: list[ParamVar]
return_type: Type
def __init__(self, *args, **kwargs) -> None: ...

class FunctionTemplate(Template):
class FunctionTemplate(Decl, Template):
function: Function
num_min_required_args: int
def __init__(self, *args, **kwargs) -> None: ...
Expand All @@ -50,12 +59,14 @@ class Method(Function):
def __init__(self, *args, **kwargs) -> None: ...
def is_move_constructor(self) -> bool: ...

class ParamVar:
class ParamVar(Decl):
name: str
type_: Type
def __init__(self, arg0: str, arg1: Type) -> None: ...

class Record:
class ParseError(Exception): ...

class Record(Decl):
alignof_: int
fields: list[Field]
methods: list[Method]
Expand All @@ -66,12 +77,26 @@ class Record:
templated_methods: list[FunctionTemplate]
def __init__(self, *args, **kwargs) -> None: ...

class SourceLocation:
@overload
def __init__(self) -> None: ...
@overload
def __init__(self, arg0: str, arg1: int, arg2: int, arg3: bool) -> None: ...
@property
def column(self) -> int: ...
@property
def file_name(self) -> str: ...
@property
def is_valid(self) -> bool: ...
@property
def line(self) -> int: ...

class Template:
num_min_required_args: int
template_parameters: list[TemplateParam]
def __init__(self, arg0: list[TemplateParam], arg1: int) -> None: ...

class TemplateParam:
class TemplateParam(Decl):
kind: template_param_kind
name: str
type_: Type
Expand All @@ -89,7 +114,7 @@ class Type:
def is_left_reference(self) -> bool: ...
def is_right_reference(self) -> bool: ...

class Typedef:
class Typedef(Decl):
name: str
underlying_name: str
def __init__(self, *args, **kwargs) -> None: ...
Expand Down
Loading
Loading