Skip to content

Quantize vit_b_16 tutorial - Part 1 #60

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion torchao/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from . import dtypes
from .quantization.quant_api import apply_dynamic_quant
from .quantization.quant_api import apply_weight_only_int8_quant

__all__ = [
"dtypes"
"dtypes",
"apply_dynamic_quant",
]
10 changes: 3 additions & 7 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,10 @@ def apply_weight_only_int8_quant(model, filter_fn=None):
def apply_dynamic_quant(model, filter_fn=None):
"""
Applies dynamic symmetric per-token activation and per-channel weight
quantization to all linear layers in the given model using
module swaps.
quantization to all linear layers by converting all linear weight
tensors to the `Int8DynamicallyQuantizedLinearWeight` Tensor subclass.
"""
_replace_with_custom_fn_if_matches_filter(
model,
lambda mod: DynamicallyPerAxisQuantizedLinear.from_float(mod),
_is_linear if filter_fn is None else filter_fn,
)
change_linear_weights_to_int8_dqtensors(model, filter_fn)


def _get_subclass_inserter(cls, **kwargs):
Expand Down
Binary file added tutorials/quantize_vit/bfloat16.json.gz
Binary file not shown.
1,682 changes: 1,682 additions & 0 deletions tutorials/quantize_vit/bfloat16_code.py

Large diffs are not rendered by default.

Binary file added tutorials/quantize_vit/quant.json.gz
Binary file not shown.
2,413 changes: 2,413 additions & 0 deletions tutorials/quantize_vit/quant_code.py

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions tutorials/quantize_vit/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/bash

# Run bfloat16 version
TORCH_LOGS='graph_breaks,recompiles' python run_vit_b.py

# Run dynamic quantized version
TORCH_LOGS='graph_breaks,recompiles' python run_vit_b_quant.py

# Store the output code for further inspection
echo "bfloat16 generated code lives in:"
TORCH_LOGS='output_code' python run_vit_b.py 2>&1 | grep "Output code written to: " | awk -F" " '{print $NF}'
echo "quantization generated code lives in:"
TORCH_LOGS='output_code' python run_vit_b_quant.py 2>&1 | grep "Output code written to: " | awk -F" " '{print $NF}'
46 changes: 46 additions & 0 deletions tutorials/quantize_vit/run_vit_b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import torch
import torchvision.models.vision_transformer as models

# Load Vision Transformer model
model = models.vit_b_16(pretrained=True)

# Set the model to evaluation mode
model.eval().cuda().to(torch.bfloat16)

# Input tensor (batch_size, channels, height, width)
input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')

model = torch.compile(model, mode='max-autotune')

def benchmark_model(model, num_runs, input_tensor):
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

# benchmark
for _ in range(num_runs):
with torch.autograd.profiler.record_function("timed region"):
model(input_tensor)

end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / num_runs

def profiler_runner(path, fn, *args, **kwargs):
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
record_shapes=True) as prof:
result = fn(*args, **kwargs)
prof.export_chrome_trace(path)
return result

# Must run with no_grad when optimizing for inference
with torch.no_grad():
# warmup
benchmark_model(model, 5, input_tensor)
# benchmark
print("elapsed_time: ", benchmark_model(model, 100, input_tensor), " milliseconds")
# Create a trace
profiler_runner("bfloat16.json.gz", benchmark_model, model, 5, input_tensor)
53 changes: 53 additions & 0 deletions tutorials/quantize_vit/run_vit_b_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import torch
import torchao
import torchvision.models.vision_transformer as models

# Load Vision Transformer model
model = models.vit_b_16(pretrained=True)

# Set the model to evaluation mode
model.eval().cuda().to(torch.bfloat16)

# Input tensor (batch_size, channels, height, width)
input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')

## Quantization code - start
torchao.apply_dynamic_quant(model)
from torch._inductor import config as inductorconfig
inductorconfig.force_fuse_int_mm_with_mul = True
## Quantization code - end

model = torch.compile(model, mode='max-autotune')

def benchmark_model(model, num_runs, input_tensor):
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

# benchmark
for _ in range(num_runs):
with torch.autograd.profiler.record_function("timed region"):
model(input_tensor)

end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / num_runs

def profiler_runner(path, fn, *args, **kwargs):
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
record_shapes=True) as prof:
result = fn(*args, **kwargs)
prof.export_chrome_trace(path)
return result

# Must run with no_grad when optimizing for inference
with torch.no_grad():
# warmup
benchmark_model(model, 5, input_tensor)
# benchmark
print("elapsed_time: ", benchmark_model(model, 100, input_tensor), " milliseconds")
# Create a trace
profiler_runner("quant.json.gz", benchmark_model, model, 5, input_tensor)