Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 57 additions & 16 deletions flow360/component/simulation/blueprint/core/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# pylint: disable=too-many-return-statements

from typing import Any, Callable
from typing import Any, Callable, Optional

from flow360.component.simulation.blueprint.core.expressions import (
BinOpNode,
Expand Down Expand Up @@ -38,7 +38,7 @@ def _indent(code: str, level: int = 1) -> str:
return "\n".join(spaces + line if line else line for line in code.split("\n"))


def _empty(syntax):
def _empty(syntax: TargetSyntax) -> str:
if syntax == TargetSyntax.PYTHON:
return "None"
if syntax == TargetSyntax.CPP:
Expand All @@ -49,27 +49,31 @@ def _empty(syntax):
)


def _name(expr, name_translator):
def _name(expr: NameNode, name_translator: Optional[Callable[[str], str]]) -> str:
if name_translator:
return name_translator(expr.id)
return expr.id


def _constant(expr):
def _constant(expr: ConstantNode) -> str:
if isinstance(expr.value, str):
return f"'{expr.value}'"
return str(expr.value)


def _unary_op(expr, syntax, name_translator):
def _unary_op(
expr: UnaryOpNode, syntax: TargetSyntax, name_translator: Optional[Callable[[str], str]]
) -> str:
op_info = UNARY_OPERATORS[expr.op]

arg = expr_to_code(expr.operand, syntax, name_translator)

return f"{op_info.symbol}{arg}"


def _binary_op(expr, syntax, name_translator):
def _binary_op(
expr: BinOpNode, syntax: TargetSyntax, name_translator: Optional[Callable[[str], str]]
) -> str:
left = expr_to_code(expr.left, syntax, name_translator)
right = expr_to_code(expr.right, syntax, name_translator)

Expand All @@ -86,15 +90,19 @@ def _binary_op(expr, syntax, name_translator):
return f"({left} {op_info.symbol} {right})"


def _range_call(expr, syntax, name_translator):
def _range_call(
expr: RangeCallNode, syntax: TargetSyntax, name_translator: Optional[Callable[[str], str]]
) -> str:
if syntax == TargetSyntax.PYTHON:
arg = expr_to_code(expr.arg, syntax, name_translator)
return f"range({arg})"

raise ValueError("Range calls are only supported for Python target syntax")


def _call_model(expr, syntax, name_translator):
def _call_model(
expr: CallModelNode, syntax: TargetSyntax, name_translator: Optional[Callable[[str], str]]
) -> str:
if syntax == TargetSyntax.PYTHON:
args = []
for arg in expr.args:
Expand Down Expand Up @@ -128,7 +136,9 @@ def _call_model(expr, syntax, name_translator):
)


def _tuple(expr, syntax, name_translator):
def _tuple(
expr: TupleNode, syntax: TargetSyntax, name_translator: Optional[Callable[[str], str]]
) -> str:
elements = [expr_to_code(e, syntax, name_translator) for e in expr.elements]

if syntax == TargetSyntax.PYTHON:
Expand All @@ -147,7 +157,9 @@ def _tuple(expr, syntax, name_translator):
)


def _list(expr, syntax, name_translator):
def _list(
expr: ListNode, syntax: TargetSyntax, name_translator: Optional[Callable[[str], str]]
) -> str:
elements = [expr_to_code(e, syntax, name_translator) for e in expr.elements]

if syntax == TargetSyntax.PYTHON:
Expand All @@ -166,7 +178,9 @@ def _list(expr, syntax, name_translator):
)


def _list_comp(expr, syntax, name_translator):
def _list_comp(
expr: ListCompNode, syntax: TargetSyntax, name_translator: Optional[Callable[[str], str]]
) -> str:
if syntax == TargetSyntax.PYTHON:
element = expr_to_code(expr.element, syntax, name_translator)
target = expr_to_code(expr.target, syntax, name_translator)
Expand All @@ -177,14 +191,39 @@ def _list_comp(expr, syntax, name_translator):
raise ValueError("List comprehensions are only supported for Python target syntax")


def _subscript(expr, syntax, name_translator): # pylint:disable=unused-argument
return f"{name_translator(expr.value.id)}[{expr.slice.value}]"
def _subscript(
expr: SubscriptNode, syntax: TargetSyntax, name_translator: Optional[Callable[[str], str]]
) -> str: # pylint:disable=unused-argument
# Generate code for the value and the index recursively.
base = expr_to_code(expr.value, syntax, name_translator)
index = expr_to_code(expr.slice, syntax, name_translator)

# Push the subscript into the right-hand side of a multiplication
# to generate valid scalar*vector component access like: ((scalar * vector[index])).
if syntax == TargetSyntax.CPP and isinstance(expr.value, BinOpNode):
op = expr.value.op
if op == "Mult":
left = expr_to_code(expr.value.left, syntax, name_translator)
right = expr_to_code(expr.value.right, syntax, name_translator)
symbol = BINARY_OPERATORS[op].symbol
# Avoid redundant parentheses for simple names
if isinstance(expr.value.right, NameNode):
right_indexed = f"{right}[{index}]"
else:
right_indexed = f"({right})[{index}]"
return f"(({left} {symbol} {right_indexed}))"

