Skip to content

Commit 8fe3ec2

Browse files
authored
[SAM3] Fix MPS race condition in add_point_inputs (#43042)
[SAM2] Fix MPS race condition in add_point_inputs
1 parent e8c51d1 commit 8fe3ec2

File tree

4 files changed

+4
-4
lines changed

4 files changed

+4
-4
lines changed

src/transformers/models/edgetam_video/modeling_edgetam_video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1001,7 +1001,7 @@ def add_point_inputs(self, obj_idx: int, frame_idx: int, inputs: dict):
10011001
device_inputs = {}
10021002
for key, value in inputs.items():
10031003
if isinstance(value, torch.Tensor):
1004-
device_inputs[key] = value.to(self.inference_device, non_blocking=True)
1004+
device_inputs[key] = value.to(self.inference_device, non_blocking=False)
10051005
else:
10061006
device_inputs[key] = value
10071007
self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs

src/transformers/models/sam2_video/modeling_sam2_video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def add_point_inputs(self, obj_idx: int, frame_idx: int, inputs: dict):
209209
device_inputs = {}
210210
for key, value in inputs.items():
211211
if isinstance(value, torch.Tensor):
212-
device_inputs[key] = value.to(self.inference_device, non_blocking=True)
212+
device_inputs[key] = value.to(self.inference_device, non_blocking=False)
213213
else:
214214
device_inputs[key] = value
215215
self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs

src/transformers/models/sam2_video/modular_sam2_video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def add_point_inputs(self, obj_idx: int, frame_idx: int, inputs: dict):
478478
device_inputs = {}
479479
for key, value in inputs.items():
480480
if isinstance(value, torch.Tensor):
481-
device_inputs[key] = value.to(self.inference_device, non_blocking=True)
481+
device_inputs[key] = value.to(self.inference_device, non_blocking=False)
482482
else:
483483
device_inputs[key] = value
484484
self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs

src/transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def add_point_inputs(self, obj_idx: int, frame_idx: int, inputs: dict):
213213
device_inputs = {}
214214
for key, value in inputs.items():
215215
if isinstance(value, torch.Tensor):
216-
device_inputs[key] = value.to(self.inference_device, non_blocking=True)
216+
device_inputs[key] = value.to(self.inference_device, non_blocking=False)
217217
else:
218218
device_inputs[key] = value
219219
self.point_inputs_per_obj[obj_idx][frame_idx] = device_inputs

0 commit comments

Comments
 (0)