Skip to content

Commit e8f63b0

Browse files
authored
Merge pull request #2077 from anutosh491/GSoC_PR5
Added support for symbolic Expand & Differentiation
2 parents a18690b + 315e020 commit e8f63b0

File tree

6 files changed

+164
-8
lines changed

6 files changed

+164
-8
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,7 @@ RUN(NAME symbolics_01 LABELS cpython_sym c_sym)
602602
RUN(NAME symbolics_02 LABELS cpython_sym c_sym)
603603
RUN(NAME symbolics_03 LABELS cpython_sym c_sym)
604604
RUN(NAME symbolics_04 LABELS cpython_sym c_sym)
605+
RUN(NAME symbolics_05 LABELS cpython_sym c_sym)
605606

606607
RUN(NAME sizeof_01 LABELS llvm c
607608
EXTRAFILES sizeof_01b.c)

integration_tests/symbolics_05.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from sympy import Symbol, expand, diff
2+
from lpython import S
3+
4+
def test_operations():
5+
x: S = Symbol('x')
6+
y: S = Symbol('y')
7+
z: S = Symbol('z')
8+
a: S = (x + y)**S(2)
9+
b: S = (x + y + z)**S(3)
10+
11+
# test expand
12+
assert(a.expand() == S(2)*x*y + x**S(2) + y**S(2))
13+
assert(expand(b) == S(3)*x*y**S(2) + S(3)*x*z**S(2) + S(3)*x**S(2)*y + S(3)*x**S(2)*z +\
14+
S(3)*y*z**S(2) + S(3)*y**S(2)*z + S(6)*x*y*z + x**S(3) + y**S(3) + z**S(3))
15+
print(a.expand())
16+
print(expand(b))
17+
18+
# test diff
19+
assert(a.diff(x) == S(2)*(x + y))
20+
assert(diff(b, x) == S(3)*(x + y + z)**S(2))
21+
print(a.diff(x))
22+
print(diff(b, x))
23+
24+
test_operations()

src/libasr/codegen/asr_to_c_cpp.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2755,6 +2755,10 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
27552755
src = performSymbolicOperation("basic_pow", x);
27562756
return;
27572757
}
2758+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiff)): {
2759+
src = performSymbolicOperation("basic_diff", x);
2760+
return;
2761+
}
27582762
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicPi)): {
27592763
headers.insert("symengine/cwrapper.h");
27602764
LCOMPILERS_ASSERT(x.n_args == 0);
@@ -2781,6 +2785,22 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
27812785
src = target;
27822786
return;
27832787
}
2788+
case (static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExpand)): {
2789+
headers.insert("symengine/cwrapper.h");
2790+
LCOMPILERS_ASSERT(x.n_args == 1);
2791+
std::string target = symengine_queue.push();
2792+
std::string target_src = symengine_src;
2793+
this->visit_expr(*x.m_args[0]);
2794+
std::string arg1 = src;
2795+
std::string arg1_src = symengine_src;
2796+
if (ASR::is_a<ASR::Var_t>(*x.m_args[0])) {
2797+
symengine_queue.pop();
2798+
}
2799+
symengine_src = target_src + arg1_src;
2800+
symengine_src += indent + "basic_expand(" + target + ", " + arg1 + ");\n";
2801+
src = target;
2802+
return;
2803+
}
27842804
default : {
27852805
throw LCompilersException("IntrinsicFunction: `"
27862806
+ ASRUtils::get_intrinsic_name(x.m_intrinsic_id)

src/libasr/pass/intrinsic_function_registry.h

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ enum class IntrinsicFunctions : int64_t {
7272
SymbolicPow,
7373
SymbolicPi,
7474
SymbolicInteger,
75+
SymbolicDiff,
76+
SymbolicExpand,
7577
Sum,
7678
// ...
7779
};
@@ -2056,7 +2058,7 @@ namespace SymbolicSymbol {
20562058

20572059
} // namespace SymbolicSymbol
20582060

2059-
#define create_symbolic_binop_macro(X) \
2061+
#define create_symbolic_binary_macro(X) \
20602062
namespace X{ \
20612063
\
20622064
static inline void verify_args(const ASR::IntrinsicFunction_t& x, \
@@ -2107,11 +2109,12 @@ namespace X{
21072109
} \
21082110
} // namespace X
21092111

2110-
create_symbolic_binop_macro(SymbolicAdd)
2111-
create_symbolic_binop_macro(SymbolicSub)
2112-
create_symbolic_binop_macro(SymbolicMul)
2113-
create_symbolic_binop_macro(SymbolicDiv)
2114-
create_symbolic_binop_macro(SymbolicPow)
2112+
create_symbolic_binary_macro(SymbolicAdd)
2113+
create_symbolic_binary_macro(SymbolicSub)
2114+
create_symbolic_binary_macro(SymbolicMul)
2115+
create_symbolic_binary_macro(SymbolicDiv)
2116+
create_symbolic_binary_macro(SymbolicPow)
2117+
create_symbolic_binary_macro(SymbolicDiff)
21152118

21162119
namespace SymbolicPi {
21172120

@@ -2166,6 +2169,46 @@ namespace SymbolicInteger {
21662169
}
21672170
} // namespace SymbolicInteger
21682171

2172+
namespace SymbolicExpand {
2173+
2174+
static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnostics& diagnostics) {
2175+
const Location& loc = x.base.base.loc;
2176+
ASRUtils::require_impl(x.n_args == 1,
2177+
"SymbolicExpand must have exactly 1 input argument",
2178+
loc, diagnostics);
2179+
2180+
ASR::ttype_t* input_type = ASRUtils::expr_type(x.m_args[0]);
2181+
ASRUtils::require_impl(ASR::is_a<ASR::SymbolicExpression_t>(*input_type),
2182+
"SymbolicExpand expects an argument of type SymbolicExpression",
2183+
x.base.base.loc, diagnostics);
2184+
}
2185+
2186+
static inline ASR::expr_t *eval_SymbolicExpand(Allocator &/*al*/,
2187+
const Location &/*loc*/, Vec<ASR::expr_t*>& /*args*/) {
2188+
// TODO
2189+
return nullptr;
2190+
}
2191+
2192+
static inline ASR::asr_t* create_SymbolicExpand(Allocator& al, const Location& loc,
2193+
Vec<ASR::expr_t*>& args,
2194+
const std::function<void (const std::string &, const Location &)> err) {
2195+
if (args.size() != 1) {
2196+
err("Intrinsic expand function accepts exactly 1 argument", loc);
2197+
}
2198+
2199+
ASR::ttype_t* argtype = ASRUtils::expr_type(args[0]);
2200+
if(!ASR::is_a<ASR::SymbolicExpression_t>(*argtype)) {
2201+
err("Argument of SymbolicExpand function must be of type SymbolicExpression",
2202+
args[0]->base.loc);
2203+
}
2204+
2205+
ASR::ttype_t *to_type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, loc));
2206+
return UnaryIntrinsicFunction::create_UnaryFunction(al, loc, args, eval_SymbolicExpand,
2207+
static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExpand), 0, to_type);
2208+
}
2209+
2210+
} // namespace SymbolicExpand
2211+
21692212
namespace IntrinsicFunctionRegistry {
21702213

21712214
static const std::map<int64_t,
@@ -2228,6 +2271,10 @@ namespace IntrinsicFunctionRegistry {
22282271
{nullptr, &SymbolicPi::verify_args}},
22292272
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicInteger),
22302273
{nullptr, &SymbolicInteger::verify_args}},
2274+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiff),
2275+
{nullptr, &SymbolicDiff::verify_args}},
2276+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExpand),
2277+
{nullptr, &SymbolicExpand::verify_args}},
22312278
};
22322279

