Skip to content

Commit 07c3f5a

Browse files
authored
Update correlation.py
fix CUDA_ERROR_ILLEGAL_ADDRESS
1 parent ed3af4f commit 07c3f5a

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

correlation/correlation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,8 @@ def backward(self, gradOutput):
381381
# end
382382

383383
def FunctionCorrelation(tenOne, tenTwo):
384-
return _FunctionCorrelation.apply(tenOne, tenTwo)
384+
with cupy.cuda.Device(tenOne.get_device()):
385+
return _FunctionCorrelation.apply(tenOne, tenTwo)
385386
# end
386387

387388
class ModuleCorrelation(torch.nn.Module):

0 commit comments

Comments
 (0)