Skip to content

Commit bd5869e

Browse files
Copilotmtauraso
andauthored
Simplify NestedPandasDataset: use nested-pandas bracket operator, add docstrings
Agent-Logs-Url: https://github.com/lincc-frameworks/hyrax/sessions/637f35ea-20ee-4019-ba3e-8d6e2c3627c8 Co-authored-by: mtauraso <31012+mtauraso@users.noreply.github.com>
1 parent f0f8a05 commit bd5869e

File tree

2 files changed

+5
-16
lines changed

2 files changed

+5
-16
lines changed

src/hyrax/datasets/nested_pandas_dataset.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from pathlib import Path
22
from types import MethodType
3-
from typing import Any
43

54
from hyrax.datasets.dataset_registry import HyraxDataset
65

@@ -21,7 +20,7 @@ def __init__(self, config: dict, data_location: Path | str | None = None):
2120
self._register_getters()
2221
super().__init__(config)
2322

24-
def _load_nested_frame(self, read_kwargs: dict[str, Any]):
23+
def _load_nested_frame(self, read_kwargs: dict):
2524
try:
2625
import nested_pandas as npd
2726
except ImportError as err:
@@ -38,23 +37,10 @@ def _all_available_fields(self) -> list[str]:
3837
fields.extend(self.nested_frame.get_subcolumns())
3938
return fields
4039

41-
def _value_for_field(self, idx: int, field: str):
42-
if "." not in field:
43-
return self.nested_frame.iloc[idx][field]
44-
45-
nested_col, sub_col = field.split(".", maxsplit=1)
46-
if nested_col not in self.nested_frame.columns:
47-
return self.nested_frame.iloc[idx][field]
48-
49-
nested_value = self.nested_frame.iloc[idx][nested_col]
50-
if nested_value is None:
51-
return None
52-
return nested_value[sub_col]
53-
5440
def _register_getters(self) -> None:
5541
def _make_getter(field_name: str):
5642
def getter(self, idx, _field_name=field_name):
57-
return self._value_for_field(idx, _field_name)
43+
return self.nested_frame[_field_name].loc[self.nested_frame.index[idx]]
5844

5945
return getter
6046

tests/hyrax/test_nested_pandas_dataset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def _write_nested_lightcurve_parquet(path: Path) -> None:
1919

2020

2121
def test_nested_pandas_dataset_reads_parquet_and_exposes_nested_fields(tmp_path):
22+
"""Test that NestedPandasDataset reads a parquet file and exposes nested field getters."""
2223
data_path = tmp_path / "random_lightcurve.parquet"
2324
_write_nested_lightcurve_parquet(data_path)
2425

@@ -34,6 +35,7 @@ def test_nested_pandas_dataset_reads_parquet_and_exposes_nested_fields(tmp_path)
3435

3536

3637
def test_nested_pandas_dataset_passes_read_kwargs(tmp_path):
38+
"""Test that read_kwargs are forwarded to nested_pandas.read_parquet."""
3739
data_path = tmp_path / "random_lightcurve.parquet"
3840
_write_nested_lightcurve_parquet(data_path)
3941

@@ -54,5 +56,6 @@ def test_nested_pandas_dataset_passes_read_kwargs(tmp_path):
5456

5557

5658
def test_nested_pandas_dataset_requires_data_location():
59+
"""Test that omitting data_location raises a ValueError."""
5760
with pytest.raises(ValueError):
5861
NestedPandasDataset(config={"data_set": {}})

0 commit comments

Comments
 (0)