Skip to content

Commit d0e857a

Browse files
authored
Merge pull request #2313 from Smit-create/lf_2248
Update FMA/flip_sign pass
2 parents f6a4606 + 588ef41 commit d0e857a

File tree

7 files changed

+73
-11
lines changed

7 files changed

+73
-11
lines changed

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1884,6 +1884,30 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
18841884
}
18851885
break ;
18861886
}
1887+
case ASRUtils::IntrinsicScalarFunctions::FlipSign: {
1888+
Vec<ASR::call_arg_t> args;
1889+
args.reserve(al, 2);
1890+
ASR::call_arg_t arg0_, arg1_;
1891+
arg0_.loc = x.m_args[0]->base.loc, arg0_.m_value = x.m_args[0];
1892+
args.push_back(al, arg0_);
1893+
arg1_.loc = x.m_args[1]->base.loc, arg1_.m_value = x.m_args[1];
1894+
args.push_back(al, arg1_);
1895+
generate_flip_sign(args.p);
1896+
break;
1897+
}
1898+
case ASRUtils::IntrinsicScalarFunctions::FMA: {
1899+
Vec<ASR::call_arg_t> args;
1900+
args.reserve(al, 3);
1901+
ASR::call_arg_t arg0_, arg1_, arg2_;
1902+
arg0_.loc = x.m_args[0]->base.loc, arg0_.m_value = x.m_args[0];
1903+
args.push_back(al, arg0_);
1904+
arg1_.loc = x.m_args[1]->base.loc, arg1_.m_value = x.m_args[1];
1905+
args.push_back(al, arg1_);
1906+
arg2_.loc = x.m_args[2]->base.loc, arg2_.m_value = x.m_args[2];
1907+
args.push_back(al, arg2_);
1908+
generate_fma(args.p);
1909+
break;
1910+
}
18871911
default: {
18881912
throw CodeGenError( ASRUtils::IntrinsicScalarFunctionRegistry::
18891913
get_intrinsic_function_name(x.m_intrinsic_id) +
@@ -7372,7 +7396,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
73727396
llvm::Value* int_var = builder->CreateBitCast(CreateLoad(variable), shifted_signal->getType());
73737397
tmp = builder->CreateXor(shifted_signal, int_var);
73747398
llvm::Type* variable_type = llvm_utils->get_type_from_ttype_t_util(asr_variable->m_type, module.get());
7375-
builder->CreateStore(builder->CreateBitCast(tmp, variable_type->getPointerTo()), variable);
7399+
tmp = builder->CreateBitCast(tmp, variable_type);
73767400
}
73777401

73787402
void generate_fma(ASR::call_arg_t* m_args) {
@@ -8300,7 +8324,12 @@ Result<std::unique_ptr<LLVMModule>> asr_to_llvm(ASR::TranslationUnit_t &asr,
83008324
pass_options.run_fun = run_fn;
83018325
pass_options.always_run = false;
83028326
pass_options.verbose = co.verbose;
8327+
std::vector<int64_t> skip_optimization_func_instantiation;
8328+
skip_optimization_func_instantiation.push_back(static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FlipSign));
8329+
skip_optimization_func_instantiation.push_back(static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FMA));
8330+
pass_options.skip_optimization_func_instantiation = skip_optimization_func_instantiation;
83038331
pass_manager.rtlib = co.rtlib;
8332+
83048333
pass_manager.apply_passes(al, &asr, pass_options, diagnostics);
83058334

83068335
// Uncomment for debugging the ASR after the transformation

src/libasr/pass/flip_sign.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class FlipSignVisitor : public PassUtils::SkipOptimizationFunctionVisitor<FlipSi
100100
LCOMPILERS_ASSERT(flip_sign_signal_variable);
101101
LCOMPILERS_ASSERT(flip_sign_variable);
102102
ASR::expr_t* flip_sign_result = PassUtils::get_flipsign(flip_sign_signal_variable,
103-
flip_sign_variable, al, unit, x.base.base.loc);
103+
flip_sign_variable, al, unit, x.base.base.loc, pass_options);
104104
pass_result.push_back(al, ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc,
105105
flip_sign_variable, flip_sign_result, nullptr)));
106106
}

src/libasr/pass/fma.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ class FMAVisitor : public PassUtils::SkipOptimizationFunctionVisitor<FMAVisitor>
118118
}
119119

120120
fma_var = PassUtils::get_fma(other_expr, first_arg, second_arg,
121-
al, unit, x.base.base.loc);
121+
al, unit, x.base.base.loc, pass_options);
122122
from_fma = false;
123123
}
124124

src/libasr/pass/intrinsic_function_registry.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2451,7 +2451,7 @@ namespace IntrinsicScalarFunctionRegistry {
24512451
{static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
24522452
{&FMA::instantiate_FMA, &FMA::verify_args}},
24532453
{static_cast<int64_t>(IntrinsicScalarFunctions::FlipSign),
2454-
{&FlipSign::instantiate_FlipSign, &FMA::verify_args}},
2454+
{&FlipSign::instantiate_FlipSign, &FlipSign::verify_args}},
24552455
{static_cast<int64_t>(IntrinsicScalarFunctions::Abs),
24562456
{&Abs::instantiate_Abs, &Abs::verify_args}},
24572457
{static_cast<int64_t>(IntrinsicScalarFunctions::Partition),

src/libasr/pass/pass_utils.cpp

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -587,14 +587,34 @@ namespace LCompilers {
587587
int32_type, bound_type, nullptr));
588588
}
589589

