
Commit a10aa6a

Fix bug when loading a 4-bit checkpoint quantized with INC
```
SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=1 \
python run_generation.py \
    -o acc_load_uint4_model.txt \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --use_hpu_graphs \
    --use_kv_cache \
    --trim_logits \
    --batch_size 1 \
    --bf16 \
    --attn_softmax_bf16 \
    --quantized_inc_model_path <local_model_path_from_inc>
```

The command above fails with an error caused by a dtype mismatch:

```
  File "/home/xinhe3/optimum-habana-fork/examples/text-generation/run_generation.py", line 656, in generate
    output_tokens = model.generate(
  File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/xinhe3/optimum-habana-fork/optimum/habana/transformers/generation/utils.py", line 1292, in generate
    result = self._sample(
  File "/home/xinhe3/optimum-habana-fork/optimum/habana/transformers/generation/utils.py", line 2259, in _sample
    outputs = self(
  File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1556, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1565, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/habana_frameworks/torch/hpu/graphs.py", line 724, in forward
    return wrapped_hpugraph_forward(
  File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/habana_frameworks/torch/hpu/graphs.py", line 597, in wrapped_hpugraph_forward
    outputs = orig_fwd(*args, **kwargs)
  File "/home/xinhe3/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 1345, in forward
    outputs = self.model(
  File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1556, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1606, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/xinhe3/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 1236, in forward
    layer_outputs = decoder_layer(
  File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1556, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/xinhe3/qnpu/venv-118/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1606, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/xinhe3/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 925, in forward
    hidden_states, self_attn_weights, present_key_value = self.pre_attn(
  File "/home/xinhe3/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 982, in pre_attn
    hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
  File "/home/xinhe3/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 647, in pre_attn_forward
    key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
  File "/home/xinhe3/optimum-habana-fork/optimum/habana/transformers/models/llama/modeling_llama.py", line 424, in update
    prev.index_copy_(dim, idx - 1, cur)
RuntimeError: cpu fallback is not supported during hpu graph capturing
```
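For reference, here is a minimal CPU-only sketch of how a dtype mismatch makes `index_copy_` fail; the shapes, dtypes, and variable names below are illustrative assumptions, not values taken from the actual run. On HPU the mismatched op is routed to the CPU fallback, which is rejected while an HPU graph is being captured, producing the `RuntimeError` above.

```python
import torch

# Illustrative sketch (assumed shapes/dtypes): the pre-allocated KV cache is bf16,
# while the key states coming out of the 4-bit-quantized checkpoint are fp32.
cache = torch.zeros(1, 8, 16, 64, dtype=torch.bfloat16)      # stand-in for the k_cache storage
key_states = torch.randn(1, 8, 1, 64, dtype=torch.float32)   # stand-in for the new token's keys
token_idx = torch.tensor([3])

try:
    # Mirrors prev.index_copy_(dim, idx - 1, cur) from the traceback above.
    cache.index_copy_(2, token_idx, key_states)
except RuntimeError as e:
    print(e)  # index_copy_(): self and source expected to have the same dtype ...
```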
1 parent c625680 commit a10aa6a

File tree

1 file changed: +3 -0 lines changed


examples/text-generation/utils.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -269,6 +269,9 @@ def setup_model(args, model_dtype, model_kwargs, logger):
             original_model=org_model,
             **model_kwargs,
         )
+        # TODO: This will be removed in v1.19 Synapse release
+        # the loaded model should have the same dtype as original_model
+        model = model.to(model_kwargs["torch_dtype"])
     else:
         if args.assistant_model is not None:
             assistant_model = AutoModelForCausalLM.from_pretrained(
```
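For readers skimming the diff, a short sketch of the pattern it enforces follows. The wrapper name and the `load_fn` callable are hypothetical placeholders, not the real loader used in `setup_model`; only the final `.to(model_kwargs["torch_dtype"])` cast reflects the committed change.

```python
import torch.nn as nn


def load_inc_quantized_matching_dtype(load_fn, org_model: nn.Module, **model_kwargs) -> nn.Module:
    """Hypothetical wrapper: load the INC-quantized 4-bit checkpoint, then cast
    the result to the dtype the original model uses (e.g. torch.bfloat16)."""
    model = load_fn(original_model=org_model, **model_kwargs)
    # The committed fix: without this cast, modules that stay unquantized keep the
    # checkpoint's dtype and later clash with the bf16 KV cache during decoding.
    model = model.to(model_kwargs["torch_dtype"])
    return model
```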
