Skip to content

Commit 5559405

Browse files
authored
Benchmarking updates for semi-structured sparse training (#398)
* Benchmarking updates for semi-structured sparse training. Summary: This PR does the following: (1) adds e2e ViT benchmarks for semi-structured sparse training; (2) adds nn.Linear microbenchmarks; (3) removes the extra xformers benchmarking utils I copied over; (4) removes the MLP block benchmarks; (5) updates README.md with the new benchmarks and accuracy benchmarks. Given that we have nn.Linear microbenchmarks and e2e benchmarks, I felt that the MLP block benchmarks were unnecessary. As a sanity check, I ran the MLP benchmarks with both the new benchmarking suite and the old one, and got the same results. Test Plan: Reviewers: Subscribers: Tasks: Tags: * update * add units
1 parent e5ee771 commit 5559405

File tree

4 files changed

+262
-873
lines changed

4 files changed

+262
-873
lines changed

benchmarks/benchmark_semi_sparse.py

Lines changed: 0 additions & 129 deletions
This file was deleted.
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
#
3+
# This source code is licensed under the BSD license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
6+
#
7+
# This source code is licensed under the BSD license found in the
8+
# LICENSE file in the root directory of this source tree.
9+
import argparse
10+
import itertools
11+
import gc
12+
13+
from typing import Tuple
14+
15+
import torch
16+
import torch.nn.functional as F
17+
from torch.utils import benchmark
18+
19+
from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear
20+
from torchao.sparsity.training.autograd import semi_structured_sparsify
21+
22+
from segment_anything_fast import sam_model_registry
23+
import pandas as pd
24+
25+
def product_dict(**kwargs):
    """Yield one dict per element of the Cartesian product of the keyword values.

    Each keyword argument maps a name to an iterable of candidate values.
    Every yielded dict has the same keys as ``kwargs`` and one value chosen
    from each iterable, e.g. ``product_dict(a=[1, 2], b=[3])`` yields
    ``{"a": 1, "b": 3}`` and then ``{"a": 2, "b": 3}``.
    """
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))
30+
31+
def benchmark_helper(
    functions,
    cases,
    fw: bool = False,
    bw: bool = False,
    cuda_graph: bool = False,
    compile: bool = False,
    blocked_autorange: bool = False,
):
    """Time every benchmark class in ``functions`` on every case in ``cases``.

    Args:
        functions: mapping from a label (stored under ``"sparsity_config"``)
            to a class. The class is instantiated as ``benchmark_cls(**case)``
            and must expose ``fw()`` and ``bw()`` methods (plus a ``model``
            attribute when ``compile`` is set).
        cases: iterable of keyword dicts, one per benchmark configuration.
        fw: include ``fw()`` in the timed callable.
        bw: include ``bw()`` in the timed callable.
        cuda_graph: capture the callable into a ``torch.cuda.CUDAGraph`` and
            time ``replay()`` instead of the eager call.
        compile: wrap ``benchmark_object.model`` with
            ``torch.compile(..., mode="max-autotune")`` before timing.
        blocked_autorange: use ``Timer.blocked_autorange()`` rather than
            ``Timer.adaptive_autorange(...)``.

    Returns:
        A ``pandas.DataFrame`` with one row per (function, case) pair holding
        the label, the case values, ``time`` (median scaled by 1e3) and
        ``memory`` (peak allocated bytes divided by 1e9). Both are the string
        ``"OOM"`` when a CUDA out-of-memory error was hit.

    Raises:
        AssertionError: if neither ``fw`` nor ``bw`` is set, or if both
            ``cuda_graph`` and ``compile`` are set.
        Exception: any non-OOM exception from the benchmark is re-raised.
    """
    assert fw or bw
    assert not (cuda_graph and compile)
    print(f"Running benchmarks with: fw={fw}, bw={bw}, cuda_graph={cuda_graph}, compile={compile}: ")

    results = []

    def handle_case(**case):
        for sparsity_config, benchmark_cls in functions.items():
            result = {
                "sparsity_config": sparsity_config,
            }
            result.update(**case)
            # Bind up front so the cleanup in ``finally`` never has to probe
            # locals() to find out how far the ``try`` body got.
            benchmark_object = None
            graph = None
            try:
                benchmark_object = benchmark_cls(**case)

                def run_one():
                    if fw:
                        benchmark_object.fw()
                    if bw:
                        benchmark_object.bw()

                if cuda_graph:
                    # Run once eagerly, then capture on a fresh object.
                    run_one()
                    benchmark_object = benchmark_cls(**case)
                    graph = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(graph):
                        run_one()

                    # From here on the timed callable replays the capture.
                    def run_one():
                        graph.replay()

                if compile:
                    benchmark_object.model = torch.compile(benchmark_object.model, mode="max-autotune")

                # benchmark
                torch.cuda.reset_peak_memory_stats()
                t0 = benchmark.Timer(
                    stmt="fn()",
                    globals={
                        "fn": run_one,
                    },
                    label="benchmark",
                )
                if blocked_autorange:
                    res = t0.blocked_autorange()
                else:
                    res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20)
                result.update({'time': res.median * 1e3, 'memory': torch.cuda.max_memory_allocated() / 1e9})
            except Exception as e:
                # Only out-of-memory is an expected outcome; anything else is
                # a real failure and must propagate.
                if "CUDA out of memory" not in str(e):
                    raise
                result.update({'time': 'OOM', 'memory': 'OOM'})
            finally:
                # Clean up: drop our references (the closures above see the
                # same variables) so the allocator can reclaim memory before
                # the next configuration runs.
                benchmark_object = None
                graph = None
                gc.collect()
                torch.cuda.empty_cache()
            results.append(result)

    for case in cases:
        handle_case(**case)
    return pd.DataFrame(results)
