diff --git a/pytorch3d/structures/utils.py b/pytorch3d/structures/utils.py index c6b75d93..833b69f2 100644 --- a/pytorch3d/structures/utils.py +++ b/pytorch3d/structures/utils.py @@ -135,22 +135,16 @@ def list_to_packed(x: List[torch.Tensor]): - **item_packed_to_list_idx**: tensor of shape sum(Mi) containing the index of the element in the list the item belongs to. """ - N = len(x) - num_items = torch.zeros(N, dtype=torch.int64, device=x[0].device) - item_packed_first_idx = torch.zeros(N, dtype=torch.int64, device=x[0].device) - item_packed_to_list_idx = [] - cur = 0 - for i, y in enumerate(x): - num = len(y) - num_items[i] = num - item_packed_first_idx[i] = cur - item_packed_to_list_idx.append( - torch.full((num,), i, dtype=torch.int64, device=y.device) - ) - cur += num - + if not x: + raise ValueError("Input list is empty") + device = x[0].device + sizes = [xi.shape[0] for xi in x] + num_items = torch.tensor(sizes, dtype=torch.int64).to(device) + item_packed_first_idx = torch.zeros_like(num_items) + item_packed_first_idx[1:] = torch.cumsum(num_items[:-1], dim=0) + item_packed_to_list_idx = torch.arange(torch.sum(num_items), dtype=torch.int64).to(device) + item_packed_to_list_idx = torch.bucketize(item_packed_to_list_idx, item_packed_first_idx, right=True) - 1 x_packed = torch.cat(x, dim=0) - item_packed_to_list_idx = torch.cat(item_packed_to_list_idx, dim=0) return x_packed, num_items, item_packed_first_idx, item_packed_to_list_idx