 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
-from typing import Callable, Dict, Optional, Union
+
+from typing import Any, Callable, Dict, Optional, Union
 from abc import ABC, abstractmethod

 import torch
@@ -132,7 +133,7 @@ class TransformerArgs:
     ffn_dim_multiplier: Optional[int] = None
     use_tiktoken: bool = False
     max_seq_length: int = 8192
-    use_scaled_rope: bool = False
+    rope_scaling: Optional[Dict[str, Any]] = None
     # For pipeline parallel
     n_stages: int = 1
     stage_idx: int = 0
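
For reference, a sketch of what the new rope_scaling field might hold. The keys are the ones the reworked apply_scaling below validates, and the values mirror the grid-search constants the old hard-coded path used; in practice they would come from the model's configuration, so treat the dict as illustrative only.

    # Illustrative rope_scaling dict; keys match what apply_scaling requires,
    # values copied from the constants the old code hard-coded.
    rope_scaling = {
        "factor": 8,                               # scale applied to long-wavelength (low-frequency) components
        "low_freq_factor": 1,                      # low_freq_wavelen = original_len / low_freq_factor
        "high_freq_factor": 4,                     # high_freq_wavelen = original_len / high_freq_factor
        "original_max_position_embeddings": 8192,  # pre-extension context length (original Llama 3 length)
    }
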
@@ -418,8 +419,6 @@ def __init__(self, config: TransformerArgs) -> None:
         self.norm = None
         self.output = None

-        # self.freqs_cis: Optional[Tensor] = None
-        # self.mask_cache: Optional[Tensor] = None
         self.max_batch_size = -1
         self.max_seq_length = -1
         # For supporting sequence parallel (default is off, thus value of 1)
@@ -444,7 +443,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
             self.config.dim // self.config.n_heads,
             self.config.block_size * 2,
             self.config.rope_base,
-            use_scaled=self.config.use_scaled_rope,
+            rope_scaling=self.config.rope_scaling,
         )
         self.register_buffer("freqs_cis", freqs_cis, persistent=True)
         causal_mask = torch.tril(
@@ -681,12 +680,16 @@ def forward(self, x: Tensor) -> Tensor:
         return output * self.weight


-def apply_scaling(freqs: torch.Tensor):
-    # Values obtained from grid search
-    scale_factor = 8
-    low_freq_factor = 1
-    high_freq_factor = 4
-    old_context_len = 8192  # original llama3 length
+def apply_scaling(freqs: torch.Tensor, rope_scaling: Dict[str, Any]):
+    # Check for the presence of the required keys
+    required_keys = {"factor", "low_freq_factor", "high_freq_factor", "original_max_position_embeddings"}
+    if not required_keys.issubset(rope_scaling.keys()):
+        raise ValueError(f"Missing required keys in apply_scaling. Expected: {required_keys}")
+
+    scale_factor = rope_scaling["factor"]
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    old_context_len = rope_scaling["original_max_position_embeddings"]

     low_freq_wavelen = old_context_len / low_freq_factor
     high_freq_wavelen = old_context_len / high_freq_factor
@@ -707,16 +710,16 @@ def apply_scaling(freqs: torch.Tensor):


 def precompute_freqs_cis(
-    n_elem: int, seq_len: int, base: int = 10000, dtype=None, use_scaled: bool = False
+    n_elem: int, seq_len: int, base: int = 10000, dtype=None, rope_scaling: Optional[Dict[str, Any]] = None
 ) -> Tensor:
     if not dtype:
         dtype = get_precision()
     freqs = 1.0 / (
         base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
     )
     t = torch.arange(seq_len, device=freqs.device)
-    if use_scaled:
-        freqs = apply_scaling(freqs)
+    if rope_scaling is not None:
+        freqs = apply_scaling(freqs, rope_scaling)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
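
A usage sketch of the updated signature (the "model" import path, head dimension, and rope base below are assumptions for illustration, not part of this change): passing a rope_scaling dict enables the scaled path, while omitting it keeps the original unscaled behavior.

    # Hypothetical usage of the new precompute_freqs_cis signature.
    import torch
    from model import precompute_freqs_cis  # assumed module name

    freqs_cis_scaled = precompute_freqs_cis(
        128,                  # n_elem: per-head dim (config.dim // config.n_heads), assumed
        8192 * 2,             # seq_len: block_size * 2, as in setup_caches above
        500000,               # base: rope_base, assumed Llama-3-style value
        dtype=torch.float32,  # pass dtype explicitly to skip the get_precision() default
        rope_scaling={
            "factor": 8,
            "low_freq_factor": 1,
            "high_freq_factor": 4,
            "original_max_position_embeddings": 8192,
        },
    )

    # rope_scaling defaults to None, so this call applies no scaling,
    # matching the old use_scaled_rope=False behavior.
    freqs_cis_plain = precompute_freqs_cis(128, 8192 * 2, dtype=torch.float32)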