105+
106+
# test classes for Linear
107+
class LinearTest(torch.nn.Module):
    """Dense ``nn.Linear`` benchmark fixture (fp16, CUDA).

    ``mkn`` is an ``(m, k, n)`` triple: the layer is ``Linear(k, n)``, the
    input is a random ``[m, k]`` tensor and the upstream gradient fed to
    ``backward`` is a random ``[m, n]`` tensor.
    """

    def __init__(self, mkn):
        super().__init__()
        m, k, n = mkn
        self.model = torch.nn.Linear(k, n).cuda().half()
        self.input = torch.randn([m, k], device='cuda', dtype=torch.half, requires_grad=True)
        self.grad = torch.randn([m, n], device="cuda", dtype=torch.half)

    def fw(self):
        """Run the forward pass and keep the output for ``bw``."""
        self.out = self.model(self.input)

    def bw(self):
        """Backpropagate ``self.grad`` through the stored output.

        ``retain_graph=True`` keeps the autograd graph so ``bw`` can be
        called again on the same output.
        """
        self.out.backward(self.grad, retain_graph=True)
120+
121+
class SemiSparseLinearTest(LinearTest):
    """``LinearTest`` variant whose layer is converted to ``SemiSparseLinear``.

    The dense layer built by ``LinearTest.__init__`` is replaced with
    ``SemiSparseLinear.from_dense(...)``; ``fw`` and ``bw`` are inherited.
    """

    def __init__(self, mkn):
        super().__init__(mkn)
        self.model = SemiSparseLinear.from_dense(self.model)
125+
126+
class SemiSparseKernelTest(LinearTest):
    """Benchmark fixture that times only ``semi_structured_sparsify``.

    Construction is inherited unchanged from ``LinearTest``. The linear layer
    is never invoked: ``fw`` calls ``semi_structured_sparsify`` on the input
    and ``bw`` is a no-op.
    """

    def fw(self):
        """Apply ``semi_structured_sparsify`` to the input tensor."""
        self.out = semi_structured_sparsify(self.input)

    def bw(self):
        """Do nothing; this fixture has no backward step to time."""
        pass
