Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
through the `output.splitting` section in the config file, and support for
optionally computing statistics for a given split (with
`output.splitting.splits.{split_name}.compute_statistics`).
![\#28](https://github.com/mllam/mllam-data-prep/pull/10)
![\#28](https://github.com/mllam/mllam-data-prep/pull/10).

- include `units` and `long_name` attributes for all stacked variables as
`{output_variable}_units` and `{output_variable}_long_name`
![\#11](https://github.com/mllam/mllam-data-prep/pull/11).

### Changed

Expand Down
53 changes: 48 additions & 5 deletions mllam_data_prep/ops/stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,24 @@ def stack_variables_as_coord_values(ds, name_format, combined_dim_name):
combine all variables in an xr.Dataset into a single xr.DataArray
by stacking the variables along a new coordinate named `combined_dim_name`,
with the coordinate values given by `name_format` (which should include the
variable name, `var_name`)


Parameters
----------
ds : xr.Dataset
source dataset with variables to stack
name_format : str
format string to construct the new coordinate values for the
stacked variables, e.g. "{var_name}_level"
combined_dim_name : str
name of the new dimension to create for the stacked variables, for
example "forcing_feature"

Returns
-------
da_combined : xr.DataArray
The combined data array with all variables stacked along the new
`combined_dim_name` coordinate
"""
if "{var_name}" not in name_format:
raise ValueError(
Expand All @@ -14,10 +32,21 @@ def stack_variables_as_coord_values(ds, name_format, combined_dim_name):
)
dataarrays = []
for var_name in list(ds.data_vars):
da = ds[var_name]
da.coords[combined_dim_name] = name_format.format(var_name=var_name)
da = ds[var_name].expand_dims(combined_dim_name)
da.coords[combined_dim_name] = [name_format.format(var_name=var_name)]

# add extra coordinates (spanning along `combined_dim_name`) for
# keeping track of `units` and `long_name` attributes
for attr in ["units", "long_name"]:
da_attr = xr.DataArray(
[ds[var_name].attrs.get(attr, "")],
dims=[combined_dim_name],
coords={combined_dim_name: da.coords[combined_dim_name]},
)
da.coords[f"{combined_dim_name}_{attr}"] = da_attr
dataarrays.append(da)
da_combined = xr.concat(dataarrays, dim=combined_dim_name)

return da_combined


Expand All @@ -40,6 +69,11 @@ def stack_variables_by_coord_values(ds, coord, name_format, combined_dim_name):
3. stack all the variables along the `combined_dim_name` dimension to
produce a single xr.DataArray

In addition to the stacked variables, we also add extra coordinates for
keeping track of `units` and `long_name` attributes for each variable in
`{combined_dim_name}_units` and `{combined_dim_name}_long_name`
respectively.

Parameters
----------
ds : xr.Dataset
Expand Down Expand Up @@ -73,15 +107,24 @@ def stack_variables_by_coord_values(ds, coord, name_format, combined_dim_name):
)

datasets = []
for var in list(ds.data_vars):
da = ds[var]
for var_name in list(ds.data_vars):
da = ds[var_name]
coord_values = da.coords[coord].values
new_coord_values = [
name_format.format(var_name=var, **{coord: val}) for val in coord_values
name_format.format(var_name=var_name, **{coord: val})
for val in coord_values
]
da = da.assign_coords({coord: new_coord_values}).rename(
{coord: combined_dim_name}
)

# add extra coordinates for keeping track of `units` and `long_name` attributes
for attr in ["units", "long_name"]:
da_attr = xr.DataArray(
[ds[var_name].attrs.get(attr, "")] * len(coord_values),
dims=[combined_dim_name],
)
da.coords[f"{combined_dim_name}_{attr}"] = da_attr
datasets.append(da)

da_combined = xr.concat(datasets, dim=combined_dim_name)
Expand Down