diff --git a/tests/test_components/test_IO.py b/tests/test_components/test_IO.py index 21d2ff8481..f5f58fe5ee 100644 --- a/tests/test_components/test_IO.py +++ b/tests/test_components/test_IO.py @@ -13,7 +13,7 @@ import tidy3d as td from tidy3d import __version__ -from tidy3d.components.data.data_array import DATA_ARRAY_MAP +from tidy3d.components.data.data_array import is_data_array_name from tidy3d.components.data.sim_data import DATA_TYPE_MAP from ..test_data.test_monitor_data import make_flux_data @@ -242,7 +242,7 @@ def test_to_json_data(): # type saved in the combined json file? data = make_flux_data() json_dict = json.loads(data._json_string) - assert json_dict["flux"] in DATA_ARRAY_MAP + assert is_data_array_name(json_dict["flux"]) def test_to_hdf5_group_path_sim_data(tmp_path): diff --git a/tests/test_components/test_custom.py b/tests/test_components/test_custom.py index 31f4df7077..4d5f421f3e 100644 --- a/tests/test_components/test_custom.py +++ b/tests/test_components/test_custom.py @@ -9,6 +9,7 @@ from pydantic import ValidationError import tidy3d as td +from tidy3d.components.data.data_array import _isinstance from tidy3d.components.data.dataset import PermittivityDataset from tidy3d.components.data.utils import UnstructuredGridDataset, _get_numpy_array from tidy3d.components.medium import ( @@ -554,7 +555,9 @@ def verify_custom_medium_methods(mat, reduced_fields): # data fields in medium classes could be SpatialArrays or 2d tuples of spatial arrays # lets convert everything into 2d tuples of spatial arrays for uniform handling - if isinstance(original, (td.SpatialDataArray, UnstructuredGridDataset)): + if _isinstance(original, td.SpatialDataArray) or isinstance( + original, UnstructuredGridDataset + ): original = [[original]] reduced = [[reduced]] @@ -562,7 +565,7 @@ def verify_custom_medium_methods(mat, reduced_fields): assert len(or_set) == len(re_set) for ind in range(len(or_set)): - if isinstance(or_set[ind], td.SpatialDataArray): + if 
_isinstance(or_set[ind], td.SpatialDataArray): diff = (or_set[ind] - re_set[ind]).abs assert diff.does_cover(subsection.bounds) assert np.allclose(diff, 0) diff --git a/tests/test_components/test_microwave.py b/tests/test_components/test_microwave.py index 5ea56e0349..7eaa27c1b7 100644 --- a/tests/test_components/test_microwave.py +++ b/tests/test_components/test_microwave.py @@ -13,7 +13,7 @@ from shapely import LineString import tidy3d as td -from tidy3d.components.data.data_array import FreqModeDataArray +from tidy3d.components.data.data_array import FreqModeDataArray, _isinstance from tidy3d.components.data.monitor_data import FreqDataArray from tidy3d.components.microwave.formulas.circuit_parameters import ( capacitance_colinear_cylindrical_wire_segments, @@ -440,8 +440,8 @@ def test_antenna_parameters(): ) # Test that all essential parameters exist and are correct type - assert isinstance(antenna_params.radiation_efficiency, FreqDataArray) - assert isinstance(antenna_params.reflection_efficiency, FreqDataArray) + assert _isinstance(antenna_params.radiation_efficiency, FreqDataArray) + assert _isinstance(antenna_params.reflection_efficiency, FreqDataArray) assert np.allclose(antenna_params.reflection_efficiency, 0.75) assert isinstance(antenna_params.gain, xr.DataArray) assert isinstance(antenna_params.realized_gain, xr.DataArray) diff --git a/tests/test_data/test_data_arrays.py b/tests/test_data/test_data_arrays.py index 5c47d2462a..4fba5925a8 100644 --- a/tests/test_data/test_data_arrays.py +++ b/tests/test_data/test_data_arrays.py @@ -8,10 +8,17 @@ import autograd.numpy as np import numpy import pytest +import xarray as xr import xarray.testing as xrt from autograd.test_util import check_grads +from pydantic import BaseModel, ValidationError import tidy3d as td +from tidy3d._common.components.data.data_array import DataArray +from tidy3d.components.data.data_array import ( + data_array_annotated_type, +) +from tidy3d.components.data.dataset import 
TimeDataset from tidy3d.exceptions import DataError np.random.seed(4) @@ -343,13 +350,93 @@ def test_abs(): def test_angle(): - # Make sure works on real data and the type is correct + # Make sure works on real data and preserves DataArray structure data = make_scalar_field_time_data_array("Ex") angle_data = data.angle - assert type(data) is type(angle_data) + assert isinstance(angle_data, xr.DataArray) + assert angle_data.dims == data.dims + assert angle_data.coords.equals(data.coords) data = make_mode_amps_data_array() angle_data = data.angle - assert type(data) is type(angle_data) + assert isinstance(angle_data, xr.DataArray) + assert angle_data.dims == data.dims + assert angle_data.coords.equals(data.coords) + + +def test_annotated_data_array_spec(): + ScalarFieldSpec = data_array_annotated_type(td.ScalarFieldDataArray) + + class Model(BaseModel): + field: ScalarFieldSpec + + data = make_scalar_field_data_array("Ex") + data_plain = xr.DataArray(data.data, coords=data.coords, dims=data.dims) + model = Model(field=data_plain) + assert model.field.dims == data.dims + assert "tidy3d.data.scalar_field" in model.model_dump_json() + + with pytest.raises(ValidationError): + Model(field=xr.DataArray(np.zeros((2, 2)), dims=("x", "y"))) + + +def test_annotated_accepts_legacy_class(): + ScalarFieldSpec = data_array_annotated_type(td.ScalarFieldDataArray) + + class Model(BaseModel): + field: ScalarFieldSpec + + data = make_scalar_field_data_array("Ex") + model = Model(field=data) + assert model.field.dims == data.dims + + +def test_legacy_data_array_shims(): + arr = xr.DataArray( + np.random.random((3, 4, 5)), + coords={ + "x": np.linspace(0, 1, 3), + "y": np.linspace(1, 2, 4), + "z": np.linspace(2, 3, 5), + }, + ) + bounds = ((0.2, 1.1, 2.1), (0.9, 1.9, 2.9)) + selected = arr.sel_inside(bounds) + assert selected.dims == arr.dims + reflected = arr.reflect(axis=0, center=-0.5, reflection_only=True) + assert reflected.dims == arr.dims + updated = 
arr._with_updated_data(data=np.zeros((1, 1, 1)), coords={"x": 0, "y": 1, "z": 2}) + assert updated.dims == arr.dims + + +def test_annotated_dataset_hdf5_roundtrip(tmp_path): + times = np.linspace(0, 1e-12, 4) + values = np.random.random(len(times)) + data = xr.DataArray(values, coords={"t": times}, dims=("t",)) + dataset = TimeDataset(values=data) + + path = tmp_path / "time_dataset.hdf5" + dataset.to_hdf5(path) + loaded = TimeDataset.from_hdf5(path) + + assert type(loaded.values) is DataArray + assert loaded.values.dims == data.dims + assert loaded.values.coords["t"].equals(data.coords["t"]) + + +def test_legacy_class_spec_validation(): + class Model(BaseModel): + field: data_array_annotated_type(td.ScalarFieldDataArray) + + data = xr.DataArray( + np.random.random((len(FS), 2, 3, 4)), + coords={"f": FS, "x": [0, 1], "y": [0, 1, 2], "z": [0, 1, 2, 3]}, + dims=("f", "x", "y", "z"), + ) + model = Model(field=data) + assert model.field.dims == ("x", "y", "z", "f") + + with pytest.raises(ValidationError): + Model(field=xr.DataArray(np.zeros((2, 2)), dims=("x", "y"))) def test_heat_data_array(): diff --git a/tests/test_plugins/smatrix/test_terminal_component_modeler.py b/tests/test_plugins/smatrix/test_terminal_component_modeler.py index 95bf5d92f1..60fdd90a5a 100644 --- a/tests/test_plugins/smatrix/test_terminal_component_modeler.py +++ b/tests/test_plugins/smatrix/test_terminal_component_modeler.py @@ -13,7 +13,7 @@ import tidy3d.plugins.smatrix.utils from tidy3d import SimulationDataMap from tidy3d.components.boundary import BroadbandModeABCSpec -from tidy3d.components.data.data_array import FreqDataArray +from tidy3d.components.data.data_array import FreqDataArray, _isinstance from tidy3d.exceptions import SetupError, Tidy3dError, Tidy3dKeyError from tidy3d.plugins.smatrix import ( CoaxialLumpedPort, @@ -1236,8 +1236,8 @@ def test_antenna_helpers(monkeypatch, tmp_path): # Test power wave amplitude computation a, b = 
modeler_data.compute_power_wave_amplitudes_at_each_port(sim_data=sim_data) - assert isinstance(a, PortDataArray) - assert isinstance(b, PortDataArray) + assert _isinstance(a, PortDataArray) + assert _isinstance(b, PortDataArray) @pytest.mark.parametrize("port_type", ["lumped", "wave"]) @@ -1288,8 +1288,8 @@ def test_antenna_parameters(monkeypatch, port_type): antenna_params = modeler_data.get_antenna_metrics_data() # Test that all essential parameters exist and are correct type - assert isinstance(antenna_params.radiation_efficiency, FreqDataArray) - assert isinstance(antenna_params.reflection_efficiency, FreqDataArray) + assert _isinstance(antenna_params.radiation_efficiency, FreqDataArray) + assert _isinstance(antenna_params.reflection_efficiency, FreqDataArray) assert isinstance(antenna_params.gain, xr.DataArray) assert isinstance(antenna_params.realized_gain, xr.DataArray) @@ -1345,8 +1345,8 @@ def test_get_combined_antenna_parameters_data(monkeypatch, tmp_path): ) # Check that essential properties exist and are correct type - assert isinstance(antenna_params.radiation_efficiency, FreqDataArray) - assert isinstance(antenna_params.reflection_efficiency, FreqDataArray) + assert _isinstance(antenna_params.radiation_efficiency, FreqDataArray) + assert _isinstance(antenna_params.reflection_efficiency, FreqDataArray) assert isinstance(antenna_params.partial_gain(), xr.Dataset) assert isinstance(antenna_params.gain, xr.DataArray) assert isinstance(antenna_params.partial_realized_gain(), xr.Dataset) diff --git a/tidy3d/_common/components/autograd/derivative_utils.py b/tidy3d/_common/components/autograd/derivative_utils.py index 5c4104d8f3..ce3aea1995 100644 --- a/tidy3d/_common/components/autograd/derivative_utils.py +++ b/tidy3d/_common/components/autograd/derivative_utils.py @@ -9,7 +9,12 @@ import xarray as xr from numpy.typing import NDArray -from tidy3d._common.components.data.data_array import FreqDataArray, ScalarFieldDataArray +from 
tidy3d._common.components.data.data_array import ( + FreqDataArray, + ScalarFieldDataArray, + _isinstance, + data_array_annotated_type, +) from tidy3d._common.components.types.base import ArrayLike, Bound, Complex from tidy3d._common.config import config from tidy3d._common.constants import C_0, EPSILON_0, LARGE_NUMBER, MU_0 @@ -27,7 +32,7 @@ FieldData = dict[str, ScalarFieldDataArray] PermittivityData = dict[str, ScalarFieldDataArray] -EpsType = Union[Complex, FreqDataArray] +EpsType = Union[Complex, data_array_annotated_type(FreqDataArray)] ArrayFloat = NDArray[np.floating] ArrayComplex = NDArray[np.complexfloating] @@ -706,7 +711,7 @@ def _prepare_epsilon(eps: EpsType) -> np.ndarray: For FreqDataArray, extracts values and broadcasts to shape (1, n_freqs). For scalar values, broadcasts to shape (1, 1) for consistency with multi-frequency. """ - if isinstance(eps, FreqDataArray): + if _isinstance(eps, FreqDataArray): # data is already sliced, just extract values eps_values = eps.values # shape: (n_freqs,) - need to broadcast to (1, n_freqs) @@ -812,7 +817,7 @@ def adaptive_vjp_spacing( min_allowed_spacing_fraction = config.adjoint.minimum_spacing_fraction # handle FreqDataArray or scalar eps_in - if isinstance(self.eps_in, FreqDataArray): + if _isinstance(self.eps_in, FreqDataArray): eps_real = np.asarray(self.eps_in.values, dtype=np.complex128).real else: eps_real = np.asarray(self.eps_in, dtype=np.complex128).real diff --git a/tidy3d/_common/components/base.py b/tidy3d/_common/components/base.py index c0ad4f769d..47e6a87f82 100644 --- a/tidy3d/_common/components/base.py +++ b/tidy3d/_common/components/base.py @@ -29,7 +29,12 @@ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator, model_validator from tidy3d._common.components.autograd.utils import get_static -from tidy3d._common.components.data.data_array import DATA_ARRAY_MAP +from tidy3d._common.components.data.data_array import ( + data_array_spec_from_name, + is_data_array_name, + 
iter_data_array_names, + write_data_array_to_hdf5, +) from tidy3d._common.components.file_util import compress_file_to_gzip, extract_gzip_file from tidy3d._common.components.types.base import TYPE_TAG_STR, Undefined from tidy3d._common.exceptions import FileError @@ -1053,7 +1058,7 @@ def to_yaml(self, fname: PathLike) -> None: @staticmethod def _warn_if_contains_data(json_str: str) -> None: """Log a warning if the json string contains data, used in '.json' and '.yaml' file.""" - if any((key in json_str for key, _ in DATA_ARRAY_MAP.items())): + if any(name in json_str for name in iter_data_array_names()): log.warning( "Data contents found in the model to be written to file. " "Note that this data will not be included in '.json' or '.yaml' formats. " @@ -1155,7 +1160,7 @@ def dict_from_hdf5( def is_data_array(value: Any) -> bool: """Whether a value is supposed to be a data array based on the contents.""" - return isinstance(value, str) and value in DATA_ARRAY_MAP + return is_data_array_name(value) fname_path = Path(fname) @@ -1178,10 +1183,10 @@ def load_data_from_file(model_dict: dict, group_path: str = "") -> None: # write the path to the element of the json dict where the data_array should be if is_data_array(value): - data_array_type = DATA_ARRAY_MAP[value] - model_dict[key] = data_array_type.from_hdf5( - fname=fname_path, group_path=subpath - ) + spec = data_array_spec_from_name(value) + if spec is None: + raise FileError(f"Unrecognized DataArray schema '{value}'.") + model_dict[key] = spec.from_hdf5(fname=fname_path, group_path=subpath) continue # if a list, assign each element a unique key, recurse @@ -1291,7 +1296,7 @@ def add_data_to_file(data_dict: dict, group_path: str = "") -> None: # write the path to the element of the json dict where the data_array should be if isinstance(value, xr.DataArray): - value.to_hdf5(fname=f_handle, group_path=subpath) + write_data_array_to_hdf5(value, f_handle=f_handle, group_path=subpath) # if a tuple, assign each element a 
unique key if isinstance(value, (list, tuple)): @@ -1444,7 +1449,11 @@ def _fields_equal(a: Any, b: Any) -> bool: if a is b: return True if type(a) is not type(b): - if not (isinstance(a, (list, tuple)) and isinstance(b, (list, tuple))): + if isinstance(a, (xr.DataArray, xr.Dataset)) and isinstance( + b, (xr.DataArray, xr.Dataset) + ): + pass + elif not (isinstance(a, (list, tuple)) and isinstance(b, (list, tuple))): return False if isinstance(a, np.ndarray): return np.array_equal(a, b) diff --git a/tidy3d/_common/components/data/data_array.py b/tidy3d/_common/components/data/data_array.py index 0f14b5e984..1fc7ceb66e 100644 --- a/tidy3d/_common/components/data/data_array.py +++ b/tidy3d/_common/components/data/data_array.py @@ -2,9 +2,14 @@ from __future__ import annotations +import functools import pathlib +import re +import warnings from abc import ABC -from typing import TYPE_CHECKING, Any +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Annotated, Any, ParamSpec, TypeVar import autograd.numpy as anp import h5py @@ -27,12 +32,11 @@ RADIAN, SECOND, ) -from tidy3d._common.exceptions import DataError, FileError +from tidy3d._common.exceptions import DataError if TYPE_CHECKING: - from collections.abc import Mapping from os import PathLike - from typing import Optional, Union + from typing import Callable, Literal, Optional, Union from numpy.typing import NDArray from pydantic.annotated_handlers import GetCoreSchemaHandler @@ -71,8 +75,246 @@ # name of the DataArray.values in the hdf5 file (xarray's default name too) DATA_ARRAY_VALUE_NAME = "__xarray_dataarray_variable__" -DATA_ARRAY_MAP: dict[str, type[DataArray]] = {} -DATA_ARRAY_TYPES: list[type[DataArray]] = [] +# Toggle for emitting deprecation warnings from legacy xarray shims. +LEGACY_SHIM_WARNINGS = True +# Toggle for installing legacy xarray shims on import. 
+LEGACY_SHIM_ENABLED = True + + +@dataclass(frozen=True) +class DataArraySpec: + """Declarative schema for an ``xarray.DataArray`` field.""" + + id: str + dims: tuple[str, ...] + data_attrs: Mapping[str, Any] = field(default_factory=dict) + coord_attrs: Mapping[str, Mapping[str, Any]] = field(default_factory=dict) + require_unique_coords: bool = False + + def __get_pydantic_core_schema__( + self, source_type: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + return core_schema.with_info_after_validator_function( + self._validate, + core_schema.any_schema(), + serialization=core_schema.plain_serializer_function_ser_schema( + self._serialize, info_arg=True, when_used="json" + ), + ) + + def __get_pydantic_json_schema__( + self, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + json_schema = handler(schema) + json_schema.update( + { + "title": "xarray.DataArray", + "type": "object", + "td_schema": self.id, + "td_dims": list(self.dims), + } + ) + return json_schema + + def _validate(self, value: Any, info: core_schema.ValidationInfo) -> xr.DataArray: + data_array = self._coerce_to_dataarray(value, info) + return self.validate_data_array(data_array) + + def _serialize(self, value: xr.DataArray, info: core_schema.SerializationInfo) -> str: + if isinstance(value, xr.DataArray): + schema = value.attrs.get("_td_schema") + if isinstance(schema, str): + return schema + # Preserve existing JSON placeholder behavior by default. + return self.id + + def _coerce_to_dataarray(self, value: Any, info: core_schema.ValidationInfo) -> xr.DataArray: + if isinstance(value, DataArray): + return value + + if isinstance(value, xr.DataArray): + return DataArray( + value.data, + coords=value.coords, + dims=value.dims, + name=value.name, + attrs=dict(value.attrs), + ) + + if isinstance(value, str) and is_data_array_name(value): + raise DataError( + "Trying to load a DataArray from a string placeholder but the data is missing. 
" + "DataArrays are not typically stored in JSON. Load from HDF5 or ensure the " + "DataArray object is provided." + ) + + if isinstance(value, Mapping) and "__td_dataarray__" in value: + payload = value.get("__td_dataarray__", {}) + schema = payload.get("schema") + if schema != self.id: + raise ValueError(f"schema mismatch: expected {self.id!r}, got {schema!r}") + inline = payload.get("inline") + if isinstance(inline, Mapping): + return self._from_inline(inline) + + raise ValueError("unsupported DataArray payload; missing inline data") + + raise ValueError("expected an xarray.DataArray or serialized DataArray payload") + + def _from_inline(self, inline: Mapping[str, Any]) -> xr.DataArray: + dims = inline.get("dims", self.dims) + if isinstance(dims, str): + dims = (dims,) + dims = tuple(dims) + coords = dict(inline.get("coords", {})) + data = np.asarray(inline.get("data")) + return DataArray(data, coords=coords, dims=dims) + + def from_hdf5(self, fname: PathLike, group_path: str) -> xr.DataArray: + """Load a DataArray from an hdf5 file using this spec's dimensions.""" + path = pathlib.Path(fname) + with h5py.File(path, "r") as f: + sub_group = f[group_path] + values = np.array(sub_group[DATA_ARRAY_VALUE_NAME]) + coords = {dim: np.array(sub_group[dim]) for dim in self.dims if dim in sub_group} + for key, val in coords.items(): + if val.dtype == "O": + coords[key] = [byte_string.decode() for byte_string in val.tolist()] + data_array = DataArray(values, coords=coords, dims=self.dims) + return self.validate_data_array(data_array) + + def validate_data_array(self, data_array: xr.DataArray) -> xr.DataArray: + expected = tuple(self.dims) + given = tuple(str(d) for d in data_array.dims) + if set(given) != set(expected): + raise ValueError(f"wrong dims: expected {expected}, got {given}") + + if given != expected: + data_array = data_array.transpose(*expected) + + data_array = data_array.copy(deep=False) + + if self.data_attrs: + data_array.attrs.update(self.data_attrs) + 
for dim, attrs in self.coord_attrs.items(): + if dim in data_array.coords: + data_array.coords[dim].attrs.update(attrs) + + data_array.attrs["_td_schema"] = self.id + + if self.require_unique_coords: + for dim in expected: + if data_array.coords[dim].to_index().duplicated().any(): + raise ValueError(f"duplicate coordinates in dimension {dim!r}") + + if type(data_array) is not DataArray: + data_array = DataArray( + data_array.data, + coords=data_array.coords, + dims=data_array.dims, + name=data_array.name, + attrs=dict(data_array.attrs), + ) + + return data_array + + def matches(self, data_array: xr.DataArray) -> bool: + try: + self.validate_data_array(data_array) + except ValueError: + return False + return True + + +DATA_ARRAY_SPEC_MAP: dict[str, DataArraySpec] = {} +DATA_ARRAY_SCHEMA_MAP: dict[str, type[DataArray]] = {} + + +def _camel_to_snake(name: str) -> str: + step1 = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name) + step2 = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", step1) + return step2.lower() + + +def _default_schema_id(data_array_type: type[DataArray]) -> str: + override = getattr(data_array_type, "_schema_id", None) + if override: + return override + base = data_array_type.__name__ + if base == "DataArray": + return "tidy3d.data.data_array" + if base.endswith("DataArray"): + base = base[: -len("DataArray")] + if not base: + return "tidy3d.data.data_array" + return f"tidy3d.data.{_camel_to_snake(base)}" + + +def _default_spec_for_type(data_array_type: type[DataArray]) -> DataArraySpec: + dims = tuple(getattr(data_array_type, "_dims", ())) + data_attrs = dict(getattr(data_array_type, "_data_attrs", {})) + coord_attrs = {dim: DIM_ATTRS[dim] for dim in dims if dim in DIM_ATTRS} + return DataArraySpec( + id=_default_schema_id(data_array_type), + dims=dims, + data_attrs=data_attrs, + coord_attrs=coord_attrs, + ) + + +def data_array_spec_for_type(data_array_type: type[DataArray]) -> DataArraySpec: + spec = data_array_type.__spec__ + if not isinstance(spec, 
DataArraySpec): + raise TypeError(f"{data_array_type.__name__} is missing a DataArraySpec.") + return spec + + +def data_array_annotated_type(data_array_type: type[DataArray]) -> Any: + """Return an ``Annotated[DataArray, DataArraySpec]`` alias for a DataArray class.""" + return Annotated[DataArray, data_array_spec_for_type(data_array_type)] + + +def _isinstance(value: Any, data_array_type: type[DataArray]) -> bool: + """Spec-based check that replaces subclass ``isinstance`` usage.""" + if not isinstance(value, xr.DataArray): + return False + spec = data_array_spec_for_type(data_array_type) + return spec.matches(value) + + +def register_data_array_spec(spec: DataArraySpec, data_array_type: type[DataArray]) -> None: + """Register a DataArraySpec for schema lookup and legacy compatibility.""" + DATA_ARRAY_SPEC_MAP[spec.id] = spec + DATA_ARRAY_SCHEMA_MAP[spec.id] = data_array_type + + +def data_array_spec_from_name(name: str) -> DataArraySpec | None: + spec = DATA_ARRAY_SPEC_MAP.get(name) + if spec is not None: + return spec + if name.endswith("DataArray"): + base = name[: -len("DataArray")] + if base: + legacy_id = f"tidy3d.data.{_camel_to_snake(base)}" + return DATA_ARRAY_SPEC_MAP.get(legacy_id) + return None + + +def data_array_type_from_name(name: str) -> type[DataArray] | None: + spec = data_array_spec_from_name(name) + if spec is None: + return None + return DataArray + + +def iter_data_array_names() -> tuple[str, ...]: + names = list(DATA_ARRAY_SPEC_MAP.keys()) + names.extend(da_type.__name__ for da_type in DATA_ARRAY_SCHEMA_MAP.values()) + return tuple(dict.fromkeys(names)) + + +def is_data_array_name(value: Any) -> bool: + return isinstance(value, str) and data_array_spec_from_name(value) is not None class DataArray(xr.DataArray): @@ -84,13 +326,22 @@ class DataArray(xr.DataArray): _dims = () # stores a dictionary of attributes corresponding to the data values _data_attrs: dict[str, str] = {} + # optional stable schema id (defaults to class name if not set) + 
_schema_id: str | None = None + # schema metadata for spec-based validation + __spec__: DataArraySpec = DataArraySpec(id="tidy3d.data.data_array", dims=()) def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) if cls is DataArray: return - DATA_ARRAY_MAP[cls.__name__] = cls - DATA_ARRAY_TYPES.append(cls) + spec = cls.__dict__.get("__spec__") + if spec is None: + spec = _default_spec_for_type(cls) + cls.__spec__ = spec + elif not isinstance(spec, DataArraySpec): + raise TypeError(f"{cls.__name__}.__spec__ must be a DataArraySpec.") + register_data_array_spec(spec, cls) def __init__(self, data: Any, *args: Any, **kwargs: Any) -> None: # if data is a vanilla autograd box, convert to our box @@ -102,111 +353,6 @@ def __init__(self, data: Any, *args: Any, **kwargs: Any) -> None: data.data = TidyArrayBox.from_arraybox(data.data) super().__init__(data, *args, **kwargs) - @classmethod - def __get_pydantic_core_schema__( - cls, source_type: Any, handler: GetCoreSchemaHandler - ) -> core_schema.CoreSchema: - """Core schema definition for validation & serialization.""" - - def _initial_parser(value: Any) -> Self: - if isinstance(value, cls): - return value - - if isinstance(value, str) and value == cls.__name__: - raise DataError( - f"Trying to load '{cls.__name__}' from string placeholder '{value}' " - "but the actual data is missing. DataArrays are not typically stored " - "in JSON. Load from HDF5 or ensure the DataArray object is provided." - ) - - try: - instance = cls(value) - if not isinstance(instance, cls): - raise TypeError( - f"Constructor for {cls.__name__} returned unexpected type {type(instance)}" - ) - return instance - except Exception as e: - raise ValueError( - f"Could not construct '{cls.__name__}' from input of type '{type(value)}'. " - f"Ensure input is compatible with xarray.DataArray constructor. 
Original error: {e}" - ) from e - - validation_schema = core_schema.no_info_plain_validator_function(_initial_parser) - validation_schema = core_schema.no_info_after_validator_function( - cls._validate_dims, validation_schema - ) - validation_schema = core_schema.no_info_after_validator_function( - cls._assign_data_attrs, validation_schema - ) - validation_schema = core_schema.no_info_after_validator_function( - cls._assign_coord_attrs, validation_schema - ) - - def _serialize_to_name(instance: Self) -> str: - return type(instance).__name__ - - # serialization behavior: - # - for JSON ('json' mode), use the _serialize_to_name function. - # - for Python ('python' mode), use Pydantic's default for the object type - serialization_schema = core_schema.plain_serializer_function_ser_schema( - _serialize_to_name, - return_schema=core_schema.str_schema(), - when_used="json", - ) - - return core_schema.json_or_python_schema( - python_schema=validation_schema, - json_schema=validation_schema, # Use same validation rules for JSON input - serialization=serialization_schema, - ) - - @classmethod - def __get_pydantic_json_schema__( - cls, core_schema_obj: core_schema.CoreSchema, handler: GetJsonSchemaHandler - ) -> JsonSchemaValue: - """JSON schema definition (defines how it LOOKS in a schema, not the data).""" - return { - "type": "string", - "title": cls.__name__, - "description": ( - f"Placeholder for a '{cls.__name__}' object. Actual data is typically " - "serialized separately (e.g., via HDF5) and not embedded in JSON." 
- ), - } - - @classmethod - def _validate_dims(cls, val: Self) -> Self: - """Make sure the dims are the same as ``_dims``, then put them in the correct order.""" - if set(val.dims) != set(cls._dims): - raise ValueError( - f"Wrong dims for {cls.__name__}, expected '{cls._dims}', got '{val.dims}'" - ) - if val.dims != cls._dims: - val = val.transpose(*cls._dims) - return val - - @classmethod - def _assign_data_attrs(cls, val: Self) -> Self: - """Assign the correct data attributes to the :class:`.DataArray`.""" - for attr_name, attr_val in cls._data_attrs.items(): - val.attrs[attr_name] = attr_val - return val - - @classmethod - def _assign_coord_attrs(cls, val: Self) -> Self: - """Assign the correct coordinate attributes to the :class:`.DataArray`.""" - target_dims = set(val.dims) & set(cls._dims) & set(val.coords) - for dim in target_dims: - template = DIM_ATTRS.get(dim) - if not template: - continue - - coord_attrs = val.coords[dim].attrs - missing = {k: v for k, v in template.items() if coord_attrs.get(k) != v} - coord_attrs.update(missing) - return val - def _interp_validator(self, field_name: Optional[str] = None) -> None: """Ensure the data can be interpolated or selected by checking for duplicate coordinates. 
@@ -288,36 +434,7 @@ def to_hdf5(self, fname: Union[PathLike, h5py.File], group_path: str) -> None: def to_hdf5_handle(self, f_handle: h5py.File, group_path: str) -> None: """Save an ``xr.DataArray`` to the hdf5 file handle with a given path to the group.""" - sub_group = f_handle.create_group(group_path) - sub_group[DATA_ARRAY_VALUE_NAME] = get_static(self.data) - for key, val in self.coords.items(): - if val.dtype == " Self: - """Load a DataArray from an hdf5 file with a given path to the group.""" - path = pathlib.Path(fname) - with h5py.File(path, "r") as f: - sub_group = f[group_path] - values = np.array(sub_group[DATA_ARRAY_VALUE_NAME]) - coords = {dim: np.array(sub_group[dim]) for dim in cls._dims if dim in sub_group} - for key, val in coords.items(): - if val.dtype == "O": - coords[key] = [byte_string.decode() for byte_string in val.tolist()] - return cls(values, coords=coords, dims=cls._dims) - - @classmethod - def from_file(cls, fname: PathLike, group_path: str) -> Self: - """Load a DataArray from an hdf5 file with a given path to the group.""" - path = pathlib.Path(fname) - if not any(suffix.lower() == ".hdf5" for suffix in path.suffixes): - raise FileError( - f"'DataArray' objects must be written to '.hdf5' format. Given filename of {path}." 
- ) - return cls.from_hdf5(fname=path, group_path=group_path) + write_data_array_to_hdf5(self, f_handle=f_handle, group_path=group_path) def __hash__(self) -> int: """Generate hash value for a :class:`.DataArray` instance, needed for custom components.""" @@ -553,39 +670,343 @@ def _ag_interp_func( result = result.transpose(*out_dims) return result - def _with_updated_data(self, data: np.ndarray, coords: dict[str, Any]) -> DataArray: - """Make copy of ``DataArray`` with ``data`` at specified ``coords``, autograd compatible - Constraints / Edge cases: - - `coords` must map to a specific value eg {x: '1'}, does not broadcast to arrays - - `data` will be reshaped to try to match `self.shape` except where `coords` present - """ +def write_data_array_to_hdf5( + data_array: xr.DataArray, f_handle: h5py.File, group_path: str +) -> None: + """Save an ``xr.DataArray`` to an hdf5 file handle at the given group path.""" + sub_group = f_handle.create_group(group_path) + sub_group[DATA_ARRAY_VALUE_NAME] = get_static(data_array.data) + for key, val in data_array.coords.items(): + if val.dtype == " xr.DataArray: + needs_sorting = [] + for axis in "xyz": + if axis not in data_array.coords: + raise DataError( + "Spatial DataArray methods require coordinates for 'x', 'y', and 'z' dimensions." 
+ ) + axis_coords = data_array.coords[axis].values + if len(axis_coords) > 1 and np.any(axis_coords[1:] < axis_coords[:-1]): + needs_sorting.append(axis) - # broadcast data to repeat data along the selected dimensions to match mask - new_data = new_data + np.zeros_like(old_data) + if needs_sorting: + result = data_array.sortby(needs_sorting) + return _cast_data_array(result, data_array) - new_data = np.where(mask, new_data, old_data) + return data_array + + +def _sel_inside_data_array(data_array: xr.DataArray, bounds: Bound) -> xr.DataArray: + if any(bmin > bmax for bmin, bmax in zip(*bounds)): + raise DataError( + "Min and max bounds must be packaged as '(minx, miny, minz), (maxx, maxy, maxz)'." + ) - return self.copy(deep=True, data=new_data) + sorted_data = _spatially_sorted_data_array(data_array) + inds_list = [] + + coords = (sorted_data.coords["x"], sorted_data.coords["y"], sorted_data.coords["z"]) + + for coord, smin, smax in zip(coords, bounds[0], bounds[1]): + length = len(coord) + + # one point along direction, assume invariance + if length == 1: + comp_inds = [0] + else: + # if data does not cover structure at all take the closest index + if smax < coord[0]: + comp_inds = np.arange(0, max(2, length)) + elif smin > coord[-1]: + comp_inds = np.arange(min(0, length - 2), length) + else: + if smin < coord[0]: + ind_min = 0 + else: + ind_min = max(0, (coord >= smin).argmax().data - 1) + + if smax > coord[-1]: + ind_max = length - 1 + else: + ind_max = (coord >= smax).argmax().data + + comp_inds = np.arange(ind_min, ind_max + 1) + + inds_list.append(comp_inds) + + result = sorted_data.isel(x=inds_list[0], y=inds_list[1], z=inds_list[2]) + return _cast_data_array(result, data_array) + + +def _does_cover_data_array( + data_array: xr.DataArray, bounds: Bound, rtol: float = 0.0, atol: float = 0.0 +) -> bool: + if any(bmin > bmax for bmin, bmax in zip(*bounds)): + raise DataError( + "Min and max bounds must be packaged as '(minx, miny, minz), (maxx, maxy, maxz)'." 
+ ) + + for axis in "xyz": + if axis not in data_array.coords: + raise DataError( + "Spatial DataArray methods require coordinates for 'x', 'y', and 'z' dimensions." + ) + + xyz = [data_array.coords["x"], data_array.coords["y"], data_array.coords["z"]] + data_min = [0.0, 0.0, 0.0] + data_max = [0.0, 0.0, 0.0] + for dim in range(3): + coords = xyz[dim] + if len(coords) == 1: + data_min[dim] = bounds[0][dim] + data_max[dim] = bounds[1][dim] + else: + data_min[dim] = np.min(coords) + data_max[dim] = np.max(coords) + data_bounds = (tuple(data_min), tuple(data_max)) + return bounds_contains(data_bounds, bounds, rtol=rtol, atol=atol) + + +def _is_uniform_data_array(data_array: xr.DataArray) -> bool: + raw_data = np.asarray(data_array.data).ravel() + if raw_data.size == 0: + return True + return np.allclose(raw_data, raw_data[0]) + + +def _angle_data_array(data_array: xr.DataArray) -> xr.DataArray: + values = np.angle(np.asarray(data_array.data)) + result = xr.DataArray(values, coords=data_array.coords, dims=data_array.dims) + return _cast_data_array(result, data_array) + + +def _with_updated_data( + data_array: xr.DataArray, data: np.ndarray, coords: dict[str, Any] +) -> xr.DataArray: + mask = xr.zeros_like(data_array, dtype=bool) + mask.loc[coords] = True + + old_data = np.asarray(data_array.data) + new_shape = list(old_data.shape) + for i, dim in enumerate(data_array.dims): + if dim in coords: + new_shape[i] = 1 + try: + new_data = data.reshape(new_shape) + except ValueError as e: + raise ValueError( + "Couldn't reshape the supplied 'data' to update 'DataArray'. The provided data was " + f"of shape {data.shape} and tried to reshape to {new_shape}." 
+ ) from e + + new_data = new_data + np.zeros_like(old_data) + updated = np.where(mask, new_data, old_data) + result = data_array.copy(deep=True, data=updated) + return _cast_data_array(result, data_array) + + +def _reflect_data_array( + data_array: xr.DataArray, axis: int, center: float, reflection_only: bool = False +) -> xr.DataArray: + sorted_data = _spatially_sorted_data_array(data_array) + + coords = [ + sorted_data.coords["x"].values, + sorted_data.coords["y"].values, + sorted_data.coords["z"].values, + ] + data = np.array(sorted_data.data) + + data_left_bound = coords[axis][0] + + if np.isclose(center, data_left_bound): + num_duplicates = 1 + elif center > data_left_bound: + raise DataError("Reflection center must be outside and to the left of the data region.") + else: + num_duplicates = 0 + + if reflection_only: + coords[axis] = 2 * center - coords[axis] + coords_dict = dict(zip("xyz", coords)) + reflected = type(sorted_data)(data, coords=coords_dict, dims=sorted_data.dims) + return _cast_data_array(reflected.sortby("xyz"[axis]), data_array) + + shape = np.array(np.shape(data)) + old_len = shape[axis] + shape[axis] = 2 * old_len - num_duplicates + + ind_left = [slice(shape[0]), slice(shape[1]), slice(shape[2])] + ind_right = [slice(shape[0]), slice(shape[1]), slice(shape[2])] + + ind_left[axis] = slice(old_len - 1, None, -1) + ind_right[axis] = slice(old_len - num_duplicates, None) + + new_data = np.zeros(shape) + + new_data[ind_left[0], ind_left[1], ind_left[2]] = data + new_data[ind_right[0], ind_right[1], ind_right[2]] = data + + new_coords = np.zeros(shape[axis]) + new_coords[old_len - num_duplicates :] = coords[axis] + new_coords[old_len - 1 :: -1] = 2 * center - coords[axis] + + coords[axis] = new_coords + coords_dict = dict(zip("xyz", coords)) + + reflected = type(sorted_data)(new_data, coords=coords_dict, dims=sorted_data.dims) + return _cast_data_array(reflected, data_array) + + +def _cast_data_array(result: xr.DataArray, reference: xr.DataArray) 
-> xr.DataArray: + """Preserve subclass type when possible.""" + ref_type = type(reference) + if ref_type is xr.DataArray or isinstance(result, ref_type): + return result + return ref_type( + result.data, + coords=result.coords, + dims=result.dims, + name=result.name, + attrs=dict(result.attrs), + ) + + +@xr.register_dataarray_accessor("td") +class Tidy3DAccessor: + def __init__(self, xarray_obj: xr.DataArray) -> None: + self._obj = xarray_obj + + def sel_inside(self, bounds: Bound) -> xr.DataArray: + return self._obj.sel_inside(bounds) + + def does_cover(self, bounds: Bound, rtol: float = 0.0, atol: float = 0.0) -> bool: + return self._obj.does_cover(bounds, rtol=rtol, atol=atol) + + @property + def is_uniform(self) -> bool: + return self._obj.is_uniform + + @property + def angle(self) -> xr.DataArray: + return self._obj.angle + + @property + def abs(self) -> xr.DataArray: + return self._obj.abs + + def reflect(self, axis: int, center: float, reflection_only: bool = False) -> xr.DataArray: + return self._obj.reflect(axis, center, reflection_only=reflection_only) + + def _with_updated_data(self, data: np.ndarray, coords: dict[str, Any]) -> xr.DataArray: + return _with_updated_data(self._obj, data=data, coords=coords) + + +P = ParamSpec("P") +R = TypeVar("R") + + +def legacy_da_shim( + *, + kind: Literal["method", "property"] = "method", + name: str | None = None, + new_name: str | None = None, + message: str | None = None, +) -> Callable[[Callable[..., Any]], Any]: + """ + Install a deprecated xr.DataArray shim if it doesn't already exist. 
+ + - `name`: legacy attribute name on xr.DataArray (defaults to function name) + - `new_name`: target name under da.td.* (defaults to name with one leading '_' stripped) + - `message`: override full warning text (otherwise inferred) + """ + + def deco(func: Callable[..., Any]) -> Any: + legacy_name = name or func.__name__ + target_name = new_name or legacy_name + + if message is None: + if kind == "method": + msg = ( + f"xr.DataArray.{legacy_name}(...) is deprecated; " + f"use `da.td.{target_name}(...)` instead." + ) + else: + msg = ( + f"xr.DataArray.{legacy_name} is deprecated; use `da.td.{target_name}` instead." + ) + else: + msg = message + + if hasattr(xr.DataArray, legacy_name): + return func # don't override existing attributes + + if kind == "property": + + @functools.wraps(func) + def fget(self: xr.DataArray) -> Any: + if LEGACY_SHIM_WARNINGS: + warnings.warn(msg, DeprecationWarning, stacklevel=2) + return func(self) + + fget.__name__ = legacy_name + setattr(xr.DataArray, legacy_name, property(fget)) + return fget + + @functools.wraps(func) + def wrapper(self: xr.DataArray, *args: Any, **kwargs: Any) -> Any: + if LEGACY_SHIM_WARNINGS: + warnings.warn(msg, DeprecationWarning, stacklevel=2) + return func(self, *args, **kwargs) + + wrapper.__name__ = legacy_name + setattr(xr.DataArray, legacy_name, wrapper) + return wrapper + + return deco + + +def install_legacy_shims() -> None: + @legacy_da_shim() + def reflect( + self: xr.DataArray, axis: int, center: float, reflection_only: bool = False + ) -> xr.DataArray: + return _reflect_data_array(self, axis, center, reflection_only=reflection_only) + + @legacy_da_shim() + def sel_inside(self: xr.DataArray, bounds: Bound) -> xr.DataArray: + return _sel_inside_data_array(self, bounds) + + @legacy_da_shim() + def does_cover(self: xr.DataArray, bounds: Bound, rtol: float = 0.0, atol: float = 0.0) -> bool: + return _does_cover_data_array(self, bounds, rtol=rtol, atol=atol) + + @legacy_da_shim(kind="property") + def 
angle(self: xr.DataArray) -> xr.DataArray: + return _angle_data_array(self) + + @legacy_da_shim(kind="property") + def abs(self: xr.DataArray) -> xr.DataArray: + return abs(self) + + @legacy_da_shim() + def _with_updated_data( + self: xr.DataArray, data: np.ndarray, coords: dict[str, Any] + ) -> xr.DataArray: + return self.td._with_updated_data(data=data, coords=coords) + + +register_data_array_spec(DataArray.__spec__, DataArray) + + +if LEGACY_SHIM_ENABLED: + install_legacy_shims() class FreqDataArray(DataArray): @@ -611,16 +1032,7 @@ class AbstractSpatialDataArray(DataArray, ABC): @property def _spatially_sorted(self) -> Self: """Check whether sorted and sort if not.""" - needs_sorting = [] - for axis in "xyz": - axis_coords = self.coords[axis].values - if len(axis_coords) > 1 and np.any(axis_coords[1:] < axis_coords[:-1]): - needs_sorting.append(axis) - - if len(needs_sorting) > 0: - return self.sortby(needs_sorting) - - return self + return _spatially_sorted_data_array(self) def sel_inside(self, bounds: Bound) -> Self: """Return a new SpatialDataArray that contains the minimal amount data necessary to cover @@ -638,50 +1050,7 @@ def sel_inside(self, bounds: Bound) -> Self: SpatialDataArray Extracted spatial data array. """ - if any(bmin > bmax for bmin, bmax in zip(*bounds)): - raise DataError( - "Min and max bounds must be packaged as '(minx, miny, minz), (maxx, maxy, maxz)'." 
- ) - - # make sure data is sorted with respect to coordinates - sorted_self = self._spatially_sorted - - inds_list = [] - - coords = (sorted_self.x, sorted_self.y, sorted_self.z) - - for coord, smin, smax in zip(coords, bounds[0], bounds[1]): - length = len(coord) - - # one point along direction, assume invariance - if length == 1: - comp_inds = [0] - else: - # if data does not cover structure at all take the closest index - if smax < coord[0]: # structure is completely on the left side - # take 2 if possible, so that linear iterpolation is possible - comp_inds = np.arange(0, max(2, length)) - - elif smin > coord[-1]: # structure is completely on the right side - # take 2 if possible, so that linear iterpolation is possible - comp_inds = np.arange(min(0, length - 2), length) - - else: - if smin < coord[0]: - ind_min = 0 - else: - ind_min = max(0, (coord >= smin).argmax().data - 1) - - if smax > coord[-1]: - ind_max = length - 1 - else: - ind_max = (coord >= smax).argmax().data - - comp_inds = np.arange(ind_min, ind_max + 1) - - inds_list.append(comp_inds) - - return sorted_self.isel(x=inds_list[0], y=inds_list[1], z=inds_list[2]) + return _sel_inside_data_array(self, bounds) def does_cover(self, bounds: Bound, rtol: float = 0.0, atol: float = 0.0) -> bool: """Check whether data fully covers specified by ``bounds`` spatial region. If data contains @@ -703,23 +1072,7 @@ def does_cover(self, bounds: Bound, rtol: float = 0.0, atol: float = 0.0) -> boo bool Full cover check outcome. """ - if any(bmin > bmax for bmin, bmax in zip(*bounds)): - raise DataError( - "Min and max bounds must be packaged as '(minx, miny, minz), (maxx, maxy, maxz)'." 
- ) - xyz = [self.x, self.y, self.z] - self_min = [0] * 3 - self_max = [0] * 3 - for dim in range(3): - coords = xyz[dim] - if len(coords) == 1: - self_min[dim] = bounds[0][dim] - self_max[dim] = bounds[1][dim] - else: - self_min[dim] = np.min(coords) - self_max[dim] = np.max(coords) - self_bounds = (tuple(self_min), tuple(self_max)) - return bounds_contains(self_bounds, bounds, rtol=rtol, atol=atol) + return _does_cover_data_array(self, bounds, rtol=rtol, atol=atol) class ScalarFieldDataArray(AbstractSpatialDataArray): diff --git a/tidy3d/_common/components/data/dataset.py b/tidy3d/_common/components/data/dataset.py index 68a57735bb..3f821bedc8 100644 --- a/tidy3d/_common/components/data/dataset.py +++ b/tidy3d/_common/components/data/dataset.py @@ -11,9 +11,9 @@ from tidy3d._common.components.base import Tidy3dBaseModel from tidy3d._common.components.data.data_array import ( - DataArray, TimeDataArray, TriangleMeshDataArray, + data_array_annotated_type, ) from tidy3d._common.exceptions import DataError from tidy3d._common.log import log @@ -21,7 +21,7 @@ if TYPE_CHECKING: from typing import Callable - from tidy3d._common.components.data.data_array import ScalarFieldDataArray + from tidy3d._common.components.data.data_array import DataArray, ScalarFieldDataArray from tidy3d._common.components.types.base import ArrayLike, Axis DEFAULT_MAX_SAMPLES_PER_STEP = 10_000 @@ -38,7 +38,7 @@ def data_arrs(self) -> dict: data_arrs = {} for key in self.__class__.model_fields.keys(): data = getattr(self, key) - if isinstance(data, DataArray): + if isinstance(data, xr.DataArray): data_arrs[key] = data return data_arrs @@ -46,7 +46,7 @@ def data_arrs(self) -> dict: class TriangleMeshDataset(Dataset): """Dataset for storing triangular surface data.""" - surface_mesh: TriangleMeshDataArray = Field( + surface_mesh: data_array_annotated_type(TriangleMeshDataArray) = Field( title="Surface mesh data", description="Dataset containing the surface triangles and corresponding face indices 
" "for a surface mesh.", @@ -154,7 +154,7 @@ def colocate(self, x: ArrayLike = None, y: ArrayLike = None, z: ArrayLike = None class TimeDataset(Dataset): """Dataset for storing a function of time.""" - values: TimeDataArray = Field( + values: data_array_annotated_type(TimeDataArray) = Field( title="Values", description="Values as a function of time.", ) diff --git a/tidy3d/_common/components/data/validators.py b/tidy3d/_common/components/data/validators.py index fd7ae3d2bf..d64f6d63d3 100644 --- a/tidy3d/_common/components/data/validators.py +++ b/tidy3d/_common/components/data/validators.py @@ -6,7 +6,7 @@ import numpy as np from pydantic import field_validator -from tidy3d._common.components.data.data_array import DataArray, ScalarFieldDataArray +from tidy3d._common.components.data.data_array import DataArray, ScalarFieldDataArray, _isinstance from tidy3d._common.components.data.dataset import AbstractFieldDataset from tidy3d._common.exceptions import ValidationError @@ -78,7 +78,7 @@ def validate_can_interpolate( def check_fields_interpolate(val: AbstractFieldDataset) -> AbstractFieldDataset: if isinstance(val, AbstractFieldDataset): for name, data in val.field_components.items(): - if isinstance(data, ScalarFieldDataArray): + if _isinstance(data, ScalarFieldDataArray): data._interp_validator(name) return val diff --git a/tidy3d/_common/components/geometry/mesh.py b/tidy3d/_common/components/geometry/mesh.py index 416b9eaaf1..25b19ddefe 100644 --- a/tidy3d/_common/components/geometry/mesh.py +++ b/tidy3d/_common/components/geometry/mesh.py @@ -12,7 +12,7 @@ from tidy3d._common.components.autograd import get_static from tidy3d._common.components.base import cached_property -from tidy3d._common.components.data.data_array import DATA_ARRAY_MAP, TriangleMeshDataArray +from tidy3d._common.components.data.data_array import TriangleMeshDataArray, is_data_array_name from tidy3d._common.components.data.dataset import TriangleMeshDataset from 
tidy3d._common.components.data.validators import validate_no_nans from tidy3d._common.components.geometry import base @@ -67,7 +67,7 @@ def _validate_trimesh_library(cls, data: dict[str, Any]) -> dict[str, Any]: def _warn_if_none(cls, val: TriangleMeshDataset) -> TriangleMeshDataset: """Warn if the Dataset fails to load.""" if isinstance(val, dict): - if any((v in DATA_ARRAY_MAP for _, v in val.items() if isinstance(v, str))): + if any(is_data_array_name(v) for _, v in val.items() if isinstance(v, str)): log.warning("Loading 'mesh_dataset' without data.") return None return val diff --git a/tidy3d/_common/components/validators.py b/tidy3d/_common/components/validators.py index ded4659f02..1988a41fe0 100644 --- a/tidy3d/_common/components/validators.py +++ b/tidy3d/_common/components/validators.py @@ -10,7 +10,7 @@ from pydantic import field_validator from tidy3d._common.components.autograd.utils import get_static, hasbox -from tidy3d._common.components.data.data_array import DATA_ARRAY_MAP +from tidy3d._common.components.data.data_array import is_data_array_name from tidy3d._common.exceptions import ValidationError from tidy3d._common.log import log @@ -99,7 +99,7 @@ def warn_if_dataset_none( def _warn_if_none(cls: type, val: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: """Warn if the DataArrays fail to load.""" if isinstance(val, dict): - if any((v in DATA_ARRAY_MAP for _, v in val.items() if isinstance(v, str))): + if any(is_data_array_name(v) for _, v in val.items() if isinstance(v, str)): log.warning(f"Loading {field_name} without data.", custom_loc=[field_name]) return None return val diff --git a/tidy3d/components/data/data_array.py b/tidy3d/components/data/data_array.py index ab41337f5f..91c22e942d 100644 --- a/tidy3d/components/data/data_array.py +++ b/tidy3d/components/data/data_array.py @@ -10,14 +10,33 @@ import numpy as np from tidy3d._common.components.data.data_array import ( - DATA_ARRAY_MAP, - DATA_ARRAY_TYPES, + DATA_ARRAY_SCHEMA_MAP, + 
DATA_ARRAY_SPEC_MAP, + LEGACY_SHIM_WARNINGS, AbstractSpatialDataArray, DataArray, + DataArraySpec, FreqDataArray, ScalarFieldDataArray, TimeDataArray, TriangleMeshDataArray, + _isinstance, + _reflect_data_array, + _spatially_sorted_data_array, + data_array_annotated_type, + data_array_spec_for_type, + data_array_spec_from_name, + data_array_type_from_name, + install_legacy_shims, + is_data_array_name, + iter_data_array_names, + # td_abs, + # td_angle, + # td_does_cover, + # td_reflect, + # td_sel_inside, + # td_validate, + # td_with_updated_data, ) from tidy3d._common.constants import ( AMP, @@ -116,51 +135,7 @@ def reflect(self, axis: Axis, center: float, reflection_only: bool = False) -> S Data after reflection is performed. """ - sorted_self = self._spatially_sorted - - coords = [sorted_self.x.values, sorted_self.y.values, sorted_self.z.values] - data = np.array(sorted_self.data) - - data_left_bound = coords[axis][0] - - if np.isclose(center, data_left_bound): - num_duplicates = 1 - elif center > data_left_bound: - raise DataError("Reflection center must be outside and to the left of the data region.") - else: - num_duplicates = 0 - - if reflection_only: - coords[axis] = 2 * center - coords[axis] - coords_dict = dict(zip("xyz", coords)) - - tmp_arr = SpatialDataArray(sorted_self.data, coords=coords_dict) - - return tmp_arr.sortby("xyz"[axis]) - - shape = np.array(np.shape(data)) - old_len = shape[axis] - shape[axis] = 2 * old_len - num_duplicates - - ind_left = [slice(shape[0]), slice(shape[1]), slice(shape[2])] - ind_right = [slice(shape[0]), slice(shape[1]), slice(shape[2])] - - ind_left[axis] = slice(old_len - 1, None, -1) - ind_right[axis] = slice(old_len - num_duplicates, None) - - new_data = np.zeros(shape) - - new_data[ind_left[0], ind_left[1], ind_left[2]] = data - new_data[ind_right[0], ind_right[1], ind_right[2]] = data - - new_coords = np.zeros(shape[axis]) - new_coords[old_len - num_duplicates :] = coords[axis] - new_coords[old_len - 1 :: -1] = 2 * 
center - coords[axis] - - coords[axis] = new_coords - coords_dict = dict(zip("xyz", coords)) - - return SpatialDataArray(new_data, coords=coords_dict) + return _reflect_data_array(self, axis, center, reflection_only=reflection_only) class ScalarFieldTimeDataArray(AbstractSpatialDataArray): @@ -908,7 +883,10 @@ def _make_base_result_data_array(result: DataArray) -> IntegralResultType: cls = TimeDataArray if "f" in result.coords and "mode_index" in result.coords: cls = FreqModeDataArray - return cls._assign_data_attrs(cls(data=result.data, coords=result.coords)) + spec = data_array_spec_for_type(cls) + return spec.validate_data_array( + DataArray(data=result.data, coords=result.coords, dims=result.dims) + ) def _make_voltage_data_array(result: DataArray) -> VoltageIntegralResultType: @@ -918,7 +896,10 @@ def _make_voltage_data_array(result: DataArray) -> VoltageIntegralResultType: cls = VoltageTimeDataArray if "f" in result.coords and "mode_index" in result.coords: cls = VoltageFreqModeDataArray - return cls._assign_data_attrs(cls(data=result.data, coords=result.coords)) + spec = data_array_spec_for_type(cls) + return spec.validate_data_array( + DataArray(data=result.data, coords=result.coords, dims=result.dims) + ) def _make_current_data_array(result: DataArray) -> CurrentIntegralResultType: @@ -928,7 +909,10 @@ def _make_current_data_array(result: DataArray) -> CurrentIntegralResultType: cls = CurrentTimeDataArray if "f" in result.coords and "mode_index" in result.coords: cls = CurrentFreqModeDataArray - return cls._assign_data_attrs(cls(data=result.data, coords=result.coords)) + spec = data_array_spec_for_type(cls) + return spec.validate_data_array( + DataArray(data=result.data, coords=result.coords, dims=result.dims) + ) def _make_impedance_data_array(result: DataArray) -> ImpedanceResultType: @@ -938,24 +922,40 @@ def _make_impedance_data_array(result: DataArray) -> ImpedanceResultType: cls = ImpedanceTimeDataArray if "f" in result.coords and "mode_index" in 
result.coords: cls = ImpedanceFreqModeDataArray - return cls._assign_data_attrs(cls(data=result.data, coords=result.coords)) + spec = data_array_spec_for_type(cls) + return spec.validate_data_array( + DataArray(data=result.data, coords=result.coords, dims=result.dims) + ) IndexedDataArrayTypes = Union[ - IndexedDataArray, - IndexedVoltageDataArray, - IndexedTimeDataArray, - IndexedFieldVoltageDataArray, - PointDataArray, + data_array_annotated_type(IndexedDataArray), + data_array_annotated_type(IndexedVoltageDataArray), + data_array_annotated_type(IndexedTimeDataArray), + data_array_annotated_type(IndexedFieldVoltageDataArray), + data_array_annotated_type(PointDataArray), +] + +IntegralResultType = Union[ + data_array_annotated_type(FreqDataArray), + data_array_annotated_type(FreqModeDataArray), + data_array_annotated_type(TimeDataArray), ] -IntegralResultType = Union[FreqDataArray, FreqModeDataArray, TimeDataArray] VoltageIntegralResultType = Union[ - VoltageFreqDataArray, VoltageFreqModeDataArray, VoltageTimeDataArray + data_array_annotated_type(VoltageFreqDataArray), + data_array_annotated_type(VoltageFreqModeDataArray), + data_array_annotated_type(VoltageTimeDataArray), ] + CurrentIntegralResultType = Union[ - CurrentFreqDataArray, CurrentFreqModeDataArray, CurrentTimeDataArray + data_array_annotated_type(CurrentFreqDataArray), + data_array_annotated_type(CurrentFreqModeDataArray), + data_array_annotated_type(CurrentTimeDataArray), ] + ImpedanceResultType = Union[ - ImpedanceFreqDataArray, ImpedanceFreqModeDataArray, ImpedanceTimeDataArray + data_array_annotated_type(ImpedanceFreqDataArray), + data_array_annotated_type(ImpedanceFreqModeDataArray), + data_array_annotated_type(ImpedanceTimeDataArray), ] diff --git a/tidy3d/components/data/dataset.py b/tidy3d/components/data/dataset.py index 9c36c993d9..d6bc9ce9ed 100644 --- a/tidy3d/components/data/dataset.py +++ b/tidy3d/components/data/dataset.py @@ -33,6 +33,7 @@ ScalarModeFieldCylindricalDataArray, 
ScalarModeFieldDataArray, TimeDataArray, + data_array_annotated_type, ) from tidy3d.components.data.zbf import ZBFData from tidy3d.components.types.base import xyz @@ -188,12 +189,12 @@ def _apply_mode_reorder(self, sort_inds_2d: np.ndarray) -> Self: EMScalarFieldType = Union[ - ScalarFieldDataArray, - ScalarFieldTimeDataArray, - ScalarModeFieldDataArray, - ScalarModeFieldCylindricalDataArray, - EMEScalarModeFieldDataArray, - EMEScalarFieldDataArray, + data_array_annotated_type(ScalarFieldDataArray), + data_array_annotated_type(ScalarFieldTimeDataArray), + data_array_annotated_type(ScalarModeFieldDataArray), + data_array_annotated_type(ScalarModeFieldCylindricalDataArray), + data_array_annotated_type(EMEScalarModeFieldDataArray), + data_array_annotated_type(EMEScalarFieldDataArray), ] @@ -277,32 +278,32 @@ class FieldDataset(ElectromagneticFieldDataset): >>> data = FieldDataset(Ex=scalar_field, Hz=scalar_field) """ - Ex: Optional[ScalarFieldDataArray] = Field( + Ex: Optional[data_array_annotated_type(ScalarFieldDataArray)] = Field( None, title="Ex", description="Spatial distribution of the x-component of the electric field.", ) - Ey: Optional[ScalarFieldDataArray] = Field( + Ey: Optional[data_array_annotated_type(ScalarFieldDataArray)] = Field( None, title="Ey", description="Spatial distribution of the y-component of the electric field.", ) - Ez: Optional[ScalarFieldDataArray] = Field( + Ez: Optional[data_array_annotated_type(ScalarFieldDataArray)] = Field( None, title="Ez", description="Spatial distribution of the z-component of the electric field.", ) - Hx: Optional[ScalarFieldDataArray] = Field( + Hx: Optional[data_array_annotated_type(ScalarFieldDataArray)] = Field( None, title="Hx", description="Spatial distribution of the x-component of the magnetic field.", ) - Hy: Optional[ScalarFieldDataArray] = Field( + Hy: Optional[data_array_annotated_type(ScalarFieldDataArray)] = Field( None, title="Hy", description="Spatial distribution of the y-component of the 
magnetic field.", ) - Hz: Optional[ScalarFieldDataArray] = Field( + Hz: Optional[data_array_annotated_type(ScalarFieldDataArray)] = Field( None, title="Hz", description="Spatial distribution of the z-component of the magnetic field.", @@ -409,32 +410,32 @@ class FieldTimeDataset(ElectromagneticFieldDataset): >>> data = FieldTimeDataset(Ex=scalar_field, Hz=scalar_field) """ - Ex: Optional[ScalarFieldTimeDataArray] = Field( + Ex: Optional[data_array_annotated_type(ScalarFieldTimeDataArray)] = Field( None, title="Ex", description="Spatial distribution of the x-component of the electric field.", ) - Ey: Optional[ScalarFieldTimeDataArray] = Field( + Ey: Optional[data_array_annotated_type(ScalarFieldTimeDataArray)] = Field( None, title="Ey", description="Spatial distribution of the y-component of the electric field.", ) - Ez: Optional[ScalarFieldTimeDataArray] = Field( + Ez: Optional[data_array_annotated_type(ScalarFieldTimeDataArray)] = Field( None, title="Ez", description="Spatial distribution of the z-component of the electric field.", ) - Hx: Optional[ScalarFieldTimeDataArray] = Field( + Hx: Optional[data_array_annotated_type(ScalarFieldTimeDataArray)] = Field( None, title="Hx", description="Spatial distribution of the x-component of the magnetic field.", ) - Hy: Optional[ScalarFieldTimeDataArray] = Field( + Hy: Optional[data_array_annotated_type(ScalarFieldTimeDataArray)] = Field( None, title="Hy", description="Spatial distribution of the y-component of the magnetic field.", ) - Hz: Optional[ScalarFieldTimeDataArray] = Field( + Hz: Optional[data_array_annotated_type(ScalarFieldTimeDataArray)] = Field( None, title="Hz", description="Spatial distribution of the z-component of the magnetic field.", @@ -511,19 +512,19 @@ class AuxFieldTimeDataset(AuxFieldDataset): >>> data = AuxFieldTimeDataset(Nfx=scalar_field) """ - Nfx: Optional[ScalarFieldTimeDataArray] = Field( + Nfx: Optional[data_array_annotated_type(ScalarFieldTimeDataArray)] = Field( None, title="Nfx", 
description="Spatial distribution of the free carrier density for polarization " "in the x-direction.", ) - Nfy: Optional[ScalarFieldTimeDataArray] = Field( + Nfy: Optional[data_array_annotated_type(ScalarFieldTimeDataArray)] = Field( None, title="Nfy", description="Spatial distribution of the free carrier density for polarization " "in the y-direction.", ) - Nfz: Optional[ScalarFieldTimeDataArray] = Field( + Nfz: Optional[data_array_annotated_type(ScalarFieldTimeDataArray)] = Field( None, title="Nfz", description="Spatial distribution of the free carrier density for polarization " @@ -557,50 +558,50 @@ class ModeSolverDataset(ElectromagneticFieldDataset, ModeFreqDataset): ... ) """ - Ex: Optional[ScalarModeFieldDataArray] = Field( + Ex: Optional[data_array_annotated_type(ScalarModeFieldDataArray)] = Field( None, title="Ex", description="Spatial distribution of the x-component of the electric field of the mode.", ) - Ey: Optional[ScalarModeFieldDataArray] = Field( + Ey: Optional[data_array_annotated_type(ScalarModeFieldDataArray)] = Field( None, title="Ey", description="Spatial distribution of the y-component of the electric field of the mode.", ) - Ez: Optional[ScalarModeFieldDataArray] = Field( + Ez: Optional[data_array_annotated_type(ScalarModeFieldDataArray)] = Field( None, title="Ez", description="Spatial distribution of the z-component of the electric field of the mode.", ) - Hx: Optional[ScalarModeFieldDataArray] = Field( + Hx: Optional[data_array_annotated_type(ScalarModeFieldDataArray)] = Field( None, title="Hx", description="Spatial distribution of the x-component of the magnetic field of the mode.", ) - Hy: Optional[ScalarModeFieldDataArray] = Field( + Hy: Optional[data_array_annotated_type(ScalarModeFieldDataArray)] = Field( None, title="Hy", description="Spatial distribution of the y-component of the magnetic field of the mode.", ) - Hz: Optional[ScalarModeFieldDataArray] = Field( + Hz: Optional[data_array_annotated_type(ScalarModeFieldDataArray)] = 
Field( None, title="Hz", description="Spatial distribution of the z-component of the magnetic field of the mode.", ) - n_complex: ModeIndexDataArray = Field( + n_complex: data_array_annotated_type(ModeIndexDataArray) = Field( title="Propagation Index", description="Complex-valued effective propagation constants associated with the mode.", ) - n_group_raw: Optional[GroupIndexDataArray] = Field( + n_group_raw: Optional[data_array_annotated_type(GroupIndexDataArray)] = Field( None, alias="n_group", # This is for backwards compatibility only when loading old data title="Group Index", description="Index associated with group velocity of the mode.", ) - dispersion_raw: Optional[ModeDispersionDataArray] = Field( + dispersion_raw: Optional[data_array_annotated_type(ModeDispersionDataArray)] = Field( None, title="Dispersion", description="Dispersion parameter for the mode.", @@ -669,15 +670,15 @@ def plot_field(self, *args: Any, **kwargs: Any) -> None: class AbstractMediumPropertyDataset(AbstractFieldDataset, ABC): """Dataset storing medium property.""" - eps_xx: ScalarFieldDataArray = Field( + eps_xx: data_array_annotated_type(ScalarFieldDataArray) = Field( title="Epsilon xx", description="Spatial distribution of the xx-component of the relative permittivity.", ) - eps_yy: ScalarFieldDataArray = Field( + eps_yy: data_array_annotated_type(ScalarFieldDataArray) = Field( title="Epsilon yy", description="Spatial distribution of the yy-component of the relative permittivity.", ) - eps_zz: ScalarFieldDataArray = Field( + eps_zz: data_array_annotated_type(ScalarFieldDataArray) = Field( title="Epsilon zz", description="Spatial distribution of the zz-component of the relative permittivity.", ) @@ -727,15 +728,15 @@ class MediumDataset(AbstractMediumPropertyDataset): >>> data = MediumDataset(eps_xx=sclr_fld, eps_yy=sclr_fld, eps_zz=sclr_fld, mu_xx=sclr_fld, mu_yy=sclr_fld, mu_zz=sclr_fld) """ - mu_xx: ScalarFieldDataArray = Field( + mu_xx: 
data_array_annotated_type(ScalarFieldDataArray) = Field( title="Mu xx", description="Spatial distribution of the xx-component of the relative permeability.", ) - mu_yy: ScalarFieldDataArray = Field( + mu_yy: data_array_annotated_type(ScalarFieldDataArray) = Field( title="Mu yy", description="Spatial distribution of the yy-component of the relative permeability.", ) - mu_zz: ScalarFieldDataArray = Field( + mu_zz: data_array_annotated_type(ScalarFieldDataArray) = Field( title="Mu zz", description="Spatial distribution of the zz-component of the relative permeability.", ) diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index dd79c719e0..42bf11701c 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -15,6 +15,7 @@ from tidy3d.components.base import cached_property from tidy3d.components.base_sim.data.monitor_data import AbstractMonitorData +from tidy3d.components.data.data_array import data_array_annotated_type from tidy3d.components.grid.grid import Coords, Grid from tidy3d.components.medium import Medium, MediumType from tidy3d.components.monitor import ( @@ -120,10 +121,10 @@ GRID_CORRECTION_TYPE = Union[ float, - FreqDataArray, - TimeDataArray, - FreqModeDataArray, - EMEFreqModeDataArray, + data_array_annotated_type(FreqDataArray), + data_array_annotated_type(TimeDataArray), + data_array_annotated_type(FreqModeDataArray), + data_array_annotated_type(EMEFreqModeDataArray), ] @@ -1697,7 +1698,7 @@ class ModeData(ModeSolverDataset, ElectromagneticFieldData): description="Mode monitor associated with the data.", ) - amps: ModeAmpsDataArray = Field( + amps: data_array_annotated_type(ModeAmpsDataArray) = Field( title="Amplitudes", description="Complex-valued amplitudes associated with the mode.", ) @@ -2492,7 +2493,7 @@ class ModeSolverData(ModeData): description="Mode solver monitor associated with the data.", ) - amps: Optional[ModeAmpsDataArray] = Field( + amps: 
Optional[data_array_annotated_type(ModeAmpsDataArray)] = Field( None, title="Amplitudes", description="Unused for ModeSolverData.", @@ -2802,7 +2803,7 @@ class FluxData(MonitorData): description="Frequency-domain flux monitor associated with the data.", ) - flux: FluxDataArray = Field( + flux: data_array_annotated_type(FluxDataArray) = Field( title="Flux", description="Flux values in the frequency-domain.", ) @@ -2858,17 +2859,17 @@ class FluxTimeData(MonitorData): description="Time-domain flux monitor associated with the data.", ) - flux: FluxTimeDataArray = Field( + flux: data_array_annotated_type(FluxTimeDataArray) = Field( title="Flux", description="Flux values in the time-domain.", ) ProjFieldType = Union[ - FieldProjectionAngleDataArray, - FieldProjectionCartesianDataArray, - FieldProjectionKSpaceDataArray, - DiffractionDataArray, + data_array_annotated_type(FieldProjectionAngleDataArray), + data_array_annotated_type(FieldProjectionCartesianDataArray), + data_array_annotated_type(FieldProjectionKSpaceDataArray), + data_array_annotated_type(DiffractionDataArray), ] ProjMonitorType = Union[ @@ -3192,27 +3193,27 @@ class FieldProjectionAngleData(AbstractFieldProjectionData): description="Surfaces of the monitor where near fields were recorded for projection", ) - Er: FieldProjectionAngleDataArray = Field( + Er: data_array_annotated_type(FieldProjectionAngleDataArray) = Field( title="Er", description="Spatial distribution of r-component of the electric field.", ) - Etheta: FieldProjectionAngleDataArray = Field( + Etheta: data_array_annotated_type(FieldProjectionAngleDataArray) = Field( title="Etheta", description="Spatial distribution of the theta-component of the electric field.", ) - Ephi: FieldProjectionAngleDataArray = Field( + Ephi: data_array_annotated_type(FieldProjectionAngleDataArray) = Field( title="Ephi", description="Spatial distribution of phi-component of the electric field.", ) - Hr: FieldProjectionAngleDataArray = Field( + Hr: 
data_array_annotated_type(FieldProjectionAngleDataArray) = Field( title="Hr", description="Spatial distribution of r-component of the magnetic field.", ) - Htheta: FieldProjectionAngleDataArray = Field( + Htheta: data_array_annotated_type(FieldProjectionAngleDataArray) = Field( title="Htheta", description="Spatial distribution of theta-component of the magnetic field.", ) - Hphi: FieldProjectionAngleDataArray = Field( + Hphi: data_array_annotated_type(FieldProjectionAngleDataArray) = Field( title="Hphi", description="Spatial distribution of phi-component of the magnetic field.", ) @@ -3394,27 +3395,27 @@ class FieldProjectionCartesianData(AbstractFieldProjectionData): description="Surfaces of the monitor where near fields were recorded for projection", ) - Er: FieldProjectionCartesianDataArray = Field( + Er: data_array_annotated_type(FieldProjectionCartesianDataArray) = Field( title="Er", description="Spatial distribution of r-component of the electric field.", ) - Etheta: FieldProjectionCartesianDataArray = Field( + Etheta: data_array_annotated_type(FieldProjectionCartesianDataArray) = Field( title="Etheta", description="Spatial distribution of the theta-component of the electric field.", ) - Ephi: FieldProjectionCartesianDataArray = Field( + Ephi: data_array_annotated_type(FieldProjectionCartesianDataArray) = Field( title="Ephi", description="Spatial distribution of phi-component of the electric field.", ) - Hr: FieldProjectionCartesianDataArray = Field( + Hr: data_array_annotated_type(FieldProjectionCartesianDataArray) = Field( title="Hr", description="Spatial distribution of r-component of the magnetic field.", ) - Htheta: FieldProjectionCartesianDataArray = Field( + Htheta: data_array_annotated_type(FieldProjectionCartesianDataArray) = Field( title="Htheta", description="Spatial distribution of theta-component of the magnetic field.", ) - Hphi: FieldProjectionCartesianDataArray = Field( + Hphi: data_array_annotated_type(FieldProjectionCartesianDataArray) = 
Field( title="Hphi", description="Spatial distribution of phi-component of the magnetic field.", ) @@ -3539,27 +3540,27 @@ class FieldProjectionKSpaceData(AbstractFieldProjectionData): description="Surfaces of the monitor where near fields were recorded for projection", ) - Er: FieldProjectionKSpaceDataArray = Field( + Er: data_array_annotated_type(FieldProjectionKSpaceDataArray) = Field( title="Er", description="Spatial distribution of r-component of the electric field.", ) - Etheta: FieldProjectionKSpaceDataArray = Field( + Etheta: data_array_annotated_type(FieldProjectionKSpaceDataArray) = Field( title="Etheta", description="Spatial distribution of the theta-component of the electric field.", ) - Ephi: FieldProjectionKSpaceDataArray = Field( + Ephi: data_array_annotated_type(FieldProjectionKSpaceDataArray) = Field( title="Ephi", description="Spatial distribution of phi-component of the electric field.", ) - Hr: FieldProjectionKSpaceDataArray = Field( + Hr: data_array_annotated_type(FieldProjectionKSpaceDataArray) = Field( title="Hr", description="Spatial distribution of r-component of the magnetic field.", ) - Htheta: FieldProjectionKSpaceDataArray = Field( + Htheta: data_array_annotated_type(FieldProjectionKSpaceDataArray) = Field( title="Htheta", description="Spatial distribution of theta-component of the magnetic field.", ) - Hphi: FieldProjectionKSpaceDataArray = Field( + Hphi: data_array_annotated_type(FieldProjectionKSpaceDataArray) = Field( title="Hphi", description="Spatial distribution of phi-component of the magnetic field.", ) @@ -3666,27 +3667,27 @@ class DiffractionData(AbstractFieldProjectionData): description="Diffraction monitor associated with the data.", ) - Er: DiffractionDataArray = Field( + Er: data_array_annotated_type(DiffractionDataArray) = Field( title="Er", description="Spatial distribution of r-component of the electric field.", ) - Etheta: DiffractionDataArray = Field( + Etheta: data_array_annotated_type(DiffractionDataArray) = Field( 
title="Etheta", description="Spatial distribution of the theta-component of the electric field.", ) - Ephi: DiffractionDataArray = Field( + Ephi: data_array_annotated_type(DiffractionDataArray) = Field( title="Ephi", description="Spatial distribution of phi-component of the electric field.", ) - Hr: DiffractionDataArray = Field( + Hr: data_array_annotated_type(DiffractionDataArray) = Field( title="Hr", description="Spatial distribution of r-component of the magnetic field.", ) - Htheta: DiffractionDataArray = Field( + Htheta: data_array_annotated_type(DiffractionDataArray) = Field( title="Htheta", description="Spatial distribution of theta-component of the magnetic field.", ) - Hphi: DiffractionDataArray = Field( + Hphi: data_array_annotated_type(DiffractionDataArray) = Field( title="Hphi", description="Spatial distribution of phi-component of the magnetic field.", ) @@ -4008,7 +4009,7 @@ class DirectivityData(FieldProjectionAngleData): description="Monitor describing the angle-based projection grid on which to measure directivity data.", ) - flux: FluxDataArray = Field( + flux: data_array_annotated_type(FluxDataArray) = Field( title="Flux", description="Flux values that are either computed from fields recorded on the " "projection surfaces or by integrating the projected fields over a spherical surface.", diff --git a/tidy3d/components/data/sim_data.py b/tidy3d/components/data/sim_data.py index de839cfd87..07c531e991 100644 --- a/tidy3d/components/data/sim_data.py +++ b/tidy3d/components/data/sim_data.py @@ -17,6 +17,7 @@ from tidy3d.components.autograd.utils import split_list from tidy3d.components.base import JSON_TAG, Tidy3dBaseModel, cached_property from tidy3d.components.base_sim.data.sim_data import AbstractSimulationData +from tidy3d.components.data.data_array import data_array_annotated_type from tidy3d.components.simulation import Simulation from tidy3d.components.source.current import CustomCurrentSource from tidy3d.components.source.time import 
GaussianPulse @@ -68,7 +69,7 @@ class AdjointSourceInfo(Tidy3dBaseModel): description="Set of processed sources to include in the adjoint simulation.", ) - post_norm: Union[float, FreqDataArray] = Field( + post_norm: Union[float, data_array_annotated_type(FreqDataArray)] = Field( title="Post Normalization Values", description="Factor to multiply the adjoint fields by after running " "given the adjoint source pipeline used.", diff --git a/tidy3d/components/data/unstructured/base.py b/tidy3d/components/data/unstructured/base.py index b166c87c07..6dfc9ed882 100644 --- a/tidy3d/components/data/unstructured/base.py +++ b/tidy3d/components/data/unstructured/base.py @@ -13,12 +13,13 @@ from tidy3d.components.base import cached_property from tidy3d.components.data.data_array import ( - DATA_ARRAY_MAP, CellDataArray, IndexedDataArray, IndexedDataArrayTypes, PointDataArray, SpatialDataArray, + data_array_annotated_type, + is_data_array_name, ) from tidy3d.components.data.dataset import Dataset from tidy3d.constants import inf @@ -53,7 +54,7 @@ class UnstructuredGridDataset(Dataset, np.lib.mixins.NDArrayOperatorsMixin, ABC): """Abstract base for datasets that store unstructured grid data.""" - points: PointDataArray = Field( + points: data_array_annotated_type(PointDataArray) = Field( title="Grid Points", description="Coordinates of points composing the unstructured grid.", ) @@ -63,7 +64,7 @@ class UnstructuredGridDataset(Dataset, np.lib.mixins.NDArrayOperatorsMixin, ABC) description="Values stored at the grid points.", ) - cells: CellDataArray = Field( + cells: data_array_annotated_type(CellDataArray) = Field( title="Grid Cells", description="Cells composing the unstructured grid specified as connections between grid " "points.", @@ -219,7 +220,7 @@ def _warn_if_none(cls, data: Any) -> Any: no_data_fields = [] for field_name in ["points", "cells", "values"]: field = data.get(field_name) - if isinstance(field, str) and field in DATA_ARRAY_MAP.keys(): + if isinstance(field, 
str) and is_data_array_name(field): no_data_fields.append(field_name) if len(no_data_fields) > 0: diff --git a/tidy3d/components/data/utils.py b/tidy3d/components/data/utils.py index ada0a4e77a..4ed1cc7ff1 100644 --- a/tidy3d/components/data/utils.py +++ b/tidy3d/components/data/utils.py @@ -9,7 +9,7 @@ from tidy3d.components.types.base import discriminated_union -from .data_array import SpatialDataArray +from .data_array import SpatialDataArray, data_array_annotated_type from .unstructured.base import UnstructuredGridDataset from .unstructured.tetrahedral import TetrahedralGridDataset from .unstructured.triangular import TriangularGridDataset @@ -21,10 +21,13 @@ UnstructuredGridDatasetType = Union[TriangularGridDataset, TetrahedralGridDataset] -CustomSpatialDataType = Union[SpatialDataArray, UnstructuredGridDatasetType] +CustomSpatialDataType = Union[ + data_array_annotated_type(SpatialDataArray), + UnstructuredGridDatasetType, +] CustomSpatialDataTypeAnnotated = Union[ discriminated_union(UnstructuredGridDatasetType), - SpatialDataArray, + data_array_annotated_type(SpatialDataArray), ] diff --git a/tidy3d/components/eme/data/dataset.py b/tidy3d/components/eme/data/dataset.py index db2383ad4c..84aeaddb4c 100644 --- a/tidy3d/components/eme/data/dataset.py +++ b/tidy3d/components/eme/data/dataset.py @@ -16,6 +16,7 @@ EMEScalarFieldDataArray, EMEScalarModeFieldDataArray, EMESMatrixDataArray, + data_array_annotated_type, ) from tidy3d.components.data.dataset import Dataset, ElectromagneticFieldDataset from tidy3d.exceptions import ValidationError @@ -24,19 +25,19 @@ class EMESMatrixDataset(Dataset): """Dataset storing S matrix.""" - S11: EMESMatrixDataArray = Field( + S11: data_array_annotated_type(EMESMatrixDataArray) = Field( title="S11 matrix", description="S matrix relating output modes at port 1 to input modes at port 1.", ) - S12: EMESMatrixDataArray = Field( + S12: data_array_annotated_type(EMESMatrixDataArray) = Field( title="S12 matrix", description="S matrix 
relating output modes at port 1 to input modes at port 2.", ) - S21: EMESMatrixDataArray = Field( + S21: data_array_annotated_type(EMESMatrixDataArray) = Field( title="S21 matrix", description="S matrix relating output modes at port 2 to input modes at port 1.", ) - S22: EMESMatrixDataArray = Field( + S22: data_array_annotated_type(EMESMatrixDataArray) = Field( title="S22 matrix", description="S matrix relating output modes at port 2 to input modes at port 2.", ) @@ -45,19 +46,19 @@ class EMESMatrixDataset(Dataset): class EMEInterfaceSMatrixDataset(Dataset): """Dataset storing S matrices associated with EME cell interfaces.""" - S11: EMEInterfaceSMatrixDataArray = Field( + S11: data_array_annotated_type(EMEInterfaceSMatrixDataArray) = Field( title="S11 matrix", description="S matrix relating output modes at port 1 to input modes at port 1.", ) - S12: EMEInterfaceSMatrixDataArray = Field( + S12: data_array_annotated_type(EMEInterfaceSMatrixDataArray) = Field( title="S12 matrix", description="S matrix relating output modes at port 1 to input modes at port 2.", ) - S21: EMEInterfaceSMatrixDataArray = Field( + S21: data_array_annotated_type(EMEInterfaceSMatrixDataArray) = Field( title="S21 matrix", description="S matrix relating output modes at port 2 to input modes at port 1.", ) - S22: EMEInterfaceSMatrixDataArray = Field( + S22: data_array_annotated_type(EMEInterfaceSMatrixDataArray) = Field( title="S22 matrix", description="S matrix relating output modes at port 2 to input modes at port 2.", ) @@ -73,15 +74,15 @@ class EMEOverlapDataset(Dataset): in cell ``i``, and ``mode_index_in`` refers to the mode index in cell ``j``. 
""" - O11: EMEInterfaceSMatrixDataArray = Field( + O11: data_array_annotated_type(EMEInterfaceSMatrixDataArray) = Field( title="O11 matrix", description="Overlap integral between E field and H field in the same cell.", ) - O12: EMEInterfaceSMatrixDataArray = Field( + O12: data_array_annotated_type(EMEInterfaceSMatrixDataArray) = Field( title="O12 matrix", description="Overlap integral between E field on side 1 and H field on side 2.", ) - O21: EMEInterfaceSMatrixDataArray = Field( + O21: data_array_annotated_type(EMEInterfaceSMatrixDataArray) = Field( title="O21 matrix", description="Overlap integral between E field on side 2 and H field on side 1.", ) @@ -102,25 +103,25 @@ class EMECoefficientDataset(Dataset): between EME cells. """ - A: Optional[EMECoefficientDataArray] = Field( + A: Optional[data_array_annotated_type(EMECoefficientDataArray)] = Field( None, title="A coefficient", description="Coefficient for forward mode in this cell.", ) - B: Optional[EMECoefficientDataArray] = Field( + B: Optional[data_array_annotated_type(EMECoefficientDataArray)] = Field( None, title="B coefficient", description="Coefficient for backward mode in this cell.", ) - n_complex: Optional[EMEModeIndexDataArray] = Field( + n_complex: Optional[data_array_annotated_type(EMEModeIndexDataArray)] = Field( None, title="Propagation Index", description="Complex-valued effective propagation indices associated with the EME modes.", ) - flux: Optional[EMEFluxDataArray] = Field( + flux: Optional[data_array_annotated_type(EMEFluxDataArray)] = Field( None, title="Flux", description="Power flux of the EME modes.", @@ -176,32 +177,32 @@ def normalized_copy(self) -> EMECoefficientDataset: class EMEFieldDataset(ElectromagneticFieldDataset): """Dataset storing scalar components of E and H fields as a function of freq, mode_index, and port_index.""" - Ex: Optional[EMEScalarFieldDataArray] = Field( + Ex: Optional[data_array_annotated_type(EMEScalarFieldDataArray)] = Field( None, title="Ex", 
description="Spatial distribution of the x-component of the electric field of the mode.", ) - Ey: Optional[EMEScalarFieldDataArray] = Field( + Ey: Optional[data_array_annotated_type(EMEScalarFieldDataArray)] = Field( None, title="Ey", description="Spatial distribution of the y-component of the electric field of the mode.", ) - Ez: Optional[EMEScalarFieldDataArray] = Field( + Ez: Optional[data_array_annotated_type(EMEScalarFieldDataArray)] = Field( None, title="Ez", description="Spatial distribution of the z-component of the electric field of the mode.", ) - Hx: Optional[EMEScalarFieldDataArray] = Field( + Hx: Optional[data_array_annotated_type(EMEScalarFieldDataArray)] = Field( None, title="Hx", description="Spatial distribution of the x-component of the magnetic field of the mode.", ) - Hy: Optional[EMEScalarFieldDataArray] = Field( + Hy: Optional[data_array_annotated_type(EMEScalarFieldDataArray)] = Field( None, title="Hy", description="Spatial distribution of the y-component of the magnetic field of the mode.", ) - Hz: Optional[EMEScalarFieldDataArray] = Field( + Hz: Optional[data_array_annotated_type(EMEScalarFieldDataArray)] = Field( None, title="Hz", description="Spatial distribution of the z-component of the magnetic field of the mode.", @@ -211,32 +212,32 @@ class EMEFieldDataset(ElectromagneticFieldDataset): class EMEModeSolverDataset(ElectromagneticFieldDataset): """Dataset storing EME modes as a function of freq, mode_index, and cell_index.""" - n_complex: EMEModeIndexDataArray = Field( + n_complex: data_array_annotated_type(EMEModeIndexDataArray) = Field( title="Propagation Index", description="Complex-valued effective propagation constants associated with the mode.", ) - Ex: EMEScalarModeFieldDataArray = Field( + Ex: data_array_annotated_type(EMEScalarModeFieldDataArray) = Field( title="Ex", description="Spatial distribution of the x-component of the electric field of the mode.", ) - Ey: EMEScalarModeFieldDataArray = Field( + Ey: 
data_array_annotated_type(EMEScalarModeFieldDataArray) = Field( title="Ey", description="Spatial distribution of the y-component of the electric field of the mode.", ) - Ez: EMEScalarModeFieldDataArray = Field( + Ez: data_array_annotated_type(EMEScalarModeFieldDataArray) = Field( title="Ez", description="Spatial distribution of the z-component of the electric field of the mode.", ) - Hx: EMEScalarModeFieldDataArray = Field( + Hx: data_array_annotated_type(EMEScalarModeFieldDataArray) = Field( title="Hx", description="Spatial distribution of the x-component of the magnetic field of the mode.", ) - Hy: EMEScalarModeFieldDataArray = Field( + Hy: data_array_annotated_type(EMEScalarModeFieldDataArray) = Field( title="Hy", description="Spatial distribution of the y-component of the magnetic field of the mode.", ) - Hz: EMEScalarModeFieldDataArray = Field( + Hz: data_array_annotated_type(EMEScalarModeFieldDataArray) = Field( title="Hz", description="Spatial distribution of the z-component of the magnetic field of the mode.", ) diff --git a/tidy3d/components/grid/grid.py b/tidy3d/components/grid/grid.py index fd7d291614..bb039315ea 100644 --- a/tidy3d/components/grid/grid.py +++ b/tidy3d/components/grid/grid.py @@ -8,7 +8,12 @@ from pydantic import Field from tidy3d.components.base import Tidy3dBaseModel, cached_property -from tidy3d.components.data.data_array import DataArray, ScalarFieldDataArray, SpatialDataArray +from tidy3d.components.data.data_array import ( + DataArray, + ScalarFieldDataArray, + SpatialDataArray, + _isinstance, +) from tidy3d.components.data.utils import UnstructuredGridDataset from tidy3d.components.types import ArrayFloat1D from tidy3d.exceptions import SetupError @@ -261,7 +266,7 @@ def spatial_interp( # Check for empty dimensions result_coords = dict(self.to_dict) if any(len(v) == 0 for v in result_coords.values()): - if isinstance(array, (SpatialDataArray, ScalarFieldDataArray)): + if _isinstance(array, SpatialDataArray) or _isinstance(array, 
ScalarFieldDataArray): for c in array.coords: if c not in result_coords: result_coords[c] = array.coords[c].values diff --git a/tidy3d/components/medium.py b/tidy3d/components/medium.py index 4ab3e1292d..0a2a223cc0 100644 --- a/tidy3d/components/medium.py +++ b/tidy3d/components/medium.py @@ -44,11 +44,16 @@ from .autograd.derivative_utils import integrate_within_bounds from .autograd.types import TracedFloat, TracedPolesAndResidues, TracedPositiveFloat from .base import Tidy3dBaseModel, cached_property -from .data.data_array import DATA_ARRAY_MAP, ScalarFieldDataArray, SpatialDataArray +from .data.data_array import ( + ScalarFieldDataArray, + SpatialDataArray, + _isinstance, + _spatially_sorted_data_array, + is_data_array_name, +) from .data.dataset import PermittivityDataset from .data.unstructured.base import UnstructuredGridDataset from .data.utils import ( - CustomSpatialDataType, CustomSpatialDataTypeAnnotated, _check_same_coordinates, _get_numpy_array, @@ -94,6 +99,7 @@ from .autograd.derivative_utils import DerivativeInfo from .autograd.types import AutogradFieldMap from .data.dataset import ElectromagneticFieldDataset + from .data.utils import CustomSpatialDataType from .transformation import RotationType from .types import ( ArrayComplex1D, @@ -1103,7 +1109,7 @@ def sel_inside(self, bounds: Bound) -> AbstractCustomMedium: @staticmethod def _not_loaded(field: Any) -> bool: """Check whether data was not loaded.""" - if isinstance(field, str) and field in DATA_ARRAY_MAP: + if isinstance(field, str) and is_data_array_name(field): return True # attempting to construct an UnstructuredGridDataset from a dict if isinstance(field, dict) and field.get("type") in ( @@ -1111,7 +1117,7 @@ def _not_loaded(field: Any) -> bool: "TetrahedralGridDataset", ): return any( - isinstance(subfield, str) and subfield in DATA_ARRAY_MAP + isinstance(subfield, str) and is_data_array_name(subfield) for subfield in [field["points"], field["cells"], field["values"]] ) # attempting to 
pass an UnstructuredGridDataset with zero points @@ -1713,7 +1719,7 @@ def _warn_if_none(cls, data: dict) -> dict: fail_load = True eps_ds = data.get("eps_dataset") if isinstance(eps_ds, dict): - if any(isinstance(v, str) and v in DATA_ARRAY_MAP for v in eps_ds.values()): + if any(isinstance(v, str) and is_data_array_name(v) for v in eps_ds.values()): log.warning( "Loading 'eps_dataset' without data; constructing a vacuum medium instead." ) @@ -1920,7 +1926,7 @@ def _check_permittivity_conductivity_interpolate( ) -> Optional[CustomSpatialDataType]: """Check that the custom medium 'SpatialDataArrays' can be interpolated.""" - if isinstance(val, SpatialDataArray): + if _isinstance(val, SpatialDataArray): val._interp_validator(info.field_name) return val @@ -1944,14 +1950,14 @@ def _permittivity_sorted(self) -> SpatialDataArray | None: """Cached copy of permittivity sorted along spatial axes.""" if self.permittivity is None: return None - return self.permittivity._spatially_sorted + return _spatially_sorted_data_array(self.permittivity) @cached_property def _conductivity_sorted(self) -> SpatialDataArray | None: """Cached copy of conductivity sorted along spatial axes.""" if self.conductivity is None: return None - return self.conductivity._spatially_sorted + return _spatially_sorted_data_array(self.conductivity) @cached_property def _eps_components_sorted(self) -> dict[str, ScalarFieldDataArray]: @@ -1959,7 +1965,8 @@ def _eps_components_sorted(self) -> dict[str, ScalarFieldDataArray]: if self.eps_dataset is None: return {} return { - key: comp._spatially_sorted for key, comp in self.eps_dataset.field_components.items() + key: _spatially_sorted_data_array(comp) + for key, comp in self.eps_dataset.field_components.items() } @cached_property @@ -2149,7 +2156,7 @@ def from_eps_raw( :class:`.CustomMedium` Medium containing the spatially varying permittivity data. 
""" - if isinstance(eps, CustomSpatialDataType.__args__): + if _isinstance(eps, SpatialDataArray) or isinstance(eps, UnstructuredGridDataset): # purely real, not need to know `freq` if CustomMedium._validate_isreal_dataarray(eps): return cls(permittivity=eps, interp_method=interp_method, **kwargs) @@ -2229,7 +2236,7 @@ def from_nk( """ # lossless if k is None: - if isinstance(n, ScalarFieldDataArray): + if _isinstance(n, ScalarFieldDataArray): n = SpatialDataArray(n.squeeze(dim="f", drop=True)) freq = 0 # dummy value eps_real, _ = CustomMedium.nk_to_eps_sigma(n, 0 * n, freq) @@ -2240,7 +2247,7 @@ def from_nk( raise SetupError("'n' and 'k' must be of the same type and must have same coordinates.") # k is a SpatialDataArray - if isinstance(k, CustomSpatialDataType.__args__): + if _isinstance(k, SpatialDataArray) or isinstance(k, UnstructuredGridDataset): if freq is None: raise SetupError( "For a lossy medium, must supply 'freq' at which to convert 'n' " @@ -4400,7 +4407,7 @@ def _all_larger( coeff_b: tuple[tuple[CustomSpatialDataType, CustomSpatialDataType], ...], ) -> bool: """``coeff_a`` and ``coeff_b`` can be either float or SpatialDataArray.""" - if isinstance(coeff_a, CustomSpatialDataType.__args__): + if _isinstance(coeff_a, SpatialDataArray) or isinstance(coeff_a, UnstructuredGridDataset): return np.all(_get_numpy_array(coeff_a) > _get_numpy_array(coeff_b)) return coeff_a > coeff_b @@ -7221,6 +7228,8 @@ def perturbed_copy( if self.perturbation_spec is not None: pspec = self.perturbation_spec + delta_eps: Optional[CustomSpatialDataType] = None + delta_sigma: Optional[CustomSpatialDataType] = None if isinstance(pspec, PermittivityPerturbation): delta_eps, delta_sigma = pspec._sample_delta_eps_delta_sigma( temperature, electron_density, hole_density diff --git a/tidy3d/components/microwave/data/dataset.py b/tidy3d/components/microwave/data/dataset.py index b7a59bc500..9dbaac1dbc 100644 --- a/tidy3d/components/microwave/data/dataset.py +++ 
b/tidy3d/components/microwave/data/dataset.py @@ -8,6 +8,7 @@ CurrentFreqModeDataArray, ImpedanceFreqModeDataArray, VoltageFreqModeDataArray, + data_array_annotated_type, ) from tidy3d.components.data.dataset import ModeFreqDataset @@ -23,19 +24,19 @@ class TransmissionLineDataset(ModeFreqDataset): or :class:`ModeSimulation`. """ - Z0: ImpedanceFreqModeDataArray = Field( + Z0: data_array_annotated_type(ImpedanceFreqModeDataArray) = Field( title="Characteristic Impedance", description="The characteristic impedance of the transmission line.", ) - voltage_coeffs: VoltageFreqModeDataArray = Field( + voltage_coeffs: data_array_annotated_type(VoltageFreqModeDataArray) = Field( title="Mode Voltage Coefficients", description="Quantity calculated for transmission lines, which associates " "a voltage-like quantity with each mode profile that scales linearly with the " "complex-valued mode amplitude.", ) - current_coeffs: CurrentFreqModeDataArray = Field( + current_coeffs: data_array_annotated_type(CurrentFreqModeDataArray) = Field( title="Mode Current Coefficients", description="Quantity calculated for transmission lines, which associates " "a current-like quantity with each mode profile that scales linearly with the " diff --git a/tidy3d/components/microwave/data/monitor_data.py b/tidy3d/components/microwave/data/monitor_data.py index 5a64b4687e..80fdc65cb4 100644 --- a/tidy3d/components/microwave/data/monitor_data.py +++ b/tidy3d/components/microwave/data/monitor_data.py @@ -13,6 +13,7 @@ FreqDataArray, FreqModeDataArray, ImpedanceFreqModeDataArray, + data_array_annotated_type, ) from tidy3d.components.data.monitor_data import DirectivityData, ModeData, ModeSolverData from tidy3d.components.microwave.base import MicrowaveBaseModel @@ -90,12 +91,12 @@ class AntennaMetricsData(DirectivityData, MicrowaveBaseModel): John Wiley & Sons, Chapter 2.9 (2016). 
""" - power_incident: FreqDataArray = Field( + power_incident: data_array_annotated_type(FreqDataArray) = Field( title="Power incident", description="Array of values representing the incident power to an antenna.", ) - power_reflected: FreqDataArray = Field( + power_reflected: data_array_annotated_type(FreqDataArray) = Field( title="Power reflected", description="Array of values representing power reflected due to an impedance mismatch with the antenna.", ) diff --git a/tidy3d/components/microwave/path_integrals/integrals/base.py b/tidy3d/components/microwave/path_integrals/integrals/base.py index e49469fe01..08781c0629 100644 --- a/tidy3d/components/microwave/path_integrals/integrals/base.py +++ b/tidy3d/components/microwave/path_integrals/integrals/base.py @@ -12,6 +12,7 @@ ScalarFieldTimeDataArray, ScalarModeFieldDataArray, _make_base_result_data_array, + data_array_annotated_type, ) from tidy3d.components.data.monitor_data import FieldData, FieldTimeData, ModeData, ModeSolverData from tidy3d.components.microwave.path_integrals.specs.base import ( @@ -25,7 +26,11 @@ from tidy3d.components.data.data_array import IntegralResultType IntegrableMonitorDataType = Union[FieldData, FieldTimeData, ModeData, ModeSolverData] -EMScalarFieldType = Union[ScalarFieldDataArray, ScalarFieldTimeDataArray, ScalarModeFieldDataArray] +EMScalarFieldType = Union[ + data_array_annotated_type(ScalarFieldDataArray), + data_array_annotated_type(ScalarFieldTimeDataArray), + data_array_annotated_type(ScalarModeFieldDataArray), +] FieldParameter = Literal["E", "H"] diff --git a/tidy3d/components/microwave/path_integrals/integrals/current.py b/tidy3d/components/microwave/path_integrals/integrals/current.py index 1cdea6b457..1782e30ab2 100644 --- a/tidy3d/components/microwave/path_integrals/integrals/current.py +++ b/tidy3d/components/microwave/path_integrals/integrals/current.py @@ -8,7 +8,11 @@ import xarray as xr from tidy3d.components.base import cached_property -from 
tidy3d.components.data.data_array import FreqModeDataArray, _make_current_data_array +from tidy3d.components.data.data_array import ( + FreqModeDataArray, + _isinstance, + _make_current_data_array, +) from tidy3d.components.data.monitor_data import FieldTimeData from tidy3d.components.microwave.path_integrals.integrals.base import ( AxisAlignedPathIntegral, @@ -253,7 +257,7 @@ def _check_phase_sign_consistency( "Please provide the current path specifications manually." ) - if isinstance(phase_difference, FreqModeDataArray): + if _isinstance(phase_difference, FreqModeDataArray): inconsistent_modes = [] mode_indices = phase_difference.mode_index.values for mode_idx in range(len(mode_indices)): @@ -294,7 +298,7 @@ def _check_phase_amplitude_consistency( "specifications manually." ) - if isinstance(current_in_phase, FreqModeDataArray): + if _isinstance(current_in_phase, FreqModeDataArray): inconsistent_modes = [] mode_indices = current_in_phase.mode_index.values for mode_idx in range(len(mode_indices)): diff --git a/tidy3d/components/parameter_perturbation.py b/tidy3d/components/parameter_perturbation.py index c679bd61bf..addec15f06 100644 --- a/tidy3d/components/parameter_perturbation.py +++ b/tidy3d/components/parameter_perturbation.py @@ -9,6 +9,7 @@ import numpy as np from pydantic import Field, NonNegativeFloat, model_validator +from tidy3d.components.data.data_array import data_array_annotated_type from tidy3d.components.types.base import ArrayComplex, ArrayFloat, discriminated_union from tidy3d.constants import C_0, CMCUBE, EPSILON_0, HERTZ, KELVIN, PERCMCUBE, inf from tidy3d.exceptions import DataError @@ -21,6 +22,7 @@ IndexedDataArray, PerturbationCoefficientDataArray, SpatialDataArray, + _isinstance, ) from .data.unstructured.base import UnstructuredGridDataset from .data.utils import ( @@ -321,7 +323,7 @@ class CustomHeatPerturbation(HeatPerturbation): ... 
) """ - perturbation_values: HeatDataArray = Field( + perturbation_values: data_array_annotated_type(HeatDataArray) = Field( title="Perturbation Values", description="Sampled perturbation values.", ) @@ -407,7 +409,7 @@ def sample( sampled = np.reshape(sampled, np.shape(temp_clip)) # preserve input type - if isinstance(temperature, SpatialDataArray): + if _isinstance(temperature, SpatialDataArray): return SpatialDataArray(sampled, coords=temperature.coords) if isinstance(temperature, UnstructuredGridDataset): return temperature.updated_copy( @@ -794,7 +796,7 @@ class CustomChargePerturbation(ChargePerturbation): ... ) """ - perturbation_values: ChargeDataArray = Field( + perturbation_values: data_array_annotated_type(ChargeDataArray) = Field( title="Petrubation Values", description="2D array (vs electron and hole densities) of sampled perturbation values.", ) @@ -956,7 +958,7 @@ def sample( # preserve input type for arr in inputs: - if isinstance(arr, SpatialDataArray): + if _isinstance(arr, SpatialDataArray): return SpatialDataArray(sampled, coords=arr.coords) if isinstance(arr, UnstructuredGridDataset): @@ -1224,7 +1226,7 @@ def _sample_delta_eps_delta_sigma( temperature: Optional[CustomSpatialDataType] = None, electron_density: Optional[CustomSpatialDataType] = None, hole_density: Optional[CustomSpatialDataType] = None, - ) -> CustomSpatialDataType: + ) -> tuple[Optional[CustomSpatialDataType], Optional[CustomSpatialDataType]]: """Compute effictive pertubation to eps and sigma.""" delta_eps_sampled = None @@ -1266,7 +1268,7 @@ class NedeljkovicSorefMashanovich(AbstractDeltaModel): ------- """ - perturb_coeffs: PerturbationCoefficientDataArray = Field( + perturb_coeffs: data_array_annotated_type(PerturbationCoefficientDataArray) = Field( default_factory=lambda: PerturbationCoefficientDataArray( np.column_stack( [ diff --git a/tidy3d/components/scene.py b/tidy3d/components/scene.py index 267a84e9be..14c9ed8041 100644 --- a/tidy3d/components/scene.py +++ 
b/tidy3d/components/scene.py @@ -54,6 +54,7 @@ from tidy3d.log import log from .base import Tidy3dBaseModel, cached_property +from .data.data_array import _isinstance from .data.utils import ( SpatialDataArray, TetrahedralGridDataset, @@ -2025,7 +2026,7 @@ def doping_bounds(self) -> tuple[list[float], list[float]]: if doping > limits[1]: limits[1] = doping # NOTE: This will be deprecated. - if isinstance(doping, SpatialDataArray): + if _isinstance(doping, SpatialDataArray): min_value = np.min(doping.data.flatten()) max_value = np.max(doping.data.flatten()) if min_value < limits[0]: @@ -2109,7 +2110,7 @@ def _get_absolute_minimum_from_doping( return np.abs(doping) # NOTE: This will be deprecated. - if isinstance(doping, SpatialDataArray): + if _isinstance(doping, SpatialDataArray): return np.min(np.abs(doping.data.flatten())) if isinstance(doping, tuple): @@ -2173,7 +2174,7 @@ def _pcolormesh_shape_doping_box( if isinstance(doping, float): struct_doping[n] = struct_doping[n] + doping # NOTE: This will be deprecated. 
- if isinstance(doping, SpatialDataArray): + if _isinstance(doping, SpatialDataArray): struct_coords = {"xyz"[d]: coords_2D[i] for i, d in enumerate(plane_axes_inds)} data_2D = doping # check whether the provided doping data is 2 or 3D diff --git a/tidy3d/components/simulation.py b/tidy3d/components/simulation.py index 448a1b66d8..40c4e6558d 100644 --- a/tidy3d/components/simulation.py +++ b/tidy3d/components/simulation.py @@ -19,6 +19,7 @@ model_validator, ) +from tidy3d.components.data.data_array import data_array_annotated_type from tidy3d.components.microwave.mode_spec import MicrowaveModeSpec from tidy3d.components.types.base import discriminated_union from tidy3d.constants import C_0, SECOND, fp_eps, inf @@ -383,7 +384,7 @@ class AbstractYeeGridSimulation(AbstractSimulation, ABC): "``autograd`` gradient processing.", ) - post_norm: Union[float, FreqDataArray] = Field( + post_norm: Union[float, data_array_annotated_type(FreqDataArray)] = Field( 1.0, title="Post Normalization Values", description="Factor to multiply the fields by after running, " diff --git a/tidy3d/components/tcad/data/monitor_data/abstract.py b/tidy3d/components/tcad/data/monitor_data/abstract.py index 7228da5adf..d05850278d 100644 --- a/tidy3d/components/tcad/data/monitor_data/abstract.py +++ b/tidy3d/components/tcad/data/monitor_data/abstract.py @@ -10,7 +10,11 @@ from pydantic import Field from tidy3d.components.base_sim.data.monitor_data import AbstractMonitorData -from tidy3d.components.data.data_array import SpatialDataArray +from tidy3d.components.data.data_array import ( + SpatialDataArray, + _isinstance, + data_array_annotated_type, +) from tidy3d.components.data.utils import TetrahedralGridDataset, TriangularGridDataset from tidy3d.components.tcad.types import HeatChargeMonitorType from tidy3d.components.types import Coordinate, ScalarSymmetry @@ -19,7 +23,8 @@ from tidy3d.log import log FieldDataset = Union[ - SpatialDataArray, discriminated_union(Union[TriangularGridDataset, 
TetrahedralGridDataset]) + data_array_annotated_type(SpatialDataArray), + discriminated_union(Union[TriangularGridDataset, TetrahedralGridDataset]), ] UnstructuredFieldType = Union[TriangularGridDataset, TetrahedralGridDataset] @@ -81,7 +86,7 @@ def _symmetry_expanded_copy_base(self, property: FieldDataset) -> FieldDataset: mnt_bounds = np.array(self.monitor.bounds) - if isinstance(new_property, SpatialDataArray): + if _isinstance(new_property, SpatialDataArray): data_bounds = [ [np.min(new_property.x), np.min(new_property.y), np.min(new_property.z)], [np.max(new_property.x), np.max(new_property.y), np.max(new_property.z)], @@ -127,7 +132,7 @@ def _symmetry_expanded_copy_base(self, property: FieldDataset) -> FieldDataset: for dim in dims_need_clipping_right: clip_bounds[1][dim] = mnt_bounds[1][dim] - if isinstance(new_property, SpatialDataArray): + if _isinstance(new_property, SpatialDataArray): new_property = new_property.sel_inside(clip_bounds) else: new_property = new_property.box_clip(bounds=clip_bounds) diff --git a/tidy3d/components/tcad/data/monitor_data/charge.py b/tidy3d/components/tcad/data/monitor_data/charge.py index f3dc5b95b4..c439768870 100644 --- a/tidy3d/components/tcad/data/monitor_data/charge.py +++ b/tidy3d/components/tcad/data/monitor_data/charge.py @@ -13,6 +13,8 @@ PointDataArray, SpatialDataArray, SteadyVoltageDataArray, + _isinstance, + data_array_annotated_type, ) from tidy3d.components.data.utils import TetrahedralGridDataset, TriangularGridDataset from tidy3d.components.tcad.data.monitor_data.abstract import HeatChargeMonitorData @@ -102,7 +104,7 @@ def check_correct_data_type(self) -> Self: field_data = {field: getattr(self, field) for field in ["electrons", "holes"]} for field, data in field_data.items(): if isinstance(data, TetrahedralGridDataset) or isinstance(data, TriangularGridDataset): - if not isinstance(data.values, IndexedVoltageDataArray): + if not _isinstance(data.values, IndexedVoltageDataArray): raise ValueError( f"In the 
data associated with monitor {self.monitor}, the " f"field {field} does not contain data associated to any voltage value." @@ -193,7 +195,7 @@ def check_correct_data_type(self) -> Self: for field, data in field_data.items(): if isinstance(data, TetrahedralGridDataset) or isinstance(data, TriangularGridDataset): - if not isinstance(data.values, IndexedVoltageDataArray): + if not _isinstance(data.values, IndexedVoltageDataArray): raise ValueError( f"In the data associated with monitor {self.monitor}, the " f"field {field} does not contain data associated to any voltage value." @@ -302,14 +304,14 @@ class SteadyCapacitanceData(HeatChargeMonitorData): description="Capacitance data associated with a Charge simulation.", ) - hole_capacitance: Optional[SteadyVoltageDataArray] = Field( + hole_capacitance: Optional[data_array_annotated_type(SteadyVoltageDataArray)] = Field( None, title="Hole capacitance", description="Small signal capacitance :math:`(\\frac{dQ_p}{dV})` associated to the monitor.", ) # C_p = hole_capacitance - electron_capacitance: Optional[SteadyVoltageDataArray] = Field( + electron_capacitance: Optional[data_array_annotated_type(SteadyVoltageDataArray)] = Field( None, title="Electron capacitance", description="Small signal capacitance :math:`(\\frac{dQn}{dV})` associated to the monitor.", @@ -382,7 +384,10 @@ def check_correct_data_type(self) -> Self: """Issue error if incorrect data type is used""" if isinstance(self.E, TetrahedralGridDataset) or isinstance(self.E, TriangularGridDataset): - if not isinstance(self.E.values, (IndexedFieldVoltageDataArray, PointDataArray)): + if not ( + _isinstance(self.E.values, IndexedFieldVoltageDataArray) + or _isinstance(self.E.values, PointDataArray) + ): raise ValueError( f"The data associated with monitor {self.monitor.name} must contain a field. This can be " "defined with 'IndexedFieldVoltageDataArray' or 'PointDataArray'." 
@@ -423,8 +428,10 @@ def check_correct_data_type(self) -> Self: J = self.J if isinstance(J, TetrahedralGridDataset) or isinstance(J, TriangularGridDataset): - AcceptedTypes = (IndexedFieldVoltageDataArray, PointDataArray) - if not isinstance(J.values, AcceptedTypes): + if not ( + _isinstance(J.values, IndexedFieldVoltageDataArray) + or _isinstance(J.values, PointDataArray) + ): raise ValueError( f"In the data associated with monitor {mnt}, must contain a field. This can be " "defined with IndexedFieldVoltageDataArray or PointDataArray." diff --git a/tidy3d/components/tcad/data/sim_data.py b/tidy3d/components/tcad/data/sim_data.py index 94d427147c..f2e7f0e0e1 100644 --- a/tidy3d/components/tcad/data/sim_data.py +++ b/tidy3d/components/tcad/data/sim_data.py @@ -14,6 +14,8 @@ FreqVoltageDataArray, SpatialDataArray, SteadyVoltageDataArray, + _isinstance, + data_array_annotated_type, ) from tidy3d.components.data.utils import ( TetrahedralGridDataset, @@ -66,35 +68,39 @@ class DeviceCharacteristics(Tidy3dBaseModel): """ - steady_dc_hole_capacitance: Optional[SteadyVoltageDataArray] = Field( + steady_dc_hole_capacitance: Optional[data_array_annotated_type(SteadyVoltageDataArray)] = Field( None, title="Steady DC hole capacitance", description="Device steady DC capacitance data based on holes. If the simulation " "has converged, these result should be close to that of electrons.", ) - steady_dc_electron_capacitance: Optional[SteadyVoltageDataArray] = Field( - None, - title="Steady DC electron capacitance", - description="Device steady DC capacitance data based on electrons. If the simulation " - "has converged, these result should be close to that of holes.", + steady_dc_electron_capacitance: Optional[data_array_annotated_type(SteadyVoltageDataArray)] = ( + Field( + None, + title="Steady DC electron capacitance", + description="Device steady DC capacitance data based on electrons. 
 If the simulation " + "has converged, these results should be close to that of holes.", + ) ) - steady_dc_current_voltage: Optional[SteadyVoltageDataArray] = Field( + steady_dc_current_voltage: Optional[data_array_annotated_type(SteadyVoltageDataArray)] = Field( None, title="Steady DC current-voltage", description="Device steady DC current-voltage relation for the device.", ) - steady_dc_resistance_voltage: Optional[SteadyVoltageDataArray] = Field( - None, - title="Small signal resistance", - description="Steady DC computation of the small signal resistance. This is computed " - "as the derivative of the current-voltage relation :math:`\\frac{\\Delta V}{\\Delta I}`, and the result " - "is given in Ohms. Note that in 2D the resistance is given in :math:`\\Omega \\mu`.", + steady_dc_resistance_voltage: Optional[data_array_annotated_type(SteadyVoltageDataArray)] = ( + Field( + None, + title="Small signal resistance", + description="Steady DC computation of the small signal resistance. This is computed " + "as the derivative of the current-voltage relation :math:`\\frac{\\Delta V}{\\Delta I}`, and the result " + "is given in Ohms. Note that in 2D the resistance is given in :math:`\\Omega \\mu`.", + ) ) - ac_current_voltage: Optional[FreqVoltageDataArray] = Field( + ac_current_voltage: Optional[data_array_annotated_type(FreqVoltageDataArray)] = Field( None, title="Small-signal AC current-voltage", description="Small-signal AC current as a function of DC bias voltage and frequency. 
" @@ -392,7 +398,7 @@ def plot_field( min_bounds.pop(axis) max_bounds.pop(axis) - if isinstance(field_data, SpatialDataArray): + if _isinstance(field_data, SpatialDataArray): # interp out any monitor.size==0 dimensions monitor = self.simulation.get_monitor_by_name(monitor_name) thin_dims = { diff --git a/tidy3d/components/tcad/doping.py b/tidy3d/components/tcad/doping.py index 83d284bd8f..66f6db66fc 100644 --- a/tidy3d/components/tcad/doping.py +++ b/tidy3d/components/tcad/doping.py @@ -10,7 +10,7 @@ from tidy3d.components.autograd import TracedSize from tidy3d.components.base import cached_property -from tidy3d.components.data.data_array import SpatialDataArray +from tidy3d.components.data.data_array import SpatialDataArray, data_array_annotated_type from tidy3d.components.geometry.base import Box from tidy3d.constants import MICROMETER, PERCMCUBE, inf from tidy3d.exceptions import SetupError @@ -313,7 +313,7 @@ class CustomDoping(AbstractDopingBox): ... ) """ - concentration: SpatialDataArray = Field( + concentration: data_array_annotated_type(SpatialDataArray) = Field( title="Doping concentration data array.", description="Doping concentration data array.", units=PERCMCUBE, diff --git a/tidy3d/components/tcad/generation_recombination.py b/tidy3d/components/tcad/generation_recombination.py index 5e2016f84d..bcb89ef7e1 100644 --- a/tidy3d/components/tcad/generation_recombination.py +++ b/tidy3d/components/tcad/generation_recombination.py @@ -6,7 +6,7 @@ from pydantic import Field, PositiveFloat, model_validator from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.components.data.data_array import SpatialDataArray +from tidy3d.components.data.data_array import SpatialDataArray, data_array_annotated_type from tidy3d.constants import PERCMCUBE, SECOND if TYPE_CHECKING: @@ -218,7 +218,7 @@ class DistributedGeneration(Tidy3dBaseModel): >>> dist_g = td.DistributedGeneration(rate=fd) """ - rate: SpatialDataArray = Field( + rate: 
data_array_annotated_type(SpatialDataArray) = Field( title="Generation rate", description="Spatially varying generation rate.", units="1/(cm^3 s^1)", diff --git a/tidy3d/components/tcad/source/heat.py b/tidy3d/components/tcad/source/heat.py index 1fbbde2007..5f0b247a32 100644 --- a/tidy3d/components/tcad/source/heat.py +++ b/tidy3d/components/tcad/source/heat.py @@ -6,7 +6,7 @@ from pydantic import Field, model_validator -from tidy3d.components.data.data_array import SpatialDataArray +from tidy3d.components.data.data_array import SpatialDataArray, data_array_annotated_type from tidy3d.components.tcad.source.abstract import StructureBasedHeatChargeSource from tidy3d.constants import VOLUMETRIC_HEAT_RATE from tidy3d.log import log @@ -21,7 +21,7 @@ class HeatSource(StructureBasedHeatChargeSource): >>> heat_source = HeatSource(rate=1, structures=["box"]) """ - rate: Union[float, SpatialDataArray] = Field( + rate: Union[float, data_array_annotated_type(SpatialDataArray)] = Field( title="Volumetric Heat Rate", description="Volumetric rate of heating or cooling (if negative).", units=VOLUMETRIC_HEAT_RATE, diff --git a/tidy3d/components/time_modulation.py b/tidy3d/components/time_modulation.py index 17094c7462..a93223eea9 100644 --- a/tidy3d/components/time_modulation.py +++ b/tidy3d/components/time_modulation.py @@ -9,11 +9,12 @@ import numpy as np from pydantic import Field, PositiveFloat, field_validator, model_validator +from tidy3d.components.data.data_array import data_array_annotated_type from tidy3d.constants import HERTZ, RADIAN from tidy3d.exceptions import ValidationError from .base import Tidy3dBaseModel, cached_property -from .data.data_array import SpatialDataArray +from .data.data_array import SpatialDataArray, _isinstance from .data.validators import validate_no_nans from .time import AbstractTimeDependence from .types import InterpMethod @@ -135,14 +136,14 @@ class SpaceModulation(AbstractSpaceModulation): >>> space = SpaceModulation(amplitude=amp, 
phase=phase) """ - amplitude: Union[float, SpatialDataArray] = Field( + amplitude: Union[float, data_array_annotated_type(SpatialDataArray)] = Field( 1, title="Amplitude of modulation in space", description="Amplitude of modulation that can vary spatially. " "It takes the unit of whatever is being modulated.", ) - phase: Union[float, SpatialDataArray] = Field( + phase: Union[float, data_array_annotated_type(SpatialDataArray)] = Field( 0, title="Phase of modulation in space", description="Phase of modulation that can vary spatially.", @@ -188,12 +189,12 @@ def sel_inside(self, bounds: Bound) -> Self: SpaceModulation with reduced data. """ - if isinstance(self.amplitude, SpatialDataArray): + if _isinstance(self.amplitude, SpatialDataArray): amp_reduced = self.amplitude.sel_inside(bounds) else: amp_reduced = self.amplitude - if isinstance(self.phase, SpatialDataArray): + if _isinstance(self.phase, SpatialDataArray): phase_reduced = self.phase.sel_inside(bounds) else: phase_reduced = self.phase diff --git a/tidy3d/components/validators.py b/tidy3d/components/validators.py index 2ed5f84fe3..e814dd581e 100644 --- a/tidy3d/components/validators.py +++ b/tidy3d/components/validators.py @@ -20,7 +20,6 @@ validate_name_str, warn_if_dataset_none, ) -from tidy3d.components.data.data_array import DATA_ARRAY_MAP from tidy3d.components.geometry.base import Box from tidy3d.exceptions import SetupError, ValidationError from tidy3d.log import log diff --git a/tidy3d/plugins/smatrix/data/terminal.py b/tidy3d/plugins/smatrix/data/terminal.py index c32b17b40f..cda880303a 100644 --- a/tidy3d/plugins/smatrix/data/terminal.py +++ b/tidy3d/plugins/smatrix/data/terminal.py @@ -8,7 +8,7 @@ from pydantic import Field from tidy3d.components.base import cached_property -from tidy3d.components.data.data_array import FreqDataArray +from tidy3d.components.data.data_array import FreqDataArray, _isinstance, data_array_annotated_type from tidy3d.components.microwave.base import MicrowaveBaseModel 
from tidy3d.constants import C_0 from tidy3d.plugins.smatrix.component_modelers.terminal import TerminalComponentModeler @@ -39,13 +39,13 @@ class MicrowaveSMatrixData(MicrowaveBaseModel): """Stores the computed S-matrix and reference impedances for the terminal ports.""" - port_reference_impedances: Optional[PortDataArray] = Field( + port_reference_impedances: Optional[data_array_annotated_type(PortDataArray)] = Field( None, title="Port Reference Impedances", description="Reference impedance for each port used in the S-parameter calculation. This is optional and may not be present if not specified or computed.", ) - data: TerminalPortDataArray = Field( + data: data_array_annotated_type(TerminalPortDataArray) = Field( title="S-Matrix Data", description="An array containing the computed S-matrix of the device. The data is organized by terminal ports, representing the scattering parameters between them.", ) @@ -261,7 +261,7 @@ def _monitor_data_at_port_amplitude( port, mode_index = self.modeler.network_dict[port_index] sim_data_port = self.data[self.modeler.get_task_name(port, mode_index)] monitor_data = sim_data_port[monitor_name] - if not isinstance(a_port, FreqDataArray): + if not _isinstance(a_port, FreqDataArray): freqs = list(monitor_data.monitor.freqs) array_vals = a_port * np.ones(len(freqs)) a_port = FreqDataArray(array_vals, coords={"f": freqs}) diff --git a/tidy3d/plugins/smatrix/ports/types.py b/tidy3d/plugins/smatrix/ports/types.py index acee0be42c..48ebccf3bd 100644 --- a/tidy3d/plugins/smatrix/ports/types.py +++ b/tidy3d/plugins/smatrix/ports/types.py @@ -7,6 +7,7 @@ CurrentFreqModeDataArray, VoltageFreqDataArray, VoltageFreqModeDataArray, + data_array_annotated_type, ) from tidy3d.plugins.smatrix.ports.coaxial_lumped import CoaxialLumpedPort from tidy3d.plugins.smatrix.ports.modal import Port @@ -16,5 +17,12 @@ LumpedPortType = Union[LumpedPort, CoaxialLumpedPort] TerminalPortType = Union[LumpedPortType, WavePort] PortType = Union[Port, 
TerminalPortType] -PortVoltageType = Union[VoltageFreqDataArray, VoltageFreqModeDataArray] -PortCurrentType = Union[CurrentFreqDataArray, CurrentFreqModeDataArray] +PortVoltageType = Union[ + data_array_annotated_type(VoltageFreqDataArray), + data_array_annotated_type(VoltageFreqModeDataArray), +] + +PortCurrentType = Union[ + data_array_annotated_type(CurrentFreqDataArray), + data_array_annotated_type(CurrentFreqModeDataArray), +] diff --git a/tidy3d/plugins/smatrix/utils.py b/tidy3d/plugins/smatrix/utils.py index 02dd588bdc..82925f27b9 100644 --- a/tidy3d/plugins/smatrix/utils.py +++ b/tidy3d/plugins/smatrix/utils.py @@ -11,6 +11,7 @@ import numpy as np +from tidy3d.components.data.data_array import _isinstance from tidy3d.exceptions import Tidy3dError from tidy3d.plugins.smatrix.data.data_array import PortDataArray, TerminalPortDataArray @@ -261,7 +262,7 @@ def s_to_z( shape_left = (len(s_matrix.f), len(s_matrix.port_out), 1) shape_right = (len(s_matrix.f), 1, len(s_matrix.port_in)) # Setup the port reference impedance array (scalar) - if isinstance(reference, PortDataArray): + if _isinstance(reference, PortDataArray): Zport = reference.values.reshape(shape_right) F = compute_F(Zport, s_param_def).reshape(shape_right) Finv = (1.0 / F).reshape(shape_left)