135+
136+
# test class for ViT (SAM image encoder)
137+
class SAMTest(torch.nn.Module):
    """Dense baseline fixture for the SAM image encoder (fp16, CUDA, train mode).

    The model is ``sam_model_registry[model_type]().image_encoder``. The
    input is a random ``[batch_size, 3, 1024, 1024]`` tensor and the upstream
    gradient fed to ``backward`` is a random ``[batch_size, 256, 64, 64]``
    tensor.
    """

    def __init__(self, model_type, batch_size):
        super().__init__()
        self.model = sam_model_registry[model_type]().image_encoder.cuda().half().train()
        self.input = torch.randn(batch_size, 3, 1024, 1024, device='cuda', dtype=torch.half, requires_grad=True)
        self.grad = torch.randn([batch_size, 256, 64, 64], device="cuda", dtype=torch.half)

    def fw(self):
        """Run the forward pass and keep the output for ``bw``."""
        self.out = self.model(self.input)

    def bw(self):
        """Backpropagate ``self.grad`` through the stored output.

        ``retain_graph=True`` keeps the autograd graph so ``bw`` can be
        called again on the same output.
        """
        self.out.backward(self.grad, retain_graph=True)
150+
151+
class SAM_W24_MLP_ONLY(SAMTest):
    """``SAMTest`` variant with only the MLP linear layers swapped.

    Every ``nn.Linear`` submodule whose qualified name contains ``"mlp"`` is
    mapped to ``SemiSparseLinear``, and the mapping is handed to
    ``swap_linear_with_semi_sparse_linear``.
    """

    def __init__(self, model_type, batch_size):
        super().__init__(model_type, batch_size)
        # Apply to just MLP linear layers of SAM image encoder (ViT)
        sparse_config = {
            name: SemiSparseLinear
            for name, mod in self.model.named_modules()
            if isinstance(mod, torch.nn.Linear) and 'mlp' in name
        }
        swap_linear_with_semi_sparse_linear(self.model, sparse_config)
160+
161+
class SAM_W24_ALL(SAMTest):
    """``SAMTest`` variant with every linear layer swapped.

    Every ``nn.Linear`` submodule of the image encoder is mapped to
    ``SemiSparseLinear``, and the mapping is handed to
    ``swap_linear_with_semi_sparse_linear``.
    """

    def __init__(self, model_type, batch_size):
        super().__init__(model_type, batch_size)
        # Apply to all linear layers of SAM image encoder (ViT)
        sparse_config = {
            name: SemiSparseLinear
            for name, mod in self.model.named_modules()
            if isinstance(mod, torch.nn.Linear)
        }
        swap_linear_with_semi_sparse_linear(self.model, sparse_config)
170+
171+
if __name__ == "__main__":
    # Command-line entry point: choose the nn.Linear microbenchmarks or the
    # end-to-end ViT benchmark, print the results table and optionally write
    # it to a CSV file named after the selected mode.
    print("BENCHMARKING")
    parser = argparse.ArgumentParser(description='run semi-structured sparse training benchmarks')
    parser.add_argument('--mode', type=str, choices=["linear", "vit"], help='nn.Linear/ViT-e2e benchmarking', default="vit")
    parser.add_argument('--save', action="store_true", help="save benchmarking results")
    args = parser.parse_args()
    if args.mode == "linear":
        functions = {
            "dense_linear": LinearTest,
            "semi_sparse_linear": SemiSparseLinearTest,
            "semi_sparse_prune+compress_time_only": SemiSparseKernelTest,
        }
        cases = list(
            product_dict(
                mkn=[
                    # DINO ViT-L mlp.lin1
                    (13008, 1024, 4096),
                    # DINO ViT-L mlp.lin2
                    (13008, 4096, 1024),
                ],
            )
        )

        # Linear microbenchmarks: forward + backward, captured in a CUDA graph.
        df = benchmark_helper(
            functions,
            cases,
            fw=True,
            bw=True,
            cuda_graph=True,
            blocked_autorange=True)

    elif args.mode == "vit":
        functions = {
            "ViT dense (baseline)": SAMTest,
            "ViT MLP weight 2:4 sparse": SAM_W24_MLP_ONLY,
            # "ViT all(MLP+ATN) Linear weight 2:4 sparse": SAM_W24_ALL
        }
        cases = list(
            product_dict(
                model_type=['vit_l'],
                batch_size=[8]
            )
        )

        # End-to-end ViT benchmark: forward + backward with torch.compile.
        df = benchmark_helper(
            functions,
            cases,
            fw=True,
            bw=True,
            compile=True)

    print(df)
    if args.save:
        df.to_csv(f"{args.mode}_semi_structured_training_benchmarks.csv")

0 commit comments

Comments
 (0)