@@ -266,30 +266,8 @@ def _parse_auxiliary_input(
266
266
aux_input_C = None
267
267
268
268
if isinstance (aux_input , list ):
269
- if len (aux_input ) != self ._N :
270
- raise ValueError ("Points and auxiliary input must be the same length." )
271
- for p , d in zip (self ._num_points_per_cloud , aux_input ):
272
- if p != d .shape [0 ]:
273
- raise ValueError (
274
- "A cloud has mismatched numbers of points and inputs"
275
- )
276
- if d .device != self .device :
277
- raise ValueError (
278
- "All auxiliary inputs must be on the same device as the points."
279
- )
280
- if p > 0 :
281
- if d .dim () != 2 :
282
- raise ValueError (
283
- "A cloud auxiliary input must be of shape PxC or empty"
284
- )
285
- if aux_input_C is None :
286
- aux_input_C = d .shape [1 ]
287
- if aux_input_C != d .shape [1 ]:
288
- raise ValueError (
289
- "The clouds must have the same number of channels"
290
- )
291
- return aux_input , None , aux_input_C
292
- elif torch .is_tensor (aux_input ):
269
+ return self ._parse_auxiliary_input_list (aux_input )
270
+ if torch .is_tensor (aux_input ):
293
271
if aux_input .dim () != 3 :
294
272
raise ValueError ("Auxiliary input tensor has incorrect dimensions." )
295
273
if self ._N != aux_input .shape [0 ]:
@@ -312,6 +290,72 @@ def _parse_auxiliary_input(
312
290
points in a cloud."
313
291
)
314
292
293
+ def _parse_auxiliary_input_list (
294
+ self , aux_input : list
295
+ ) -> Tuple [Optional [List [torch .Tensor ]], None , Optional [int ]]:
296
+ """
297
+ Interpret the auxiliary inputs (normals, features) given to __init__,
298
+ if a list.
299
+
300
+ Args:
301
+ aux_input:
302
+ - List where each element is a tensor of shape (num_points, C)
303
+ containing the features for the points in the cloud.
304
+ For normals, C = 3
305
+
306
+ Returns:
307
+ 3-element tuple of list, padded=None, num_channels.
308
+ If aux_input is list, then padded is None. If aux_input is a tensor,
309
+ then list is None.
310
+ """
311
+ aux_input_C = None
312
+ good_empty = None
313
+ needs_fixing = False
314
+
315
+ if len (aux_input ) != self ._N :
316
+ raise ValueError ("Points and auxiliary input must be the same length." )
317
+ for p , d in zip (self ._num_points_per_cloud , aux_input ):
318
+ valid_but_empty = p == 0 and d is not None and d .ndim == 2
319
+ if p > 0 or valid_but_empty :
320
+ if p != d .shape [0 ]:
321
+ raise ValueError (
322
+ "A cloud has mismatched numbers of points and inputs"
323
+ )
324
+ if d .dim () != 2 :
325
+ raise ValueError (
326
+ "A cloud auxiliary input must be of shape PxC or empty"
327
+ )
328
+ if aux_input_C is None :
329
+ aux_input_C = d .shape [1 ]
330
+ elif aux_input_C != d .shape [1 ]:
331
+ raise ValueError ("The clouds must have the same number of channels" )
332
+ if d .device != self .device :
333
+ raise ValueError (
334
+ "All auxiliary inputs must be on the same device as the points."
335
+ )
336
+ else :
337
+ needs_fixing = True
338
+
339
+ if aux_input_C is None :
340
+ # We found nothing useful
341
+ return None , None , None
342
+
343
+ # If we have empty but "wrong" inputs we want to store "fixed" versions.
344
+ if needs_fixing :
345
+ if good_empty is None :
346
+ good_empty = torch .zeros ((0 , aux_input_C ), device = self .device )
347
+ aux_input_out = []
348
+ for p , d in zip (self ._num_points_per_cloud , aux_input ):
349
+ valid_but_empty = p == 0 and d is not None and d .ndim == 2
350
+ if p > 0 or valid_but_empty :
351
+ aux_input_out .append (d )
352
+ else :
353
+ aux_input_out .append (good_empty )
354
+ else :
355
+ aux_input_out = aux_input
356
+
357
+ return aux_input_out , None , aux_input_C
358
+
315
359
def __len__ (self ) -> int :
316
360
return self ._N
317
361
0 commit comments