diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index c671342acc..5f71d5649c 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -559,7 +559,7 @@ RUN(NAME enum_06 LABELS cpython llvm c) RUN(NAME enum_07 IMPORT_PATH .. LABELS cpython llvm c) RUN(NAME union_01 LABELS cpython llvm c) -RUN(NAME union_02 LABELS llvm c) +RUN(NAME union_02 LABELS cpython llvm c) RUN(NAME union_03 LABELS cpython llvm c) RUN(NAME union_04 IMPORT_PATH .. LABELS cpython llvm c) diff --git a/integration_tests/union_02.py b/integration_tests/union_02.py index c64f08463a..67cb4a3254 100644 --- a/integration_tests/union_02.py +++ b/integration_tests/union_02.py @@ -19,9 +19,9 @@ class C: @ccall @union class D(Union): - a: A = A() - b: B = B() - c: C = C() + a: A = A(0, 3.0) + b: B = B(i64(0), 2.0) + c: C = C(i64(0), 0.0, 1.0) def test_struct_union(): d: D = D() diff --git a/src/runtime/lpython/lpython.py b/src/runtime/lpython/lpython.py index 103966f77d..2e01b885b0 100644 --- a/src/runtime/lpython/lpython.py +++ b/src/runtime/lpython/lpython.py @@ -4,6 +4,7 @@ import platform from dataclasses import dataclass as py_dataclass, is_dataclass as py_is_dataclass + # TODO: this does not seem to restrict other imports __slots__ = ["i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64", "f32", "f64", "c32", "c64", "CPtr", "overload", "ccall", "TypeVar", "pointer", "c_p_pointer", "Pointer", @@ -89,10 +90,19 @@ def __init__(self, type, dims): Const = ConstType("Const") Callable = Type("Callable") Allocatable = Type("Allocatable") -Union = ctypes.Union Pointer = PointerType("Pointer") +class Union: + def __init__(self): + pass + + def __setattr__(self, name: str, value): + self.__dict__[name] = value + + def __getattr__(self, name: str): + return self.__dict__[name] + class Intent: def __init__(self, type): self._type = type @@ -381,7 +391,6 @@ def convert_to_ctypes_Union(f): for name in f.__annotations__: ltype_ = f.__annotations__[name] fields.append((name, convert_type_to_ctype(ltype_))) - f._fields_ = fields f.__annotations__ = {} @@ -467,12 +476,16 @@ def inner(fn): def union(f): fields = [] + fa = {} for name in f.__annotations__: ltype_ = f.__annotations__[name] - fields.append((name, convert_type_to_ctype(ltype_))) + ltype_ = convert_type_to_ctype(ltype_) + fa[name] = ltype_ + + fields.append((name, ltype_)) f._fields_ = fields - f.__annotations__ = {} + f.__annotations__ = fa return f def pointer(x, type_=None):