Skip to content

Add lower_bounds in CPtrToPointer #1822

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ RUN(NAME bindc_01 LABELS cpython llvm c)
RUN(NAME bindc_02 LABELS cpython llvm c)
RUN(NAME bindc_04 LABELS llvm c)
RUN(NAME bindc_07 LABELS cpython llvm c)
RUN(NAME bindc_08 LABELS cpython llvm c)
RUN(NAME exit_01 LABELS cpython llvm c)
RUN(NAME exit_02 FAIL LABELS cpython llvm c)
RUN(NAME exit_03 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
Expand Down
26 changes: 26 additions & 0 deletions integration_tests/bindc_08.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# file: main.py
from lpython import CPtr, i32, dataclass, c_p_pointer, Pointer, empty_c_void_p, p_c_pointer

from numpy import empty

@dataclass
class Foo:
x: i32
y: i32

def init(foos_ptr: CPtr) -> None:
foos: Pointer[Foo[1]] = c_p_pointer(foos_ptr, Foo[1])
foos[0] = Foo(3, 2)

def main() -> None:
foos: Foo[1] = empty(1, dtype=Foo)
foos_ptr: CPtr = empty_c_void_p()
foos[0] = Foo(0, 1)
p_c_pointer(foos, foos_ptr)
init(foos_ptr)
print("foos[0].x = ", foos[0].x)
print("foos[0].y = ", foos[0].y)
assert foos[0].x == 3
assert foos[0].y == 2

main()
2 changes: 1 addition & 1 deletion src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ stmt
| ListAppend(expr a, expr ele)
| AssociateBlockCall(symbol m)
| SelectType(expr selector, type_stmt* body, stmt* default)
| CPtrToPointer(expr cptr, expr ptr, expr? shape)
| CPtrToPointer(expr cptr, expr ptr, expr? shape, expr? lower_bounds)
| BlockCall(int label, symbol m)
| SetInsert(expr a, expr ele)
| SetRemove(expr a, expr ele)
Expand Down
9 changes: 7 additions & 2 deletions src/libasr/codegen/asr_to_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1090,14 +1090,19 @@ R"(
std::string dest_src = std::move(src);
src = "";
std::string indent(indentation_level*indentation_spaces, ' ');
ASR::ArrayConstant_t* lower_bounds = nullptr;
if( x.m_lower_bounds ) {
LCOMPILERS_ASSERT(ASR::is_a<ASR::ArrayConstant_t>(*x.m_lower_bounds));
lower_bounds = ASR::down_cast<ASR::ArrayConstant_t>(x.m_lower_bounds);
}
if( ASRUtils::is_array(ASRUtils::expr_type(x.m_ptr)) ) {
std::string dim_set_code = "";
ASR::dimension_t* m_dims = nullptr;
int n_dims = ASRUtils::extract_dimensions_from_ttype(ASRUtils::expr_type(x.m_ptr), m_dims);
dim_set_code = indent + dest_src + "->n_dims = " + std::to_string(n_dims) + ";\n";
for( int i = 0; i < n_dims; i++ ) {
if( m_dims[i].m_start ) {
visit_expr(*m_dims[i].m_start);
if( lower_bounds ) {
visit_expr(*lower_bounds->m_args[i]);
} else {
src = "0";
}
Expand Down
58 changes: 18 additions & 40 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4400,6 +4400,12 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm_cptr = builder->CreateBitCast(llvm_cptr, llvm_fptr_data_type->getPointerTo());
builder->CreateStore(llvm_cptr, fptr_data);
llvm::Value* prod = llvm::ConstantInt::get(context, llvm::APInt(32, 1));
ASR::ArrayConstant_t* lower_bounds = nullptr;
if( x.m_lower_bounds ) {
LCOMPILERS_ASSERT(ASR::is_a<ASR::ArrayConstant_t>(*x.m_lower_bounds));
lower_bounds = ASR::down_cast<ASR::ArrayConstant_t>(x.m_lower_bounds);
LCOMPILERS_ASSERT(fptr_rank == (int) lower_bounds->n_args);
}
for( int i = 0; i < fptr_rank; i++ ) {
llvm::Value* curr_dim = llvm::ConstantInt::get(context, llvm::APInt(32, i));
llvm::Value* desi = arr_descr->get_pointer_to_dimension_descriptor(fptr_des, curr_dim);
Expand All @@ -4409,6 +4415,13 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
builder->CreateStore(prod, desi_stride);
llvm::Value* i32_one = llvm::ConstantInt::get(context, llvm::APInt(32, 1));
llvm::Value* new_lb = i32_one;
if( lower_bounds ) {
int ptr_loads_copy = ptr_loads;
ptr_loads = 2;
this->visit_expr_wrapper(lower_bounds->m_args[i], true);
ptr_loads = ptr_loads_copy;
new_lb = tmp;
}
llvm::Value* new_ub = shape_data ? CreateLoad(llvm_utils->create_ptr_gep(shape_data, i)) : i32_one;
builder->CreateStore(new_lb, desi_lb);
llvm::Value* new_size = builder->CreateAdd(builder->CreateSub(new_ub, new_lb), i32_one);
Expand Down Expand Up @@ -6048,46 +6061,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
for (size_t i=0; i < x.n_args; i++) {
llvm::Value *llvm_el = llvm_utils->create_gep(p_fxn, i);
ASR::expr_t *el = x.m_args[i];
llvm::Value *llvm_val;
if (ASR::is_a<ASR::Integer_t>(*x.m_type)) {
ASR::IntegerConstant_t *ci = ASR::down_cast<ASR::IntegerConstant_t>(el);
switch (ASR::down_cast<ASR::Integer_t>(x.m_type)->m_kind) {
case (4) : {
int32_t el_value = ci->m_n;
llvm_val = llvm::ConstantInt::get(context, llvm::APInt(32, static_cast<int32_t>(el_value), true));
break;
}
case (8) : {
int64_t el_value = ci->m_n;
llvm_val = llvm::ConstantInt::get(context, llvm::APInt(32, el_value, true));
break;
}
default :
throw CodeGenError("ConstArray integer kind not supported yet");
}
} else if (ASR::is_a<ASR::Real_t>(*x.m_type)) {
ASR::RealConstant_t *cr = ASR::down_cast<ASR::RealConstant_t>(el);
switch (ASR::down_cast<ASR::Real_t>(x.m_type)->m_kind) {
case (4) : {
float el_value = cr->m_r;
llvm_val = llvm::ConstantFP::get(context, llvm::APFloat(el_value));
break;
}
case (8) : {
double el_value = cr->m_r;
llvm_val = llvm::ConstantFP::get(context, llvm::APFloat(el_value));
break;
}
default :
throw CodeGenError("ConstArray real kind not supported yet");
}
} else if (ASR::is_a<ASR::Logical_t>(*x.m_type)) {
ASR::LogicalConstant_t *cr = ASR::down_cast<ASR::LogicalConstant_t>(el);
llvm_val = llvm::ConstantInt::get(context, llvm::APInt(1, cr->m_value));
} else {
throw CodeGenError("ConstArray type not supported yet");
}
builder->CreateStore(llvm_val, llvm_el);
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 2;
this->visit_expr_wrapper(el, true);
ptr_loads = ptr_loads_copy;
builder->CreateStore(tmp, llvm_el);
}
// Return the vector as float* type:
tmp = llvm_utils->create_gep(p_fxn, 0);
Expand Down
39 changes: 35 additions & 4 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2490,6 +2490,35 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
current_scope->add_or_overwrite_symbol(var_name, v_sym);
}

#define fill_shape_and_lower_bound_for_CPtrToPointer() ASR::dimension_t* target_dims = nullptr; \
int target_n_dims = ASRUtils::extract_dimensions_from_ttype(target_type, target_dims); \
ASR::expr_t* target_shape = nullptr; \
ASR::expr_t* lower_bounds = nullptr; \
if( target_n_dims > 0 ) { \
Vec<ASR::expr_t*> sizes, lbs; \
sizes.reserve(al, target_n_dims); \
lbs.reserve(al, target_n_dims); \
bool success = true; \
for( int i = 0; i < target_n_dims; i++ ) { \
if( target_dims->m_length == nullptr ) { \
success = false; \
break; \
} \
sizes.push_back(al, target_dims->m_length); \
lbs.push_back(al, ASRUtils::EXPR(ASR::make_IntegerConstant_t( \
al, loc, 0, ASRUtils::TYPE( \
ASR::make_Integer_t(al, loc, 4, nullptr, 0))))); \
} \
if( success ) { \
target_shape = ASRUtils::EXPR(ASR::make_ArrayConstant_t(al, \
loc, sizes.p, sizes.size(), ASRUtils::expr_type(target_dims[0].m_length), \
ASR::arraystorageType::RowMajor)); \
lower_bounds = ASRUtils::EXPR(ASR::make_ArrayConstant_t(al, \
loc, lbs.p, lbs.size(), ASRUtils::expr_type(lbs[0]), \
ASR::arraystorageType::RowMajor)); \
} \
} \

ASR::asr_t* create_CPtrToPointerFromArgs(AST::expr_t* ast_cptr, AST::expr_t* ast_pptr,
AST::expr_t* ast_type_expr, const Location& loc) {
this->visit_expr(*ast_cptr);
Expand All @@ -2509,8 +2538,8 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
);
throw SemanticAbort();
}
return ASR::make_CPtrToPointer_t(al, loc, cptr,
pptr, nullptr);
fill_shape_and_lower_bound_for_CPtrToPointer();
return ASR::make_CPtrToPointer_t(al, loc, cptr, pptr, target_shape, lower_bounds);
}

void visit_AnnAssignUtil(const AST::AnnAssign_t& x, std::string& var_name,
Expand Down Expand Up @@ -6125,8 +6154,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
);
throw SemanticAbort();
}
return ASR::make_CPtrToPointer_t(al, x.base.base.loc, cptr,
pptr, nullptr);
const Location& loc = x.base.base.loc;
fill_shape_and_lower_bound_for_CPtrToPointer();
return ASR::make_CPtrToPointer_t(al, loc, cptr,
pptr, target_shape, lower_bounds);
}

