@@ -797,9 +797,11 @@ class ExprStmtDuplicatorVisitor(ASDLVisitor):
797
797
def __init__ (self , stream , data ):
798
798
self .duplicate_stmt = []
799
799
self .duplicate_expr = []
800
+ self .duplicate_ttype = []
800
801
self .duplicate_case_stmt = []
801
802
self .is_stmt = False
802
803
self .is_expr = False
804
+ self .is_ttype = False
803
805
self .is_case_stmt = False
804
806
self .is_product = False
805
807
super (ExprStmtDuplicatorVisitor , self ).__init__ (stream , data )
@@ -834,6 +836,13 @@ def visitModule(self, mod):
834
836
self .duplicate_expr .append (("" , 0 ))
835
837
self .duplicate_expr .append ((" switch(x->type) {" , 1 ))
836
838
839
+ self .duplicate_ttype .append ((" ASR::ttype_t* duplicate_ttype(ASR::ttype_t* x) {" , 0 ))
840
+ self .duplicate_ttype .append ((" if( !x ) {" , 1 ))
841
+ self .duplicate_ttype .append ((" return nullptr;" , 2 ))
842
+ self .duplicate_ttype .append ((" }" , 1 ))
843
+ self .duplicate_ttype .append (("" , 0 ))
844
+ self .duplicate_ttype .append ((" switch(x->type) {" , 1 ))
845
+
837
846
self .duplicate_case_stmt .append ((" ASR::case_stmt_t* duplicate_case_stmt(ASR::case_stmt_t* x) {" , 0 ))
838
847
self .duplicate_case_stmt .append ((" if( !x ) {" , 1 ))
839
848
self .duplicate_case_stmt .append ((" return nullptr;" , 2 ))
@@ -858,6 +867,14 @@ def visitModule(self, mod):
858
867
self .duplicate_expr .append ((" return nullptr;" , 1 ))
859
868
self .duplicate_expr .append ((" }" , 0 ))
860
869
870
+ self .duplicate_ttype .append ((" default: {" , 2 ))
871
+ self .duplicate_ttype .append ((' LCOMPILERS_ASSERT_MSG(false, "Duplication of " + std::to_string(x->type) + " type is not supported yet.");' , 3 ))
872
+ self .duplicate_ttype .append ((" }" , 2 ))
873
+ self .duplicate_ttype .append ((" }" , 1 ))
874
+ self .duplicate_ttype .append (("" , 0 ))
875
+ self .duplicate_ttype .append ((" return nullptr;" , 1 ))
876
+ self .duplicate_ttype .append ((" }" , 0 ))
877
+
861
878
self .duplicate_case_stmt .append ((" default: {" , 2 ))
862
879
self .duplicate_case_stmt .append ((' LCOMPILERS_ASSERT_MSG(false, "Duplication of " + std::to_string(x->type) + " case statement is not supported yet.");' , 3 ))
863
880
self .duplicate_case_stmt .append ((" }" , 2 ))
@@ -872,6 +889,9 @@ def visitModule(self, mod):
872
889
for line , level in self .duplicate_expr :
873
890
self .emit (line , level = level )
874
891
self .emit ("" )
892
+ for line , level in self .duplicate_ttype :
893
+ self .emit (line , level = level )
894
+ self .emit ("" )
875
895
for line , level in self .duplicate_case_stmt :
876
896
self .emit (line , level = level )
877
897
self .emit ("" )
@@ -885,8 +905,9 @@ def visitType(self, tp):
885
905
def visitSum (self , sum , * args ):
886
906
self .is_stmt = args [0 ] == 'stmt'
887
907
self .is_expr = args [0 ] == 'expr'
908
+ self .is_ttype = args [0 ] == "ttype"
888
909
self .is_case_stmt = args [0 ] == 'case_stmt'
889
- if self .is_stmt or self .is_expr or self .is_case_stmt :
910
+ if self .is_stmt or self .is_expr or self .is_case_stmt or self . is_ttype :
890
911
for tp in sum .types :
891
912
self .visit (tp , * args )
892
913
@@ -933,6 +954,10 @@ def make_visitor(self, name, fields):
933
954
self .duplicate_expr .append ((" }" , 3 ))
934
955
self .duplicate_expr .append ((" return down_cast<ASR::expr_t>(self().duplicate_%s(down_cast<ASR::%s_t>(x)));" % (name , name ), 3 ))
935
956
self .duplicate_expr .append ((" }" , 2 ))
957
+ elif self .is_ttype :
958
+ self .duplicate_ttype .append ((" case ASR::ttypeType::%s: {" % name , 2 ))
959
+ self .duplicate_ttype .append ((" return down_cast<ASR::ttype_t>(self().duplicate_%s(down_cast<ASR::%s_t>(x)));" % (name , name ), 3 ))
960
+ self .duplicate_ttype .append ((" }" , 2 ))
936
961
elif self .is_case_stmt :
937
962
self .duplicate_case_stmt .append ((" case ASR::case_stmtType::%s: {" % name , 2 ))
938
963
self .duplicate_case_stmt .append ((" return down_cast<ASR::case_stmt_t>(self().duplicate_%s(down_cast<ASR::%s_t>(x)));" % (name , name ), 3 ))
@@ -949,7 +974,8 @@ def visitField(self, field):
949
974
field .type == "do_loop_head" or
950
975
field .type == "array_index" or
951
976
field .type == "alloc_arg" or
952
- field .type == "case_stmt" ):
977
+ field .type == "case_stmt" or
978
+ field .type == "ttype" ):
953
979
level = 2
954
980
if field .seq :
955
981
self .used = True
@@ -1107,10 +1133,12 @@ def visitField(self, field):
1107
1133
self .used = True
1108
1134
self .emit ("for (size_t i = 0; i < x->n_%s; i++) {" % field .name , level )
1109
1135
if field .type == "call_arg" :
1110
- self .emit (" ASR::expr_t** current_expr_copy_%d = current_expr;" % (self .current_expr_copy_variable_count ), level )
1111
- self .emit (" current_expr = &(x->m_%s[i].m_value);" % (field .name ), level )
1112
- self .emit (" self().replace_expr(x->m_%s[i].m_value);" % (field .name ), level )
1113
- self .emit (" current_expr = current_expr_copy_%d;" % (self .current_expr_copy_variable_count ), level )
1136
+ self .emit (" if (x->m_%s[i].m_value != nullptr) {" % (field .name ), level )
1137
+ self .emit (" ASR::expr_t** current_expr_copy_%d = current_expr;" % (self .current_expr_copy_variable_count ), level + 1 )
1138
+ self .emit (" current_expr = &(x->m_%s[i].m_value);" % (field .name ), level + 1 )
1139
+ self .emit (" self().replace_expr(x->m_%s[i].m_value);" % (field .name ), level + 1 )
1140
+ self .emit (" current_expr = current_expr_copy_%d;" % (self .current_expr_copy_variable_count ), level + 1 )
1141
+ self .emit (" }" , level )
1114
1142
self .current_expr_copy_variable_count += 1
1115
1143
self .emit ("}" , level )
1116
1144
else :
@@ -2310,6 +2338,8 @@ def make_visitor(self, name, fields):
2310
2338
LCOMPILERS_ASSERT(e->m_external);
2311
2339
LCOMPILERS_ASSERT(!ASR::is_a<ASR::ExternalSymbol_t>(*e->m_external));
2312
2340
s = e->m_external;
2341
+ } else if (s->type == ASR::symbolType::Function) {
2342
+ return ASR::down_cast<ASR::Function_t>(s)->m_function_signature;
2313
2343
}
2314
2344
return ASR::down_cast<ASR::Variable_t>(s)->m_type;
2315
2345
}""" \
@@ -2529,6 +2559,9 @@ def main(argv):
2529
2559
subs ["MOD" ] = "LPython::AST"
2530
2560
subs ["mod" ] = "ast"
2531
2561
subs ["lcompiler" ] = "lpython"
2562
+ elif subs ["MOD" ] == "AST" :
2563
+ subs ["MOD" ] = "LFortran::AST"
2564
+ subs ["lcompiler" ] = "lfortran"
2532
2565
else :
2533
2566
subs ["lcompiler" ] = "lfortran"
2534
2567
is_asr = (mod .name .upper () == "ASR" )
0 commit comments