Skip to content

Commit c45140d

Browse files
authored
Merge pull request #1813 from Shaikh-Ubaid/struct_init_name_args
Support struct initialization with named arguments
2 parents 0407d30 + 462a731 commit c45140d

20 files changed

+299
-15
lines changed

integration_tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,8 @@ RUN(NAME structs_20 LABELS cpython llvm c
478478
EXTRAFILES structs_20b.c)
479479
RUN(NAME structs_21 LABELS cpython llvm c)
480480
RUN(NAME structs_22 LABELS cpython llvm c)
481+
RUN(NAME structs_23 LABELS cpython llvm c)
482+
RUN(NAME structs_24 LABELS cpython llvm c)
481483
RUN(NAME sizeof_01 LABELS llvm c
482484
EXTRAFILES sizeof_01b.c)
483485
RUN(NAME sizeof_02 LABELS cpython llvm c)

integration_tests/structs_23.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from lpython import dataclass, i32, u64, f64
2+
3+
@dataclass
4+
class A:
5+
a: i32
6+
b: i32
7+
8+
@dataclass
9+
class B:
10+
a: u64
11+
b: f64
12+
13+
def main0():
14+
s: A = A(b=-24, a=6)
15+
print(s.a)
16+
print(s.b)
17+
18+
assert s.a == 6
19+
assert s.b == -24
20+
21+
def main1():
22+
s: B = B(u64(22), b=3.14)
23+
print(s.a)
24+
print(s.b)
25+
26+
assert s.a == u64(22)
27+
assert abs(s.b - 3.14) <= 1e-12
28+
29+
main0()
30+
main1()

integration_tests/structs_24.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from lpython import dataclass, i32, f64, u64
2+
from numpy import array
3+
4+
@dataclass
5+
class Foo:
6+
x: i32
7+
y: i32
8+
9+
def main0() -> None:
10+
foos: Foo[2] = array([Foo(y=2, x=1), Foo(x=3, y=4)])
11+
print(foos[0].x, foos[0].y, foos[1].x, foos[1].y)
12+
13+
assert foos[0].x == 1
14+
assert foos[0].y == 2
15+
assert foos[1].x == 3
16+
assert foos[1].y == 4
17+
18+
main0()

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 85 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,63 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
893893
return true;
894894
}
895895

