-
Notifications
You must be signed in to change notification settings - Fork 9
Add dependency cycle checking and add non-dimensioned array handling to expression validation service #1013
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,6 +7,7 @@ | |
| from numbers import Number | ||
| from typing import Any, Collection, Dict, Literal, Optional, Tuple, Union | ||
|
|
||
| import numpy as np | ||
| import pydantic as pd | ||
| from unyt import unyt_array, unyt_quantity | ||
| from unyt.exceptions import UnitParseError | ||
|
|
@@ -787,16 +788,14 @@ def validate_expression(variables: list[dict], expressions: list[str]): | |
| values = [] | ||
| units = [] | ||
|
|
||
| loc = "" | ||
|
|
||
| # Populate variable scope | ||
| for i in range(len(variables)): | ||
| variable = variables[i] | ||
| loc = f"variables/{i}" | ||
| try: | ||
| variable = UserVariable(name=variable["name"], value=variable["value"]) | ||
| if variable and isinstance(variable.value, Expression): | ||
| _ = variable.value.evaluate() | ||
| _ = variable.value.evaluate(strict=False) | ||
| except (ValueError, KeyError, NameError, UnitParseError) as e: | ||
| errors.append({"loc": loc, "msg": str(e)}) | ||
|
|
||
|
|
@@ -807,15 +806,22 @@ def validate_expression(variables: list[dict], expressions: list[str]): | |
| unit = None | ||
| try: | ||
| expression_object = Expression(expression=expression) | ||
| result = expression_object.evaluate() | ||
| if isinstance(result, Number): | ||
| result = expression_object.evaluate(strict=False) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we pass validation and evaluate to NaN (perhaps because of a solver variable) we should not throw an error, just not write an evaluated value. |
||
| if np.isnan(result): | ||
| pass | ||
| elif isinstance(result, Number): | ||
| value = result | ||
| elif isinstance(result, unyt_array): | ||
| if result.size == 1: | ||
| value = float(result.value) | ||
| else: | ||
| value = tuple(result.value.tolist()) | ||
| unit = str(result.units.expr) | ||
| elif isinstance(result, np.ndarray): | ||
| if result.size == 1: | ||
| value = float(result[0]) | ||
| else: | ||
| value = tuple(result.tolist()) | ||
| except (ValueError, KeyError, NameError, UnitParseError) as e: | ||
| errors.append({"loc": loc, "msg": str(e)}) | ||
| values.append(value) | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,10 +1,8 @@ | ||||||
| from __future__ import annotations | ||||||
|
|
||||||
| import re | ||||||
| from numbers import Number | ||||||
| from typing import Generic, Iterable, Optional, TypeVar | ||||||
|
|
||||||
| import pydantic as pd | ||||||
| from pydantic import BeforeValidator | ||||||
| from typing_extensions import Self | ||||||
| from unyt import Unit, unyt_array | ||||||
|
|
@@ -214,6 +212,28 @@ def update_context(cls, value): | |||||
| _user_variables.add(value.name) | ||||||
| return value | ||||||
|
|
||||||
| @pd.model_validator(mode="after") | ||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This runs a regular DFS traversal but by keeping track of the path we traversed we can check for cycles. If a cycle is found we throw an error and print the cycle in a readable format, e.g. |
||||||
| @classmethod | ||||||
| def check_dependencies(cls, value): | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| visited = set() | ||||||
| stack = [(value.name, [value.name])] | ||||||
| while stack: | ||||||
| (current_name, current_path) = stack.pop() | ||||||
| current_value = _global_ctx.get(current_name) | ||||||
| if isinstance(current_value, Expression): | ||||||
| used_names = current_value.user_variable_names() | ||||||
| if [name for name in used_names if name in current_path]: | ||||||
| path_string = " -> ".join(current_path + [current_path[0]]) | ||||||
| details = InitErrorDetails( | ||||||
| type="value_error", | ||||||
| ctx={"error": f"Cyclic dependency between variables {path_string}"}, | ||||||
| ) | ||||||
| raise pd.ValidationError.from_exception_data("Variable value error", [details]) | ||||||
| stack.extend( | ||||||
| [(name, current_path + [name]) for name in used_names if name not in visited] | ||||||
| ) | ||||||
| return value | ||||||
|
|
||||||
|
|
||||||
| class SolverVariable(Variable): | ||||||
| solver_name: Optional[str] = pd.Field(None) | ||||||
|
|
@@ -273,7 +293,7 @@ def _validate_expression(cls, value) -> Self: | |||||
| details = InitErrorDetails( | ||||||
| type="value_error", ctx={"error": f"Invalid type {type(value)}"} | ||||||
| ) | ||||||
| raise pd.ValidationError.from_exception_data("expression type error", [details]) | ||||||
| raise pd.ValidationError.from_exception_data("Expression type error", [details]) | ||||||
| try: | ||||||
| expr_to_model(expression, _global_ctx) | ||||||
| except SyntaxError as s_err: | ||||||
|
|
@@ -286,7 +306,7 @@ def _validate_expression(cls, value) -> Self: | |||||
|
|
||||||
| def evaluate( | ||||||
| self, context: EvaluationContext = None, strict: bool = True | ||||||
| ) -> Union[float, list[float], unyt_array]: | ||||||
| ) -> Union[float, np.ndarray, unyt_array]: | ||||||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After numpy interop has been implemented using regular list types is not supported to avoid confusion |
||||||
| if context is None: | ||||||
| context = _global_ctx | ||||||
| expr = expr_to_model(self.expression, context) | ||||||
|
|
@@ -296,11 +316,17 @@ def evaluate( | |||||
| def user_variables(self): | ||||||
| expr = expr_to_model(self.expression, _global_ctx) | ||||||
| names = expr.used_names() | ||||||
|
|
||||||
| names = [name for name in names if name in _user_variables] | ||||||
|
|
||||||
| return [UserVariable(name=name, value=_global_ctx.get(name)) for name in names] | ||||||
|
|
||||||
| def user_variable_names(self): | ||||||
| expr = expr_to_model(self.expression, _global_ctx) | ||||||
| names = expr.used_names() | ||||||
| names = [name for name in names if name in _user_variables] | ||||||
|
|
||||||
| return names | ||||||
|
|
||||||
| def to_solver_code(self): | ||||||
| expr = expr_to_model(self.expression, _global_ctx) | ||||||
| source = expr_to_code(expr, TargetSyntax.CPP, _solver_variables) | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This means that using solver variables in expressions will not throw an error. We are fine with solver variables as long as their value is not being used.