-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Speedup model loading by 4-5x in Diffusers ⚡ #3674
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -222,6 +222,8 @@ def set_module_tensor_to_device( | |
| dtype: Optional[Union[str, torch.dtype]] = None, | ||
| fp16_statistics: Optional[torch.HalfTensor] = None, | ||
| tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None, | ||
| non_blocking: bool = False, | ||
| _empty_cache: bool = True, | ||
|
||
| ): | ||
| """ | ||
| A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing | ||
|
|
@@ -295,9 +297,9 @@ def set_module_tensor_to_device( | |
|
|
||
| if dtype is None: | ||
| # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model | ||
| value = value.to(old_value.dtype) | ||
| value = value.to(old_value.dtype, non_blocking=non_blocking) | ||
| elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): | ||
| value = value.to(dtype) | ||
| value = value.to(dtype, non_blocking=non_blocking) | ||
|
|
||
| device_quantization = None | ||
| with torch.no_grad(): | ||
|
|
@@ -326,15 +328,15 @@ def set_module_tensor_to_device( | |
| if "xpu" in str(device) and not is_xpu_available(): | ||
| raise ValueError(f'{device} is not available, you should use device="cpu" instead') | ||
| if value is None: | ||
| new_value = old_value.to(device) | ||
| new_value = old_value.to(device, non_blocking=non_blocking) | ||
| if dtype is not None and device in ["meta", torch.device("meta")]: | ||
| if not str(old_value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): | ||
| new_value = new_value.to(dtype) | ||
| new_value = new_value.to(dtype, non_blocking=non_blocking) | ||
|
|
||
| if not is_buffer: | ||
| module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad) | ||
| elif isinstance(value, torch.Tensor): | ||
| new_value = value.to(device) | ||
| new_value = value.to(device, non_blocking=non_blocking) | ||
| else: | ||
| new_value = torch.tensor(value, device=device) | ||
| if device_quantization is not None: | ||
|
|
@@ -347,24 +349,30 @@ def set_module_tensor_to_device( | |
| if param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]: | ||
| if param_cls.__name__ == "Int8Params" and new_value.dtype == torch.float32: | ||
| # downcast to fp16 if any - needed for 8bit serialization | ||
| new_value = new_value.to(torch.float16) | ||
| new_value = new_value.to(torch.float16, non_blocking=non_blocking) | ||
| # quantize module that are going to stay on the cpu so that we offload quantized weights | ||
| if device == "cpu" and param_cls.__name__ == "Int8Params": | ||
| new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(0).to("cpu") | ||
| new_value.CB = new_value.CB.to("cpu") | ||
| new_value.SCB = new_value.SCB.to("cpu") | ||
| else: | ||
| new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device) | ||
| new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to( | ||
| device, non_blocking=non_blocking | ||
| ) | ||
| elif param_cls.__name__ in ["QTensor", "QBitsTensor"]: | ||
| new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(device) | ||
| new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to( | ||
| device, non_blocking=non_blocking | ||
| ) | ||
| elif param_cls.__name__ in ["AffineQuantizedTensor"]: | ||
| new_value = new_value.to(device) | ||
| new_value = new_value.to(device, non_blocking=non_blocking) | ||
| else: | ||
| new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device) | ||
| new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to( | ||
| device, non_blocking=non_blocking | ||
| ) | ||
|
|
||
| module._parameters[tensor_name] = new_value | ||
| if fp16_statistics is not None: | ||
| module._parameters[tensor_name].SCB = fp16_statistics.to(device) | ||
| module._parameters[tensor_name].SCB = fp16_statistics.to(device, non_blocking=non_blocking) | ||
| del fp16_statistics | ||
| # as we put the weight to meta, it doesn't have SCB attr anymore. make sure that it is not a meta weight | ||
| if ( | ||
|
|
@@ -390,9 +398,11 @@ def set_module_tensor_to_device( | |
| device_index = torch.device(device).index if torch.device(device).type == "cuda" else None | ||
| if not getattr(module.weight, "quant_state", None) and device_index is not None: | ||
| module.weight = module.weight.cuda(device_index) | ||
|
|
||
| # clean pre and post forward hook | ||
| if device != "cpu": | ||
| clear_device_cache() | ||
| if _empty_cache: | ||
| if device != "cpu": | ||
| clear_device_cache() | ||
|
|
||
| # When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in | ||
| # order to avoid duplicating memory, see above. | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change is unrelated. Seems to come from ruff, and I'm not sure why. I'm using the ruff version from the setup.py