896+
int64_t find_argument_position_from_name(ASR::StructType_t* orig_struct, std::string arg_name) {
897+
int64_t arg_position = -1;
898+
for( size_t i = 0; i < orig_struct->n_members; i++ ) {
899+
std::string original_arg_name = std::string(orig_struct->m_members[i]);
900+
if( original_arg_name == arg_name ) {
901+
return i;
902+
}
903+
}
904+
return arg_position;
905+
}
906+
907+
void visit_expr_list(AST::expr_t** pos_args, size_t n_pos_args,
908+
AST::keyword_t* kwargs, size_t n_kwargs,
909+
Vec<ASR::call_arg_t>& call_args_vec,
910+
ASR::StructType_t* orig_struct, const Location &loc) {
911+
LCOMPILERS_ASSERT(call_args_vec.reserve_called);
912+
913+
// Fill the whole call_args_vec with nullptr
914+
// This is for error handling later on.
915+
for( size_t i = 0; i < n_pos_args + n_kwargs; i++ ) {
916+
ASR::call_arg_t call_arg;
917+
Location loc;
918+
loc.first = loc.last = 1;
919+
call_arg.m_value = nullptr;
920+
call_arg.loc = loc;
921+
call_args_vec.push_back(al, call_arg);
922+
}
923+
924+
// Now handle positional arguments in the following loop
925+
for( size_t i = 0; i < n_pos_args; i++ ) {
926+
this->visit_expr(*pos_args[i]);
927+
ASR::expr_t* expr = ASRUtils::EXPR(tmp);
928+
call_args_vec.p[i].loc = expr->base.loc;
929+
call_args_vec.p[i].m_value = expr;
930+
}
931+
932+
// Now handle keyword arguments in the following loop
933+
for( size_t i = 0; i < n_kwargs; i++ ) {
934+
this->visit_expr(*kwargs[i].m_value);
935+
ASR::expr_t* expr = ASRUtils::EXPR(tmp);
936+
std::string arg_name = std::string(kwargs[i].m_arg);
937+
int64_t arg_pos = find_argument_position_from_name(orig_struct, arg_name);
938+
if( arg_pos == -1 ) {
939+
throw SemanticError("Member '" + arg_name + "' not found in struct", kwargs[i].loc);
940+
} else if (arg_pos >= (int64_t)call_args_vec.size()) {
941+
throw SemanticError("Not enough arguments to " + std::string(orig_struct->m_name)
942+
+ "(), expected " + std::to_string(orig_struct->n_members), loc);
943+
}
944+
if( call_args_vec[arg_pos].m_value != nullptr ) {
945+
throw SemanticError(std::string(orig_struct->m_name) + "() got multiple values for argument '"
946+
+ arg_name + "'", kwargs[i].loc);
947+
}
948+
call_args_vec.p[arg_pos].loc = expr->base.loc;
949+
call_args_vec.p[arg_pos].m_value = expr;
950+
}
951+
}
952+
896953
void visit_expr_list_with_cast(ASR::expr_t** m_args, size_t n_args,
897954
Vec<ASR::call_arg_t>& call_args_vec,
898955
Vec<ASR::call_arg_t>& args,
@@ -1195,7 +1252,17 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
11951252
}
11961253
} else if(ASR::is_a<ASR::StructType_t>(*s)) {
11971254
ASR::StructType_t* StructType = ASR::down_cast<ASR::StructType_t>(s);
1198-
for( size_t i = 0; i < std::min(args.size(), StructType->n_members); i++ ) {
1255+
if (n_kwargs > 0) {
1256+
args.reserve(al, n_pos_args + n_kwargs);
1257+
visit_expr_list(pos_args, n_pos_args, kwargs, n_kwargs,
1258+
args, StructType, loc);
1259+
}
1260+
1261+
if (args.size() > 0 && args.size() != StructType->n_members) {
1262+
throw SemanticError("StructConstructor arguments do not match the number of struct members", loc);
1263+
}
1264+
1265+
for( size_t i = 0; i < args.size(); i++ ) {
11991266
std::string member_name = StructType->m_members[i];
12001267
ASR::Variable_t* member_var = ASR::down_cast<ASR::Variable_t>(
12011268
StructType->m_symtab->resolve_symbol(member_name));
@@ -6599,6 +6666,14 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
65996666
tmp = ASR::make_StringConstant_t(al, loc, s2c(al, s_var), str_type);
66006667
}
66016668

6669+
void parse_args(const AST::Call_t &x, Vec<ASR::call_arg_t> &args) {
6670+
// Keyword arguments handled in make_call_helper()
6671+
if( x.n_keywords == 0 ) {
6672+
args.reserve(al, x.n_args);
6673+
visit_expr_list(x.m_args, x.n_args, args);
6674+
}
6675+
}
6676+
66026677
void visit_Call(const AST::Call_t &x) {
66036678
std::string call_name = "";
66046679
Vec<ASR::call_arg_t> args;
@@ -6612,14 +6687,9 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
66126687
tmp = nullptr;
66136688
return ;
66146689
}
6615-
// Keyword arguments handled in make_call_helper
6616-
#define parse_args() if( x.n_keywords == 0 ) { \
6617-
args.reserve(al, x.n_args); \
6618-
visit_expr_list(x.m_args, x.n_args, args); \
6619-
} \
66206690

66216691
if (AST::is_a<AST::Attribute_t>(*x.m_func)) {
6622-
parse_args()
6692+
parse_args(x, args);
66236693
AST::Attribute_t *at = AST::down_cast<AST::Attribute_t>(x.m_func);
66246694
if (AST::is_a<AST::Name_t>(*at->m_value)) {
66256695
AST::Name_t *n = AST::down_cast<AST::Name_t>(at->m_value);
@@ -6788,7 +6858,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
67886858
// This will all be removed once we port it to intrinsic functions
67896859
// Intrinsic functions
67906860
if (call_name == "size") {
6791-
parse_args();
6861+
parse_args(x, args);;
67926862
if( args.size() < 1 || args.size() > 2 ) {
67936863
throw SemanticError("array accepts only 1 (arr) or 2 (arr, axis) arguments, got " +
67946864
std::to_string(args.size()) + " arguments instead.",
@@ -6820,7 +6890,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
68206890
tmp = nullptr;
68216891
return;
68226892
} else if (call_name == "callable") {
6823-
parse_args()
6893+
parse_args(x, args);
68246894
if (args.size() != 1) {
68256895
throw SemanticError(call_name + "() takes exactly one argument (" +
68266896
std::to_string(args.size()) + " given)", x.base.base.loc);
@@ -6836,13 +6906,13 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
68366906
tmp = ASR::make_LogicalConstant_t(al, x.base.base.loc, result, type);
68376907
return;
68386908
} else if( call_name == "pointer" ) {
6839-
parse_args()
6909+
parse_args(x, args);
68406910
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_Pointer_t(al, x.base.base.loc,
68416911
ASRUtils::expr_type(args[0].m_value)));
68426912
tmp = ASR::make_GetPointer_t(al, x.base.base.loc, args[0].m_value, type, nullptr);
68436913
return ;
68446914
} else if( call_name == "array" ) {
6845-
parse_args()
6915+
parse_args(x, args);
68466916
if( args.size() != 1 ) {
68476917
throw SemanticError("array accepts only 1 argument for now, got " +
68486918
std::to_string(args.size()) + " arguments instead.",
@@ -6862,7 +6932,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
68626932
}
68636933
return;
68646934
} else if( call_name == "deepcopy" ) {
6865-
parse_args()
6935+
parse_args(x, args);
68666936
if( args.size() != 1 ) {
68676937
throw SemanticError("deepcopy only accepts one argument, found " +
68686938
std::to_string(args.size()) + " instead.",
@@ -6921,7 +6991,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
69216991
call_name == "c32" ||
69226992
call_name == "c64"
69236993
) {
6924-
parse_args()
6994+
parse_args(x, args);
69256995
ASR::ttype_t* target_type = nullptr;
69266996
if( call_name == "i8" ) {
69276997
target_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 1, nullptr, 0));
@@ -6953,7 +7023,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
69537023
tmp = (ASR::asr_t*) arg;
69547024
return ;
69557025
} else if (intrinsic_node_handler.is_present(call_name)) {
6956-
parse_args()
7026+
parse_args(x, args);
69577027
tmp = intrinsic_node_handler.get_intrinsic_node(call_name, al,
69587028
x.base.base.loc, args);
69597029
return;
@@ -6965,7 +7035,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
69657035
} // end of "comment"
69667036
}
69677037

6968-
parse_args()
7038+
parse_args(x, args);
69697039
tmp = make_call_helper(al, s, current_scope, args, call_name, x.base.base.loc,
69707040
false, x.m_args, x.n_args, x.m_keywords, x.n_keywords);
69717041
}

tests/errors/structs_03.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from lpython import i32, dataclass
2+
3+
@dataclass
4+
class S:
5+
x: i32
6+
7+
def main0():
8+
s: S = S(y=2)
9+
10+
main0()

tests/errors/structs_04.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from lpython import i32, dataclass
2+
3+
@dataclass
4+
class S:
5+
x: i32
6+
y: i32
7+
8+
def main0():
9+
s: S = S(24, x=2)
10+
11+
main0()

tests/errors/structs_05.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from lpython import i32, dataclass
2+
3+
@dataclass
4+
class S:
5+
x: i32
6+
y: i32
7+
8+
def main0():
9+
s: S = S(2)
10+
11+
main0()

tests/errors/structs_06.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from lpython import i32, dataclass
2+
3+
@dataclass
4+
class S:
5+
x: i32
6+
y: i32
7+
8+
def main0():
9+
s: S = S(2, 3, 4, 5)
10+
11+
main0()

tests/errors/structs_07.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from lpython import i32, dataclass
2+
3+
@dataclass
4+
class S:
5+
x: i32
6+
y: i32
7+
8+
def main0():
9+
s: S = S(y=2)
10+
11+
main0()
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"basename": "asr-structs_03-754fb64",
3+
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
4+
"infile": "tests/errors/structs_03.py",
5+
"infile_hash": "19180d0a7a22141e74e61452cc6cc185f1dd1c4f4315446450ce98db",
6+
"outfile": null,
7+
"outfile_hash": null,
8+
"stdout": null,
9+
"stdout_hash": null,
10+
"stderr": "asr-structs_03-754fb64.stderr",
11+
"stderr_hash": "c6410f9948863d922cb0a0cd36613c529ad45fdf556d393d36e2df07",
12+
"returncode": 2
13+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
semantic error: Member 'y' not found in struct
2+
--> tests/errors/structs_03.py:8:14
3+
|
4+
8 | s: S = S(y=2)
5+
| ^^^
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"basename": "asr-structs_04-7b864bc",
3+
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
4+
"infile": "tests/errors/structs_04.py",
5+
"infile_hash": "5951c49d2d7f143bbe3d67b982770ceb6d709939eb2d5ed544888f16",
6+
"outfile": null,
7+
"outfile_hash": null,
8+
"stdout": null,
9+
"stdout_hash": null,
10+
"stderr": "asr-structs_04-7b864bc.stderr",
11+
"stderr_hash": "e4e04a1a30ae38b6587c4c3ad12a7e83839c63938c025a3884f62ef8",
12+
"returncode": 2
13+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
semantic error: S() got multiple values for argument 'x'
2+
--> tests/errors/structs_04.py:9:18
3+
|
4+
9 | s: S = S(24, x=2)
5+
| ^^^
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"basename": "asr-structs_05-a89315d",
3+
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
4+
"infile": "tests/errors/structs_05.py",
5+
"infile_hash": "3b94e692a074b226736f068daf39c876f113277a73468bd21c01d3cc",
6+
"outfile": null,
7+
"outfile_hash": null,
8+
"stdout": null,
9+
"stdout_hash": null,
10+
"stderr": "asr-structs_05-a89315d.stderr",
11+
"stderr_hash": "227decb39171becb34a42cbdd93d96bcdd4d8c9dc5151706a74d7074",
12+
"returncode": 2
13+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
semantic error: StructConstructor arguments do not match the number of struct members
2+
--> tests/errors/structs_05.py:9:12
3+
|
4+
9 | s: S = S(2)
5+
| ^^^^
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"basename": "asr-structs_06-6e14537",
3+
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
4+
"infile": "tests/errors/structs_06.py",
5+
"infile_hash": "9f4273c5fb4469837f65003255dcdca067c5c17735d0642757fd069c",
6+
"outfile": null,
7+
"outfile_hash": null,
8+
"stdout": null,
9+
"stdout_hash": null,
10+
"stderr": "asr-structs_06-6e14537.stderr",
11+
"stderr_hash": "21e94af3d6a631d4871d9bc2a86200c3c3c3b661964a079105721dde",
12+
"returncode": 2
13+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
semantic error: StructConstructor arguments do not match the number of struct members
2+
--> tests/errors/structs_06.py:9:12
3+
|
4+
9 | s: S = S(2, 3, 4, 5)
5+
| ^^^^^^^^^^^^^

0 commit comments

Comments
 (0)