diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 3fb00bc1d5..9c20af5421 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -289,6 +289,7 @@ RUN(NAME bindc_02 LABELS cpython llvm c) RUN(NAME bindc_04 LABELS llvm c) RUN(NAME bindc_07 LABELS cpython llvm c) RUN(NAME bindc_08 LABELS cpython llvm c) +RUN(NAME bindc_09 LABELS cpython llvm c) RUN(NAME exit_01 LABELS cpython llvm c) RUN(NAME exit_02 FAIL LABELS cpython llvm c) RUN(NAME exit_03 LABELS cpython llvm c wasm wasm_x86 wasm_x64) diff --git a/integration_tests/bindc_09.py b/integration_tests/bindc_09.py new file mode 100644 index 0000000000..18fdaad29e --- /dev/null +++ b/integration_tests/bindc_09.py @@ -0,0 +1,43 @@ +from enum import Enum + +from lpython import CPtr, c_p_pointer, p_c_pointer, dataclass, empty_c_void_p, pointer, Pointer, i32, ccallable + +class Value(Enum): + TEN: i32 = 10 + TWO: i32 = 2 + ONE: i32 = 1 + FIVE: i32 = 5 + +@dataclass +class Foo: + value: Value + +@ccallable +@dataclass +class FooC: + value: Value + +def bar(foo_ptr: CPtr) -> None: + foo: Pointer[Foo] = c_p_pointer(foo_ptr, Foo) + foo.value = Value.FIVE + +def barc(foo_ptr: CPtr) -> None: + foo: Pointer[FooC] = c_p_pointer(foo_ptr, FooC) + foo.value = Value.ONE + +def main() -> None: + foo: Foo = Foo(Value.TEN) + fooc: FooC = FooC(Value.TWO) + foo_ptr: CPtr = empty_c_void_p() + + p_c_pointer(pointer(foo), foo_ptr) + bar(foo_ptr) + print(foo.value, foo.value.name) + assert foo.value == Value.FIVE + + p_c_pointer(pointer(fooc), foo_ptr) + barc(foo_ptr) + print(fooc.value) + assert fooc.value == Value.ONE.value + +main() diff --git a/integration_tests/structs_15.py b/integration_tests/structs_15.py index 4918cf15fd..7341d42d60 100644 --- a/integration_tests/structs_15.py +++ b/integration_tests/structs_15.py @@ -1,5 +1,6 @@ -from lpython import i32, i16, i8, i64, CPtr, dataclass, ccall, Pointer, c_p_pointer, sizeof +from lpython import i32, i16, i8, CPtr, dataclass, ccall, Pointer, c_p_pointer, sizeof, ccallable +@ccallable @dataclass class A: x: i16 diff --git a/src/runtime/lpython/lpython.py b/src/runtime/lpython/lpython.py index 9d8fd0d85a..23dfdf87e2 100644 --- a/src/runtime/lpython/lpython.py +++ b/src/runtime/lpython/lpython.py @@ -50,6 +50,9 @@ def __class_getitem__(key): return py_dataclass(arg) +def is_ctypes_Structure(obj): + return (isclass(obj) and issubclass(obj, ctypes.Structure)) + def is_dataclass(obj): return ((isclass(obj) and issubclass(obj, ctypes.Structure)) or py_is_dataclass(obj)) @@ -236,6 +239,7 @@ class c_double_complex(c_complex): _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)] def convert_type_to_ctype(arg): + from enum import Enum if arg == f64: return ctypes.c_double elif arg == f32: @@ -275,6 +279,9 @@ def convert_type_to_ctype(arg): return ctypes.POINTER(type) elif is_dataclass(arg): return convert_to_ctypes_Structure(arg) + elif issubclass(arg, Enum): + # TODO: store enum in ctypes.Structure with name and value as fields. + return ctypes.c_int64 else: raise NotImplementedError("Type %r not implemented" % arg) @@ -422,6 +429,7 @@ def __init__(self, *args): super().__init__(*args) for field, arg in zip(self._fields_, args): + from enum import Enum member = self.__getattribute__(field[0]) value = arg if isinstance(member, ctypes.Array): @@ -434,6 +442,8 @@ def __init__(self, *args): value = value.flatten().tolist() value = [c_double_complex(val.real, val.imag) for val in value] value = type(member)(*value) + elif isinstance(value, Enum): + value = value.value self.__setattr__(field[0], value) ctypes_Structure.__name__ = f.__name__ @@ -515,6 +525,7 @@ def __getattr__(self, name: str): def __setattr__(self, name: str, value): name_ = self.ctypes_ptr.contents.__getattribute__(name) + from enum import Enum if isinstance(name_, c_float_complex): if isinstance(value, complex): value = c_float_complex(value.real, value.imag) @@ -535,6 +546,8 @@ def __setattr__(self, name: str, value): value = value.flatten().tolist() value = [c_double_complex(val.real, val.imag) for val in value] value = type(name_)(*value) + elif isinstance(value, Enum): + value = value.value self.ctypes_ptr.contents.__setattr__(name, value) def c_p_pointer(cptr, targettype): @@ -545,9 +558,14 @@ def c_p_pointer(cptr, targettype): newa = ctypes.cast(cptr, targettype_ptr) return newa else: + if py_is_dataclass(targettype): + if cptr.value is None: + return None + return ctypes.cast(cptr, ctypes.py_object).value + targettype_ptr = ctypes.POINTER(targettype_ptr) newa = ctypes.cast(cptr, targettype_ptr) - if is_dataclass(targettype): + if is_ctypes_Structure(targettype): # return after wrapping newa inside PointerToStruct return PointerToStruct(newa) return newa