Skip to content

Commit a838e88

Browse files
authored
Merge pull request #2301 from Smit-create/flip_sign
PASS: Update FlipSign pass to use Intrinsic Function
2 parents ab5d201 + 1cb7340 commit a838e88

File tree

6 files changed

+130
-21
lines changed

6 files changed

+130
-21
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,7 @@ RUN(NAME expr_19 LABELS cpython llvm c)
494494
RUN(NAME expr_20 LABELS cpython llvm c)
495495
RUN(NAME expr_21 LABELS cpython llvm c)
496496
RUN(NAME expr_22 LABELS cpython llvm c)
497+
RUN(NAME expr_23 LABELS cpython llvm c)
497498

498499
RUN(NAME expr_01u LABELS cpython llvm c NOFAST)
499500
RUN(NAME expr_02u LABELS cpython llvm c NOFAST)

integration_tests/expr_23.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from lpython import f32, i32
2+
3+
def flip_sign_check():
4+
x: f32
5+
eps: f32 = f32(1e-5)
6+
7+
number: i32 = 123
8+
x = f32(5.5)
9+
10+
if (number%2 == 1):
11+
x = -x
12+
13+
assert abs(x - f32(-5.5)) < eps
14+
15+
number = 124
16+
x = f32(5.5)
17+
18+
if (number%2 == 1):
19+
x = -x
20+
21+
assert abs(x - f32(5.5)) < eps
22+
23+
flip_sign_check()

src/libasr/pass/flip_sign.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ class FlipSignVisitor : public PassUtils::SkipOptimizationFunctionVisitor<FlipSi
9999
// xi = xor(shiftl(int(Nd),63), xi)
100100
LCOMPILERS_ASSERT(flip_sign_signal_variable);
101101
LCOMPILERS_ASSERT(flip_sign_variable);
102-
ASR::stmt_t* flip_sign_call = PassUtils::get_flipsign(flip_sign_signal_variable,
103-
flip_sign_variable, al, unit, pass_options, current_scope,
104-
[&](const std::string &msg, const Location &) { throw LCompilersException(msg); });
105-
pass_result.push_back(al, flip_sign_call);
102+
ASR::expr_t* flip_sign_result = PassUtils::get_flipsign(flip_sign_signal_variable,
103+
flip_sign_variable, al, unit, x.base.base.loc);
104+
pass_result.push_back(al, ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc,
105+
flip_sign_variable, flip_sign_result, nullptr)));
106106
}
107107
}
108108

@@ -212,6 +212,8 @@ void pass_replace_flip_sign(Allocator &al, ASR::TranslationUnit_t &unit,
212212
const LCompilers::PassOptions& pass_options) {
213213
FlipSignVisitor v(al, unit, pass_options);
214214
v.visit_TranslationUnit(unit);
215+
PassUtils::UpdateDependenciesVisitor u(al);
216+
u.visit_TranslationUnit(unit);
215217
}
216218

217219