# Parenthesize non-trivial bases to preserve precedence, e.g., (a * b)[0].
if not isinstance(expr.value, NameNode):
base = f"({base})"

return f"{base}[{index}]"


def expr_to_code(
expr: Any,
syntax: TargetSyntax = TargetSyntax.PYTHON,
name_translator: Callable[[str], str] = None,
name_translator: Optional[Callable[[str], str]] = None,
) -> str:
"""Convert an expression model back to source code."""
if expr is None:
Expand Down Expand Up @@ -225,7 +264,9 @@ def expr_to_code(


def stmt_to_code(
stmt: Any, syntax: TargetSyntax = TargetSyntax.PYTHON, remap: dict[str, str] = None
stmt: Any,
syntax: TargetSyntax = TargetSyntax.PYTHON,
remap: Optional[dict[str, str]] = None,
) -> str:
"""Convert a statement model back to source code."""
if syntax == TargetSyntax.PYTHON:
Expand Down Expand Up @@ -277,7 +318,7 @@ def stmt_to_code(
def model_to_function(
func: FunctionNode,
syntax: TargetSyntax = TargetSyntax.PYTHON,
remap: dict[str, str] = None,
remap: Optional[dict[str, str]] = None,
) -> str:
"""Convert a Function model back to source code."""
if syntax == TargetSyntax.PYTHON:
Expand Down
68 changes: 68 additions & 0 deletions tests/simulation/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,74 @@ class ScalarModel(Flow360BaseModel):
assert model.scalar.evaluate() == 10


def test_subscript_on_binary_expression_codegen_cpp():
from flow360.component.simulation.blueprint.core.generator import expr_to_code
from flow360.component.simulation.blueprint.core.parser import expr_to_model
from flow360.component.simulation.blueprint.core.types import TargetSyntax
from flow360.component.simulation.user_code.core.context import default_context

# Ensure codegen supports subscript on a BinOp value, e.g., (a * b)[0]
expression_str = "(solution.pressure * solution.node_area_vector)[0]"
expr_model = expr_to_model(expression_str, default_context)
code = expr_to_code(expr_model, TargetSyntax.CPP)

# In C++ we prefer pushing the subscript into the vector operand:
# ((solution.pressure * solution.node_area_vector[0]))
assert code == "((solution.pressure * solution.node_area_vector[0]))"


def test_subscript_on_binary_expression_velocity_cpp():
from flow360.component.simulation.blueprint.core.generator import expr_to_code
from flow360.component.simulation.blueprint.core.parser import expr_to_model
from flow360.component.simulation.blueprint.core.types import TargetSyntax
from flow360.component.simulation.user_code.core.context import default_context

expression_str = "(solution.pressure * solution.velocity)[1]"
expr_model = expr_to_model(expression_str, default_context)
code = expr_to_code(expr_model, TargetSyntax.CPP)

assert code == "((solution.pressure * solution.velocity[1]))"


def test_subscript_on_binary_expression_constant_left_cpp():
from flow360.component.simulation.blueprint.core.generator import expr_to_code
from flow360.component.simulation.blueprint.core.parser import expr_to_model
from flow360.component.simulation.blueprint.core.types import TargetSyntax
from flow360.component.simulation.user_code.core.context import default_context

expression_str = "(2.0 * solution.node_area_vector)[2]"
expr_model = expr_to_model(expression_str, default_context)
code = expr_to_code(expr_model, TargetSyntax.CPP)

assert code == "((2.0 * solution.node_area_vector[2]))"


def test_subscript_on_binary_expression_dynamic_index_cpp():
from flow360.component.simulation.blueprint.core.generator import expr_to_code
from flow360.component.simulation.blueprint.core.parser import expr_to_model
from flow360.component.simulation.blueprint.core.types import TargetSyntax
from flow360.component.simulation.user_code.core.context import default_context

expression_str = "(solution.pressure * solution.node_area_vector)[control.physicalStep]"
expr_model = expr_to_model(expression_str, default_context)
code = expr_to_code(expr_model, TargetSyntax.CPP)

assert code == "((solution.pressure * solution.node_area_vector[control.physicalStep]))"


def test_subscript_on_binary_expression_with_left_parens_cpp():
from flow360.component.simulation.blueprint.core.generator import expr_to_code
from flow360.component.simulation.blueprint.core.parser import expr_to_model
from flow360.component.simulation.blueprint.core.types import TargetSyntax
from flow360.component.simulation.user_code.core.context import default_context

expression_str = "((solution.pressure + 1) * solution.node_area_vector)[2]"
expr_model = expr_to_model(expression_str, default_context)
code = expr_to_code(expr_model, TargetSyntax.CPP)

assert code == "(((solution.pressure + 1) * solution.node_area_vector[2]))"


def test_error_message():
class TestModel(Flow360BaseModel):
field: ValueOrExpression[VelocityType] = pd.Field()
Expand Down
Loading