21 changes: 19 additions & 2 deletions intake_esm/_search.py
@@ -5,6 +5,17 @@
 import pandas as pd
 
 
+def unpack_iterable_column(df: pd.DataFrame, column: str) -> pd.DataFrame:
+    """Return a DataFrame where elements of a given iterable column have been unpacked into multiple lines."""
+    rows = []
+    for _, row in df.iterrows():
+        for val in row[column]:
+            new_row = row.copy()
+            new_row[column] = val
+            rows.append(new_row)
+    return pd.DataFrame(rows)
+
+
 def is_pattern(value):
     if isinstance(value, typing.Pattern):
         return True
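
A minimal sketch (not part of the PR) of what the new helper does, assuming it is imported from intake_esm._search: every element of the iterable cell becomes its own row, with the remaining columns duplicated. The tiny frame and the printed shape in the comments are illustrative only.

import pandas as pd

from intake_esm._search import unpack_iterable_column

df = pd.DataFrame({'path': ['file1'], 'variable': [['A', 'B']]})
print(unpack_iterable_column(df, 'variable'))
#     path variable
# 0  file1        A
# 0  file1        B   (the original pandas index is repeated, not reset)
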
@@ -50,6 +61,7 @@ def search_apply_require_all_on(
     df: pd.DataFrame,
     query: typing.Dict[str, typing.Any],
     require_all_on: typing.Union[str, typing.List[typing.Any]],
+    columns_with_iterables: set = None,
 ) -> pd.DataFrame:
     _query = query.copy()
     # Make sure to remove columns that were already
@@ -66,12 +78,17 @@ def search_apply_require_all_on(
     condition = set(itertools.product(*values))
     query_results = []
     for _, group in grouped:
-        index = group.set_index(keys).index
+        group_for_index = group
+        # Unpack iterables to get testable index.
+        for column in (columns_with_iterables or set()).intersection(keys):
+            group_for_index = unpack_iterable_column(group_for_index, column)
+
+        index = group_for_index.set_index(keys).index
         if not isinstance(index, pd.MultiIndex):
             index = {(element,) for element in index.to_list()}
         else:
             index = set(index.to_list())
-        if index == condition:
+        if condition.issubset(index):  # with iterables we could have more than requested
             query_results.append(group)
 
     if query_results:
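
Why the equality check became condition.issubset(index): once iterables are unpacked, a group can legitimately contain more key/value combinations than the query asked for, so requiring set equality would wrongly reject it. A small sketch of that set logic, using illustrative values rather than a real catalog:

import itertools

# Query asks for variable in ('A', 'B') and random in ('bx',) -- illustrative values.
condition = set(itertools.product(('A', 'B'), ('bx',)))   # {('A', 'bx'), ('B', 'bx')}

# Index built from an unpacked group that also carries extra combinations.
index = {('A', 'bx'), ('B', 'bx'), ('C', 'bx'), ('A', 'by')}

assert index != condition            # the old equality test would have dropped this group
assert condition.issubset(index)     # the new test keeps it: every requested combination is present
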
5 changes: 4 additions & 1 deletion intake_esm/cat.py
@@ -356,7 +356,10 @@ def search(
         )
         if _query.require_all_on is not None and not results.empty:
             results = search_apply_require_all_on(
-                df=results, query=_query.query, require_all_on=_query.require_all_on
+                df=results,
+                query=_query.query,
+                require_all_on=_query.require_all_on,
+                columns_with_iterables=self.columns_with_iterables,
             )
         return results
 
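At the catalog level the only change is that columns_with_iterables is now passed through to search_apply_require_all_on. A hedged usage sketch of the user-facing effect (the catalog path and the column names are hypothetical, not taken from this PR):

import intake

# Hypothetical datastore whose 'variable' column stores a list of variables per asset.
cat = intake.open_esm_datastore('my-catalog.json')

# With the change above, a require_all_on search can be combined with a query
# on an iterable column such as 'variable'.
subset = cat.search(require_all_on=['source_id'], variable=['tas', 'pr'])
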
39 changes: 39 additions & 0 deletions tests/test_search.py
@@ -171,3 +171,42 @@ def test_search_columns_with_iterables(query, expected):
         df=df, query=query_model.query, columns_with_iterables={'variable', 'random'}
     ).to_dict(orient='records')
     assert results == expected
+
+
+@pytest.mark.parametrize(
+    'query,expected',
+    [
+        (
+            dict(variable=['A', 'B'], random='bx'),
+            [
+                {'path': 'file1', 'variable': ['A', 'B'], 'attr': 1, 'random': {'bx', 'by'}},
+                {'path': 'file3', 'variable': ['A'], 'attr': 2, 'random': {'bx', 'bz'}},
+                {'path': 'file4', 'variable': ['B', 'C'], 'attr': 2, 'random': {'bx', 'bz'}},
+            ],
+        ),
+    ],
+)
+def test_search_require_all_on_columns_with_iterables(query, expected):
+    df = pd.DataFrame(
+        {
+            'path': ['file1', 'file2', 'file3', 'file4', 'file5'],
+            'variable': [['A', 'B'], ['C', 'D'], ['A'], ['B', 'C'], ['C', 'D', 'A']],
+            'attr': [1, 1, 2, 2, 3],
+            'random': [
+                {'bx', 'by'},
+                {'bx', 'by'},
+                {'bx', 'bz'},
+                {'bx', 'bz'},
+                {'bx', 'by'},
+            ],
+        }
+    )
+    query_model = QueryModel(query=query, columns=df.columns.tolist(), require_all_on=['attr'])
+    results = search(df=df, query=query_model.query, columns_with_iterables={'variable', 'random'})
+    results = search_apply_require_all_on(
+        df=results,
+        query=query_model.query,
+        require_all_on=query_model.require_all_on,
+        columns_with_iterables={'variable', 'random'},
+    ).to_dict(orient='records')
+    assert results == expected
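
To make the expected records above easier to check by hand, here is the matching spelled out (a sketch of what search and search_apply_require_all_on do with this frame, not a call into intake-esm itself):

import itertools

# search() first drops file2, whose variables contain neither 'A' nor 'B'.
# The remaining rows are grouped by 'attr' and each group is tested against:
condition = set(itertools.product(('A', 'B'), ('bx',)))   # {('A', 'bx'), ('B', 'bx')}

# attr == 1 -> file1: unpacking yields ('A', 'bx'), ('A', 'by'), ('B', 'bx'), ('B', 'by'),
#              a superset of condition, so the group is kept.
# attr == 2 -> file3 contributes ('A', 'bx') and file4 contributes ('B', 'bx'), so the group is kept.
# attr == 3 -> file5 has no variable 'B', so ('B', 'bx') is missing and the group is dropped.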