You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/source/serialization.rst
+75-79Lines changed: 75 additions & 79 deletions
Original file line number
Diff line number
Diff line change
@@ -3,101 +3,97 @@ Serialization
3
3
4
4
Serialization and deserialization is an important question that people care about especially when we integrate torchao with other libraries. Here we want to describe how serialization and deserialization works for torchao optimized (quantized or sparsified) models.
print(f"quantized model size: {get_model_size_in_bytes(m) / 1024 / 1024} MB")
39
-
40
-
ref = m(*example_inputs)
41
-
with tempfile.NamedTemporaryFile() as f:
42
-
torch.save(m.state_dict(), f)
43
-
f.seek(0)
44
-
state_dict = torch.load(f)
58
+
What happens when serializing an optimized model?
59
+
=================================================
60
+
To serialize an optimized model, we just need to call ``torch.save(m.state_dict(), f)``, because in torchao, we use tensor subclass to represent different dtypes or support different optimization techniques like quantization and sparsity. So after optimization, the only thing change is the weight Tensor is changed to an optimized weight Tensor, and the model structure is not changed at all. For example:
To serialize an optimized model, we just need to call `torch.save(m.state_dict(), f)`, because in torchao, we use tensor subclass to represent different dtypes or support different optimization techniques like quantization and sparsity. So after optimization, the only thing change is the weight Tensor is changed to an optimized weight Tensor, and the model structure is not changed at all. For example:
71
+
The size of the quantized model is typically going to be smaller to the original floating point model, but it also depends on the specific techinque and implementation you are using. You can print the model size with ``torchao.utils.get_model_size_in_bytes`` utility function, specifically for the above example using int4_weight_only quantization, we can see the size reduction is around 4x::
To deserialize an optimized model, we can initialize the floating point model in `meta <https://pytorch.org/docs/stable/meta.html>`__ device and then load the optimized ``state_dict`` with ``assign=True`` using `model.load_state_dict <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict>`__::
72
80
73
-
The size of the quantized model is typically going to be smaller to the original floating point model, but it also depends on the specific techinque and implementation you are using. You can print the model size with `torchao.utils.get_model_size_in_bytes` utility function, specifically for the above example using int4_weight_only quantization, we can see the size reduction is around 4x:
To deserialize an optimized model, we can initialize the floating point model in `meta <https://pytorch.org/docs/stable/meta.html>`__ device and then load the optimized `state_dict` with `assign=True` using `model.load_state_dict <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict>`__:
85
+
print(f"type of weight before loading: {type(m_loaded.linear1.weight), type(m_loaded.linear2.weight)}")
86
+
m_loaded.load_state_dict(state_dict, assign=True)
87
+
print(f"type of weight after loading: {type(m_loaded.linear1.weight), type(m_loaded.linear2.weight)}")
print(f"type of weight before loading: {type(m_loaded.linear1.weight), type(m_loaded.linear2.weight)}")
89
-
m_loaded.load_state_dict(state_dict, assign=True)
90
-
print(f"type of weight after loading: {type(m_loaded.linear1.weight), type(m_loaded.linear2.weight)}")
91
-
```
90
+
The reason we initialize the model in ``meta`` device is to avoid initializing the original floating point model since original floating point model may not fit into the device that we want to use for inference.
92
91
93
-
The reason we initialize the model in `meta` device is to avoid initializing the original floating point model since original floating point model may not fit into the device that we want to use for inference.
92
+
What happens in ``m_loaded.load_state_dict(state_dict, assign=True)`` is that the corresponding weights (e.g. m_loaded.linear1.weight) are updated with the Tensors in ``state_dict``, which is an optimized tensor subclass instance (e.g. int4 ``AffineQuantizedTensor``). No dependency on torchao is needed for this to work.
94
93
95
-
What happens in `m_loaded.load_state_dict(state_dict, assign=True)` is that the corresponding weights (e.g. m_loaded.linear1.weight) are updated with the Tensors in `state_dict`, which is an optimized tensor subclass instance (e.g. int4 `AffineQuantizedTensor`). No dependency on torchao is needed for this to work.
94
+
We can also verify that the weightis properly loaded by checking the type of weight tensor::
96
95
97
-
We can also verify that the weight is properly loaded by checking the type of weight tensor:
98
-
```
99
-
type of weight before loading: (<class 'torch.Tensor'>, <class 'torch.Tensor'>)
100
-
type of weight after loading: (<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>, <class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>)
96
+
type of weight before loading: (<class 'torch.Tensor'>, <class 'torch.Tensor'>)
97
+
type of weight after loading: (<class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>, <class 'torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor'>)
0 commit comments