Skip to content

Commit 16d2d63

Browse files
authored
Merge pull request #1038 from iasonkrom/fix-apply-to-fileset
fix: make `apply_to_fileset` be able to handle tuple outputs of `data_manipulation`.
2 parents a5674f2 + 2ec485a commit 16d2d63

File tree

2 files changed

+86
-3
lines changed

2 files changed

+86
-3
lines changed

src/coffea/dataset_tools/apply_processor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def apply_to_dataset(
8686

8787
if report is not None:
8888
return out, report
89-
return out
89+
return (out,)
9090

9191

9292
def apply_to_fileset(
@@ -125,10 +125,10 @@ def apply_to_fileset(
125125
dataset_out = apply_to_dataset(
126126
data_manipulation, dataset, schemaclass, metadata, uproot_options
127127
)
128-
if isinstance(dataset_out, tuple):
128+
if isinstance(dataset_out, tuple) and len(dataset_out) > 1:
129129
out[name], report[name] = dataset_out
130130
else:
131-
out[name] = dataset_out
131+
out[name] = dataset_out[0]
132132
if len(report) > 0:
133133
return out, report
134134
return out

tests/test_dataset_tools.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,89 @@
198198
}
199199

200200

201+
def _my_analysis_output_2(events):
202+
return events.Electron.pt, events.Muon.pt
203+
204+
205+
def _my_analysis_output_3(events):
206+
return events.Electron.pt, events.Muon.pt, events.Tau.pt
207+
208+
209+
@pytest.mark.parametrize("allow_read_errors_with_report", [True, False])
210+
def test_tuple_data_manipulation_output(allow_read_errors_with_report):
211+
import dask_awkward
212+
213+
out = apply_to_fileset(
214+
_my_analysis_output_2,
215+
_runnable_result,
216+
uproot_options={"allow_read_errors_with_report": allow_read_errors_with_report},
217+
)
218+
219+
if allow_read_errors_with_report:
220+
assert isinstance(out, tuple)
221+
assert len(out) == 2
222+
out, report = out
223+
assert isinstance(out, dict)
224+
assert isinstance(report, dict)
225+
assert out.keys() == {"ZJets", "Data"}
226+
assert report.keys() == {"ZJets", "Data"}
227+
assert isinstance(out["ZJets"], tuple)
228+
assert isinstance(out["Data"], tuple)
229+
assert len(out["ZJets"]) == 2
230+
assert len(out["Data"]) == 2
231+
for i, j in zip(out["ZJets"], out["Data"]):
232+
assert isinstance(i, dask_awkward.Array)
233+
assert isinstance(j, dask_awkward.Array)
234+
assert isinstance(report["ZJets"], dask_awkward.Array)
235+
assert isinstance(report["Data"], dask_awkward.Array)
236+
else:
237+
assert isinstance(out, dict)
238+
assert len(out) == 2
239+
assert out.keys() == {"ZJets", "Data"}
240+
assert isinstance(out["ZJets"], tuple)
241+
assert isinstance(out["Data"], tuple)
242+
assert len(out["ZJets"]) == 2
243+
assert len(out["Data"]) == 2
244+
for i, j in zip(out["ZJets"], out["Data"]):
245+
assert isinstance(i, dask_awkward.Array)
246+
assert isinstance(j, dask_awkward.Array)
247+
248+
out = apply_to_fileset(
249+
_my_analysis_output_3,
250+
_runnable_result,
251+
uproot_options={"allow_read_errors_with_report": allow_read_errors_with_report},
252+
)
253+
254+
if allow_read_errors_with_report:
255+
assert isinstance(out, tuple)
256+
assert len(out) == 2
257+
out, report = out
258+
assert isinstance(out, dict)
259+
assert isinstance(report, dict)
260+
assert out.keys() == {"ZJets", "Data"}
261+
assert report.keys() == {"ZJets", "Data"}
262+
assert isinstance(out["ZJets"], tuple)
263+
assert isinstance(out["Data"], tuple)
264+
assert len(out["ZJets"]) == 3
265+
assert len(out["Data"]) == 3
266+
for i, j in zip(out["ZJets"], out["Data"]):
267+
assert isinstance(i, dask_awkward.Array)
268+
assert isinstance(j, dask_awkward.Array)
269+
assert isinstance(report["ZJets"], dask_awkward.Array)
270+
assert isinstance(report["Data"], dask_awkward.Array)
271+
else:
272+
assert isinstance(out, dict)
273+
assert len(out) == 2
274+
assert out.keys() == {"ZJets", "Data"}
275+
assert isinstance(out["ZJets"], tuple)
276+
assert isinstance(out["Data"], tuple)
277+
assert len(out["ZJets"]) == 3
278+
assert len(out["Data"]) == 3
279+
for i, j in zip(out["ZJets"], out["Data"]):
280+
assert isinstance(i, dask_awkward.Array)
281+
assert isinstance(j, dask_awkward.Array)
282+
283+
201284
@pytest.mark.parametrize(
202285
"proc_and_schema",
203286
[(NanoTestProcessor, BaseSchema), (NanoEventsProcessor, NanoAODSchema)],

0 commit comments

Comments
 (0)