Skip to content

Commit b46a93d

Browse files
authored
Introducing Symbolic Binary operators and i32 to S Casting function (#1964)
1 parent f2ba0ef commit b46a93d

File tree

9 files changed

+288
-85
lines changed

9 files changed

+288
-85
lines changed

integration_tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,8 @@ RUN(NAME structs_26 LABELS cpython llvm c)
580580
RUN(NAME structs_27 LABELS cpython llvm c)
581581

582582
RUN(NAME symbolics_01 LABELS cpython_sym c_sym)
583+
RUN(NAME symbolics_02 LABELS cpython_sym c_sym)
584+
RUN(NAME symbolics_03 LABELS cpython_sym c_sym)
583585

584586
RUN(NAME sizeof_01 LABELS llvm c
585587
EXTRAFILES sizeof_01b.c)

integration_tests/symbolics_02.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from sympy import Symbol
2+
from lpython import S
3+
4+
def test_symbolic_operations():
5+
x: S = Symbol('x')
6+
y: S = Symbol('y')
7+
8+
# Addition
9+
z: S = x + y
10+
print(z) # Expected: x + y
11+
12+
# Subtraction
13+
w: S = x - y
14+
print(w) # Expected: x - y
15+
16+
# Multiplication
17+
u: S = x * y
18+
print(u) # Expected: x*y
19+
20+
# Division
21+
v: S = x / y
22+
print(v) # Expected: x/y
23+
24+
# Power
25+
p: S = x ** y
26+
print(p) # Expected: x**y
27+
28+
# Casting
29+
a: S = S(100)
30+
b: S = S(-100)
31+
c: S = a + b
32+
print(c) # Expected: 0
33+
34+
test_symbolic_operations()

integration_tests/symbolics_03.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from sympy import Symbol, pi
2+
from lpython import S
3+
4+
def test_operator_chaining():
5+
w: S = S(2)
6+
x: S = Symbol('x')
7+
y: S = Symbol('y')
8+
z: S = Symbol('z')
9+
Pi: S = Symbol('pi')
10+
11+
a: S = x * w
12+
b: S = a + Pi
13+
c: S = b / z
14+
d: S = c ** w
15+
16+
print(a) # Expected: 2*x
17+
print(b) # Expected: pi + 2*x
18+
print(c) # Expected: (pi + 2*x)/z
19+
print(d) # Expected: (pi + 2*x)**2/z**2
20+
21+
test_operator_chaining()

src/libasr/ASR.asdl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,7 @@ cast_kind
432432
| RealToUnsignedInteger
433433
| CPtrToUnsignedInteger
434434
| UnsignedIntegerToCPtr
435+
| IntegerToSymbolicExpression
435436

436437
dimension = (expr? start, expr? length)
437438

src/libasr/asr_utils.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <libasr/utils.h>
99
#include <libasr/modfile.h>
1010
#include <libasr/pass/pass_utils.h>
11+
#include <libasr/pass/intrinsic_function_registry.h>
1112

1213
namespace LCompilers {
1314

@@ -1196,6 +1197,15 @@ ASR::asr_t* make_Cast_t_value(Allocator &al, const Location &a_loc,
11961197
double real = value_complex->m_re;
11971198
value = ASR::down_cast<ASR::expr_t>(
11981199
ASR::make_RealConstant_t(al, a_loc, real, a_type));
1200+
} else if (a_kind == ASR::cast_kindType::IntegerToSymbolicExpression) {
1201+
Vec<ASR::expr_t*> args;
1202+
args.reserve(al, 1);
1203+
args.push_back(al, a_arg);
1204+
LCompilers::ASRUtils::create_intrinsic_function create_function =
1205+
LCompilers::ASRUtils::IntrinsicFunctionRegistry::get_create_function("SymbolicInteger");
1206+
value = ASR::down_cast<ASR::expr_t>(create_function(al, a_loc, args,
1207+
[](const std::string&, const Location&) {
1208+
}));
11991209
}
12001210
}
12011211

src/libasr/codegen/asr_to_c_cpp.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1634,6 +1634,9 @@ R"(#include <stdio.h>
16341634
last_expr_precedence = 2;
16351635
break;
16361636
}
1637+
case (ASR::cast_kindType::IntegerToSymbolicExpression): {
1638+
break;
1639+
}
16371640
default : throw CodeGenError("Cast kind " + std::to_string(x.m_kind) + " not implemented",
16381641
x.base.base.loc);
16391642
}
@@ -2382,8 +2385,13 @@ R"(#include <stdio.h>
23822385
SET_INTRINSIC_NAME(Exp2, "exp2");
23832386
SET_INTRINSIC_NAME(Expm1, "expm1");
23842387
SET_INTRINSIC_NAME(SymbolicSymbol, "Symbol");
2388+
SET_INTRINSIC_NAME(SymbolicInteger, "Integer");
23852389
SET_INTRINSIC_NAME(SymbolicPi, "pi");
2386-
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd)): {
2390+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicAdd)):
2391+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSub)):
2392+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicMul)):
2393+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiv)):
2394+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPow)): {
23872395
LCOMPILERS_ASSERT(x.n_args == 2);
23882396
this->visit_expr(*x.m_args[0]);
23892397
std::string arg1 = src;
@@ -2404,7 +2412,8 @@ R"(#include <stdio.h>
24042412
src = out;
24052413
} else if (x.n_args == 1) {
24062414
this->visit_expr(*x.m_args[0]);
2407-
if (x.m_intrinsic_id != static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSymbol)) {
2415+
if ((x.m_intrinsic_id != static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicSymbol)) &&
2416+
(x.m_intrinsic_id != static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicInteger))) {
24082417
out += "(" + src + ")";
24092418
src = out;
24102419
}

