MLX backend POC #1365
base: main
Conversation
Codecov Report

❌ Attention: Your patch check has failed because the patch coverage (80.23%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

```
@@            Coverage Diff             @@
##             main    #1365      +/-   ##
==========================================
+ Coverage   82.02%   82.04%   +0.01%
==========================================
  Files         203      208       +5
  Lines       48845    48949     +104
  Branches     8691     8701      +10
==========================================
+ Hits        40067    40162      +95
- Misses       6627     6632       +5
- Partials     2151     2155       +4
```
I suggest basing yourself on the numba linker; torch has a lot of hacks we hopefully don't need here.
Thanks for the pointer. I simplified the one method. Do you think that
Yeah, you shouldn't need that; you just need a call to typify on the runtime inputs as well.
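A minimal sketch of the idea in that suggestion: instead of special-casing conversions inside the linker, wrap the compiled function so every raw runtime input goes through the backend's typify first. All names here are hypothetical stand-ins (`typify_runtime_inputs` is not a real PyTensor helper, and `float` stands in for conversion to `mx.array`):

```python
# Hypothetical sketch: run the backend's typify over runtime inputs
# before calling the compiled function.
def typify_runtime_inputs(fn, typify):
    def wrapper(*args):
        # Convert each raw input (numpy array, Python scalar, ...) to
        # the backend's array type before the call.
        return fn(*(typify(a) for a in args))

    return wrapper


def compiled(x, y):
    return x + y


# float() stands in for mlx_typify / mx.array conversion here.
wrapped = typify_runtime_inputs(compiled, float)
print(wrapped("1", 2))  # inputs are typified before the call -> 3.0
```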
A shout out for the fathers of the day! Co-Authored-By: Ricardo Vieira <[email protected]> Co-Authored-By: Jesse Grabowski <[email protected]>
Still need to get this to run:

```python
import pytensor

pytensor.config.mode = "MLX"
```
Hey, big thanks to @jessegrabowski and @ricardoV94 for helping with this PR! I feel the PR is big enough already. Should we do a first merge and then iterate in follow-up versions, cleaning up and making everything more consistent with the other backends? Thanks to @williambdean for opening the PR!
```python
@mlx_typify.register(np.ndarray)
@mlx_typify.register(mx.array)
```
`mx.array` should be registered in `mlx_typify_no_conversion_needed`.
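To illustrate the split the review is asking for, here is a sketch using `functools.singledispatch`, the mechanism PyTensor's backends use for typify dispatch. The dispatcher names mirror the ones in this PR, but the registered types are stand-ins (`list` for `np.ndarray`, which needs conversion; `tuple` for `mx.array`, which should pass through untouched):

```python
from functools import singledispatch


@singledispatch
def mlx_typify(data, **kwargs):
    raise NotImplementedError(f"No MLX conversion for {type(data)}")


def mlx_typify_no_conversion_needed(data, **kwargs):
    # Types the backend already understands are returned untouched.
    return data


# list stands in for np.ndarray: needs conversion to the backend type.
@mlx_typify.register(list)
def _(data, **kwargs):
    return tuple(data)


# tuple stands in for mx.array: already an MLX type, so route it to the
# no-conversion handler as the review suggests.
mlx_typify.register(tuple, mlx_typify_no_conversion_needed)

print(mlx_typify([1, 2]))  # converted: (1, 2)
print(mlx_typify((1, 2)))  # passes through unchanged
```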
```python
@mlx_funcify.register(Assert)
@mlx_funcify.register(CheckAndRaise)
def mlx_funcify_CheckAndRaise(op, **kwargs):
    warnings.warn(
        f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as MLX tracing would remove it.""",
        stacklevel=2,
    )

    def assert_fn(x, *inputs):
        return x

    return assert_fn
```
Is this true, or just copy/pasta from JAX?
```python
from pytensor.tensor.signal.conv import Conv1d


def blockwise_conv1d(op, node, **kwargs):
```
Not needed anymore since they fixed it upstream, right?
```python
# ------------------------------------------------------------------
# Join
# ------------------------------------------------------------------
```
Don't introduce comments like this; the code is readable and won't get stale if we move things around.
```python
# ------------------------------------------------------------------
# Join
# ------------------------------------------------------------------
@mlx_funcify.register(Join)  # MLX
```
The multiple #MLX comments are useless
```python
# Convert scalar to array if needed
if isinstance(x, int | float) or (
    isinstance(x, np.number) and not isinstance(x, np.ndarray)
):
    x = mx.array(x)
```
Should not be needed
```python
@mlx_funcify.register(CAReduce)
def mlx_funcify_CAReduce(op, **kwargs):
    if isinstance(op.scalar_op, Add):

        def sum(x):
            return mx.sum(x, axis=op.axis)

        return sum
    elif isinstance(op.scalar_op, Mul):

        def prod(x):
            return mx.prod(x, axis=op.axis)

        return prod
    elif isinstance(op.scalar_op, AND):

        def all(x):
            return x.all(axis=op.axis)

        return all
    elif isinstance(op.scalar_op, OR):

        def any(x):
            return mx.any(x, axis=op.axis)

        return any
    else:
        raise NotImplementedError(f"MLX does not support Elemwise {op.scalar_op}")
```
This should do a second-level dispatch on the core op, something like `mlx_funcify_CAReduce` that is called on `op.scalar_op`.
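A sketch of what that second-level dispatch could look like, so new reductions are added by registration rather than by growing the `if/elif` chain. The scalar-op classes and the dispatcher name `mlx_funcify_scalar_reduce` are stand-ins, and plain-Python reductions replace `mx.sum`/`mx.prod` so the sketch is self-contained:

```python
import math
from functools import singledispatch


# Dummy scalar-op classes standing in for pytensor.scalar's Add / Mul.
class Add: ...


class Mul: ...


# Second level: dispatch on the scalar op itself.
@singledispatch
def mlx_funcify_scalar_reduce(scalar_op, axis):
    raise NotImplementedError(f"MLX does not support CAReduce with {scalar_op}")


@mlx_funcify_scalar_reduce.register(Add)
def _(scalar_op, axis):
    return lambda x: sum(x)  # mx.sum(x, axis=axis) in the real backend


@mlx_funcify_scalar_reduce.register(Mul)
def _(scalar_op, axis):
    return lambda x: math.prod(x)  # mx.prod(x, axis=axis)


def mlx_funcify_CAReduce(op, **kwargs):
    # First level receives the CAReduce op and defers to its scalar op.
    return mlx_funcify_scalar_reduce(op.scalar_op, op.axis)


class FakeCAReduce:
    scalar_op = Add()
    axis = None


f = mlx_funcify_CAReduce(FakeCAReduce())
print(f([1, 2, 3]))  # -> 6
```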
```python
@mlx_funcify.register(Elemwise)
def mlx_funcify_Elemwise(op, **kwargs):
```
Like CAReduce, it should have a second-level dispatch. Also, we need to enforce the runtime broadcastable checks (same in Alloc). And we should have a default implementation for that second-level dispatch that tries to use `getattr(mx, "func_name")`, similar to how JAX does it already.
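A sketch of that default fallback, assuming the JAX-style pattern of looking the function up by the scalar op's name on the backend module. `numpy` stands in for `mlx.core`, and the `Exp` class and `mlx_funcify_scalar_op` name are hypothetical:

```python
from functools import singledispatch

import numpy as np


# Dummy scalar op with no explicit registration; it will hit the
# getattr-based default below.
class Exp: ...


@singledispatch
def mlx_funcify_scalar_op(scalar_op):
    # Default: derive the backend function from the op's class name,
    # mirroring the getattr(jnp, name) fallback in the JAX backend.
    # numpy stands in for mlx.core here.
    name = type(scalar_op).__name__.lower()
    func = getattr(np, name, None)
    if func is None:
        raise NotImplementedError(f"No MLX implementation for {scalar_op}")
    return func


exp = mlx_funcify_scalar_op(Exp())
print(exp(0.0))  # np.exp(0.0) -> 1.0
```

Explicit registrations then only cover ops whose names or semantics don't line up with the backend, which keeps the dispatch table small.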
```python
if not op.inplace:
    x = deepcopy(x)
x[indices] = y
return x
```
Need tests for all of these, including the inplace variants.
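A minimal sketch of what such a test should pin down: the non-inplace variant must leave the input untouched, while the inplace variant must mutate it. `set_subtensor` here is a stand-in mirroring the snippet under review, not the real Op implementation:

```python
from copy import deepcopy

import numpy as np


def set_subtensor(x, y, indices, inplace=False):
    # Mirrors the snippet above: copy unless the op is inplace.
    if not inplace:
        x = deepcopy(x)
    x[indices] = y
    return x


x = np.zeros(3)

out = set_subtensor(x, 5.0, 0, inplace=False)
assert x[0] == 0.0 and out[0] == 5.0  # original untouched

out = set_subtensor(x, 5.0, 0, inplace=True)
assert x[0] == 5.0 and out is x  # mutated in place
```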
```python
for n in self.fgraph.inputs:
    sinput = storage_map[n]
    # Handle random number generators specially
    if isinstance(sinput[0], RandomState | Generator):
        new_value = mlx_typify(
            sinput[0], dtype=getattr(sinput[0], "dtype", None)
        )
        sinput[0] = new_value
    thunk_inputs.append(sinput)
```
Since we don't have random-variable support yet, we shouldn't include this code.
Description
Getting the ball rolling; started with #1350.
📚 Documentation preview 📚: https://pytensor--1365.org.readthedocs.build/en/1365/