Skip to content

Commit 821feab

Browse files
authored
Bugs found while working on slchallenge_demo notebook. (#392)
* WIP - Addressing bugs found while working on slchallenge_demo notebook. * Adding `ids()` method to the dataset class fixed the duplicated ids problem. * Adding some classifier evaluation metrics. * Updated slchallenge notebook with confusion matrix plot. * Trival commit * Removing slchallenge_demo.ipynb from this PR. It is tracked in lincc-frameworks/notebooks_lf.
1 parent 37c9936 commit 821feab

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

src/hyrax/verbs/infer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,10 @@ def _save_batch(batch: Union[Tensor, list, tuple, dict], batch_results: Tensor):
112112
object_ids[id] for id in range(write_index, write_index + len(batch_results))
113113
]
114114
elif batch_has_ids:
115-
batch_object_ids = batch["object_id"].tolist()
115+
if isinstance(batch["object_id"], list):
116+
batch_object_ids = batch["object_id"]
117+
else:
118+
batch_object_ids = batch["object_id"].tolist()
116119
elif isinstance(batch, dict):
117120
msg = "Dataset dictionary should be returning object_ids to avoid ordering errors. "
118121
msg += "Modify the __getitem__ or __iter__ function of your dataset to include 'object_id' "

0 commit comments

Comments
 (0)