41
41
42
42
def backwarp (tenInput , tenFlow ):
43
43
if str (tenFlow .shape ) not in backwarp_tenGrid :
44
- tenHor = torch .linspace (- 1.0 + ( 1.0 / tenFlow . shape [ 3 ]) , 1.0 - ( 1.0 / tenFlow . shape [ 3 ]) , tenFlow .shape [3 ]).view (1 , 1 , 1 , - 1 ).repeat (1 , 1 , tenFlow .shape [2 ], 1 )
45
- tenVer = torch .linspace (- 1.0 + ( 1.0 / tenFlow . shape [ 2 ]) , 1.0 - ( 1.0 / tenFlow . shape [ 2 ]) , tenFlow .shape [2 ]).view (1 , 1 , - 1 , 1 ).repeat (1 , 1 , 1 , tenFlow .shape [3 ])
44
+ tenHor = torch .linspace (- 1.0 , 1.0 , tenFlow .shape [3 ]).view (1 , 1 , 1 , - 1 ).repeat (1 , 1 , tenFlow .shape [2 ], 1 )
45
+ tenVer = torch .linspace (- 1.0 , 1.0 , tenFlow .shape [2 ]).view (1 , 1 , - 1 , 1 ).repeat (1 , 1 , 1 , tenFlow .shape [3 ])
46
46
47
47
backwarp_tenGrid [str (tenFlow .shape )] = torch .cat ([ tenHor , tenVer ], 1 ).cuda ()
48
48
# end
@@ -51,10 +51,10 @@ def backwarp(tenInput, tenFlow):
51
51
backwarp_tenPartial [str (tenFlow .shape )] = tenFlow .new_ones ([ tenFlow .shape [0 ], 1 , tenFlow .shape [2 ], tenFlow .shape [3 ] ])
52
52
# end
53
53
54
- tenFlow = torch .cat ([ tenFlow [:, 0 :1 , :, :] / (( tenInput .shape [3 ] - 1.0 ) / 2.0 ), tenFlow [:, 1 :2 , :, :] / (( tenInput .shape [2 ] - 1.0 ) / 2.0 ) ], 1 )
54
+ tenFlow = torch .cat ([ tenFlow [:, 0 :1 , :, :] * ( 2.0 / (tenInput .shape [3 ] - 1.0 )), tenFlow [:, 1 :2 , :, :] * ( 2.0 / (tenInput .shape [2 ] - 1.0 )) ], 1 )
55
55
tenInput = torch .cat ([ tenInput , backwarp_tenPartial [str (tenFlow .shape )] ], 1 )
56
56
57
- tenOutput = torch .nn .functional .grid_sample (input = tenInput , grid = (backwarp_tenGrid [str (tenFlow .shape )] + tenFlow ).permute (0 , 2 , 3 , 1 ), mode = 'bilinear' , padding_mode = 'zeros' , align_corners = False )
57
+ tenOutput = torch .nn .functional .grid_sample (input = tenInput , grid = (backwarp_tenGrid [str (tenFlow .shape )] + tenFlow ).permute (0 , 2 , 3 , 1 ), mode = 'bilinear' , padding_mode = 'zeros' , align_corners = True )
58
58
59
59
tenMask = tenOutput [:, - 1 :, :, :]; tenMask [tenMask > 0.999 ] = 1.0 ; tenMask [tenMask < 1.0 ] = 0.0
60
60
0 commit comments