Commit bc3a365

Add missing Block size + Update Configs to not hardcode rope_scaling (#1128)
* Update Configs to not hardcode rope_scaling fields
* Adding explicit error
1 parent 5152988 commit bc3a365

File tree

5 files changed: +21 -18 lines
torchchat/model.py

Lines changed: 17 additions & 14 deletions
@@ -10,7 +10,8 @@
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
-from typing import Callable, Dict, Optional, Union
+
+from typing import Any, Callable, Dict, Optional, Union
 from abc import ABC, abstractmethod
 
 import torch
@@ -132,7 +133,7 @@ class TransformerArgs:
     ffn_dim_multiplier: Optional[int] = None
     use_tiktoken: bool = False
     max_seq_length: int = 8192
-    use_scaled_rope: bool = False
+    rope_scaling: Optional[Dict[str, Any]] = None
     # For pipeline parallel
     n_stages: int = 1
     stage_idx: int = 0
@@ -418,8 +419,6 @@ def __init__(self, config: TransformerArgs) -> None:
         self.norm = None
         self.output = None
 
-        # self.freqs_cis: Optional[Tensor] = None
-        # self.mask_cache: Optional[Tensor] = None
         self.max_batch_size = -1
         self.max_seq_length = -1
         # For supporting sequence parallel (default is off, thus value of 1)
@@ -444,7 +443,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
             self.config.dim // self.config.n_heads,
             self.config.block_size * 2,
             self.config.rope_base,
-            use_scaled=self.config.use_scaled_rope,
+            rope_scaling=self.config.rope_scaling,
         )
         self.register_buffer("freqs_cis", freqs_cis, persistent=True)
         causal_mask = torch.tril(
@@ -681,12 +680,16 @@ def forward(self, x: Tensor) -> Tensor:
         return output * self.weight
 
 
-def apply_scaling(freqs: torch.Tensor):
-    # Values obtained from grid search
-    scale_factor = 8
-    low_freq_factor = 1
-    high_freq_factor = 4
-    old_context_len = 8192  # original llama3 length
+def apply_scaling(freqs: torch.Tensor, rope_scaling: Dict[str, Any]):
+    # Check for the presence of the required keys
+    required_keys = {"factor", "low_freq_factor", "high_freq_factor", "original_max_position_embeddings"}
+    if not required_keys.issubset(rope_scaling.keys()):
+        raise ValueError(f"Missing required keys in apply_scaling. Expected: {required_keys}")
+
+    scale_factor = rope_scaling["factor"]
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    old_context_len = rope_scaling["original_max_position_embeddings"]
 
     low_freq_wavelen = old_context_len / low_freq_factor
     high_freq_wavelen = old_context_len / high_freq_factor
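For context, a minimal standalone sketch of what the reworked apply_scaling computes. The key validation and parameter lookup mirror the added lines above; the wavelength-based interpolation after the shown context lines is not part of this hunk and is assumed here to follow the standard Llama 3.1 scheme, so treat the body below as illustrative rather than the file's exact code.

import math
import torch

def apply_scaling_sketch(freqs: torch.Tensor, rope_scaling: dict) -> torch.Tensor:
    # Per-model scaling parameters now come from the config instead of hardcoded constants.
    scale_factor = rope_scaling["factor"]
    low_freq_factor = rope_scaling["low_freq_factor"]
    high_freq_factor = rope_scaling["high_freq_factor"]
    old_context_len = rope_scaling["original_max_position_embeddings"]

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency components are kept as-is.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency components are scaled down by the factor.
            new_freqs.append(freq / scale_factor)
        else:
            # Smooth interpolation between the two regimes.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)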
@@ -707,16 +710,16 @@ def apply_scaling(freqs: torch.Tensor):
 
 
 def precompute_freqs_cis(
-    n_elem: int, seq_len: int, base: int = 10000, dtype=None, use_scaled: bool = False
+    n_elem: int, seq_len: int, base: int = 10000, dtype=None, rope_scaling: Optional[Dict[str, Any]] = None
 ) -> Tensor:
     if not dtype:
         dtype = get_precision()
     freqs = 1.0 / (
         base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
     )
     t = torch.arange(seq_len, device=freqs.device)
-    if use_scaled:
-        freqs = apply_scaling(freqs)
+    if rope_scaling is not None:
+        freqs = apply_scaling(freqs, rope_scaling)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
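A usage sketch of the new call path, reusing apply_scaling_sketch from above. The function name here is hypothetical and get_precision() is replaced by an explicit dtype default so the snippet stands alone; only the rope_scaling plumbing mirrors the diff.

from typing import Optional

import torch

def precompute_freqs_cis_sketch(
    n_elem: int,
    seq_len: int,
    base: int = 10000,
    dtype: torch.dtype = torch.float32,  # stand-in for torchchat's get_precision()
    rope_scaling: Optional[dict] = None,
) -> torch.Tensor:
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    if rope_scaling is not None:
        # Scaling only happens when the model config provides a rope_scaling block.
        freqs = apply_scaling_sketch(freqs, rope_scaling)
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1).to(dtype=dtype)

# Llama 3.1-style parameters, matching the rope_scaling block in the configs below.
rope_scaling = {
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}
freqs_cis = precompute_freqs_cis_sketch(128, 8192, base=500000, rope_scaling=rope_scaling)
print(freqs_cis.shape)  # torch.Size([8192, 64, 2])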

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-{"dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 4096, "n_heads": 64, "n_local_heads": 8, "n_layers": 80, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true}
+{"block_size": 8192, "dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 4096, "n_heads": 64, "n_local_heads": 8, "n_layers": 80, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true}

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-{"dim": 4096, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "n_heads": 32, "n_local_heads": 8, "n_layers": 32, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true}
+{"block_size": 8192, "dim": 4096, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "n_heads": 32, "n_local_heads": 8, "n_layers": 32, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true}

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-{"dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 4096, "n_heads": 64, "n_local_heads": 8, "n_layers": 80, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "use_scaled_rope": true}
+{"block_size": 131072, "dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 4096, "n_heads": 64, "n_local_heads": 8, "n_layers": 80, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "rope_scaling": {"factor": 8.0, "low_freq_factor": 1.0, "high_freq_factor": 4.0, "original_max_position_embeddings": 8192}}

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-{"dim": 4096, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "n_heads": 32, "n_local_heads": 8, "n_layers": 32, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "use_scaled_rope": true}
+{"block_size": 131072, "dim": 4096, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "n_heads": 32, "n_local_heads": 8, "n_layers": 32, "rope_base": 500000.0, "vocab_size": 128256, "use_tiktoken": true, "norm_eps": 1e-05, "rope_scaling": {"factor": 8.0, "low_freq_factor": 1.0, "high_freq_factor": 4.0, "original_max_position_embeddings": 8192}}
