diff --git a/intake_esm/_search.py b/intake_esm/_search.py
index 47978b4c..a97c49c4 100644
--- a/intake_esm/_search.py
+++ b/intake_esm/_search.py
@@ -5,6 +5,33 @@
 import pandas as pd
 
 
+def unpack_iterable_column(df: pd.DataFrame, column: str) -> pd.DataFrame:
+    """Return a copy of ``df`` with one row per element of an iterable column.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        Frame whose ``column`` holds an iterable (list, tuple, set, ...) in every row.
+    column : str
+        Name of the column to unpack.
+
+    Returns
+    -------
+    pd.DataFrame
+        One row per element of ``row[column]``, other values repeated; rows whose
+        iterable is empty produce no output row. The columns of ``df`` are kept.
+    """
+    rows = []
+    for _, row in df.iterrows():
+        for val in row[column]:
+            new_row = row.copy()
+            new_row[column] = val
+            rows.append(new_row)
+    # Pass the columns explicitly: with no rows, pandas would otherwise build a frame
+    # without any columns and the caller's ``set_index(keys)`` would raise KeyError.
+    return pd.DataFrame(rows, columns=df.columns)
+
+
 def is_pattern(value):
     if isinstance(value, typing.Pattern):
         return True
@@ -50,6 +77,7 @@ def search_apply_require_all_on(
     df: pd.DataFrame,
     query: typing.Dict[str, typing.Any],
     require_all_on: typing.Union[str, typing.List[typing.Any]],
+    columns_with_iterables: typing.Optional[typing.Set[str]] = None,
 ) -> pd.DataFrame:
     _query = query.copy()
     # Make sure to remove columns that were already
@@ -66,12 +94,17 @@ def search_apply_require_all_on(
     condition = set(itertools.product(*values))
     query_results = []
     for _, group in grouped:
-        index = group.set_index(keys).index
+        group_for_index = group
+        # Unpack iterables to get testable index.
+        for column in (columns_with_iterables or set()).intersection(keys):
+            group_for_index = unpack_iterable_column(group_for_index, column)
+
+        index = group_for_index.set_index(keys).index
         if not isinstance(index, pd.MultiIndex):
             index = {(element,) for element in index.to_list()}
         else:
             index = set(index.to_list())
-        if index == condition:
+        if condition.issubset(index):  # with iterables we could have more than requested
             query_results.append(group)
 
     if query_results:
diff --git a/intake_esm/cat.py b/intake_esm/cat.py
index 42f4d11d..9b2110c7 100644
--- a/intake_esm/cat.py
+++ b/intake_esm/cat.py
@@ -356,7 +356,10 @@ def search(
         )
         if _query.require_all_on is not None and not results.empty:
             results = search_apply_require_all_on(
-                df=results, query=_query.query, require_all_on=_query.require_all_on
+                df=results,
+                query=_query.query,
+                require_all_on=_query.require_all_on,
+                columns_with_iterables=self.columns_with_iterables,
             )
         return results
 
diff --git a/tests/test_search.py b/tests/test_search.py
index 90fb3f20..bc6f6c84 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -171,3 +171,42 @@ def test_search_columns_with_iterables(query, expected):
         df=df, query=query_model.query, columns_with_iterables={'variable', 'random'}
     ).to_dict(orient='records')
     assert results == expected
+
+
+@pytest.mark.parametrize(
+    'query,expected',
+    [
+        (
+            dict(variable=['A', 'B'], random='bx'),
+            [
+                {'path': 'file1', 'variable': ['A', 'B'], 'attr': 1, 'random': {'bx', 'by'}},
+                {'path': 'file3', 'variable': ['A'], 'attr': 2, 'random': {'bx', 'bz'}},
+                {'path': 'file4', 'variable': ['B', 'C'], 'attr': 2, 'random': {'bx', 'bz'}},
+            ],
+        ),
+    ],
+)
+def test_search_require_all_on_columns_with_iterables(query, expected):
+    df = pd.DataFrame(
+        {
+            'path': ['file1', 'file2', 'file3', 'file4', 'file5'],
+            'variable': [['A', 'B'], ['C', 'D'], ['A'], ['B', 'C'], ['C', 'D', 'A']],
+            'attr': [1, 1, 2, 2, 3],
+            'random': [
+                {'bx', 'by'},
+                {'bx', 'by'},
+                {'bx', 'bz'},
+                {'bx', 'bz'},
+                {'bx', 'by'},
+            ],
+        }
+    )
+    query_model = QueryModel(query=query, columns=df.columns.tolist(), require_all_on=['attr'])
+    results = search(df=df, query=query_model.query, columns_with_iterables={'variable', 'random'})
+    results = search_apply_require_all_on(
+        df=results,
+        query=query_model.query,
+        require_all_on=query_model.require_all_on,
+        columns_with_iterables={'variable', 'random'},
+    ).to_dict(orient='records')
+    assert results == expected