src/libasr/codegen/c_utils.h

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,14 @@ class CCPPDSUtils {
632632
return result;
633633
}
634634

635+
std::string generate_binary_operator_code(std::string value, std::string target, std::string operatorName) {
636+
size_t delimiterPos = value.find(",");
637+
std::string leftPart = value.substr(0, delimiterPos);
638+
std::string rightPart = value.substr(delimiterPos + 1);
639+
std::string result = operatorName + "(" + target + ", " + leftPart + ", " + rightPart + ");";
640+
return result;
641+
}
642+
635643
std::string get_deepcopy_symbolic(ASR::expr_t *value_expr, std::string value, std::string target) {
636644
std::string result;
637645
if (ASR::is_a<ASR::Var_t>(*value_expr)) {
@@ -645,22 +653,43 @@ class CCPPDSUtils {
645653
break;
646654
}
647655
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicAdd: {
648-
size_t delimiterPos = value.find(",");
649-
std::string leftPart = value.substr(0, delimiterPos);
650-
std::string rightPart = value.substr(delimiterPos + 1);
651-
result = "basic_add(" + target + ", " + leftPart + ", " + rightPart + ");";
656+
result = generate_binary_operator_code(value, target, "basic_add");
657+
break;
658+
}
659+
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicSub: {
660+
result = generate_binary_operator_code(value, target, "basic_sub");
661+
break;
662+
}
663+
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicMul: {
664+
result = generate_binary_operator_code(value, target, "basic_mul");
665+
break;
666+
}
667+
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicDiv: {
668+
result = generate_binary_operator_code(value, target, "basic_div");
669+
break;
670+
}
671+
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPow: {
672+
result = generate_binary_operator_code(value, target, "basic_pow");
652673
break;
653674
}
654675
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicPi: {
655676
result = "basic_const_pi(" + target + ");";
656677
break;
657678
}
679+
case LCompilers::ASRUtils::IntrinsicFunctions::SymbolicInteger: {
680+
result = "integer_set_si(" + target + ", " + value + ");";
681+
break;
682+
}
658683
default: {
659684
throw LCompilersException("IntrinsicFunction: `"
660685
+ LCompilers::ASRUtils::get_intrinsic_name(intrinsic_id)
661686
+ "` is not implemented");
662687
}
663688
}
689+
} else if (ASR::is_a<ASR::Cast_t>(*value_expr)) {
690+
ASR::Cast_t* cast_expr = ASR::down_cast<ASR::Cast_t>(value_expr);
691+
std::string cast_value_expr = get_deepcopy_symbolic(cast_expr->m_value, value, target);
692+
return cast_value_expr;
664693
}
665694
return result;
666695
}

0 commit comments

Comments
 (0)