Skip to content

Commit 90d730b

Browse files
committed
remove memory_format
1 parent b076d5a commit 90d730b

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

torchao/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -292,16 +292,16 @@ def _get_to_kwargs(self, *args, **kwargs):
292292
args.remove(arg)
293293
if "layout" in kwargs:
294294
kwargs.pop("layout")
295-
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
295+
# ignoring `non_blocking` and `memory_format` args since these are not
296+
# very useful for most of the tensor subclasses
297+
# if in the future there are use cases that need these, we'd recommend
298+
# to override `_get_to_kwargs` and return these args
299+
device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
296300
device = self.device if device is None else device
297301
dtype = self.dtype if dtype is None else dtype
298-
memory_format = (
299-
memory_format if memory_format is not None else torch.preserve_format
300-
)
301302
kwargs = {
302303
"device": device,
303304
"dtype": dtype,
304-
"memory_format": memory_format,
305305
}
306306
return kwargs
307307

0 commit comments

Comments
 (0)