-
Notifications
You must be signed in to change notification settings - Fork 78
Open
Description
Command run:
USE_LADE=1 LOAD_LADE=1 python minimal.py
Result:
Traceback (most recent call last):
File "/home/workspace/LookaheadDecoding/minimal.py", line 32, in <module>
greedy_output = model.generate(**model_inputs, max_new_tokens=1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/lade/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/lade/lib/python3.12/site-packages/transformers/generation/utils.py", line 1718, in generate
return self.greedy_search(
^^^^^^^^^^^^^^^^^^^
File "/home/workspace/LookaheadDecoding/lade/decoding.py", line 24, in greedy_search_proxy
return jacobi_greedy_search_multilevel(self, chat=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/workspace/LookaheadDecoding/lade/decoding.py", line 988, in jacobi_greedy_search_multilevel
outputs = self.jforward_multilevel(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/workspace/LookaheadDecoding/lade/models/modeling_llama.py", line 1519, in jforward_multilevel
outputs = self.model.LlamaModeljforward(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/workspace/LookaheadDecoding/lade/models/modeling_llama.py", line 1224, in LlamaModeljforward
layer_outputs = decoder_layer(
^^^^^^^^^^^^^^
File "/root/miniconda3/envs/lade/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/lade/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/workspace/LookaheadDecoding/lade/models/modeling_llama.py", line 864, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn.forward(
^^^^^^^^^^^^^^^^^^^^^^^
TypeError: LlamaSdpaAttention.forward() got an unexpected keyword argument 'lookahead'
How can I solve this?
Metadata
Metadata
Assignees
Labels
No labels