46 changes: 35 additions & 11 deletions flow360/component/simulation/framework/base_model.py
@@ -5,14 +5,15 @@
import hashlib
import json
from itertools import chain
from typing import Any, List, Literal, get_origin
from typing import Any, List, Literal, get_origin, Set, Optional

import pydantic as pd
import rich
import yaml
from pydantic import ConfigDict
from pydantic._internal._decorators import Decorator, FieldValidatorDecoratorInfo
from pydantic_core import InitErrorDetails
from unyt import unit_registry

from flow360.component.simulation.conversion import need_conversion, unit_converter
from flow360.component.simulation.validation import validation_context
@@ -66,6 +67,15 @@ class Conflicts(pd.BaseModel):
field2: str


class RegistryLookup:
"""
Helper object to cache the conversion unit system registry
"""

converted_fields: Set[str] = set()
registry: Optional[unit_registry] = None


class Flow360BaseModel(pd.BaseModel):
"""Base pydantic (V2) model that all Flow360 components inherit from.
Defines configuration for handling data structures
@@ -552,6 +562,7 @@ def _nondimensionalization(
params,
exclude: List[str] = None,
required_by: List[str] = None,
registry_lookup: RegistryLookup = None,
) -> dict:
solver_values = {}
self_dict = self.__dict__
@@ -570,13 +581,18 @@
if field is not None and field.alias is not None:
loc_name = field.alias
if need_conversion(value) and property_name not in exclude:
flow360_conv_system = unit_converter(
value.units.dimensions,
params=params,
required_by=[*required_by, loc_name],
)
# pylint: disable=no-member
value.units.registry = flow360_conv_system.registry
dimension = value.units.dimensions
if dimension not in registry_lookup.converted_fields:
flow360_conv_system = unit_converter(
value.units.dimensions,
params=params,
required_by=[*required_by, loc_name],
)
# Calling unit_converter is always additive on the global conversion system,
# so we only need to keep track of the most recent registry and use it
registry_lookup.registry = flow360_conv_system.registry
registry_lookup.converted_fields.add(dimension)
value.units.registry = registry_lookup.registry
solver_values[property_name] = value.in_base(unit_system="flow360_v2")
else:
solver_values[property_name] = value
@@ -589,6 +605,7 @@ def preprocess(
params=None,
exclude: List[str] = None,
required_by: List[str] = None,
registry_lookup: RegistryLookup = None,
) -> Flow360BaseModel:
"""
Loops through all fields, for Flow360BaseModel runs .preprocess() recursively. For dimensioned values performs
@@ -609,22 +626,27 @@
required_by: List[str] (optional)
Path to property which requires conversion.

registry_lookup: RegistryLookup (optional)
Lookup object that allows us to quickly perform conversions by
reducing redundant calls to the conversion system getter

Returns
-------
caller class
returns caller class with units all in flow360 base unit system
"""

if registry_lookup is None:
registry_lookup = RegistryLookup()

if exclude is None:
exclude = []

if required_by is None:
required_by = []

solver_values = self._nondimensionalization(
params=params,
exclude=exclude,
required_by=required_by,
params=params, exclude=exclude, required_by=required_by, registry_lookup=registry_lookup
)
for property_name, value in self.__dict__.items():
if property_name in exclude:
@@ -638,6 +660,7 @@
params=params,
required_by=[*required_by, loc_name],
exclude=exclude,
registry_lookup=registry_lookup,
)
elif isinstance(value, list):
for i, item in enumerate(value):
@@ -646,6 +669,7 @@
params=params,
required_by=[*required_by, loc_name, f"{i}"],
exclude=exclude,
registry_lookup=registry_lookup,
)

return self.__class__(**solver_values)
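
The caching above is the heart of the change: unit_converter only needs to run once per physical dimension, after which the most recent registry can be reused for every remaining dimensioned field. Below is a minimal standalone sketch of that idea, not the Flow360 API; LookupSketch, fake_unit_converter, and the dummy scaling factors are invented placeholders.

from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Set


@dataclass
class LookupSketch:
    """Toy cache: remembers which dimensions already triggered a converter lookup."""

    converted_fields: Set[str] = field(default_factory=set)
    registry: Optional[Dict[str, float]] = None


def fake_unit_converter(dimension: str) -> Dict[str, float]:
    """Stand-in for the real unit_converter; pretend this call is expensive."""
    print(f"building conversion system for {dimension}")
    return {"length": 0.5, "time": 2.0}  # dummy scaling factors


def nondimensionalize(values: Dict[str, Any], lookup: LookupSketch) -> Dict[str, float]:
    out = {}
    for name, (magnitude, dimension) in values.items():
        if dimension not in lookup.converted_fields:
            # Only hit the (expensive, additive) converter once per dimension.
            lookup.registry = fake_unit_converter(dimension)
            lookup.converted_fields.add(dimension)
        out[name] = magnitude * lookup.registry[dimension]
    return out


lookup = LookupSketch()
fields = {
    "chord": (2.0, "length"),
    "span": (10.0, "length"),  # reuses the cached registry
    "period": (0.1, "time"),   # triggers one more lookup
}
print(nondimensionalize(fields, lookup))

Running the sketch prints the "building conversion system" message twice (once for length, once for time) even though three fields are converted, which is the saving the RegistryLookup cache buys on large models.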
67 changes: 1 addition & 66 deletions flow360/component/simulation/framework/entity_base.py
@@ -12,7 +12,6 @@
import pydantic as pd
import unyt

from flow360.component.simulation.conversion import need_conversion, unit_converter
from flow360.component.simulation.framework.base_model import Flow360BaseModel
from flow360.component.simulation.utils import is_exact_instance

@@ -481,70 +480,6 @@ def _get_expanded_entities(
return copy.deepcopy(expanded_entities)
return expanded_entities

# pylint: disable=too-many-locals
def _batch_preprocess(self, **kwargs):
"""
Batch preprocesses properties for all child entities that need processing.

Inspects each attribute of every stored entity. For attributes that need conversion
(as determined by conversion.need_conversion), it groups values by attribute name.

- If the value's underlying array is not already 2D (i.e. not a true batched array),
the value is grouped for batch processing.
- If the value is already a 2D array, it is marked for direct (traditional) conversion.

For batch groups, the underlying data of each unyt_array is converted to a common unit,
stacked into a single unyt_array, and then processed in one vectorized call.

For directly converted values, the conversion is applied individually.
"""
stored_entities = self.stored_entities
groups = {}
direct = {}

for idx, entity in enumerate(stored_entities):
for attr, value in entity.__dict__.items():
if need_conversion(value):
if getattr(value, "ndim", 1) == 2:
direct.setdefault(attr, []).append(idx)
else:
groups.setdefault(attr, {"indices": [], "values": []})
groups[attr]["indices"].append(idx)
groups[attr]["values"].append(value)

for attr, data in groups.items():
group_values = data["values"]
ref_unit = group_values[0].units
converted = np.empty((len(group_values), group_values[0].size))
for i, val in enumerate(group_values):
converted[i] = val.to(ref_unit).v
data["values"] = unyt.unyt_array(converted, ref_unit)

params = kwargs.get("params")
required_by = kwargs.get("required_by", [])

new_entities = [entity.__dict__.copy() for entity in stored_entities]

for attr, data in groups.items():
flow360_conv_system = unit_converter(
data["values"].units.dimensions,
params=params,
required_by=[*required_by, attr],
)
# pylint: disable=no-member
data["values"].units.registry = flow360_conv_system.registry
processed_array = data["values"].in_base(unit_system="flow360_v2")

for idx, processed_val in zip(data["indices"], processed_array):
new_entities[idx][attr] = processed_val

for attr, indices in direct.items():
for idx in indices:
new_entities[idx] = stored_entities[idx].preprocess(**kwargs)

solver_values = {"stored_entities": new_entities}
return solver_values

# pylint: disable=arguments-differ
def preprocess(self, **kwargs):
"""
@@ -553,4 +488,4 @@ def preprocess(self, **kwargs):
"""
# WARNING: this is very expensive for long lists as it is quadratic
self.stored_entities = self._get_expanded_entities(create_hard_copy=False)
return self._batch_preprocess(**kwargs)
return super().preprocess(**kwargs)
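
With the cache threaded through the recursion, the removed _batch_preprocess path no longer buys anything: the base-class preprocess already visits every nested model and list element with the same registry_lookup object, so the converter is still built only once per dimension for the whole entity list. A toy, self-contained sketch of that shared-cache recursion follows; NodeSketch, LookupSketch, and the call counter are invented stand-ins, not Flow360 classes.

from typing import List, Optional, Set


class LookupSketch:
    """Toy cache shared by the whole recursion."""

    def __init__(self) -> None:
        self.converted: Set[str] = set()
        self.converter_calls = 0  # stand-in for counting unit_converter() calls


class NodeSketch:
    """Toy stand-in for a nested Flow360BaseModel tree."""

    def __init__(self, name: str, children: Optional[List["NodeSketch"]] = None) -> None:
        self.name = name
        self.children = children or []

    def preprocess(self, lookup: Optional[LookupSketch] = None) -> "NodeSketch":
        # Mirror of the PR's pattern: build the cache once at the root,
        # then pass the same object into every recursive call.
        if lookup is None:
            lookup = LookupSketch()
        if "length" not in lookup.converted:  # every node has one length field here
            lookup.converter_calls += 1
            lookup.converted.add("length")
        for child in self.children:
            child.preprocess(lookup=lookup)
        return self


shared = LookupSketch()
root = NodeSketch("wing", [NodeSketch("flap"), NodeSketch("slat")])
root.preprocess(lookup=shared)
print(shared.converter_calls)  # 1 -- one converter build for the whole tree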