# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
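# Benchmarks runtime 2:4 semi-structured sparse training against dense training,
# both for standalone nn.Linear shapes and end-to-end on the SAM (ViT) image encoder.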
import argparse
import gc
import itertools

import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils import benchmark

from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear
from torchao.sparsity.training.autograd import semi_structured_sparsify

from segment_anything_fast import sam_model_registry

def product_dict(**kwargs):
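    # Yield one kwargs dict per element of the Cartesian product of the value lists,
    # e.g. product_dict(a=[1, 2], b=[3]) -> {"a": 1, "b": 3}, {"a": 2, "b": 3}.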
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))

def benchmark_helper(
    functions,
    cases,
    fw: bool = False,
    bw: bool = False,
    cuda_graph: bool = False,
    compile: bool = False,
    blocked_autorange: bool = False,
):
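    """Time the fw/bw passes of each benchmark class in `functions` over every case in `cases`.

    Returns a pandas DataFrame with one row per (sparsity_config, case) pair, containing the
    median runtime in ms and peak CUDA memory in GB (or "OOM" if the case ran out of memory).
    """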
    assert fw or bw, "benchmark at least one of the forward/backward passes"
    assert not (cuda_graph and compile), "CUDA graph capture and torch.compile are mutually exclusive here"
    print(f"Running benchmarks with: fw={fw}, bw={bw}, cuda_graph={cuda_graph}, compile={compile}")

    results = []

    def handle_case(**case):
        for sparsity_config, benchmark_cls in functions.items():
            result = {
                "sparsity_config": sparsity_config,
            }
            result.update(**case)
            try:
                benchmark_object = benchmark_cls(**case)

                def run_one():
                    if fw:
                        benchmark_object.fw()
                    if bw:
                        benchmark_object.bw()

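                # With CUDA graphs: warm up once eagerly, rebuild the benchmark object so the
                # capture starts from fresh state, record one fw/bw iteration, then time replays.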
                if cuda_graph:
                    run_one()
                    benchmark_object = benchmark_cls(**case)
                    g = torch.cuda.CUDAGraph()
                    with torch.cuda.graph(g):
                        run_one()

                    def run_one():
                        g.replay()

                if compile:
                    benchmark_object.model = torch.compile(benchmark_object.model, mode="max-autotune")

                # benchmark
                torch.cuda.reset_peak_memory_stats()
                t0 = benchmark.Timer(
                    stmt="fn()",
                    globals={
                        "fn": run_one,
                    },
                    label="benchmark",
                )
                if blocked_autorange:
                    res = t0.blocked_autorange()
                else:
                    res = t0.adaptive_autorange(0.03, min_run_time=0.2, max_run_time=20)
                result.update({"time": res.median * 1e3, "memory": torch.cuda.max_memory_allocated() / 1e9})
            except Exception as e:
                if "CUDA out of memory" not in str(e):
                    raise
                else:
                    result.update({"time": "OOM", "memory": "OOM"})
            finally:
                # clean up
                if "benchmark_object" in locals():
                    del benchmark_object
                if "g" in locals():
                    del g
                gc.collect()
                torch.cuda.empty_cache()
            results.append(result)

    for case in cases:
        handle_case(**case)
    return pd.DataFrame(results)

# test classes for Linear
class LinearTest(torch.nn.Module):
    def __init__(self, mkn):
        super().__init__()
        m, k, n = mkn
        self.model = torch.nn.Linear(k, n).cuda().half()
        self.input = torch.randn([m, k], device="cuda", dtype=torch.half, requires_grad=True)
        self.grad = torch.randn([m, n], device="cuda", dtype=torch.half)

    def fw(self):
        self.out = self.model(self.input)

    def bw(self):
        self.out.backward(self.grad, retain_graph=True)

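# Same shapes as LinearTest, but with the dense nn.Linear swapped for torchao's
# SemiSparseLinear, so the matmuls run with runtime 2:4 semi-structured sparsity.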
class SemiSparseLinearTest(LinearTest):
    def __init__(self, mkn):
        super().__init__(mkn)
        self.model = SemiSparseLinear.from_dense(self.model)

class SemiSparseKernelTest(LinearTest):
    def __init__(self, mkn):
        super().__init__(mkn)

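    # Only measures the 2:4 prune + compress (sparsification) kernel on the input;
    # no matmul is performed and backward is a no-op.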
    def fw(self):
        self.out = semi_structured_sparsify(self.input)

    def bw(self):
        pass

# test class for ViT (SAM image encoder)
class SAMTest(torch.nn.Module):
    def __init__(self, model_type, batch_size):
        super().__init__()
        self.model = sam_model_registry[model_type]().image_encoder.cuda().half().train()
        self.input = torch.randn(batch_size, 3, 1024, 1024, device="cuda", dtype=torch.half, requires_grad=True)
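        # incoming gradient matching the image encoder's [B, 256, 64, 64] output embedding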
        self.grad = torch.randn([batch_size, 256, 64, 64], device="cuda", dtype=torch.half)

    def fw(self):
        self.out = self.model(self.input)

    def bw(self):
        self.out.backward(self.grad, retain_graph=True)

class SAM_W24_MLP_ONLY(SAMTest):
    def __init__(self, model_type, batch_size):
        super().__init__(model_type, batch_size)
        # Apply 2:4 sparsity to just the MLP linear layers of the SAM image encoder (ViT)
        sparse_config = {}
        for name, mod in self.model.named_modules():
            if isinstance(mod, torch.nn.Linear) and "mlp" in name:
                sparse_config[name] = SemiSparseLinear
        swap_linear_with_semi_sparse_linear(self.model, sparse_config)

class SAM_W24_ALL(SAMTest):
    def __init__(self, model_type, batch_size):
        super().__init__(model_type, batch_size)
        # Apply 2:4 sparsity to all linear layers (MLP + attention) of the SAM image encoder (ViT)
        sparse_config = {}
        for name, mod in self.model.named_modules():
            if isinstance(mod, torch.nn.Linear):
                sparse_config[name] = SemiSparseLinear
        swap_linear_with_semi_sparse_linear(self.model, sparse_config)

if __name__ == "__main__":
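    # Example invocations (script filename assumed; adjust to the actual file name):
    #   python benchmark_semi_sparse_training.py --mode linear
    #   python benchmark_semi_sparse_training.py --mode vit --save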
    print("BENCHMARKING")
    parser = argparse.ArgumentParser(description="run semi-structured sparse training benchmarks")
    parser.add_argument("--mode", type=str, choices=["linear", "vit"], help="nn.Linear/ViT-e2e benchmarking", default="vit")
    parser.add_argument("--save", action="store_true", help="save benchmarking results")
    args = parser.parse_args()
    if args.mode == "linear":
        functions = {
            "dense_linear": LinearTest,
            "semi_sparse_linear": SemiSparseLinearTest,
            "semi_sparse_prune+compress_time_only": SemiSparseKernelTest,
        }
        cases = list(
            product_dict(
                mkn=[
                    # DINO ViT-L mlp.lin1
                    (13008, 1024, 4096),
                    # DINO ViT-L mlp.lin2
                    (13008, 4096, 1024),
                ],
            )
        )

        df = benchmark_helper(
            functions,
            cases,
            fw=True,
            bw=True,
            cuda_graph=True,
            blocked_autorange=True)

    elif args.mode == "vit":
        functions = {
            "ViT dense (baseline)": SAMTest,
            "ViT MLP weight 2:4 sparse": SAM_W24_MLP_ONLY,
            # "ViT all (MLP+ATN) Linear weight 2:4 sparse": SAM_W24_ALL,
        }
        cases = list(
            product_dict(
                model_type=["vit_l"],
                batch_size=[8],
            )
        )

        df = benchmark_helper(
            functions,
            cases,
            fw=True,
            bw=True,
            compile=True)

    print(df)
    if args.save:
        df.to_csv(f"{args.mode}_semi_structured_training_benchmarks.csv")