diff --git a/integration_tests/symbolics_05.py b/integration_tests/symbolics_05.py index 9214298d93..e915a7afba 100644 --- a/integration_tests/symbolics_05.py +++ b/integration_tests/symbolics_05.py @@ -1,4 +1,4 @@ -from sympy import Symbol, expand, diff +from sympy import Symbol, expand, diff, sin, cos, exp, pi from lpython import S def test_operations(): @@ -21,4 +21,16 @@ def test_operations(): print(a.diff(x)) print(diff(b, x)) + # test diff 2 + c:S = sin(x) + d:S = cos(x) + assert(sin(Symbol("x")).diff(x) == d) + assert(sin(x).diff(Symbol("x")) == d) + assert(sin(x).diff(x) == d) + assert(sin(x).diff(x).diff(x) == S(-1)*c) + assert(sin(x).expand().diff(x).diff(x) == S(-1)*c) + assert((sin(x) + cos(x)).diff(x) == S(-1)*c + d) + assert((sin(x) + cos(x) + exp(x) + pi).diff(x).expand().diff(x) == exp(x) + S(-1)*c + S(-1)*d) + + test_operations() \ No newline at end of file diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 7ec0e032fc..36cbea9e1f 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -7166,6 +7166,27 @@ class BodyVisitor : public CommonVisitor { st = current_scope->get_symbol(call_name_store); } else { st = current_scope->resolve_symbol(mod_name); + std::set symbolic_attributes = { + "diff", "expand" + }; + std::set symbolic_constants = { + "pi" + }; + if (symbolic_attributes.find(call_name) != symbolic_attributes.end() && + symbolic_constants.find(mod_name) != symbolic_constants.end()){ + ASRUtils::create_intrinsic_function create_func; + create_func = ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function(mod_name); + Vec eles; eles.reserve(al, args.size()); + Vec args_; args_.reserve(al, 1); + for (size_t i=0; ibase.base.loc, args_, + [&](const std::string &msg, const Location &loc) { + throw SemanticError(msg, loc); }); + handle_symbolic_attribute(ASRUtils::EXPR(tmp), call_name, loc, eles); + return; + } if (!st) { throw SemanticError("NameError: '" + mod_name + "' is not defined", n->base.base.loc); } @@ -7220,6 +7241,32 @@ class BodyVisitor : public CommonVisitor { ASR::expr_t* expr = ASR::down_cast(tmp); handle_builtin_attribute(expr, at->m_attr, loc, eles); return; + } else if (AST::is_a(*at->m_value)) { + AST::BinOp_t* bop = AST::down_cast(at->m_value); + std::set symbolic_attributes = { + "diff", "expand" + }; + if (symbolic_attributes.find(at->m_attr) != symbolic_attributes.end()){ + switch (bop->m_op) { + case (AST::operatorType::Add) : + case (AST::operatorType::Sub) : + case (AST::operatorType::Mult) : + case (AST::operatorType::Div) : + case (AST::operatorType::Pow) : { + visit_BinOp(*bop); + Vec eles; + eles.reserve(al, args.size()); + for (size_t i=0; im_attr, loc, eles); + return; + } + default : { + throw SemanticError("Binary operator type not supported", loc); + } + } + } } else if (AST::is_a(*at->m_value)) { if (std::string(at->m_attr) == std::string("bit_length")) { //bit_length() attribute: @@ -7241,6 +7288,41 @@ class BodyVisitor : public CommonVisitor { std::string res = n->m_value; handle_constant_string_attributes(res, args, at->m_attr, loc); return; + } else if (AST::is_a(*at->m_value)) { + AST::Call_t* call = AST::down_cast(at->m_value); + std::set symbolic_attributes = { + "diff", "expand" + }; + if (symbolic_attributes.find(at->m_attr) != symbolic_attributes.end()){ + std::set symbolic_functions = { + "sin", "cos", "log", "exp", "Abs", "Symbol" + }; + if (AST::is_a(*call->m_func)) { + visit_Call(*call); + Vec eles; + eles.reserve(al, args.size()); + for (size_t i=0; im_attr, loc, eles); + return; + } else if (AST::is_a(*call->m_func)) { + AST::Name_t *n = AST::down_cast(call->m_func); + std::string call_name = n->m_id; + if (symbolic_functions.find(call_name) != symbolic_functions.end()) { + visit_Call(*call); + Vec eles; + eles.reserve(al, args.size()); + for (size_t i=0; im_attr, loc, eles); + return; + } else { + throw SemanticError(std::string(call_name) + " not supported in Call", loc); + } + } + } } else { throw SemanticError("Only Name type and constant integers supported in Call", loc); }