|
206 | 206 | }, |
207 | 207 | "outputs": [], |
208 | 208 | "source": [ |
209 | | - "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, attention_factor=1.0, rope_type=\"default\", rope_factor=1.0, rope_orig_max=8192, dtype=torch.float32):\n", |
| 209 | + "import math\n", |
| 210 | + "\n", |
| 211 | + "\n", |
| 212 | + "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, attention_factor=1.0, rope_type=\"default\", rope_factor=1.0, rope_orig_max=8192, beta_fast=32.0, beta_slow=1.0, dtype=torch.float32):\n", |
210 | 213 | " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", |
211 | 214 | "\n", |
212 | | - " # Compute the inverse frequencies\n", |
213 | | - " inv_freq = 1.0 / (\n", |
214 | | - " theta_base ** (\n", |
215 | | - " torch.arange(0, head_dim, 2, dtype=dtype)[: head_dim // 2].float()\n", |
216 | | - " / head_dim\n", |
| 215 | + " if rope_type == \"yarn\":\n", |
| 216 | + " # Compute YaRN-style frequency scaling (as per https://huggingface.co/papers/2309.00071)\n", |
| 217 | + "\n", |
| 218 | + " def find_correction_dim(num_rotations, dim, base, max_position_embeddings):\n", |
| 219 | + " \"\"\"Inverse dimension formula to find the dimension based on the number of rotations\"\"\"\n", |
| 220 | + " return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))\n", |
| 221 | + "\n", |
| 222 | + " def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):\n", |
| 223 | + " \"\"\"Find dimension range bounds based on rotations\"\"\"\n", |
| 224 | + " low = find_correction_dim(low_rot, dim, base, max_position_embeddings)\n", |
| 225 | + " high = find_correction_dim(high_rot, dim, base, max_position_embeddings)\n", |
| 226 | + " low = math.floor(low)\n", |
| 227 | + " high = math.ceil(high)\n", |
| 228 | + " return max(low, 0), min(high, dim - 1)\n", |
| 229 | + "\n", |
| 230 | + " def linear_ramp_factor(min_val, max_val, dim):\n", |
| 231 | + " if min_val == max_val:\n", |
| 232 | + " max_val += 0.001 # Prevent singularity\n", |
| 233 | + " linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)\n", |
| 234 | + " ramp_func = torch.clamp(linear_func, 0, 1)\n", |
| 235 | + " return ramp_func\n", |
| 236 | + "\n", |
| 237 | + " # Base frequencies\n", |
| 238 | + " pos_freqs = theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype) / head_dim)\n", |
| 239 | + " inv_freq_extrapolation = 1.0 / pos_freqs # No scaling (extrapolation)\n", |
| 240 | + " inv_freq_interpolation = 1.0 / (rope_factor * pos_freqs) # With scaling (interpolation)\n", |
| 241 | + "\n", |
| 242 | + " # Find the range where we blend between interpolation and extrapolation\n", |
| 243 | + " low, high = find_correction_range(beta_fast, beta_slow, head_dim, theta_base, rope_orig_max)\n", |
| 244 | + "\n", |
| 245 | + " # Get n-dimensional rotational scaling corrected for extrapolation\n", |
| 246 | + " inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, head_dim // 2).to(dtype=dtype)\n", |
| 247 | + " inv_freq = (\n", |
| 248 | + " inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)\n", |
| 249 | + " + inv_freq_extrapolation * inv_freq_extrapolation_factor\n", |
| 250 | + " )\n", |
| 251 | + " else:\n", |
| 252 | + " # Default RoPE\n", |
| 253 | + " inv_freq = 1.0 / (\n", |
| 254 | + " theta_base ** (\n", |
| 255 | + " torch.arange(0, head_dim, 2, dtype=dtype)[: head_dim // 2].float()\n", |
| 256 | + " / head_dim\n", |
| 257 | + " )\n", |
217 | 258 | " )\n", |
218 | | - " )\n", |
219 | 259 | "\n", |
220 | 260 | " # Generate position indices\n", |
221 | 261 | " positions = torch.arange(context_length, dtype=dtype)\n", |
222 | 262 | "\n", |
223 | | - " # Optional YaRN scaling\n", |
224 | | - " if rope_type == \"yarn\":\n", |
225 | | - " positions = positions / rope_factor\n", |
226 | | - " positions = torch.clamp(positions, max=rope_orig_max - 1)\n", |
227 | | - "\n", |
228 | 263 | " # Compute the base angles (shape: [context_length, head_dim // 2])\n", |
229 | 264 | " angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)\n", |
230 | 265 | "\n", |
|
642 | 677 | " \"rope_type\": \"yarn\",\n", |
643 | 678 | " \"rope_factor\": 8.0,\n", |
644 | 679 | " \"rope_orig_max\": 8_192,\n", |
| 680 | + " \"beta_fast\": 32.0,\n", |
| 681 | + " \"beta_slow\": 1.0,\n", |
645 | 682 | " \"rms_norm_eps\": 1e-6,\n", |
646 | 683 | " \"dtype\": torch.bfloat16,\n", |
647 | 684 | " \"eos_token_id\": 100_257,\n", |
|
727 | 764 | " \"rope_type\": \"yarn\",\n", |
728 | 765 | " \"rope_factor\": 8.0,\n", |
729 | 766 | " \"rope_orig_max\": 8_192,\n", |
| 767 | + " \"beta_fast\": 32.0,\n", |
| 768 | + " \"beta_slow\": 1.0,\n", |
730 | 769 | " \"rms_norm_eps\": 1e-6,\n", |
731 | 770 | " \"dtype\": torch.bfloat16,\n", |
732 | 771 | " \"eos_token_id\": 100_257,\n", |
|
810 | 849 | { |
811 | 850 | "data": { |
812 | 851 | "text/plain": [ |
813 | | - "tensor([[[ 0.3594, -0.6289, -0.2754, ..., 1.1016, 0.4219, 0.0381],\n", |
814 | | - " [ 1.1719, 0.0283, 0.6055, ..., 0.4863, -0.1953, 0.2246],\n", |
815 | | - " [ 0.4902, -0.0425, 0.6758, ..., 0.3730, -0.5781, -0.1670]]],\n", |
| 852 | + "tensor([[[ 0.3867, -0.6328, -0.2734, ..., 1.1484, 0.4258, 0.0400],\n", |
| 853 | + " [ 1.2734, 0.0040, 0.5000, ..., 0.5625, -0.2383, 0.1855],\n", |
| 854 | + " [ 0.5859, -0.0540, 0.7930, ..., 0.3262, -0.5430, -0.1494]]],\n", |
816 | 855 | " dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)" |
817 | 856 | ] |
818 | 857 | }, |
|
1202 | 1241 | "name": "stdout", |
1203 | 1242 | "output_type": "stream", |
1204 | 1243 | "text": [ |
1205 | | - "Sure! Here’s a brief introduction to large language models: \n", |
1206 | | - "Large models are advanced AI systems trained to process vast neural networks capable of understanding and generating text, learning from vast amounts of data, learning language, performing diverse tasks, assisting in many applications, and adapting various tasks.\n", |
| 1244 | + "Large language models are advanced AI systems trained on vast amounts of text to understand and generate human-like language. They can perform a wide range of tasks, from answering questions to writing essays or code. These models have transformed natural language processing and are now foundational in many modern AI applications.\n", |
1207 | 1245 | "\n", |
1208 | 1246 | "GPU memory used: 13.71 GB\n" |
1209 | 1247 | ] |
|
0 commit comments