Draft
22 changes: 21 additions & 1 deletion torchao/_models/llama/generate.py
@@ -167,6 +167,7 @@ def main(
save: bool = False,
compile: bool = True,
compile_prefill: bool = False,
superblock: bool = False,
profile: Optional[Path] = None,
memory_profile: Optional[Path] = None,
device=default_device,
@@ -273,6 +274,24 @@ def main(
filename = str(checkpoint_path.name).split(".")[0]
torch.save(model.state_dict(), os.path.join(output_dir, filename + f"-{quantization}.pt"))

if superblock:
from torchao.sparsity.prototype.superblock.utils import (
accelerate_with_sparsity,
get_args_parser,
simulate_sparsity,
)

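# build Superblock's default argparse namespace (no CLI args), then override the sparsity settings used here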
superblock_args = get_args_parser(benchmark=True).parse_args([])
superblock_args.sparsity = "bsr"
superblock_args.sparsity_linear = 0.9
superblock_args.bsr = 64

sparsifier_or_none = simulate_sparsity(model, superblock_args)
if sparsifier_or_none is not None:
sparsifier_or_none.squash_mask()

accelerate_with_sparsity(model, superblock_args)

if compile:
print("Compiling Model")
global decode_one_token, prefill
@@ -426,6 +445,7 @@ def callback(x):
parser.add_argument('--save', action='store_true', help='Whether to save the quantized model.')
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
parser.add_argument('--superblock', action='store_true', help='Apply Superblock BSR sparsity')
parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
parser.add_argument('--memory_profile', type=Path, default=None, help='filename for memory profile.')
parser.add_argument('--device', type=str, default=default_device, help='Device to use')
@@ -435,5 +455,5 @@ def callback(x):
args = parser.parse_args()
main(
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.superblock, args.profile, args.memory_profile, args.device, args.precision, args.write_result
)
28 changes: 0 additions & 28 deletions torchao/sparsity/prototype/superblock/README.md
Contributor @jerryzh168 commented on Nov 21, 2024:

Should this be torchao/prototype/sparsity? #1013

Author replied:

I think this is because this PR was forked from an old commit that still had superblock under torchao/sparsity/prototype. When finalizing the PR we can rebase on top of main and update the directory path.

@@ -15,34 +15,6 @@ The BSR format is efficient for sparse matrices with a block structure, where no

Currently, the BSR format is optimized for Nvidia A100 GPU(s) only.

## Setup
To use SuperBlock, you will need
* [PyTorch](https://pytorch.org/get-started/locally/)

To train the model or evaluate accuracy, you will need:
* ImageNet2012-blurred dataset

At least one GPU:
* A100 or H100

## Installation
* Clone this repo
```
git clone https://github.com/pytorch-labs/superblock.git
cd superblock
```
* Create a new conda environment
```
conda create -n superblock
conda activate superblock
```
* Install PyTorch. For best performance, we recommend the pytorch nightlies
```
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
```
We ran our experiments with torch==2.6.0.dev20240924+cu121


# Results

### Benchmarking
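As background for the BSR format the README above describes: PyTorch can convert a dense tensor whose nonzeros fall into fixed-size blocks into the sparse BSR layout. A minimal sketch with illustrative shapes, separate from the changes in this PR:

```
import torch

dense = torch.zeros(256, 256)
dense[:64, :64] = torch.randn(64, 64)          # one nonzero 64x64 block
dense[128:192, 64:128] = torch.randn(64, 64)   # a second nonzero block

bsr = dense.to_sparse_bsr(blocksize=(64, 64))
print(bsr.layout)           # torch.sparse_bsr
print(bsr.values().shape)   # torch.Size([2, 64, 64]): only the nonzero blocks are stored
```

Only the nonzero blocks are stored and computed on, which is what block-sparse kernels exploit.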
190 changes: 116 additions & 74 deletions torchao/sparsity/prototype/superblock/supermask.py
@@ -7,6 +7,8 @@
import torch.nn.functional as F
import numpy as np

from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

# original supermask
scores_min=None
scores_max=9e9
@@ -35,6 +37,21 @@ def backward(ctx, g):
return g, None, None, None


class ApplyMask(torch.autograd.Function):
"""Supermask STE function"""
@staticmethod
def forward(ctx, weight, scores):
return weight * scores
@staticmethod
def backward(ctx, grad_output):
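# straight-through estimator: pass the incoming gradient through to both inputs
# unchanged, instead of scaling by the other operand as a plain multiply would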
grad_weight = grad_scores = None
if ctx.needs_input_grad[0]:
grad_weight = grad_output
if ctx.needs_input_grad[1]:
grad_scores = grad_output
return grad_weight, grad_scores


class SupermaskLinear(nn.Linear):
"""Supermask class for Linear layer"""
def __init__(self, sparsity, fixed_mask, fixed_weight, bitwidth, transform, fixed_transform, *args, **kwargs):
@@ -109,7 +126,8 @@ def sparsify_offline(self):
def forward(self, x):
if not self.sparsify_weights:
subnet = self.get_mask()
w = (self.weight*self.scale+self.shift) * subnet
w = (self.weight*self.scale+self.shift)
w = ApplyMask.apply(w, subnet)
else:
w = self.weight
return F.linear(x, w, self.bias)
@@ -179,7 +197,8 @@ def forward(self, x):
subnet = subnet.repeat_interleave(self.tile_size, dim=i)
subnet = torch.narrow(subnet, i, 0, k)

w = (self.weight*self.scale+self.shift) * subnet
w = (self.weight*self.scale+self.shift)
w = ApplyMask.apply(w, subnet)
return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)

def apply_supermask(
@@ -195,81 +214,104 @@
device="cuda",
verbose=False,
):
sparsified_modules = {}
# create filter function
# TODO: it might be better to move the filtering function to the script calling this function
is_last_layer = lambda module, name: name == "heads.head"
is_first_transformer_layer = lambda module, name: name == "encoder.layers.encoder_layer_0"
# TODO: create condition for ffn, k,v,q,o projections
reject_fn = lambda module, name : (skip_last_layer_sparsity and is_last_layer(module, name)) or (skip_first_transformer_sparsity and is_first_transformer_layer(module, name))
filter_fn = lambda module, name : not reject_fn(module, name) and isinstance(module, (torch.nn.Linear, torch.nn.Conv2d))

for n, m in model.named_modules():
# check conditions for skipping sparsity
if skip_last_layer_sparsity and n == "heads.head":
continue
if skip_first_transformer_sparsity and "encoder.layers.encoder_layer_0" in n:
continue

# convert 1x1 convolutions
if conv1x1_sparsity != 0.0 and isinstance(m, torch.nn.Conv2d) and m.kernel_size == (1, 1):
new_m = SupermaskConv2d(
conv1x1_sparsity, False, False, None, None, None,
m.in_channels,
m.out_channels,
m.kernel_size,
stride=m.stride,
padding=m.padding,
dilation=m.dilation,
groups=m.groups,
bias=m.bias is not None,
padding_mode=m.padding_mode,
device=device,
tile_size=conv1x1_sp_tilesize,
)
new_m.weight.data.copy_(m.weight.data)
if m.bias is not None:
new_m.bias.data.copy_(m.bias.data)
sparsified_modules[n] = new_m
continue
_replace_with_custom_fn_if_matches_filter(
model,
SuperMaskReplacementClass(
linear_sparsity=linear_sparsity,
linear_sp_tilesize=linear_sp_tilesize,
conv1x1_sparsity=conv1x1_sparsity,
conv1x1_sp_tilesize=conv1x1_sp_tilesize,
conv_sparsity=conv_sparsity,
conv_sp_tilesize=conv_sp_tilesize,
device=device,
verbose=verbose,
),
filter_fn,
)

# convert all other convolutions (not tested!)
if conv_sparsity != 0.0 and isinstance(m, torch.nn.Conv2d):
new_m = SupermaskConv2d(
conv_sparsity, False, False, None, None, None,
m.in_channels,
m.out_channels,
m.kernel_size,
stride=m.stride,
padding=m.padding,
dilation=m.dilation,
groups=m.groups,
bias=m.bias is not None,
padding_mode=m.padding_mode,
device=device,
tile_size=conv_sp_tilesize,
)
new_m.weight.data.copy_(m.weight.data)
if m.bias is not None:
new_m.bias.data.copy_(m.bias.data)
sparsified_modules[n] = new_m
continue
class SuperMaskReplacementClass:
def __init__(
self,
linear_sparsity=0.0,
linear_sp_tilesize=1,
conv1x1_sparsity=0.0,
conv1x1_sp_tilesize=1,
conv_sparsity=0.0,
conv_sp_tilesize=1,
device="cuda",
verbose=False,
):
self.linear_sparsity = linear_sparsity
self.linear_sp_tilesize = linear_sp_tilesize
self.conv1x1_sparsity = conv1x1_sparsity
self.conv1x1_sp_tilesize = conv1x1_sp_tilesize
self.conv_sparsity = conv_sparsity
self.conv_sp_tilesize = conv_sp_tilesize
self.device = device
self.verbose = verbose

if linear_sparsity != 0.0 and isinstance(m, torch.nn.Linear):
new_m = SupermaskLinear(
linear_sparsity, False, False, None, None, None,
m.in_features,
m.out_features,
bias=m.bias is not None,
device=device,
tile_size=linear_sp_tilesize,
)
new_m.weight.data.copy_(m.weight.data)
if m.bias is not None:
new_m.bias.data.copy_(m.bias.data)
sparsified_modules[n] = new_m
continue
def __call__(self, module):
module_new = None

# add modules to model
for k, v in sparsified_modules.items():
sm_name, ch_name = k.rsplit(".", 1)
sm = model.get_submodule(sm_name)
sm.add_module(ch_name, v)
if self.conv1x1_sparsity != 0.0 and isinstance(module, torch.nn.Conv2d) and module.kernel_size == (1, 1):
# convert 1x1 convolutions
module_new = SupermaskConv2d(
self.conv1x1_sparsity, False, False, None, None, None,
module.in_channels,
module.out_channels,
module.kernel_size,
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
bias=module.bias is not None,
padding_mode=module.padding_mode,
tile_size=self.conv1x1_sp_tilesize,
).to(device=self.device, dtype=module.weight.dtype)
module_new.weight.data.copy_(module.weight.data)
if module.bias is not None:
module_new.bias.data.copy_(module.bias.data)
elif self.conv_sparsity != 0.0 and isinstance(module, torch.nn.Conv2d):
# convert all other convolutions (not tested!)
module_new = SupermaskConv2d(
self.conv_sparsity, False, False, None, None, None,
module.in_channels,
module.out_channels,
module.kernel_size,
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
bias=module.bias is not None,
padding_mode=module.padding_mode,
tile_size=self.conv_sp_tilesize,
).to(device=self.device, dtype=module.weight.dtype)
module_new.weight.data.copy_(module.weight.data)
if module.bias is not None:
module_new.bias.data.copy_(module.bias.data)
elif self.linear_sparsity != 0.0 and isinstance(module, torch.nn.Linear):
module_new = SupermaskLinear(
self.linear_sparsity, False, False, None, None, None,
module.in_features,
module.out_features,
bias=module.bias is not None,
tile_size=self.linear_sp_tilesize,
).to(device=self.device, dtype=module.weight.dtype)
module_new.weight.data.copy_(module.weight.data)
if module.bias is not None:
module_new.bias.data.copy_(module.bias.data)
else:
return module

if verbose:
print(f'sparsified module "{k}" with sparsity={v.sparsity}, tile size={v.tile_size}')
if self.verbose:
print(f'sparsified module "{module}" with sparsity={module_new.sparsity}, tile size={module_new.tile_size}')

return model
return module_new
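The refactor above routes module replacement through torchao's _replace_with_custom_fn_if_matches_filter helper. A toy sketch of its assumed behavior (the filter receives each module and its fully qualified name, and matching modules are swapped for whatever the replacement callable returns):

```
import torch
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))

_replace_with_custom_fn_if_matches_filter(
    model,
    lambda module: torch.nn.Identity(),                          # replacement_fn: build the new module
    lambda module, name: isinstance(module, torch.nn.Linear),    # filter_fn: which modules to swap
)
print(model)  # both Linear layers are now Identity
```

This keeps the replacement logic for apply_supermask in a single callable (SuperMaskReplacementClass) instead of a hand-rolled loop over named_modules.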
4 changes: 2 additions & 2 deletions torchao/sparsity/prototype/superblock/utils.py
@@ -120,7 +120,7 @@ def mlp_only(mod, name):


def superblock_only(mod, name):
return isinstance(mod, SupermaskLinear) and "mlp" in name
return isinstance(mod, SupermaskLinear)# and "mlp" in name
Contributor:

@mostafaelhoushi Should this be changed to SupermaskReplacementClass?

Author replied:

Hmm... the SupermaskReplacementClass constructor requires a lot of arguments (linear_sparsity, linear_sp_tilesize, etc.). How will we pass them here?

Author:

I think I need to do some more refactoring:

  • the ViT benchmark code assumes there is a model checkpoint trained with SuperBlock, and hence already has SupermaskLinear layers and parameters
  • the GPT-Fast code I wrote uses a hack: it converts Linear layers to SupermaskLinear layers and then applies BSR sparsification (roughly the flow sketched below)
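A rough sketch of that hack, mirroring the new --superblock path in generate.py above; the sparsity values are only illustrative, and model stands in for the loaded GPT-Fast model:

```
from torchao.sparsity.prototype.superblock.utils import (
    accelerate_with_sparsity,
    get_args_parser,
    simulate_sparsity,
)

# reuse Superblock's own argument parser, then override the relevant knobs
args = get_args_parser(benchmark=True).parse_args([])
args.sparsity = "bsr"        # accelerate with block-sparse (BSR) kernels
args.sparsity_linear = 0.9   # 90% sparsity on Linear layers
args.bsr = 64                # 64x64 blocks

model = ...  # the loaded GPT-Fast model (placeholder)

# simulate sparsity on the dense model (Linear layers become SupermaskLinear)
sparsifier = simulate_sparsity(model, args)
if sparsifier is not None:
    sparsifier.squash_mask()

# convert the masked weights to BSR tensors for accelerated inference
accelerate_with_sparsity(model, args)
```

In generate.py this runs only when --superblock is passed, before the model is compiled.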

Contributor replied:

Ah, I think I read your changes wrong. I was assuming you had created SupermaskReplacementClass to combine SupermaskLinear and SupermaskConv2d, but I see you're still using those under the hood. I think this should be fine, actually.



def mlp_only_with_args(
@@ -138,7 +138,7 @@ def mlp_only_with_args(
### Custom sparsification utils
def apply_sparsity(model):
for name, module in model.named_modules():
if isinstance(module, SupermaskLinear) and "mlp" in name:
if isinstance(module, SupermaskLinear):# and "mlp" in name: # TODO: add option in another function for "mlp" in name
module.sparsify_offline()


Expand Down