@@ -434,7 +434,7 @@ def cuda(self):
434
434
435
435
436
436
class Translate (Transform3d ):
437
- def __init__ (self , x , y = None , z = None , dtype = torch .float32 , device = "cpu" ):
437
+ def __init__ (self , x , y = None , z = None , dtype = torch .float32 , device = None ):
438
438
"""
439
439
Create a new Transform3d representing 3D translations.
440
440
@@ -448,11 +448,11 @@ def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
448
448
- A torch scalar
449
449
- A 1D torch tensor
450
450
"""
451
- super ().__init__ (device = device )
452
451
xyz = _handle_input (x , y , z , dtype , device , "Translate" )
452
+ super ().__init__ (device = xyz .device )
453
453
N = xyz .shape [0 ]
454
454
455
- mat = torch .eye (4 , dtype = dtype , device = device )
455
+ mat = torch .eye (4 , dtype = dtype , device = self . device )
456
456
mat = mat .view (1 , 4 , 4 ).repeat (N , 1 , 1 )
457
457
mat [:, 3 , :3 ] = xyz
458
458
self ._matrix = mat
@@ -468,7 +468,7 @@ def _get_matrix_inverse(self):
468
468
469
469
470
470
class Scale (Transform3d ):
471
- def __init__ (self , x , y = None , z = None , dtype = torch .float32 , device = "cpu" ):
471
+ def __init__ (self , x , y = None , z = None , dtype = torch .float32 , device = None ):
472
472
"""
473
473
A Transform3d representing a scaling operation, with different scale
474
474
factors along each coordinate axis.
@@ -485,12 +485,12 @@ def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
485
485
- torch scalar
486
486
- 1D torch tensor
487
487
"""
488
- super ().__init__ (device = device )
489
488
xyz = _handle_input (x , y , z , dtype , device , "scale" , allow_singleton = True )
489
+ super ().__init__ (device = xyz .device )
490
490
N = xyz .shape [0 ]
491
491
492
492
# TODO: Can we do this all in one go somehow?
493
- mat = torch .eye (4 , dtype = dtype , device = device )
493
+ mat = torch .eye (4 , dtype = dtype , device = self . device )
494
494
mat = mat .view (1 , 4 , 4 ).repeat (N , 1 , 1 )
495
495
mat [:, 0 , 0 ] = xyz [:, 0 ]
496
496
mat [:, 1 , 1 ] = xyz [:, 1 ]
@@ -509,7 +509,7 @@ def _get_matrix_inverse(self):
509
509
510
510
class Rotate (Transform3d ):
511
511
def __init__ (
512
- self , R , dtype = torch .float32 , device = "cpu" , orthogonal_tol : float = 1e-5
512
+ self , R , dtype = torch .float32 , device = None , orthogonal_tol : float = 1e-5
513
513
):
514
514
"""
515
515
Create a new Transform3d representing 3D rotation using a rotation
@@ -520,6 +520,7 @@ def __init__(
520
520
orthogonal_tol: tolerance for the test of the orthogonality of R
521
521
522
522
"""
523
+ device = _get_device (R , device )
523
524
super ().__init__ (device = device )
524
525
if R .dim () == 2 :
525
526
R = R [None ]
@@ -548,7 +549,7 @@ def __init__(
548
549
axis : str = "X" ,
549
550
degrees : bool = True ,
550
551
dtype = torch .float64 ,
551
- device = "cpu" ,
552
+ device = None ,
552
553
):
553
554
"""
554
555
Create a new Transform3d representing 3D rotation about an axis
@@ -578,7 +579,7 @@ def __init__(
578
579
# is for transforming column vectors. Therefore we transpose this matrix.
579
580
# R will always be of shape (N, 3, 3)
580
581
R = _axis_angle_rotation (axis , angle ).transpose (1 , 2 )
581
- super ().__init__ (device = device , R = R )
582
+ super ().__init__ (device = angle . device , R = R )
582
583
583
584
584
585
def _handle_coord (c , dtype , device ):
@@ -595,9 +596,24 @@ def _handle_coord(c, dtype, device):
595
596
c = torch .tensor (c , dtype = dtype , device = device )
596
597
if c .dim () == 0 :
597
598
c = c .view (1 )
599
+ if c .device != device :
600
+ c = c .to (device = device )
598
601
return c
599
602
600
603
604
+ def _get_device (x , device = None ):
605
+ if device is not None :
606
+ # User overriding device, leave
607
+ device = device
608
+ elif torch .is_tensor (x ):
609
+ # Set device based on input tensor
610
+ device = x .device
611
+ else :
612
+ # Default device is cpu
613
+ device = "cpu"
614
+ return device
615
+
616
+
601
617
def _handle_input (x , y , z , dtype , device , name : str , allow_singleton : bool = False ):
602
618
"""
603
619
Helper function to handle parsing logic for building transforms. The output
@@ -626,6 +642,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
626
642
Returns:
627
643
xyz: Tensor of shape (N, 3)
628
644
"""
645
+ device = _get_device (x , device )
629
646
# If x is actually a tensor of shape (N, 3) then just return it
630
647
if torch .is_tensor (x ) and x .dim () == 2 :
631
648
if x .shape [1 ] != 3 :
@@ -634,7 +651,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
634
651
if y is not None or z is not None :
635
652
msg = "Expected y and z to be None (in %s)" % name
636
653
raise ValueError (msg )
637
- return x
654
+ return x . to ( device = device )
638
655
639
656
if allow_singleton and y is None and z is None :
640
657
y = x
@@ -665,6 +682,7 @@ def _handle_angle_input(x, dtype, device, name: str):
665
682
- Python scalar
666
683
- Torch scalar
667
684
"""
685
+ device = _get_device (x , device )
668
686
if torch .is_tensor (x ) and x .dim () > 1 :
669
687
msg = "Expected tensor of shape (N,); got %r (in %s)"
670
688
raise ValueError (msg % (x .shape , name ))
0 commit comments