-
Notifications
You must be signed in to change notification settings - Fork 31.7k
Integrate colqwen2.5 using colqwen2 modelling code #40600
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
d4a79ce
6b8f487
d0697ce
26b8cbf
f5299d2
14d96d3
3ac06f9
8bc5d73
0656367
73b029b
3aa8aa8
d4be146
f591764
9577aae
e9ea6b6
6ae49f6
9297f9e
6a62d82
2032bd5
a0a6245
db2df86
272a7dc
5ca07ce
961fb9f
b6b454e
76238d3
0582b59
30dc9d9
26fe35c
673289b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -69,7 +69,7 @@ def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> d | |
| original_state_dict[key] = f.get_tensor(key) | ||
|
|
||
| # Some weights are tied, so `lm_head` is not saved. Let's clone it to load the state dict. | ||
| if "lm_head.weight" not in original_state_dict: | ||
| if "lm_head.weight" not in original_state_dict and "model.embed_tokens.weight" in original_state_dict: | ||
| original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone() | ||
|
|
||
| return original_state_dict | ||
|
|
@@ -99,10 +99,11 @@ def convert_colqwen2_weights_to_hf( | |
| push_to_hub: bool, | ||
| revision: Optional[str] = None, | ||
| original_vlm_name_or_path: Optional[str] = None, | ||
| use_qwen2_5=False, | ||
| ): | ||
| # Load the original model data | ||
| original_config = AutoConfig.from_pretrained( | ||
| model_id, | ||
| model_id, | ||
| revision=revision, | ||
| ) | ||
| if original_vlm_name_or_path is not None: | ||
|
|
@@ -119,6 +120,7 @@ def convert_colqwen2_weights_to_hf( | |
| config = ColQwen2Config( | ||
| vlm_config=original_config, | ||
| embedding_dim=128, # hardcoded in the original model | ||
| use_qwen2_5=use_qwen2_5, | ||
|
||
| ) | ||
| config.model_type = "colqwen2" | ||
| config.is_composition = False | ||
|
|
@@ -201,6 +203,12 @@ def convert_colqwen2_weights_to_hf( | |
| help="Name or path of the original VLM backbone model", | ||
| default=None, | ||
| ) | ||
| parser.add_argument( | ||
| "--use_qwen2_5", | ||
| help="Whether the original VLM backbone is Qwen2.5", | ||
| action="store_true", | ||
| default=False, | ||
| ) | ||
| args = parser.parse_args() | ||
|
|
||
| convert_colqwen2_weights_to_hf( | ||
|
|
@@ -209,4 +217,5 @@ def convert_colqwen2_weights_to_hf( | |
| push_to_hub=args.push_to_hub, | ||
| revision=args.revision, | ||
| original_vlm_name_or_path=args.original_vlm_name_or_path, | ||
| use_qwen2_5=args.use_qwen2_5, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -368,7 +368,8 @@ def forward( | |||||||||||||||
| inputs_embeds = self.vlm.language_model.embed_tokens(input_ids) | ||||||||||||||||
|
|
||||||||||||||||
| if pixel_values is not None: | ||||||||||||||||
| pixel_values = pixel_values.type(self.vlm.visual.get_dtype()) | ||||||||||||||||
| dtype, device = self._get_dtype_device() | ||||||||||||||||
| pixel_values = pixel_values.to(dtype=dtype, device=device) | ||||||||||||||||
|
||||||||||||||||
| dtype, device = self._get_dtype_device() | |
| pixel_values = pixel_values.to(dtype=dtype, device=device) | |
| pixel_values = pixel_values.to(inputs_embeds.device, inputs_embeds.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW, in qwen-2 vision tower, we cast pixels to correct dtype manually so it is not needed. Also, LM and vision might be loaded with different dtypes and devices in specific cases :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zucchini-nlp Did I understand correctly that this line could be removed entirely and it should work anyway?
@sahil-kabir Maybe worth a quick try. 😉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yep, correct
sahil-kabir marked this conversation as resolved.
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need for that then
| def _get_dtype_device(self) -> tuple[str, str]: | |
| if self.config.use_qwen2_5: | |
| parameters = next(self.vlm.visual.parameters()) | |
| else: | |
| parameters = next(self.parameters()) | |
| dtype, device = parameters.dtype, parameters.device | |
| return dtype, device |
Uh oh!
There was an error while loading. Please reload this page.