```python
import torch
from torchao.quantization.quant_api import quantize_, int8_weight_only

model = torch.nn.Linear(2048, 2048)
quantize_(model, int8_weight_only())

# model.to("cuda")  # this works
model.cuda()  # this raises the TypeError below
```
```text
  File "/home/xx/code/ao/torchao/dtypes/affine_quantized_tensor.py", line 975, in _
    args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
  File "/home/xx/code/ao/torchao/dtypes/affine_quantized_tensor.py", line 291, in to
    kwargs = self._get_to_kwargs(*args, **kwargs)
  File "/home/xx/code/ao/torchao/dtypes/affine_quantized_tensor.py", line 277, in _get_to_kwargs
    device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
TypeError: to() received an invalid combination of arguments - got (device=torch.device, layout=torch.layout, dtype=torch.dtype, ), but expected one of:
 * (torch.device device = None, torch.dtype dtype = None, bool non_blocking = False, bool copy = False, *, torch.memory_format memory_format = None)
 * (torch.dtype dtype, bool non_blocking = False, bool copy = False, *, torch.memory_format memory_format = None)
 * (Tensor tensor, bool non_blocking = False, bool copy = False, *, torch.memory_format memory_format = None)
```
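The underlying argument rejection can be seen without CUDA or torchao. A minimal sketch, assuming only PyTorch on CPU: calling `Tensor.to` with the same keyword combination shown in the traceback (`device`, `layout`, `dtype`) raises a `TypeError`, while the same call without `layout` succeeds. The helper `try_to` is hypothetical and exists only for this illustration.

```python
import torch

def try_to(t: torch.Tensor, **kwargs) -> str:
    """Hypothetical helper: call Tensor.to(**kwargs) and report whether it succeeded."""
    try:
        t.to(**kwargs)
        return "ok"
    except TypeError:
        return "TypeError"

t = torch.zeros(4)

# Same keyword combination as in the traceback: Tensor.to() has no `layout` parameter.
with_layout = try_to(t, device=torch.device("cpu"), layout=torch.strided, dtype=torch.float32)

# Dropping `layout` leaves a combination that Tensor.to() accepts.
without_layout = try_to(t, device=torch.device("cpu"), dtype=torch.float32)

print(with_layout, without_layout)  # TypeError ok
```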
I ran into similar problems with other tensor subclasses I implemented before. Basically, the fix is to remove the `layout` kwarg before it reaches `torch._C._nn._parse_to`.
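A minimal sketch of that workaround, assuming a subclass that funnels `.to()` arguments through a helper like `_get_to_kwargs` (as in the traceback). The name `get_to_kwargs` is hypothetical and this is not torchao's actual patch. It drops `layout` in both positional and keyword form, then delegates to `torch._C._nn._parse_to`, a private PyTorch function whose behavior may change between releases.

```python
import torch

def get_to_kwargs(*args, **kwargs) -> dict:
    """Hypothetical sketch of a `_get_to_kwargs`-style helper for a tensor subclass.

    torch._C._nn._parse_to accepts the same signatures as Tensor.to(), which has no
    `layout` parameter, so any layout argument is dropped before parsing.
    """
    # Drop a positional torch.layout value, if present.
    args = tuple(a for a in args if not isinstance(a, torch.layout))
    # Drop the keyword form (the form that appears in the traceback above).
    kwargs.pop("layout", None)
    device, dtype, _non_blocking, memory_format = torch._C._nn._parse_to(*args, **kwargs)
    return {"device": device, "dtype": dtype, "memory_format": memory_format}

# Same kwargs as in the traceback, but on CPU so the example runs anywhere.
parsed = get_to_kwargs(device=torch.device("cpu"), layout=torch.strided, dtype=torch.float32)
print(parsed["device"], parsed["dtype"])  # cpu torch.float32
```

Silently dropping `layout` assumes the subclass only ever uses the default strided layout; a subclass that supports other layouts would need to handle the argument explicitly instead.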
cc: @jerryzh168