Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 36 additions & 30 deletions src/anemoi/datasets/create/sources/grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
from earthkit.data import from_source
from earthkit.data.utils.patterns import Pattern

from anemoi.datasets.create.arguments import ValidDates

from ..source import Source
from . import source_registry
from .legacy import LegacySource

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -83,26 +85,23 @@ def _expand(paths: list[str]) -> Any:


@source_registry.register("grib")
class GribSource(LegacySource):
class GribSource(Source):

@staticmethod
def _execute(
def __init__(
self,
context: Any,
dates: list[Any],
path: str | list[str],
flavour: str | dict[str, Any] | None = None,
grid_definition: dict[str, Any] | None = None,
*args: Any,
**kwargs: Any,
) -> ekd.FieldList:
"""Executes the function to load data from GRIB files.
) -> None:
"""Initialise the GRIB source.

Parameters
----------
context : Any
The context in which the function is executed.
dates : list of Any
List of dates.
The context in which the source is created.
path : str or list of str
Path or list of paths to the GRIB files.
flavour : str or dict of str to Any, optional
Expand All @@ -112,22 +111,29 @@ def _execute(
*args : Any
Additional positional arguments.
**kwargs : Any
Additional keyword arguments.
Additional keyword arguments forwarded to ``.sel()``.
"""
super().__init__(context)
self.path = path
self.flavour = RuleBasedFlavour(flavour) if flavour is not None else None
self.grid = grid_registry.from_config(grid_definition) if grid_definition is not None else None
self.args = args
self.kwargs = kwargs

def execute_valid_dates(self, dates: ValidDates) -> ekd.FieldList:
"""Load data from the GRIB files for the given dates.

Parameters
----------
dates : ValidDates
The validity-time argument from the pipeline.

Returns
-------
Any
ekd.FieldList
The loaded dataset.
"""
given_paths = path if isinstance(path, list) else [path]

if flavour is not None:
flavour = RuleBasedFlavour(flavour)

if grid_definition is not None:
grid = grid_registry.from_config(grid_definition)
else:
grid = None
given_paths = self.path if isinstance(self.path, list) else [self.path]

ds = from_source("empty")
dates = [d.isoformat() for d in dates]
Expand All @@ -138,24 +144,24 @@ def _execute(
if "{" not in path:
paths = [path]
else:
paths = Pattern(path).substitute(*args, date=dates, allow_extra=True, **kwargs)
paths = Pattern(path).substitute(*self.args, date=dates, allow_extra=True, **self.kwargs)

for name in ("grid", "area", "rotation", "frame", "resol", "bitmap"):
if name in kwargs:
if name in self.kwargs:
raise ValueError(f"MARS interpolation parameter '{name}' not supported")

for path in _expand(paths):
context.trace("📁", "PATH", path)
self.context.trace("📁", "PATH", path)

if isinstance(path, str) and (path.startswith("ec:") or path.startswith("ectmp:")):
from anemoi.datasets.create.ecfs import get_ecfs_file

path = get_ecfs_file(path)

s = from_source("file", path)
if flavour is not None:
s = flavour.map(s)
sel_kwargs = kwargs.copy()
if self.flavour is not None:
s = self.flavour.map(s)
sel_kwargs = self.kwargs.copy()
if dates != []:
sel_kwargs["valid_datetime"] = dates
s = s.sel(**sel_kwargs)
Expand All @@ -164,10 +170,10 @@ def _execute(
# if kwargs and not context.partial_ok:
# BACK check(ds, given_paths, valid_datetime=dates, **kwargs)

if grid is not None:
ds = new_fieldlist_from_list([new_field_from_grid(f, grid) for f in ds])
if self.grid is not None:
ds = new_fieldlist_from_list([new_field_from_grid(f, self.grid) for f in ds])

if len(ds) == 0:
LOG.warning(f"No fields found for {dates} in {given_paths} (kwargs={kwargs})")
LOG.warning(f"No fields found for {dates} in {given_paths} (kwargs={self.kwargs})")

return ds
Loading