
Fixed the llama model #769


Merged · 6 commits · Sep 3, 2024
Changes from 1 commit
18 changes: 10 additions & 8 deletions torchao/_models/llama/model.py
@@ -197,15 +197,17 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:

         if input_pos is None:
             mask = None
             input_pos = torch.arange(0, idx.shape[1], device=idx.device)
Collaborator:
Thank you for your reply. It makes sense to separate the logic for creating mask, input_pos, and freqs_cis as you did here. However, isn't creating input_pos here (for the training case) now unnecessary?

yiliu30 (Contributor, Author) · Sep 3, 2024:

I've also tested the inference case where input_pos is None. Is input_pos a required argument for model.forward in inference mode?

gau-nernst (Collaborator) · Sep 3, 2024:

> inference case where input_pos is None

This is what I mentioned earlier. I think right now there is no code that runs inference without input_pos, and your auto-round PR doesn't seem to need it either. I think it's fine to support this case (at the cost of a tiny inefficiency: an unneeded input_pos is created during training, which is probably insignificant). Perhaps others will have other comments.

Can you update the PR description to describe the problem this PR fixes more clearly? i.e., when input_pos is None (during training), freqs_cis is overridden by line xxx. I will approve the PR once you have added tests for training mode (model.setup_caches(training=True)) and a short docstring/comment in model.forward().
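
(For context, a minimal sketch of what such a training-mode test could look like. This is not the PR's actual test: the tiny ModelArgs values and the setup_caches arguments other than training=True are assumptions about torchao/_models/llama/model.py, which follows the gpt-fast layout; only forward(idx, input_pos=None) and the training flag are taken from this thread.)

import torch
from torchao._models.llama.model import ModelArgs, Transformer

# Tiny config so the test runs quickly; field names follow the gpt-fast-style
# ModelArgs this file is based on (assumed, not verified against this revision).
config = ModelArgs(block_size=128, vocab_size=256, n_layer=2, n_head=4, dim=64)
model = Transformer(config)

# Training-mode cache setup as requested above; the max_batch_size/max_seq_length
# keyword names are assumptions about the setup_caches signature.
model.setup_caches(max_batch_size=1, max_seq_length=128, training=True)

idx = torch.randint(0, config.vocab_size, (1, 16))
logits = model(idx)  # input_pos=None -> training path; freqs_cis covers the first 16 positions
assert logits.shape == (1, 16, config.vocab_size)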

yiliu30 (Contributor, Author):

Got it, thanks for the clarification! I've updated the PR description, UTs, and docstring. Please review it again.

-        elif not self.linear_causal_mask:
-            mask = self.causal_mask[None, None, input_pos]
-        elif len(input_pos)>1 and self.linear_causal_mask: # prefill for linear causal mask
-            mask = torch.tril(torch.ones(len(input_pos), self.max_seq_length, dtype=torch.bool, device=input_pos.device)).unsqueeze(0).unsqueeze(0)
-        else: # decode_one_token for linear causal mask
-            self.causal_mask[0,0,0,input_pos] = 1
-            mask = self.causal_mask
-        freqs_cis = self.freqs_cis[input_pos]
+            freqs_cis = self.freqs_cis[:idx.shape[1]]
+        else:
+            if not self.linear_causal_mask:
+                mask = self.causal_mask[None, None, input_pos]
+            elif len(input_pos)>1 and self.linear_causal_mask: # prefill for linear causal mask
+                mask = torch.tril(torch.ones(len(input_pos), self.max_seq_length, dtype=torch.bool, device=input_pos.device)).unsqueeze(0).unsqueeze(0)
+            else: # decode_one_token for linear causal mask
+                self.causal_mask[0,0,0,input_pos] = 1
+                mask = self.causal_mask
+            freqs_cis = self.freqs_cis[input_pos]

         x = self.tok_embeddings(idx)

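After this change, the two ways of calling the patched forward() look like the following sketch (an illustration, not code from the PR; it assumes model is a torchao llama Transformer whose caches were already set up and idx is a token-id tensor of shape (batch, seq_len)):

import torch

# Training / calibration: no input_pos, so mask is None and freqs_cis is sliced
# to the first idx.shape[1] positions rather than being overridden later.
logits = model(idx)

# Inference prefill/decoding: explicit positions, so the causal mask and
# freqs_cis are indexed by input_pos.
input_pos = torch.arange(0, idx.shape[1], device=idx.device)
logits = model(idx, input_pos)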