diff --git a/llama/generation.py b/llama/generation.py
index 5f8faf9f3..8d8ec5c41 100755
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -419,3 +419,5 @@ def sample_top_p(probs, p):
     next_token = torch.multinomial(probs_sort, num_samples=1)
     next_token = torch.gather(probs_idx, -1, next_token)
     return next_token
+
+# NOTE: comment-only marker appended by automated tooling; no functional change.
diff --git a/llama/model.py b/llama/model.py
index 562fcad1b..ea171dea2 100755
--- a/llama/model.py
+++ b/llama/model.py
@@ -493,3 +493,5 @@ def forward(self, tokens: torch.Tensor, start_pos: int):
        h = self.norm(h)
        output = self.output(h).float()
        return output
+
+# NOTE: comment-only marker appended by automated tooling; no functional change.