Skip to content

Commit 1029df3

Browse files
authored
Add a test for map_location="cpu" (#497)
Summary: torchtune is using torch.load(file_name, map_location="cpu", mmap=True), so we add a test to make sure this works with tensor subclass API Test Plan: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_save_load_map_location Reviewers: Subscribers: Tasks: Tags:
1 parent 2ed010a commit 1029df3

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

test/quantization/test_quant_api.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,28 @@ def test_quantized_model_to_device(self):
635635
cuda_res = m(*example_inputs_cuda)
636636
self.assertEqual(cuda_res.cpu(), ref)
637637

638+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
639+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
640+
def test_quantized_tensor_subclass_save_load_map_location(self):
641+
m = ToyLinearModel().eval().to(dtype=torch.bfloat16, device="cuda")
642+
example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
643+
644+
quantize_(m, int8_weight_only())
645+
ref = m(*example_inputs)
646+
with tempfile.NamedTemporaryFile() as f:
647+
torch.save(m.state_dict(), f)
648+
f.seek(0)
649+
state_dict = torch.load(f.name, map_location="cpu", mmap=True)
650+
651+
with torch.device('meta'):
652+
m_copy = ToyLinearModel().eval()
653+
654+
m_copy.load_state_dict(state_dict, assign=True)
655+
m_copy.to(dtype=torch.bfloat16, device="cuda")
656+
657+
res = m_copy(*example_inputs)
658+
self.assertEqual(res, ref)
659+
638660

639661
if __name__ == "__main__":
640662
unittest.main()

0 commit comments

Comments
 (0)