Int8DynActInt4WeightQATQuantizer doesn't support qwen series #1080


Open

elfisworking opened this issue Oct 15, 2024 · 4 comments

@elfisworking
Contributor

elfisworking commented Oct 15, 2024

I use Int8DynActInt4WeightQATQuantizer to quantize the Qwen2-1.5B model, but after the prepare function I find that bias is set to False.
This is my code:

from torchtune.models.qwen2 import qwen2_1_5b
model = qwen2_1_5b()
from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
print("before prepare: ", model)
model = qat_quantizer.prepare(model)
print("after prepare: ", model)

The output is

before prepare:  TransformerDecoder(
  (tok_embeddings): Embedding(151936, 1536)
  (layers): ModuleList(
    (0-27): 28 x TransformerSelfAttentionLayer(
      (attn): MultiHeadAttention(
        (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
        (k_proj): Linear(in_features=1536, out_features=256, bias=True)
        (v_proj): Linear(in_features=1536, out_features=256, bias=True)
        (output_proj): Linear(in_features=1536, out_features=1536, bias=False)
        (pos_embeddings): Qwen2RotaryPositionalEmbeddings()
      )
      (mlp): FeedForward(
        (w1): Linear(in_features=1536, out_features=8960, bias=False)
        (w2): Linear(in_features=8960, out_features=1536, bias=False)
        (w3): Linear(in_features=1536, out_features=8960, bias=False)
        (activation): SiLU()
      )
      (sa_norm): RMSNorm()
      (mlp_norm): RMSNorm()
      (sa_scale): Identity()
      (mlp_scale): Identity()
    )
  )
  (norm): RMSNorm()
)
after prepare:  TransformerDecoder(
  (tok_embeddings): Embedding(151936, 1536)
  (layers): ModuleList(
    (0-27): 28 x TransformerSelfAttentionLayer(
      (attn): MultiHeadAttention(
        (q_proj): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=1536, bias=False)
        (k_proj): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=256, bias=False)
        (v_proj): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=256, bias=False)
        (output_proj): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=1536, bias=False)
        (pos_embeddings): Qwen2RotaryPositionalEmbeddings()
      )
      (mlp): FeedForward(
        (w1): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=8960, bias=False)
        (w2): Int8DynActInt4WeightQATLinear(in_features=8960, out_features=1536, bias=False)
        (w3): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=8960, bias=False)
        (activation): SiLU()
      )
      (sa_norm): RMSNorm()
      (mlp_norm): RMSNorm()
      (sa_scale): Identity()
      (mlp_scale): Identity()
    )
  )
  (norm): RMSNorm()
)

We can see that after the prepare function, (q_proj): Linear(in_features=1536, out_features=1536, bias=True) has become (q_proj): Int8DynActInt4WeightQATLinear(in_features=1536, out_features=1536, bias=False).
From the torchao code, we can see that in the function replacement_fn:

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
    new_linear = linear_class(
        child.in_features,
        child.out_features,
        bias=False,
        device=child.weight.device,
        groupsize=groupsize,
        precision=precision,
        scales_precision=scales_precision,
    )

bias is hard-coded to False.
Is there any solution to this problem?
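
As a side note, a quick way to see which layers would lose their bias is to list the Linear modules that carry one before calling prepare (a minimal sketch, assuming torchtune is installed as in the snippet above):

import torch.nn as nn
from torchtune.models.qwen2 import qwen2_1_5b

model = qwen2_1_5b()
# For Qwen2, only q_proj, k_proj and v_proj carry a bias, so these are the
# layers whose bias would be silently dropped by the swap above.
biased = [name for name, m in model.named_modules()
          if isinstance(m, nn.Linear) and m.bias is not None]
print(biased)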

@elfisworking
Contributor Author

elfisworking commented Oct 15, 2024

I read the code. In the function filter_fn:

    def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:
        return isinstance(child, nn.Linear) and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed)

Adding the condition child.bias is None might be a solution. For example:

    def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:
        return isinstance(child, nn.Linear) and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed) and child.bias is None

This skips the linear layers whose bias is True.
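
A quick way to sanity-check this change (a minimal sketch; it only counts module types after prepare, and assumes the patched filter_fn above is in place):

import torch.nn as nn

def count_layers(model):
    # Count how many linears were swapped to the QAT variant and how many
    # plain nn.Linear layers with a bias were left untouched.
    swapped = sum(1 for m in model.modules()
                  if type(m).__name__ == "Int8DynActInt4WeightQATLinear")
    kept_biased = sum(1 for m in model.modules()
                      if type(m).__name__ == "Linear" and m.bias is not None)
    return swapped, kept_biased

print(count_layers(model))  # e.g. run on the prepared qwen2_1_5b model from above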

@jerryzh168
Contributor

cc @andrewor14, can you take a look?

andrewor14 self-assigned this Oct 15, 2024
@andrewor14
Contributor

andrewor14 commented Oct 15, 2024

Hi @elfisworking, yes, the easy fix would be to skip the replacement when the linear has a bias. Would you like to submit a fix for this? If not, I can do it.

The longer-term fix would probably be to actually support the bias=True case. This is currently not supported because the quantized linear used in the convert path (Int8DynActInt4WeightLinear) does not support bias. If we make the convert path use the tensor subclass path (quantize_(model, int8_dynamic_activation_int4_weight())) instead, this problem will be resolved. This is on my TODO list.
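
For reference, a minimal sketch of that tensor subclass path (assuming a torchao version where quantize_ and int8_dynamic_activation_int4_weight are importable from torchao.quantization; the group_size here is just an illustrative value):

from torchao.quantization import quantize_, int8_dynamic_activation_int4_weight
from torchtune.models.qwen2 import qwen2_1_5b

model = qwen2_1_5b()
# Swaps the Linear weights for quantized tensor subclasses in place; per the
# comment above, this path should be able to handle layers that have a bias.
quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))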

@elfisworking
Contributor Author

@andrewor14 OK, I will submit a fix.

yanbing-j pushed a commit to yanbing-j/ao that referenced this issue Dec 9, 2024