22332280
static const std::map<int64_t, std::string>& intrinsic_function_id_to_name = {
@@ -2282,6 +2329,10 @@ namespace IntrinsicFunctionRegistry {
22822329
"pi"},
22832330
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicInteger),
22842331
"SymbolicInteger"},
2332+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicDiff),
2333+
"SymbolicDiff"},
2334+
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::SymbolicExpand),
2335+
"SymbolicExpand"},
22852336
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Any),
22862337
"any"},
22872338
{static_cast<int64_t>(ASRUtils::IntrinsicFunctions::Sum),
@@ -2319,6 +2370,8 @@ namespace IntrinsicFunctionRegistry {
23192370
{"SymbolicPow", {&SymbolicPow::create_SymbolicPow, &SymbolicPow::eval_SymbolicPow}},
23202371
{"pi", {&SymbolicPi::create_SymbolicPi, &SymbolicPi::eval_SymbolicPi}},
23212372
{"SymbolicInteger", {&SymbolicInteger::create_SymbolicInteger, &SymbolicInteger::eval_SymbolicInteger}},
2373+
{"diff", {&SymbolicDiff::create_SymbolicDiff, &SymbolicDiff::eval_SymbolicDiff}},
2374+
{"expand", {&SymbolicExpand::create_SymbolicExpand, &SymbolicExpand::eval_SymbolicExpand}},
23222375
};
23232376

23242377
static inline bool is_intrinsic_function(const std::string& name) {
@@ -2433,6 +2486,8 @@ inline std::string get_intrinsic_name(int x) {
24332486
INTRINSIC_NAME_CASE(SymbolicPow)
24342487
INTRINSIC_NAME_CASE(SymbolicPi)
24352488
INTRINSIC_NAME_CASE(SymbolicInteger)
2489+
INTRINSIC_NAME_CASE(SymbolicDiff)
2490+
INTRINSIC_NAME_CASE(SymbolicExpand)
24362491
INTRINSIC_NAME_CASE(Sum)
24372492
default : {
24382493
throw LCompilersException("pickle: intrinsic_id not implemented");

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,11 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
694694
return;
695695
}
696696

697+
void handle_symbolic_attribute(ASR::expr_t *s, std::string attr_name,
698+
const Location &loc, Vec<ASR::expr_t*> &args) {
699+
tmp = attr_handler.get_symbolic_attribute(s, attr_name, al, loc, args, diag);
700+
return;
701+
}
697702

698703
void fill_expr_in_ttype_t(std::vector<ASR::expr_t*>& exprs, ASR::dimension_t* dims, size_t n_dims) {
699704
for( size_t i = 0; i < n_dims; i++ ) {
@@ -7113,6 +7118,11 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
71137118
handle_string_attributes(se, args, at->m_attr, loc);
71147119
return;
71157120
}
7121+
ASR::ttype_t *type = ASRUtils::expr_type(se);
7122+
if (ASR::is_a<ASR::SymbolicExpression_t>(*type)) {
7123+
handle_symbolic_attribute(se, at->m_attr, loc, eles);
7124+
return;
7125+
}
71167126
handle_builtin_attribute(se, at->m_attr, loc, eles);
71177127
return;
71187128
}
@@ -7231,7 +7241,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
72317241

72327242
if (!s) {
72337243
std::set<std::string> not_cpython_builtin = {
7234-
"sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol",
7244+
"sin", "cos", "gamma", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "exp2", "expm1", "Symbol", "diff", "expand",
72357245
"sum" // For sum called over lists
72367246
};
72377247
if (ASRUtils::IntrinsicFunctionRegistry::is_intrinsic_function(call_name) &&

src/lpython/semantics/python_attribute_eval.h

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ struct AttributeHandler {
1515
typedef ASR::asr_t* (*attribute_eval_callback)(ASR::expr_t*, Allocator &,
1616
const Location &, Vec<ASR::expr_t*> &, diag::Diagnostics &);
1717

18-
std::map<std::string, attribute_eval_callback> attribute_map;
18+
std::map<std::string, attribute_eval_callback> attribute_map, symbolic_attribute_map;
1919
std::set<std::string> modify_attr_set;
2020

2121
AttributeHandler() {
@@ -40,6 +40,11 @@ struct AttributeHandler {
4040
modify_attr_set = {"list@append", "list@remove",
4141
"list@reverse", "list@clear", "list@insert", "list@pop",
4242
"set@pop", "set@add", "set@remove", "dict@pop"};
43+
44+
symbolic_attribute_map = {
45+
{"diff", &eval_symbolic_diff},
46+
{"expand", &eval_symbolic_expand}
47+
};
4348
}
4449

4550
std::string get_type_name(ASR::ttype_t *t) {
@@ -82,6 +87,19 @@ struct AttributeHandler {
8287
}
8388
}
8489

90+
ASR::asr_t* get_symbolic_attribute(ASR::expr_t *e, std::string attr_name,
91+
Allocator &al, const Location &loc, Vec<ASR::expr_t*> &args, diag::Diagnostics &diag) {
92+
std::string key = attr_name;
93+
auto search = symbolic_attribute_map.find(key);
94+
if (search != symbolic_attribute_map.end()) {
95+
attribute_eval_callback cb = search->second;
96+
return cb(e, al, loc, args, diag);
97+
} else {
98+
throw SemanticError("S." + attr_name + " is not implemented yet",
99+
loc);
100+
}
101+
}
102+
85103
static ASR::asr_t* eval_int_bit_length(ASR::expr_t *s, Allocator &al, const Location &loc,
86104
Vec<ASR::expr_t*> &args, diag::Diagnostics &/*diag*/) {
87105
if (args.size() != 0) {
@@ -388,6 +406,34 @@ struct AttributeHandler {
388406
return make_DictPop_t(al, loc, s, args[0], value_type, nullptr);
389407
}
390408

409+
static ASR::asr_t* eval_symbolic_diff(ASR::expr_t *s, Allocator &al, const Location &loc,
410+
Vec<ASR::expr_t*> &args, diag::Diagnostics &/*diag*/) {
411+
Vec<ASR::expr_t*> args_with_list;
412+
args_with_list.reserve(al, args.size() + 1);
413+
args_with_list.push_back(al, s);
414+
for(size_t i = 0; i < args.size(); i++) {
415+
args_with_list.push_back(al, args[i]);
416+
}
417+
ASRUtils::create_intrinsic_function create_function =
418+
ASRUtils::IntrinsicFunctionRegistry::get_create_function("diff");
419+
return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc)
420+
{ throw SemanticError(msg, loc); });
421+
}
422+
423+
static ASR::asr_t* eval_symbolic_expand(ASR::expr_t *s, Allocator &al, const Location &loc,
424+
Vec<ASR::expr_t*> &args, diag::Diagnostics &/*diag*/) {
425+
Vec<ASR::expr_t*> args_with_list;
426+
args_with_list.reserve(al, args.size() + 1);
427+
args_with_list.push_back(al, s);
428+
for(size_t i = 0; i < args.size(); i++) {
429+
args_with_list.push_back(al, args[i]);
430+
}
431+
ASRUtils::create_intrinsic_function create_function =
432+
ASRUtils::IntrinsicFunctionRegistry::get_create_function("expand");
433+
return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc)
434+
{ throw SemanticError(msg, loc); });
435+
}
436+
391437
}; // AttributeHandler
392438

393439
} // namespace LCompilers::LPython

0 commit comments

Comments
 (0)