ASR::asr_t* create_PointerToCPtr(const AST::Call_t& x) {
Expand Down
9 changes: 8 additions & 1 deletion src/runtime/lpython/lpython.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ def __call__(self, arg):
return self._convert(arg)

def dataclass(arg):
arg.__class_getitem__ = lambda self: None
def __class_getitem__(key):
return Array(arg, key)
arg.__class_getitem__ = __class_getitem__

return py_dataclass(arg)

def is_dataclass(obj):
Expand Down Expand Up @@ -254,6 +257,8 @@ def convert_type_to_ctype(arg):
elif arg is None:
raise NotImplementedError("Type cannot be None")
elif isinstance(arg, Array):
if is_dataclass(arg._type):
return arg
type = convert_type_to_ctype(arg._type)
return ctypes.POINTER(type)
elif is_dataclass(arg):
Expand Down Expand Up @@ -523,6 +528,8 @@ def __setattr__(self, name: str, value):
def c_p_pointer(cptr, targettype):
targettype_ptr = convert_type_to_ctype(targettype)
if isinstance(targettype, Array):
if py_is_dataclass(targettype._type):
return ctypes.cast(cptr.value, ctypes.py_object).value
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cptr.value is nothing but the address of foos: Foo[1] = empty(1, dtype=Foo). And ctypes.cast(cptr.value, ctypes.py_object).value just gives the original numpy array back, so effectively it becomes a pointer and you can easily make changes to it in other functions.

That's the only way forward if we want to keep using numpy arrays for arrays of dataclasses. If we can avoid using numpy arrays for dataclasses then we can define our own empty in lpython.py or in our own version of numpy.py. In our empty we can just create a ctypes.Structure array and return it via empty. For all the types like i32, etc we can keep calling numpy.empty. But I would not go for it unless extremely necessary.

newa = ctypes.cast(cptr, targettype_ptr)
return newa
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-bindc_01-6d521a9.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-bindc_01-6d521a9.stdout",
"stdout_hash": "d02c57ff6ddb41568c291b11a31301870bf2bc3a970461a71ec23a9d",
"stdout_hash": "26baf870cb5fc3b568cb3d4bf45527713699c39609b6598186843330",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
1 change: 1 addition & 0 deletions tests/reference/asr-bindc_01-6d521a9.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
(Var 5 queries)
(Var 5 x)
()
()
)
(Print
()
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-bindc_02-bc1a7ea.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-bindc_02-bc1a7ea.stdout",
"stdout_hash": "a74aa56cff206d4ef8fb0766f1cf596c122255882a7df3f5e4fcf4e7",
"stdout_hash": "dcf32037db14152961ae628f9687bfe7dfba3a34bb5a0e8da314fed1",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
2 changes: 2 additions & 0 deletions tests/reference/asr-bindc_02-bc1a7ea.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
(Var 196 queries)
(Var 196 x)
()
()
)
(Print
()
Expand Down Expand Up @@ -271,6 +272,7 @@
(Var 193 yq)
(Var 193 yptr1)
()
()
)
(Print
()
Expand Down
2 changes: 1 addition & 1 deletion tests/reference/asr-structs_02-2ab459a.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"outfile": null,
"outfile_hash": null,
"stdout": "asr-structs_02-2ab459a.stdout",
"stdout_hash": "ae8e8d2163b51eb20e19e6257618899aa4fbe78452e760a38608651d",
"stdout_hash": "c2089c22d28d4a379936dcc63db58a39094875021b2a12bd51e70583",
"stderr": null,
"stderr_hash": null,
"returncode": 0
Expand Down
1 change: 1 addition & 0 deletions tests/reference/asr-structs_02-2ab459a.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@
(Var 3 a)
(Var 3 a2)
()
()
)
(Print
()
Expand Down