src/libasr/pass/intrinsic_function_registry.h

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ enum class IntrinsicScalarFunctions : int64_t {
4242
Exp2,
4343
Expm1,
4444
FMA,
45+
FlipSign,
4546
ListIndex,
4647
Partition,
4748
ListReverse,
@@ -95,6 +96,7 @@ inline std::string get_intrinsic_name(int x) {
9596
INTRINSIC_NAME_CASE(Exp2)
9697
INTRINSIC_NAME_CASE(Expm1)
9798
INTRINSIC_NAME_CASE(FMA)
99+
INTRINSIC_NAME_CASE(FlipSign)
98100
INTRINSIC_NAME_CASE(ListIndex)
99101
INTRINSIC_NAME_CASE(Partition)
100102
INTRINSIC_NAME_CASE(ListReverse)
@@ -1343,6 +1345,86 @@ namespace FMA {
13431345

13441346
} // namespace FMA
13451347

1348+
namespace FlipSign {
1349+
1350+
static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, diag::Diagnostics& diagnostics) {
1351+
ASRUtils::require_impl(x.n_args == 2,
1352+
"ASR Verify: Call to FlipSign must have exactly 2 arguments",
1353+
x.base.base.loc, diagnostics);
1354+
ASR::ttype_t *type1 = ASRUtils::expr_type(x.m_args[0]);
1355+
ASR::ttype_t *type2 = ASRUtils::expr_type(x.m_args[1]);
1356+
ASRUtils::require_impl((is_integer(*type1) && is_real(*type2)),
1357+
"ASR Verify: Arguments to FlipSign must be of int and real type respectively",
1358+
x.base.base.loc, diagnostics);
1359+
}
1360+
1361+
static ASR::expr_t *eval_FlipSign(Allocator &al, const Location &loc,
1362+
ASR::ttype_t* t1, Vec<ASR::expr_t*> &args) {
1363+
int a = ASR::down_cast<ASR::IntegerConstant_t>(args[0])->m_n;
1364+
double b = ASR::down_cast<ASR::RealConstant_t>(args[1])->m_r;
1365+
if (a % 2 == 1) b = -b;
1366+
return make_ConstantWithType(make_RealConstant_t, b, t1, loc);
1367+
}
1368+
1369+
static inline ASR::asr_t* create_FlipSign(Allocator& al, const Location& loc,
1370+
Vec<ASR::expr_t*>& args,
1371+
const std::function<void (const std::string &, const Location &)> err) {
1372+
if (args.size() != 2) {
1373+
err("Intrinsic FlipSign function accepts exactly 2 arguments", loc);
1374+
}
1375+
ASR::ttype_t *type1 = ASRUtils::expr_type(args[0]);
1376+
ASR::ttype_t *type2 = ASRUtils::expr_type(args[1]);
1377+
if (!ASRUtils::is_integer(*type1) || !ASRUtils::is_real(*type2)) {
1378+
err("Argument of the FlipSign function must be int and real respectively",
1379+
args[0]->base.loc);
1380+
}
1381+
ASR::expr_t *m_value = nullptr;
1382+
if (all_args_evaluated(args)) {
1383+
Vec<ASR::expr_t*> arg_values; arg_values.reserve(al, 2);
1384+
arg_values.push_back(al, expr_value(args[0]));
1385+
arg_values.push_back(al, expr_value(args[1]));
1386+
m_value = eval_FlipSign(al, loc, expr_type(args[1]), arg_values);
1387+
}
1388+
return ASR::make_IntrinsicScalarFunction_t(al, loc,
1389+
static_cast<int64_t>(IntrinsicScalarFunctions::FlipSign),
1390+
args.p, args.n, 0, ASRUtils::expr_type(args[1]), m_value);
1391+
}
1392+
1393+
static inline ASR::expr_t* instantiate_FlipSign(Allocator &al, const Location &loc,
1394+
SymbolTable *scope, Vec<ASR::ttype_t*>& arg_types, ASR::ttype_t *return_type,
1395+
Vec<ASR::call_arg_t>& new_args, int64_t /*overload_id*/) {
1396+
declare_basic_variables("_lcompilers_optimization_flipsign_" + type_to_str_python(arg_types[1]));
1397+
fill_func_arg("signal", arg_types[0]);
1398+
fill_func_arg("variable", arg_types[1]);
1399+
auto result = declare(fn_name, return_type, ReturnVar);
1400+
/*
1401+
real(real32) function flipsigni32r32(signal, variable)
1402+
integer(int32), intent(in) :: signal
1403+
real(real32), intent(out) :: variable
1404+
integer(int32) :: q
1405+
q = signal/2
1406+
flipsigni32r32 = variable
1407+
if (signal - 2*q == 1 ) flipsigni32r32 = -variable
1408+
end subroutine
1409+
*/
1410+
1411+
ASR::expr_t *two = i(2, arg_types[0]);
1412+
ASR::expr_t *q = iDiv(args[0], two);
1413+
ASR::expr_t *cond = iSub(args[0], iMul(two, q));
1414+
body.push_back(al, b.If(iEq(cond, i(1, arg_types[0])), {
1415+
b.Assignment(result, f32_neg(args[1], arg_types[1]))
1416+
}, {
1417+
b.Assignment(result, args[1])
1418+
}));
1419+
1420+
ASR::symbol_t *f_sym = make_Function_t(fn_name, fn_symtab, dep, args,
1421+
body, result, Source, Implementation, nullptr);
1422+
scope->add_symbol(fn_name, f_sym);
1423+
return b.Call(f_sym, new_args, return_type, nullptr);
1424+
}
1425+
1426+
} // namespace FlipSign
1427+
13461428
#define create_exp_macro(X, stdeval) \
13471429
namespace X { \
13481430
static inline ASR::expr_t* eval_##X(Allocator &al, const Location &loc, \
@@ -2368,6 +2450,8 @@ namespace IntrinsicScalarFunctionRegistry {
23682450
{nullptr, &UnaryIntrinsicFunction::verify_args}},
23692451
{static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
23702452
{&FMA::instantiate_FMA, &FMA::verify_args}},
2453+
{static_cast<int64_t>(IntrinsicScalarFunctions::FlipSign),
2454+
{&FlipSign::instantiate_FlipSign, &FMA::verify_args}},
23712455
{static_cast<int64_t>(IntrinsicScalarFunctions::Abs),
23722456
{&Abs::instantiate_Abs, &Abs::verify_args}},
23732457
{static_cast<int64_t>(IntrinsicScalarFunctions::Partition),
@@ -2456,6 +2540,8 @@ namespace IntrinsicScalarFunctionRegistry {
24562540
"exp2"},
24572541
{static_cast<int64_t>(IntrinsicScalarFunctions::FMA),
24582542
"fma"},
2543+
{static_cast<int64_t>(IntrinsicScalarFunctions::FlipSign),
2544+
"flipsign"},
24592545
{static_cast<int64_t>(IntrinsicScalarFunctions::Expm1),
24602546
"expm1"},
24612547
{static_cast<int64_t>(IntrinsicScalarFunctions::ListIndex),

src/libasr/pass/pass_utils.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -588,24 +588,25 @@ namespace LCompilers {
588588
}
589589

590590

591-
ASR::stmt_t* get_flipsign(ASR::expr_t* arg0, ASR::expr_t* arg1,
592-
Allocator& al, ASR::TranslationUnit_t& unit,
593-
LCompilers::PassOptions& pass_options,
594-
SymbolTable*& current_scope,
595-
const std::function<void (const std::string &, const Location &)> err) {
596-
ASR::symbol_t *v = import_generic_procedure("flipsign", "lfortran_intrinsic_optimization",
597-
al, unit, pass_options, current_scope, arg0->base.loc);
591+
ASR::expr_t* get_flipsign(ASR::expr_t* arg0, ASR::expr_t* arg1,
592+
Allocator& al, ASR::TranslationUnit_t& unit, const Location& loc){
593+
ASRUtils::impl_function instantiate_function =
594+
ASRUtils::IntrinsicScalarFunctionRegistry::get_instantiate_function(
595+
static_cast<int64_t>(ASRUtils::IntrinsicScalarFunctions::FlipSign));
596+
Vec<ASR::ttype_t*> arg_types;
597+
ASR::ttype_t* type = ASRUtils::expr_type(arg1);
598+
arg_types.reserve(al, 2);
599+
arg_types.push_back(al, ASRUtils::expr_type(arg0));
600+
arg_types.push_back(al, ASRUtils::expr_type(arg1));
598601
Vec<ASR::call_arg_t> args;
599602
args.reserve(al, 2);
600603
ASR::call_arg_t arg0_, arg1_;
601604
arg0_.loc = arg0->base.loc, arg0_.m_value = arg0;
602605
args.push_back(al, arg0_);
603606
arg1_.loc = arg1->base.loc, arg1_.m_value = arg1;
604607
args.push_back(al, arg1_);
605-
return ASRUtils::STMT(
606-
ASRUtils::symbol_resolve_external_generic_procedure_without_eval(
607-
arg0->base.loc, v, args, current_scope, al,
608-
err));
608+
return instantiate_function(al, loc,
609+
unit.m_global_scope, arg_types, type, args, 0);
609610
}
610611

611612
ASR::expr_t* to_int32(ASR::expr_t* x, ASR::ttype_t* int64type, Allocator& al) {

src/libasr/pass/pass_utils.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,8 @@ namespace LCompilers {
7373
ASR::expr_t* get_bound(ASR::expr_t* arr_expr, int dim, std::string bound,
7474
Allocator& al);
7575

76-
77-
ASR::stmt_t* get_flipsign(ASR::expr_t* arg0, ASR::expr_t* arg1,
78-
Allocator& al, ASR::TranslationUnit_t& unit,
79-
LCompilers::PassOptions& pass_options,
80-
SymbolTable*& current_scope,
81-
const std::function<void (const std::string &, const Location &)> err);
76+
ASR::expr_t* get_flipsign(ASR::expr_t* arg0, ASR::expr_t* arg1,
77+
Allocator& al, ASR::TranslationUnit_t& unit, const Location& loc);
8278

8379
ASR::expr_t* to_int32(ASR::expr_t* x, ASR::ttype_t* int32type, Allocator& al);
8480

0 commit comments

Comments
 (0)