MLX backend POC #1365
Open
williambdean wants to merge 56 commits into pymc-devs:main from williambdean:mlx-poc
Changes from all commits (56 commits):
d25f214 mlx poc (williambdean)
edacc0e add test for dot (williambdean)
052fdc2 restore pytorch (williambdean)
a9ecad0 wrap in mx.array (williambdean)
e690bff modify the pytorch jit (williambdean)
ad29c17 move file (williambdean)
ba29b37 dont wrap (williambdean)
8716870 attempt to fix github action (williambdean)
9bf7edf change the rtol (williambdean)
96ba116 add init file (williambdean)
e116fa1 skip if not installed (williambdean)
5d5f754 remove torch related code / comments (williambdean)
b8cee3f simplify the fgraph_convert (williambdean)
d057453 assert type (williambdean)
ae202e6 simplify the internal (williambdean)
f1941fe remove the language (williambdean)
7c8eae7 Adding operations in pytensor (cetagostini)
67a74fb add extension (williambdean)
fb5eb52 make compare function (williambdean)
516b595 rename function (williambdean)
67bb8da correct the function name (williambdean)
82bb964 tests for elemwise (williambdean)
877d79f Changes (cetagostini)
fafedd6 Toma tu tomate William (cetagostini)
60acb8d Pushing changes with the core shit. (cetagostini)
242aba7 add more tests (williambdean)
6cb47fc additional tests (williambdean)
bc98e09 test for switch with mlx (williambdean)
4d5b34b Pushing code (cetagostini)
5abd32d Changes (cetagostini)
12daeac A lot of new code (cetagostini)
ac93949 almost there baby william (cetagostini)
a19cbc8 Another push small (cetagostini)
5c97bc8 fix for all (williambdean)
2fc81bc fix for carlos (williambdean)
e6437cc just return the compiled func (williambdean)
c3a3e1a A change for willy may! (cetagostini)
e7cf10e FINALLY BABY LETS PARTY! (IF YOU ARE READING THIS MAKE MORE PRs) (cetagostini)
880dd5c refactor to use getattr (williambdean)
1e6addd bring argmax test (williambdean)
aabbb78 use deepcopy (williambdean)
0812c55 move some tests (williambdean)
294c271 THE SUPER BLOCKWISEE YA YA YA YA JUUUUU (cetagostini)
9d3eca8 Merge branch 'mlx-poc' of https://github.com/williambdean/pytensor in… (cetagostini)
9f31ab1 Guys, I'm getting sad. We need help yisus!!!!! (cetagostini)
37440ff WILLIAM YOU NEED TO GO ANOTHER MILE! GO ON MY MATEEEEEEE, GO PHILLIES! (cetagostini)
4e4923f RETURN, WHAT A SHAME! Sad times are coming. (cetagostini)
6b27dc4 AI COULD BE COOL? OR WE ARE JUST FUCKING AROUND? (cetagostini)
e308f83 AI RULES BABY MY MATE (cetagostini)
3744a18 test conv1d case (williambdean)
b41cab0 I'm going for pizzas, it was an incredible day! (cetagostini)
323fa9d Merge branch 'mlx-poc' of https://github.com/williambdean/pytensor in… (cetagostini)
9766975 SUUUUUUUUU!!!!!! LIFE IS GOING WELL. MLX FOR MEDIA MIX MODELS BAY (cetagostini)
5ffc5ef pre-commit (cetagostini)
597f84e Almost working (cetagostini)
fb8fd2f Last PR sampling working (cetagostini)
.gitignore
@@ -27,7 +27,6 @@ __pycache__
 \#*\#
 build
 compiled/*.cpp
-core.*
 cutils_ext.cpp
 dist
 doc/.build/
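The dropped `core.*` pattern would otherwise have ignored the new `pytensor/link/mlx/dispatch/core.py` added by this PR, which is presumably why it was removed.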
pytensor/link/mlx/__init__.py
@@ -0,0 +1 @@
from pytensor.link.mlx.linker import MLXLinker
pytensor/link/mlx/dispatch/__init__.py
@@ -0,0 +1,13 @@
# isort: off
from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify

import pytensor.link.mlx.dispatch.math
import pytensor.link.mlx.dispatch.basic
import pytensor.link.mlx.dispatch.elemwise
import pytensor.link.mlx.dispatch.shape
import pytensor.link.mlx.dispatch.subtensor
import pytensor.link.mlx.dispatch.core
import pytensor.link.mlx.dispatch.signal
import pytensor.link.mlx.dispatch.signal.conv
import pytensor.link.mlx.dispatch.blockwise
# isort: on
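The bare `import` statements here exist for their side effects: each dispatch submodule registers its converters on `mlx_funcify`/`mlx_typify` at import time, which is also why isort reordering is disabled around the block.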
pytensor/link/mlx/dispatch/basic.py
@@ -0,0 +1,78 @@
import warnings
from copy import deepcopy
from functools import singledispatch
from types import NoneType

import mlx.core as mx
import numpy as np

from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import Assert, CheckAndRaise


@singledispatch
def mlx_typify(data, **kwargs):
    raise NotImplementedError(f"mlx_typify is not implemented for {type(data)}")


@mlx_typify.register(np.ndarray)
@mlx_typify.register(mx.array)
def mlx_typify_tensor(data, dtype=None, **kwargs):
    return mx.array(data, dtype=dtype)


@mlx_typify.register(slice)
@mlx_typify.register(NoneType)
@mlx_typify.register(np.number)
def mlx_typify_no_conversion_needed(data, **kwargs):
    return data


@singledispatch
def mlx_funcify(op, node=None, storage_map=None, **kwargs):
    """Create an MLX-compatible function from a PyTensor `Op`."""
    raise NotImplementedError(
        f"No MLX conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/1350` for progress or to request we prioritize this operation"
    )


@mlx_funcify.register(FunctionGraph)
def mlx_funcify_FunctionGraph(
    fgraph,
    node=None,
    fgraph_name="mlx_funcified_fgraph",
    conversion_func=mlx_funcify,
    **kwargs,
):
    built_kwargs = {"conversion_func": conversion_func, **kwargs}
    return fgraph_to_python(
        fgraph,
        conversion_func,
        type_conversion_fn=mlx_typify,
        fgraph_name=fgraph_name,
        **built_kwargs,
    )


@mlx_funcify.register(DeepCopyOp)
def mlx_funcify_DeepCopyOp(op, **kwargs):
    def deepcopyop(x):
        return deepcopy(x)

    return deepcopyop


@mlx_funcify.register(Assert)
@mlx_funcify.register(CheckAndRaise)
def mlx_funcify_CheckAndRaise(op, **kwargs):
    warnings.warn(
        f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as MLX tracing would remove it.""",
        stacklevel=2,
    )

    def assert_fn(x, *inputs):
        return x

    return assert_fn
Review comment on lines +67 to +78 (the `CheckAndRaise` handler): Is this true, or just copy/pasta from JAX?
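For readers following along: supporting a new Op means registering a converter on the `mlx_funcify` singledispatch above, exactly as `DeepCopyOp` and `CheckAndRaise` are handled. A hypothetical sketch, not part of this diff (it assumes pytensor's `CumOp` with its `mode`/`axis` attributes and MLX's `mx.cumsum`/`mx.cumprod`):

# Hypothetical example (not in this PR): registering an MLX converter for an
# Op follows the same pattern as DeepCopyOp/CheckAndRaise above.
import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.extra_ops import CumOp


@mlx_funcify.register(CumOp)
def mlx_funcify_CumOp(op, **kwargs):
    # CumOp.mode is "add" (cumsum) or "mul" (cumprod); axis may be None.
    mlx_fn = mx.cumsum if op.mode == "add" else mx.cumprod

    def cum(x):
        return mlx_fn(x, axis=op.axis)

    return cum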
pytensor/link/mlx/dispatch/blockwise.py
@@ -0,0 +1,99 @@
import mlx.core as mx

from pytensor.link.mlx.dispatch import mlx_funcify
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.signal.conv import Conv1d


def blockwise_conv1d(op, node, **kwargs):
    """
    Custom implementation of Blockwise.conv1d for MLX.
    """

    def batched_conv1d(
        x: mx.array,
        kernels: mx.array,
        mode: str = op.core_op.mode,
        stride: int = 1,
        dilation: int = 1,
    ) -> mx.array:
        """
        Apply B separate 1D convolutions (full or valid) to B sequences in parallel.

        Parameters
        ----------
        x : array of shape (B, T)
            B sequences of length T.
        kernels : array of shape (B, K)
            B kernels of length K.
        mode : {"valid", "full"}
            "valid" → no padding, output length = T - K + 1
            "full" → zero-pad so output length = T + K - 1
        stride : int, convolution stride (default=1)
        dilation : int, convolution dilation (default=1)

        Returns
        -------
        out : array of shape (B, L)
            where L =
            - T - K + 1 if mode="valid"
            - T + K - 1 if mode="full"
        """
        # --- 1) shape checks ---
        B, T = x.shape
        Bk, K = kernels.shape
        if B != Bk:
            raise ValueError(f"Batch mismatch: x has {B}, kernels has {Bk}")

        # --- 2) flip kernels for convolution ---
        kernels_flipped = kernels[:, ::-1]  # shape (B, K)

        # --- 3) decide padding ---
        if mode == "valid":
            pad = 0
        elif mode == "full":
            pad = (K - 1) * dilation
        else:
            raise ValueError(f"Unsupported mode {mode!r}: choose 'valid' or 'full'")

        # --- 4) reshape into MLX conv1d form ---
        # input: (N=1, H=T, C_in=B)
        x_in = x.T[None, :, :]

        # weight: (C_out=B, H_f=K, C_in=1)
        w = kernels_flipped[:, :, None]

        # --- 5) run grouped conv1d ---
        y = mx.conv1d(x_in, w, stride=stride, padding=pad, dilation=dilation, groups=B)
        # y shape: (1, H_out, B)

        # --- 6) return shape (B, H_out) ---
        return y[0].T

    return batched_conv1d


@mlx_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, **kwargs):
    # 1) If it's a Conv1d Blockwise, use the custom implementation
    if isinstance(op.core_op, Conv1d):
        return blockwise_conv1d(op, node, **kwargs)

    # 2) Otherwise, get the core python function for this Blockwise
    core_node = op._create_dummy_core_node(node.inputs)
    core_f = mlx_funcify(op.core_op, core_node)

    # 3) Determine how many inputs correspond to batch dimensions
    n_batch = op.batch_ndim(node)

    # 4) Build in_axes: map only the first n_batch args, keep the rest static
    in_axes = tuple(0 if i < n_batch else None for i in range(len(node.inputs)))

    # 5) Vectorize (vmap) with in_axes
    blockwise_f = mx.vmap(core_f, in_axes=in_axes)

    # 6) Return the mapped function
    def blockwise_fun(*inputs):
        return blockwise_f(*inputs)

    return blockwise_fun

Review comment on `blockwise_conv1d`: Not needed anymore since they fixed upstream right?
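Whichever way that review question resolves, the flip-pad-grouped-conv trick is easy to sanity-check against NumPy. A minimal sketch (not part of the PR; assumes `mlx` is installed, with hypothetical test sizes):

# Sanity check (hypothetical, not in the PR): the grouped-conv trick used by
# batched_conv1d should match NumPy's per-row "full" convolution.
import mlx.core as mx
import numpy as np

B, T, K = 4, 16, 5
x = np.random.randn(B, T).astype("float32")
k = np.random.randn(B, K).astype("float32")

# Reference: B independent full convolutions, one per batch row.
ref = np.stack([np.convolve(x[i], k[i], mode="full") for i in range(B)])

# Mirror of batched_conv1d: flip kernels, pad by K - 1, one grouped conv call.
x_in = mx.array(x).T[None, :, :]                 # (N=1, H=T, C_in=B)
w = mx.array(k[:, ::-1].copy())[:, :, None]      # (C_out=B, H_f=K, C_in=1)
y = mx.conv1d(x_in, w, padding=K - 1, groups=B)[0].T  # (B, T + K - 1)

np.testing.assert_allclose(np.array(y), ref, rtol=1e-5, atol=1e-5)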
Review comment: `mx.array` should be registered in `mlx_typify_no_conversion_needed`.
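Applied to `dispatch/basic.py` above, the suggestion would look something like this sketch (not a committed change):

# Sketch of the suggestion: mx.array inputs are already MLX arrays, so pass
# them through instead of re-wrapping them with mx.array(...).
@mlx_typify.register(np.ndarray)          # mx.array registration removed here
def mlx_typify_tensor(data, dtype=None, **kwargs):
    return mx.array(data, dtype=dtype)


@mlx_typify.register(slice)
@mlx_typify.register(NoneType)
@mlx_typify.register(np.number)
@mlx_typify.register(mx.array)            # added: no conversion needed
def mlx_typify_no_conversion_needed(data, **kwargs):
    return data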