590+
bool skip_instantiation(PassOptions pass_options, int64_t id) {
591+
if (!pass_options.skip_optimization_func_instantiation.empty()) {
592+
for (size_t i=0; i<pass_options.skip_optimization_func_instantiation.size(); i++) {
593+
if (pass_options.skip_optimization_func_instantiation[i] == id) {
594+
return true;
595+
}
596+
}
597+
}
598+
return false;
599+
}
590600

591601
ASR::expr_t* get_flipsign(ASR::expr_t* arg0, ASR::expr_t* arg1,
592-
Allocator& al, ASR::TranslationUnit_t& unit, const Location& loc){
602+
Allocator& al, ASR::TranslationUnit_t& unit, const Location& loc,
603+
PassOptions pass_options){
604+
ASR::ttype_t* type = ASRUtils::expr_type(arg1);
605+
int64_t fp_s = static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FlipSign);
606+
if (skip_instantiation(pass_options, fp_s)) {
607+
Vec<ASR::expr_t*> args;
608+
args.reserve(al, 2);
609+
args.push_back(al, arg0);
610+
args.push_back(al, arg1);
611+
return ASRUtils::EXPR(ASRUtils::make_IntrinsicScalarFunction_t_util(al, loc, fp_s,
612+
args.p, args.n, 0, type, nullptr));
613+
}
593614
ASRUtils::impl_function instantiate_function =
594615
ASRUtils::IntrinsicScalarFunctionRegistry::get_instantiate_function(
595616
static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FlipSign));
596617
Vec<ASR::ttype_t*> arg_types;
597-
ASR::ttype_t* type = ASRUtils::expr_type(arg1);
598618
arg_types.reserve(al, 2);
599619
arg_types.push_back(al, ASRUtils::expr_type(arg0));
600620
arg_types.push_back(al, ASRUtils::expr_type(arg1));
@@ -667,13 +687,23 @@ namespace LCompilers {
667687
}
668688

669689
ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
670-
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc){
671-
690+
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc,
691+
PassOptions pass_options){
692+
int64_t fma_id = static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FMA);
693+
ASR::ttype_t* type = ASRUtils::expr_type(arg0);
694+
if (skip_instantiation(pass_options, fma_id)) {
695+
Vec<ASR::expr_t*> args;
696+
args.reserve(al, 3);
697+
args.push_back(al, arg0);
698+
args.push_back(al, arg1);
699+
args.push_back(al, arg2);
700+
return ASRUtils::EXPR(ASRUtils::make_IntrinsicScalarFunction_t_util(al, loc, fma_id,
701+
args.p, args.n, 0, type, nullptr));
702+
}
672703
ASRUtils::impl_function instantiate_function =
673704
ASRUtils::IntrinsicScalarFunctionRegistry::get_instantiate_function(
674705
static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FMA));
675706
Vec<ASR::ttype_t*> arg_types;
676-
ASR::ttype_t* type = ASRUtils::expr_type(arg0);
677707
arg_types.reserve(al, 3);
678708
arg_types.push_back(al, ASRUtils::expr_type(arg0));
679709
arg_types.push_back(al, ASRUtils::expr_type(arg1));

src/libasr/pass/pass_utils.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ namespace LCompilers {
7474
Allocator& al);
7575

7676
ASR::expr_t* get_flipsign(ASR::expr_t* arg0, ASR::expr_t* arg1,
77-
Allocator& al, ASR::TranslationUnit_t& unit, const Location& loc);
77+
Allocator& al, ASR::TranslationUnit_t& unit, const Location& loc,
78+
PassOptions pass_options);
7879

7980
ASR::expr_t* to_int32(ASR::expr_t* x, ASR::ttype_t* int32type, Allocator& al);
8081

@@ -86,7 +87,8 @@ namespace LCompilers {
8687
ASR::intentType var_intent=ASR::intentType::Local);
8788

8889
ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
89-
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc);
90+
Allocator& al, ASR::TranslationUnit_t& unit, Location& loc,
91+
PassOptions pass_options);
9092

9193
ASR::expr_t* get_sign_from_value(ASR::expr_t* arg0, ASR::expr_t* arg1,
9294
Allocator& al, ASR::TranslationUnit_t& unit,

src/libasr/utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ namespace LCompilers {
9797
bool verbose = false; // For developer debugging
9898
bool pass_cumulative = false; // Apply passes cumulatively
9999
bool disable_main = false;
100+
std::vector<int64_t> skip_optimization_func_instantiation;
100101
bool module_name_mangling = false;
101102
bool global_symbols_mangling = false;
102103
bool intrinsic_symbols_mangling = false;

0 commit comments

Comments
 (0)