|
15 | 15 | """ |
16 | 16 |
|
17 | 17 | import math |
| 18 | +from typing import Optional |
18 | 19 |
|
19 | 20 | import paddle |
20 | 21 | import paddle.nn as nn |
@@ -57,8 +58,10 @@ def __init__(self, text_config, vision_config, prefix=""): |
57 | 58 |
|
58 | 59 | self.pre_norm = nn.LayerNorm(self.vision_config.hidden_size, epsilon=1e-05) |
59 | 60 | self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size) |
| 61 | + self.linear_1.weight.weight_loader = self.weight_loader |
60 | 62 | self.act = GELUActivation() |
61 | 63 | self.linear_2 = nn.Linear(self.hidden_size, self.text_config.hidden_size) |
| 64 | + self.linear_2.weight.weight_loader = self.weight_loader |
62 | 65 |
|
63 | 66 | def forward(self, image_features, image_grid_thw): |
64 | 67 | m1, m2 = self.merge_kernel_size |
@@ -94,6 +97,20 @@ def forward(self, image_features, image_grid_thw): |
94 | 97 | hidden_states = self.linear_2(hidden_states) |
95 | 98 | return hidden_states |
96 | 99 |
|
| 100 | + def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None): |
| 101 | + loaded_weight = get_tensor(loaded_weight) |
| 102 | + loaded_weight = loaded_weight.transpose([1, 0]) |
| 103 | + assert param.shape == loaded_weight.shape, ( |
| 104 | + f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})" |
| 105 | + ) |
| 106 | + # Ensure loaded weight dtype matches model param dtype |
| 107 | + if loaded_weight.dtype != param.dtype: |
| 108 | + if loaded_weight.dtype == paddle.int8 and param.dtype == paddle.float8_e4m3fn: |
| 109 | + loaded_weight = loaded_weight.view(param.dtype) |
| 110 | + else: |
| 111 | + loaded_weight = loaded_weight.cast(param.dtype) |
| 112 | + param.copy_(loaded_weight, False) |
| 113 | + |
97 | 114 | def load_state_dict(self, state_dict): |
98 | 115 | params_dict = dict(self.named_parameters()) |
99 | 116 | for param_name, param in params_dict.items(): |
|
0 commit comments