Add feature dim attributes to BitLinear for easier PEFT integration #34946
|
LGTM @agostinv, thanks for the feature ! |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
Great little addition, thanks. Is this change sufficient to enable PEFT LoRA with bitlinear? Do you have a snippet to show its usage? I could imagine that training and inference work out of the box with this change, but some features like merging don't work or need special handling in PEFT. Edit: As |
|
@BenjaminBossan You're exactly right! Based on my experience, training functions but merging is non-trivial (also should clarify I forked
The attributes allow us to get caught by the following code in
As far as a quick example goes, I have the following snippet that's pretty ad-hoc but is generally based on the BitsAndBytes implementations for
import warnings
from typing import Any, Optional

import torch
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from peft.utils.other import transpose
from peft.tuners.lora.layer import LoraLayer


class BitNetLinearLora(torch.nn.Module, LoraLayer):
    # Lora implemented in a dense layer
    def __init__(
        self,
        base_layer: torch.nn.Module,
        adapter_name: str,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        init_lora_weights: bool = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        LoraLayer.__init__(self, base_layer)
        self.fan_in_fan_out = False
        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
        )

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        raise NotImplementedError

    def unmerge(self) -> None:
        raise NotImplementedError

    def get_delta_weight(self, adapter):
        return (
            transpose(
                self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
                False,
            )
            * self.scaling[adapter]
        )

    def _mixed_batch_forward(
        self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
    ) -> torch.Tensor:
        # This is a special method that handles the case when users pass the argument `adapter_names`. This is an
        # extra argument that allows mixing different adapters in the same batch at inference time.
        result = self.base_layer(x, *args, **kwargs)
        unique_adapters = set(adapter_names)
        sub_batch_indices_list = []
        for adapter in unique_adapters:
            sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])

        for i, active_adapter in enumerate(unique_adapters):
            if active_adapter == "__base__":
                continue
            if active_adapter not in self.lora_A.keys():
                continue

            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]

            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                expected_dtype = result.dtype
                x = x.to(lora_A.weight.dtype)

            # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
            # layer output
            sub_batch = x[sub_batch_indices_list[i]]
            output = lora_B(lora_A(dropout(sub_batch))) * scaling
            if requires_conversion:
                output = output.to(expected_dtype)
            result[sub_batch_indices_list[i]] += output

        return result

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        self._check_forward_args(x, *args, **kwargs)
        adapter_names = kwargs.pop("adapter_names", None)

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif adapter_names is not None:
            result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)
            # As per Tim Dettmers, for 4bit, we need to defensively clone here.
            # The reason is that in some cases, an error can occur that backprop
            # does not work on a manipulated view. This issue may be solved with
            # newer PyTorch versions but this would need extensive testing to be
            # sure.
            result = result.clone()

            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A.keys():
                    continue

                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]

                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
                    x = x.to(lora_A.weight.dtype)

                if not self.use_dora[active_adapter]:
                    output = lora_B(lora_A(dropout(x))) * scaling
                else:
                    x = dropout(x)
                    output = self.lora_magnitude_vector[active_adapter](
                        x,
                        lora_A=lora_A,
                        lora_B=lora_B,
                        scaling=scaling,
                        base_layer=self.get_base_layer(),
                    )
                if requires_conversion:
                    output = output.to(expected_dtype)

                result = result + output

        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora." + rep


def dispatch_bitnet(target: torch.nn.Module, adapter_name: str, **kwargs):
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    bitnet_kwargs = kwargs.copy()
    new_module = BitNetLinearLora(target, adapter_name, **bitnet_kwargs)

    return new_module
|
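The snippet above only works if the wrapped layer exposes in_features and out_features, which is the point of this PR. As a rough illustration of that duck-typed lookup, here is a minimal, torch-free sketch. Everything in it is hypothetical: PackedLayerStub is a stand-in rather than the real BitLinear, resolve_feature_dims is not a PEFT function, and the fallback assumes a packed weight of shape (out_features // values_per_item, in_features), which may not match the actual storage layout.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class PackedLayerStub:
    """Hypothetical stand-in for a packed low-bit layer (not the real BitLinear)."""

    # Assumed layout: (out_features // values_per_item, in_features).
    weight_shape: Tuple[int, int]
    in_features: Optional[int] = None
    out_features: Optional[int] = None


def resolve_feature_dims(layer, values_per_item: int = 4) -> Tuple[int, int]:
    """Return (in_features, out_features), preferring explicit attributes."""
    in_f = getattr(layer, "in_features", None)
    out_f = getattr(layer, "out_features", None)
    if in_f is not None and out_f is not None:
        # The easy path: the layer advertises its dimensions directly.
        return in_f, out_f
    # Fallback: undo the assumed packing along the output dimension.
    packed_rows, in_f = layer.weight_shape
    return in_f, packed_rows * values_per_item


with_attrs = PackedLayerStub(weight_shape=(64, 128), in_features=128, out_features=256)
without_attrs = PackedLayerStub(weight_shape=(64, 128))
print(resolve_feature_dims(with_attrs))
print(resolve_feature_dims(without_attrs))
```

The fallback branch has to know the packing factor, which is an implementation detail of the quantized layer; reading explicit attributes avoids encoding that knowledge in the adapter code.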
ArthurZucker
left a comment
Makes sense! Can't we just pass them to the super().__init__() call?
super().__init__(in_features, out_features) should be equivalent no?
|
@agostinv Pretty cool, thanks for providing more details and the code sample. If you're interested, we can look into adding bitnet support to PEFT directly, your example already looks quite good as is and merging support is not mandatory. |
I didn't think so, and after trying on a minimum snippet it didn't seem to work (unless I've misunderstood what you meant). Since
Not opposed at all to adding direct PEFT support, especially if it is in addition to this PR. Not including these attributes in
While I currently have disallowed merging in that code snippet (mostly because I doubt it would result in a usable adapter), it feels like an official implementation should have some support for users that want to explore it. Since merging isn't super complicated, I can just go ahead and implement the most straightforward version if you'd like (i.e. dequantizing the BitLinear weights, adding the adapter, then requantizing and storing the new scales). |
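For readers following along, the dequantize, add, requantize flow mentioned above could look roughly like the sketch below. It is written in plain Python on nested lists so it runs without torch, and it assumes a per-tensor absmean ternary scheme (codes in {-1, 0, 1} times one scale). That scheme, and every function name here, is an assumption for illustration only; it is not taken from the BitLinear or PEFT source.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive dense matrix product, enough for a toy example."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def quantize_ternary(w: Matrix) -> Tuple[List[List[int]], float]:
    """Assumed absmean scheme: scale = mean(|w|), codes = clamp(round(w / scale), -1, 1)."""
    flat = [abs(v) for row in w for v in row]
    scale = max(sum(flat) / len(flat), 1e-5)
    codes = [[max(-1, min(1, round(v / scale))) for v in row] for row in w]
    return codes, scale


def dequantize(codes: List[List[int]], scale: float) -> Matrix:
    """Map ternary codes back to floating-point weights."""
    return [[c * scale for c in row] for row in codes]


def merge_lora(codes, scale, lora_A: Matrix, lora_B: Matrix, scaling: float):
    """Dequantize, add scaling * (B @ A), then requantize and return new codes and scale."""
    w = dequantize(codes, scale)
    delta = matmul(lora_B, lora_A)
    merged = [
        [w[i][j] + scaling * delta[i][j] for j in range(len(w[0]))]
        for i in range(len(w))
    ]
    return quantize_ternary(merged)


base_codes = [[1, 0], [-1, 1]]  # toy 2x2 ternary weight
base_scale = 0.5
lora_A = [[0.0, 1.0]]           # shape (r=1, in_features=2)
lora_B = [[1.0], [0.0]]         # shape (out_features=2, r=1)

new_codes, new_scale = merge_lora(base_codes, base_scale, lora_A, lora_B, scaling=1.0)
print(new_codes, new_scale)
```

Because the merged weights are snapped back to three levels, the adapter's contribution is only coarsely preserved, which is consistent with the doubt expressed above about whether a merged adapter would be usable.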
|
@MekkCyber Sorry to ping you again, but do you know if any other steps are required before merging to main, here? Assuming the state of this PR is fine. |
|
@agostinv sorry forgot about it, merged ! |
What does this PR do?
This PR is an extremely simple two-liner (adding in_features and out_features as attributes to BitLinear) whose only purpose is to improve accessibility for BitLinear to users that want to employ peft. Currently, BitLinear is not usable with LoRAs in peft out-of-the-box.
The typical flow for enabling LoRAs for custom layers in peft is to construct a custom class that describes the LoRA's behavior and then register it with a private API. The problem is that peft still needs additional information on input and output dimensionality via in_features and out_features, which BitLinear currently lacks. The current solution for this problem is to wrap BitLinear with another module that adds these attributes during initialization and then replace all instances of BitLinear with that new module. Alternatively, the LoRA source code would have to be revised to support BitLinear and derive the feature dimensions from its weight matrix. From the perspective of potential users, adding the aforementioned attributes improves accessibility and avoids requiring some hacky-looking fixes from their end.
Before submitting
Other checkmarks are left untouched, as they don't look relevant.
Who can review?