diff --git a/llama/generation.py b/llama/generation.py
index 5f8faf9f3..014f4c5e7 100755
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -168,12 +168,14 @@ def generate(
         tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
         for k, t in enumerate(prompt_tokens):
             tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+
         if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

         prev_pos = 0
         eos_reached = torch.tensor([False] * bsz, device="cuda")
         input_text_mask = tokens != pad_id
+
         if min_prompt_len == total_len:
             logits = self.model.forward(tokens, prev_pos)
             token_logprobs = -F.cross_entropy(
@@ -197,6 +199,7 @@ def generate(
                 input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
             )
             tokens[:, cur_pos] = next_token
+
             if logprobs:
                 token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                     input=logits.transpose(1, 2),
@@ -204,6 +207,7 @@ def generate(
                     target=tokens[:, prev_pos + 1 : cur_pos + 1],
                     reduction="none",
                     ignore_index=pad_id,
                 )
+
             eos_reached |= (~input_text_mask[:, cur_pos]) & (
                 next_token == self.tokenizer.eos_id
             )
@@ -213,23 +217,27 @@ def generate(
         if logprobs:
             token_logprobs = token_logprobs.tolist()
+
         out_tokens, out_logprobs = [], []
         for i, toks in enumerate(tokens.tolist()):
-            # cut to max gen len
             start = 0 if echo else len(prompt_tokens[i])
             toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
             probs = None
+
             if logprobs:
                 probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
-            # cut to eos tok if any
+
             if self.tokenizer.eos_id in toks:
                 eos_idx = toks.index(self.tokenizer.eos_id)
                 toks = toks[:eos_idx]
                 probs = probs[:eos_idx] if logprobs else None
+
             out_tokens.append(toks)
             out_logprobs.append(probs)
+
         return (out_tokens, out_logprobs if logprobs else None)
+
     def text_completion(
         self,
         prompts: List[str],
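The decode loop above draws each token via `sample_top_p(probs, top_p)`, a helper defined elsewhere in `llama/generation.py` that this patch does not touch. For reference, a minimal nucleus-sampling sketch consistent with how the loop calls it:

```python
import torch

def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    # Sort token probabilities in descending order and take the cumulative sum.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens outside the smallest set whose cumulative mass exceeds p.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    # Renormalize, sample one token, then map back to vocabulary indices.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_token)
```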
diff --git a/llama/model.py b/llama/model.py
index 562fcad1b..0576a9fa2 100755
--- a/llama/model.py
+++ b/llama/model.py
@@ -8,12 +8,14 @@
 import fairscale.nn.model_parallel.initialize as fs_init
 import torch
 import torch.nn.functional as F
+
 from fairscale.nn.model_parallel.layers import (
     ColumnParallelLinear,
     ParallelEmbedding,
     RowParallelLinear,
 )
 from torch import nn


 @dataclass
@@ -24,13 +26,57 @@ class ModelArgs:
     n_kv_heads: Optional[int] = None
     vocab_size: int = -1  # defined later by tokenizer
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
     ffn_dim_multiplier: Optional[float] = None
     norm_eps: float = 1e-5
     max_batch_size: int = 32
     max_seq_len: int = 2048
+
+    query_groups: int = 32  # GQA: query heads per group; must divide n_heads
+
+
+class GroupedQueryAttention(nn.Module):
+    """Standalone attention block that processes query heads in groups."""
+
+    def __init__(self, embed_dim, num_heads, query_groups):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.query_groups = query_groups  # query heads per group
+        self.head_dim = embed_dim // num_heads
+        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
+        assert num_heads % query_groups == 0, "num_heads must be divisible by query_groups"
+
+        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
+        self.o_proj = nn.Linear(embed_dim, embed_dim)
+        self.scale = self.head_dim ** -0.5
+
+    def forward(self, x):
+        B, T, C = x.shape
+        qkv = self.qkv_proj(x)
+        qkv = qkv.view(B, T, self.num_heads, 3 * self.head_dim)
+        q, k, v = qkv.chunk(3, dim=-1)  # each: (B, T, num_heads, head_dim)
+
+        # Split along the head dimension (dim=2), not the sequence dimension,
+        # so every group still attends over the full sequence.
+        q_groups = q.split(self.query_groups, dim=2)
+        k_groups = k.split(self.query_groups, dim=2)
+        v_groups = v.split(self.query_groups, dim=2)
+
+        attn_outputs = []
+        for q_group, k_group, v_group in zip(q_groups, k_groups, v_groups):
+            scores = torch.einsum("bthd,bThd->bhtT", q_group, k_group) * self.scale
+            attn_weights = F.softmax(scores, dim=-1)
+            attn_output = torch.einsum("bhtT,bThd->bthd", attn_weights, v_group)
+            attn_outputs.append(attn_output)
+
+        # Reassemble the heads and project back to the embedding dimension.
+        attn_output = torch.cat(attn_outputs, dim=2).contiguous()
+        attn_output = attn_output.view(B, T, C)
+        return self.o_proj(attn_output)
+
+
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         """
@@ -173,28 +219,9 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
     )


-class Attention(nn.Module):
-    """Multi-head attention module."""
-    def __init__(self, args: ModelArgs):
-        """
-        Initialize the Attention module.
-
-        Args:
-            args (ModelArgs): Model configuration parameters.
-
-        Attributes:
-            n_kv_heads (int): Number of key and value heads.
-            n_local_heads (int): Number of local query heads.
-            n_local_kv_heads (int): Number of local key and value heads.
-            n_rep (int): Number of repetitions for local heads.
-            head_dim (int): Dimension size of each attention head.
-            wq (ColumnParallelLinear): Linear transformation for queries.
-            wk (ColumnParallelLinear): Linear transformation for keys.
-            wv (ColumnParallelLinear): Linear transformation for values.
-            wo (RowParallelLinear): Linear transformation for output.
-            cache_k (torch.Tensor): Cached keys for attention.
-            cache_v (torch.Tensor): Cached values for attention.
-        """
+class Attention(nn.Module):
+    """Multi-head attention module with Grouped Query Attention."""
+
+    def __init__(self, args: ModelArgs):
         super().__init__()
         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
         model_parallel_size = fs_init.get_model_parallel_world_size()
@@ -203,6 +243,7 @@ def __init__(self, args: ModelArgs):
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
         self.head_dim = args.dim // args.n_heads
+        self.query_groups = args.query_groups  # query heads per group (see ModelArgs)

         self.wq = ColumnParallelLinear(
             args.dim,
@@ -257,19 +298,6 @@ def forward(
         freqs_cis: torch.Tensor,
         mask: Optional[torch.Tensor],
     ):
-        """
-        Forward pass of the attention module.
-
-        Args:
-            x (torch.Tensor): Input tensor.
-            start_pos (int): Starting position for caching.
-            freqs_cis (torch.Tensor): Precomputed frequency tensor.
-            mask (torch.Tensor, optional): Attention mask tensor.
-
-        Returns:
-            torch.Tensor: Output tensor after attention.
-
-        """
         bsz, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
- - """ bsz, seqlen, _ = x.shape xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) @@ -295,13 +323,26 @@ def forward( xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim) - scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) - if mask is not None: - scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) - scores = F.softmax(scores.float(), dim=-1).type_as(xq) - output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) - output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) - return self.wo(output) + + # Split queries, keys, values into groups for GQA + q_groups = xq.split(self.query_groups, dim=1) + k_groups = keys.split(self.query_groups, dim=1) + v_groups = values.split(self.query_groups, dim=1) + + attn_outputs = [] + for q_group, k_group, v_group in zip(q_groups, k_groups, v_groups): + scores = torch.matmul(q_group, k_group.transpose(2, 3)) / math.sqrt(self.head_dim) + if mask is not None: + scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) + scores = F.softmax(scores.float(), dim=-1).type_as(q_group) + attn_output = torch.matmul(scores, v_group) # (bs, n_local_heads, seqlen, head_dim) + attn_outputs.append(attn_output) + + # Concatenate attention outputs from all groups + attn_output = torch.cat(attn_outputs, dim=1).contiguous() + attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return self.wo(attn_output) + class FeedForward(nn.Module): @@ -348,6 +389,7 @@ def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) + class TransformerBlock(nn.Module): def __init__(self, layer_id: int, args: ModelArgs): """ @@ -372,7 +414,7 @@ def __init__(self, layer_id: int, args: ModelArgs): self.n_heads = args.n_heads self.dim = args.dim self.head_dim = args.dim // args.n_heads - self.attention = Attention(args) + self.attention = Attention(args) # Use the updated Attention class self.feed_forward = FeedForward( dim=args.dim, hidden_dim=4 * args.dim, @@ -410,6 +452,10 @@ def forward( return out + + + + class Transformer(nn.Module): def __init__(self, params: ModelArgs): """ @@ -427,7 +473,6 @@ def __init__(self, params: ModelArgs): norm (RMSNorm): Layer normalization for the model output. output (ColumnParallelLinear): Linear layer for final output. freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. - """ super().__init__() self.params = params @@ -448,8 +493,6 @@ def __init__(self, params: ModelArgs): ) self.freqs_cis = precompute_freqs_cis( - # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096. - # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning. self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 ) @@ -464,7 +507,6 @@ def forward(self, tokens: torch.Tensor, start_pos: int): Returns: torch.Tensor: Output logits after applying the Transformer model. - """ _bsz, seqlen = tokens.shape h = self.tok_embeddings(tokens) @@ -476,13 +518,7 @@ def forward(self, tokens: torch.Tensor, start_pos: int): mask = torch.full( (seqlen, seqlen), float("-inf"), device=tokens.device ) - mask = torch.triu(mask, diagonal=1) - - # When performing key-value caching, we compute the attention scores - # only for the new sequence. 
diff --git a/llama/tokenizer.py b/llama/tokenizer.py
index 3eda89a06..4506e7d4e 100755
--- a/llama/tokenizer.py
+++ b/llama/tokenizer.py
@@ -11,8 +11,11 @@
 logger = getLogger()


 class Tokenizer:
-    """tokenizing and encoding/decoding text using SentencePiece."""
+    """Tokenizing and encoding/decoding text using SentencePiece."""
+
     def __init__(self, model_path: str):
         """
         Initializes the Tokenizer with a SentencePiece model.
@@ -20,7 +23,6 @@ def __init__(self, model_path: str):
         Args:
             model_path (str): The path to the SentencePiece model file.
         """
-        # reload tokenizer
         assert os.path.isfile(model_path), model_path
         self.sp_model = SentencePieceProcessor(model_file=model_path)
         logger.info(f"Reloaded SentencePiece model from {model_path}")
@@ -35,7 +37,7 @@ def __init__(self, model_path: str):
         )
         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

-    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+    def encode(self, s: str, bos: bool = True, eos: bool = True) -> List[int]:
         """
         Encodes a string into a list of token IDs.
@@ -47,13 +49,13 @@ def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
         Returns:
             List[int]: A list of token IDs.
         """
-        assert type(s) is str
-        t = self.sp_model.encode(s)
+        assert isinstance(s, str)
+        tokens = self.sp_model.encode(s)
         if bos:
-            t = [self.bos_id] + t
+            tokens = [self.bos_id] + tokens
         if eos:
-            t = t + [self.eos_id]
-        return t
+            tokens = tokens + [self.eos_id]
+        return tokens
@@ -66,3 +68,28 @@ def decode(self, t: List[int]) -> str:
             str: The decoded string.
         """
         return self.sp_model.decode(t)
+
+    def tokenize(self, s: str) -> List[str]:
+        """
+        Tokenizes a string into subword tokens.
+
+        Args:
+            s (str): The input string to be tokenized.
+
+        Returns:
+            List[str]: A list of subword tokens.
+        """
+        return self.sp_model.encode_as_pieces(s)
+
+    def detokenize(self, tokens: List[str]) -> str:
+        """
+        Detokenizes a list of subword tokens into a string.
+
+        Args:
+            tokens (List[str]): The list of subword tokens to be detokenized.
+
+        Returns:
+            str: The detokenized string.
+        """
+        return self.sp_model.decode_pieces(tokens)
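The new `tokenize`/`detokenize` helpers round-trip through SentencePiece pieces rather than IDs. A usage sketch (the model path below is illustrative):

```python
from llama.tokenizer import Tokenizer

tok = Tokenizer(model_path="tokenizer.model")  # hypothetical local model file

ids = tok.encode("Hello, world!", bos=True, eos=False)
print(tok.decode(ids))                 # -> "Hello, world!"

pieces = tok.tokenize("Hello, world!") # subword pieces, e.g. ['▁Hello', ...]
print(tok.detokenize(pieces))          # -> "Hello, world!"
```

Note that `encode` now defaults to `bos=True, eos=True`; call sites that previously had to pass both flags explicitly keep their behavior only if they continue to do so, as above.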