@@ -884,7 +884,8 @@ def test_camera_class_init(self):
884884 self .assertTrue (new_cam .device == device )
885885
886886 def test_getitem (self ):
887- R_matrix = torch .randn ((6 , 3 , 3 ))
887+ N_CAMERAS = 6
888+ R_matrix = torch .randn ((N_CAMERAS , 3 , 3 ))
888889 cam = FoVPerspectiveCameras (znear = 10.0 , zfar = 100.0 , R = R_matrix )
889890
890891 # Check get item returns an instance of the same class
@@ -908,22 +909,39 @@ def test_getitem(self):
908909 self .assertClose (c012 .R , R_matrix [0 :3 , ...])
909910
910911 # Check torch.LongTensor index
911- index = torch .tensor ([1 , 3 , 5 ], dtype = torch .int64 )
912+ SLICE = [1 , 3 , 5 ]
913+ index = torch .tensor (SLICE , dtype = torch .int64 )
912914 c135 = cam [index ]
913915 self .assertEqual (len (c135 ), 3 )
914916 self .assertClose (c135 .zfar , torch .tensor ([100.0 ] * 3 ))
915917 self .assertClose (c135 .znear , torch .tensor ([10.0 ] * 3 ))
916- self .assertClose (c135 .R , R_matrix [[1 , 3 , 5 ], ...])
918+ self .assertClose (c135 .R , R_matrix [SLICE , ...])
919+
920+ # Check torch.BoolTensor index
921+ bool_slice = [i in SLICE for i in range (N_CAMERAS )]
922+ index = torch .tensor (bool_slice , dtype = torch .bool )
923+ c135 = cam [index ]
924+ self .assertEqual (len (c135 ), 3 )
925+ self .assertClose (c135 .zfar , torch .tensor ([100.0 ] * 3 ))
926+ self .assertClose (c135 .znear , torch .tensor ([10.0 ] * 3 ))
927+ self .assertClose (c135 .R , R_matrix [SLICE , ...])
917928
918929 # Check errors with get item
919930 with self .assertRaisesRegex (ValueError , "out of bounds" ):
920- cam [6 ]
931+ cam [N_CAMERAS ]
932+
933+ with self .assertRaisesRegex (ValueError , "does not match cameras" ):
934+ index = torch .tensor ([1 , 0 , 1 ], dtype = torch .bool )
935+ cam [index ]
921936
922937 with self .assertRaisesRegex (ValueError , "Invalid index type" ):
923938 cam [slice (0 , 1 )]
924939
925940 with self .assertRaisesRegex (ValueError , "Invalid index type" ):
926- index = torch .tensor ([1 , 3 , 5 ], dtype = torch .float32 )
941+ cam [[True , False ]]
942+
943+ with self .assertRaisesRegex (ValueError , "Invalid index type" ):
944+ index = torch .tensor (SLICE , dtype = torch .float32 )
927945 cam [index ]
928946
929947 def test_get_full_transform (self ):
0 commit comments