Commit a10aa6a
Fix bug when loading a 4-bit checkpoint quantized with INC
```
SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=1 \
python run_generation.py \
-o acc_load_uint4_model.txt \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--use_hpu_graphs \
--use_kv_cache \
--trim_logits \
--batch_size 1 \
--bf16 \
--attn_softmax_bf16 \
--quantized_inc_model_path <local_model_path_from_inc>
```
Running the above command fails with the following error caused by a dtype mismatch:
```
File "/home/xinhe3/optimum-habana-fork/examples/text-generation/run_generation.py", line 656, in generate
output_tokens = model.generate(
File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/home/xinhe3/optimum-habana-fork/optimum/habana/transformers/generation/utils.py", line 1292, in generate
result = self._sample(
File "/home/xinhe3/optimum-habana-fork/optimum/habana/transformers/generation/utils.py", line 2259, in _sample
outputs = self(
File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1556, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1565, in _call_impl
return forward_call(*args, **kwargs)
File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/habana_frameworks/torch/hpu/graphs.py", line 724, in forward
return wrapped_hpugraph_forward(
File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/habana_frameworks/torch/hpu/graphs.py", line 597, in wrapped_hpugraph_forward
outputs = orig_fwd(*args, **kwargs)
File "/home/xinhe3/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 1345, in forward
outputs = self.model(
File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1556, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1606, in _call_impl
result = forward_call(*args, **kwargs)
File "/home/xinhe3/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 1236, in forward
layer_outputs = decoder_layer(
File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1556, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1606, in _call_impl
result = forward_call(*args, **kwargs)
File "/home/xinhe3/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 925, in forward
hidden_states, self_attn_weights, present_key_value = self.pre_attn(
File "/home/xinhe3/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 982, in pre_attn
hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
File "/home/xinhe3/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 647, in pre_attn_forward
key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
File "/home/xinhe3/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 424, in update
prev.index_copy_(dim, idx - 1, cur)
RuntimeError: cpu fallback is not supported during hpu graph capturing
```
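The failure comes from the KV-cache update: with the INC-loaded 4-bit model, the incoming key/value states do not match the dtype of the pre-allocated cache, so `index_copy_` cannot be executed on HPU and attempts a CPU fallback, which is not allowed while an HPU graph is being captured. Below is a minimal sketch of the kind of change that addresses this, assuming the cure is to align the model's floating-point compute dtype with the bf16 KV cache after the checkpoint is loaded; the helper name and its placement are illustrative, not the actual three-line diff.

```python
import torch
import torch.nn as nn


def align_kv_cache_dtype(model: nn.Module, use_bf16: bool = True) -> nn.Module:
    """Hypothetical helper (not the actual optimum-habana change): cast the
    INC-loaded model's floating-point parameters/buffers to bf16 so the
    key/value states it produces match the bf16 KV cache and index_copy_
    stays on HPU instead of falling back to CPU."""
    if use_bf16:
        # nn.Module.to(dtype) only converts floating-point tensors; the packed
        # 4-bit integer weights keep their integer storage dtype.
        model = model.to(torch.bfloat16)
    return model.eval()
```

In the text-generation example this kind of cast would run right after the quantized checkpoint is loaded and before the model is wrapped for HPU graphs; consult the actual commit for the exact lines changed.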
1 file changed: 3 insertions(+), 0 deletions(-) (three lines added after original line 271 of the patched file).