Skip to content

Commit 26066b7

Browse files
mlazossvekarsAlannaBurke
authored
Added Torch Function modes x torch.compile tutorial (#3320)
--------- Co-authored-by: Svetlana Karslioglu <[email protected]> Co-authored-by: Alanna Burke <[email protected]>
1 parent a9ca64e commit 26066b7

File tree

2 files changed

+86
-0
lines changed

2 files changed

+86
-0
lines changed

recipes_source/recipes_index.rst

+9
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
317317
:link: ../recipes/amx.html
318318
:tags: Model-Optimization
319319

320+
.. (beta) Utilizing Torch Function modes with torch.compile
321+
322+
.. customcarditem::
323+
:header: (beta) Utilizing Torch Function modes with torch.compile
324+
:card_description: Override torch operators with Torch Function modes and torch.compile
325+
:image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
326+
:link: ../recipes/torch_compile_torch_function_modes.html
327+
:tags: Model-Optimization
328+
320329
.. (beta) Compiling the Optimizer with torch.compile
321330
322331
.. customcarditem::
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""
2+
(beta) Utilizing Torch Function modes with torch.compile
3+
============================================================
4+
5+
**Author:** `Michael Lazos <https://github.com/mlazos>`_
6+
"""
7+
8+
#########################################################
9+
# This recipe covers how to use a key torch extensibility point,
10+
# torch function modes, in tandem with ``torch.compile`` to override
11+
# the behavior of torch operators, also know as **ops**, at trace time, with no runtime overhead.
12+
#
13+
# .. note::
14+
#
15+
# This recipe requires PyTorch 2.7.0 or later.
16+
17+
18+
#####################################################################
19+
# Rewriting a torch op (torch.add -> torch.mul)
20+
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
21+
# For this example, we'll use torch function modes to rewrite occurences
22+
# of addition with multiply instead. This type of override can be common
23+
# if a certain backend has a custom implementation that should be dispatched
24+
# for a given op.
25+
import torch
26+
27+
# exit cleanly if we are on a device that doesn't support ``torch.compile``
28+
if torch.cuda.get_device_capability() < (7, 0):
29+
print("Exiting because torch.compile is not supported on this device.")
30+
import sys
31+
sys.exit(0)
32+
33+
from torch.overrides import BaseTorchFunctionMode
34+
35+
# Define our mode, Note: ``BaseTorchFunctionMode``
36+
# implements the actual invocation of func(..)
37+
class AddToMultiplyMode(BaseTorchFunctionMode):
38+
def __torch_function__(self, func, types, args=(), kwargs=None):
39+
if func == torch.Tensor.add:
40+
func = torch.mul
41+
42+
return super().__torch_function__(func, types, args, kwargs)
43+
44+
@torch.compile()
45+
def test_fn(x, y):
46+
return x + y * x # Note: infix operators map to torch.Tensor.* methods
47+
48+
x = torch.rand(2, 2)
49+
y = torch.rand_like(x)
50+
51+
with AddToMultiplyMode():
52+
z = test_fn(x, y)
53+
54+
assert torch.allclose(z, x * y * x)
55+
56+
# The mode can also be used within the compiled region as well like this:
57+
58+
@torch.compile()
59+
def test_fn(x, y):
60+
with AddToMultiplyMode():
61+
return x + y * x # Note: infix operators map to torch.Tensor.* methods
62+
63+
x = torch.rand(2, 2)
64+
y = torch.rand_like(x)
65+
z = test_fn(x, y)
66+
67+
assert torch.allclose(z, x * y * x)
68+
69+
######################################################################
70+
# Conclusion
71+
# ~~~~~~~~~~
72+
# In this recipe we demonstrated how to override the behavior of ``torch.*`` operators
73+
# using torch function modes from within ``torch.compile``. This enables users to utilize
74+
# the extensibility benefits of torch function modes without the runtime overhead
75+
# of calling torch function on every op invocation.
76+
#
77+
# * See `Extending Torch API with Modes <https://pytorch.org/docs/stable/notes/extending.html#extending-all-torch-api-with-modes>`__ for other examples and background on Torch Function modes.

0 commit comments

Comments
 (0)