There's a bug of duplicate code with a wrong indentation level when computing `attention_output` in `CausalSelfAttention._forward_inference`. Currently, `attention_output` is never computed.
`nanotron/src/nanotron/models/llama.py`, lines 499 to 666 in `c737f00`:
```python
if self.rope_interleaved:
    query_states = self.rotary_embedding(query_states, position_ids=position_ids)
    key_states = self.rotary_embedding(key_states, position_ids=position_ids)
else:
    cos, sin = self.rotary_embedding(value_states, position_ids)
    query_states, key_states = self.rotary_embedding.apply_rotary_pos_emb(query_states, key_states, cos, sin)

# Compute rotary embeddings
# Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache
old_rotary_embed_end = self.rotary_embedding.end
# interleaved version.
if self.rope_interleaved:
    query_states = self.rotary_embedding(query_states, position_ids=position_ids)
    key_states = self.rotary_embedding(key_states, position_ids=position_ids)
# non interleaved version.
else:
    cos, sin = self.rotary_embedding(value_states, position_ids)
    query_states, key_states = self.rotary_embedding.apply_rotary_pos_emb(
        query_states, key_states, cos, sin
    )

if "key" not in store:
    # First inference iteration (Prefill)
    # TODO @nouamane: support custom masking
    # assert that [ False, False, False, False, True, True, True, True, True, True] is accepted
    # but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence)
    assert ~(
        sequence_mask[:, :-1] & (~sequence_mask[:, 1:])  # True is never followed by False
    ).any(), "Can't mask in the middle of sequence, please make sure that pads are at the left of the sequence if existing"

    # preallocate k_cache, v_cache to self.prefill_kv_len
    k_cache = torch.zeros(
        (
            batch_size,
            self.prefill_kv_len,
            self.n_local_kv_heads,
            self.d_qk,
        ),
        dtype=query_states.dtype,
        device=query_states.device,
    )
    v_cache = torch.zeros(
        (batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.d_v),
        dtype=query_states.dtype,
        device=query_states.device,
    )

    # Remove pad tokens from key_states and concatenate samples in key_unpad
    # cu_seqlens_k is the cumulative sequence lengths of key_states
    (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
        query_states,
        sequence_mask,
    )
    (key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
        key_states, sequence_mask
    )
    (value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)

    # NOTE: this scale is for µTransfer,
    # in SP, we use sqrt(1/d_h)
    softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
    output_unpad = flash_attn_varlen_func(
        q=query_unpad,  # (total_q, n_local_q_heads, d_qk)
        k=key_unpad,  # (total_kv, n_local_kv_heads, d_qk)
        v=value_unpad,  # (total_kv, n_local_kv_heads, d_v)
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=0.0,
        softmax_scale=softmax_scale,
        causal=True,  # True in prefill phase, False in subsequent phases
        return_attn_probs=False,
    )  # (total_unpadded, n_local_q_heads, d_v)

    attention_output = bert_padding.pad_input(
        output_unpad, indices_q, batch_size, q_length
    )  # (batch_size, q_length, n_local_q_heads, d_v)

    pad_to_right(key_states, sequence_mask, new_tensor=k_cache)
    pad_to_right(value_states, sequence_mask, new_tensor=v_cache)
else:
    # Pull pre-computed key/value states
    # Subsequent inference iterations (q_length=1)
    k_cache = store["key"]
    v_cache = store["value"]

    # NOTE(fmom): According to flash_attn_with_kvcache, "If you pass in k / v, you must make sure that the cache is large enough to hold the new values"
    # Since rotary embedding has changed (to enable larger context), we need to enlarge k_cache and v_cache
    if self.rotary_embedding.end > old_rotary_embed_end:
        k_cache = torch.cat(
            [
                k_cache,
                torch.zeros(
                    (
                        batch_size,
                        self.rotary_embedding.end - old_rotary_embed_end,
                        self.n_local_kv_heads,
                        self.d_qk,
                    ),
                    dtype=query_states.dtype,
                    device=query_states.device,
                ),
            ],
            dim=1,
        )

        v_cache = torch.cat(
            [
                v_cache,
                torch.zeros(
                    (
                        batch_size,
                        self.rotary_embedding.end - old_rotary_embed_end,
                        self.n_local_kv_heads,
                        self.d_v,
                    ),
                    dtype=query_states.dtype,
                    device=query_states.device,
                ),
            ],
            dim=1,
        )

        assert (
            k_cache.shape[1] == self.rotary_embedding.end
        ), f"Cache size {k_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
        assert (
            v_cache.shape[1] == self.rotary_embedding.end
        ), f"Cache size {v_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"

        # [batch_size, seq_length, num_heads, d_qk]
        query_states = query_states.view(
            batch_size, q_length, self.n_local_q_heads, self.d_qk
        )  # [batch_size, q_length, self.n_heads, d_qk]
        kv_length = key_states.shape[1]
        key_states = key_states.view(
            batch_size, kv_length, self.n_local_kv_heads, self.d_qk
        )  # [batch_size, kv_length, self.n_heads, d_qk]
        value_states = value_states.view(
            batch_size, kv_length, self.n_local_kv_heads, self.d_v
        )  # [batch_size, kv_length, self.n_heads, d_v]

        # NOTE: this scale is for µTransfer,
        # in SP, we use sqrt(1/d_h)
        softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None
        attention_output = flash_attn_with_kvcache(
            query_states,
            k_cache,
            v_cache,
            key_states,
            value_states,
            rotary_cos=None,
            rotary_sin=None,
            # TODO @nouamane: seems like this doesn't help to indicate padding in (for first iteration it's just 0)
            cache_seqlens=position_offsets.contiguous(),
            softmax_scale=softmax_scale,
            causal=True,
            rotary_interleaved=False,  # the value is not used unless rotary_cos/sin is provided. https://github.com/Dao-AILab/flash-attention
        )

store.update(
    {
        "key": k_cache,  # flash-attn has updated with new key_states using cache_seqlens
        "value": v_cache,
        "position_offsets": position_offsets,
    }
)
```
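To make the failure mode concrete, here is a minimal, self-contained toy (not nanotron code) of the control-flow shape described above, assuming the decode-path `flash_attn_with_kvcache` call indeed sits inside the `if self.rotary_embedding.end > old_rotary_embed_end:` block: the name is only bound on a branch that is rarely taken, so the common decode path leaves it undefined.

```python
def decode_step(cache_needs_growing: bool) -> float:
    """Toy reproduction of the indentation bug: the result is assigned
    only inside a branch that is almost never taken."""
    if cache_needs_growing:  # analogous to `self.rotary_embedding.end > old_rotary_embed_end`
        # ... enlarge the KV cache ...
        attention_output = 1.0  # analogous to `flash_attn_with_kvcache(...)`
    # On the common path (no cache growth) `attention_output` was never assigned.
    return attention_output


decode_step(True)  # works only by accident
try:
    decode_step(False)
except NameError as err:  # UnboundLocalError is a subclass of NameError
    print(f"common decode path fails: {err}")
```

If that reading is correct, the fix would presumably be to drop one copy of the rotary-embedding block and dedent the decode-path computation of `attention_output` back to the `else:` level.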
There is also a wrong argument passed to `parametrizator_cls` when `init_model_randomly` is used for testing.
`nanotron/src/nanotron/models/llama.py`, line 1095 in `c737f00`:
```python
parametrizator = parametrizator_cls(config=config.model)
```
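The report does not say which argument is expected instead, so one way to pin it down is to compare this call site against the constructor signature of the selected parametrizator class. A minimal sketch for doing that, assuming the classes live in `nanotron.scaling.parametrization` (module path and class names are assumptions, not taken from the report):

```python
import inspect

# Assumed module path and class names; adjust to the actual repository layout.
from nanotron.scaling.parametrization import (
    SpectralMupParametrizator,
    StandardParametrizator,
)

# Print the parameter (and its annotation) each constructor expects, so it can be
# compared with what init_model_randomly actually passes (config.model).
for cls in (StandardParametrizator, SpectralMupParametrizator):
    print(cls.__name__, inspect.signature(cls.__init__))
```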