From d5ec4a3007aaff003870ac2f7879989f49c4fe20 Mon Sep 17 00:00:00 2001 From: Denis Vieriu Date: Mon, 22 Apr 2024 18:04:37 -0700 Subject: [PATCH] Fix executor_runner_mps and mpsdelegate linking with pybind (#3222) Summary: Summary of changes: - fixes mps_executor_runner build - previously it would fail to build previously due to incorrect linking with portable ops - fixes `mpsdelegate` linking with `pybind` lib - added tests to check correctness directly through pybind - added a helper file (`bench_utils.py`) to help measure models forward pass between PyTorch MPS and ExecuTorch MPS Testing (will run both AOT and runtime if MPS was built with pybind): - `./install_requirements.sh --pybind mps` - invoke a single unit test: `python3 -m unittest backends.apple.mps.test.test_mps_indexing_ops -v -k test_mps_indexing_get_1`. - invoke all tests from a file: `python3 -m unittest backends.apple.mps.test.test_mps_indexing_ops -v` cc cccclai , shoumikhin Pull Request resolved: https://github.com/pytorch/executorch/pull/3222 Reviewed By: shoumikhin Differential Revision: D56447888 Pulled By: cccclai fbshipit-source-id: 5cbbcbf8df34f29e23a1854df72f764337a9df76 (cherry picked from commit 6c30eea98c7c843b719d4dda669e2ba440f18df9) --- backends/apple/mps/CMakeLists.txt | 3 + backends/apple/mps/test/test_mps.py | 371 +++++++----------- .../apple/mps/test/test_mps_binary_ops.py | 296 ++++++++++++++ .../apple/mps/test/test_mps_indexing_ops.py | 225 +++++++++++ backends/apple/mps/test/test_mps_unary_ops.py | 26 ++ backends/apple/mps/test/test_mps_utils.py | 137 ++++--- examples/apple/mps/CMakeLists.txt | 12 +- examples/apple/mps/scripts/bench_utils.py | 117 ++++++ examples/apple/mps/scripts/mps_example.py | 132 +++++-- 9 files changed, 1011 insertions(+), 308 deletions(-) create mode 100644 backends/apple/mps/test/test_mps_binary_ops.py create mode 100644 backends/apple/mps/test/test_mps_indexing_ops.py create mode 100644 backends/apple/mps/test/test_mps_unary_ops.py create mode 100644 examples/apple/mps/scripts/bench_utils.py diff --git a/backends/apple/mps/CMakeLists.txt b/backends/apple/mps/CMakeLists.txt index 50d91fe20fe..a3b0bdab670 100644 --- a/backends/apple/mps/CMakeLists.txt +++ b/backends/apple/mps/CMakeLists.txt @@ -77,6 +77,9 @@ target_link_libraries(mpsdelegate ${MPS_GRAPG_FRAMEWORK} ) +target_link_options_shared_lib(mpsdelegate) +target_compile_options(mpsdelegate PUBLIC ${_common_compile_options}) + install( TARGETS mpsdelegate DESTINATION lib diff --git a/backends/apple/mps/test/test_mps.py b/backends/apple/mps/test/test_mps.py index 691081d35de..5ca9d0175e9 100644 --- a/backends/apple/mps/test/test_mps.py +++ b/backends/apple/mps/test/test_mps.py @@ -677,188 +677,6 @@ def forward(self, x): const_module, model_inputs, func_name=inspect.stack()[0].function[5:] ) - def test_mps_constant_add(self): - class Module(torch.nn.Module): - def __init__(self): - super().__init__() - self._constant = torch.ones(4, 4, 4) - - def forward(self, x): - out1 = x + self._constant - out2 = x + self._constant + self._constant - return out1, out2 - - const_module = Module() - model_inputs = (torch.randn(4, 4, 4),) - - self.lower_and_test_with_partitioner( - const_module, model_inputs, func_name=inspect.stack()[0].function[5:] - ) - - def test_mps_mul_scalar_float(self): - class MulScalarModule(torch.nn.Module): - def __init__(self): - super().__init__() - self._scalar = 3.14 - - def forward(self, x): - out1 = torch.ops.aten.mul.Scalar(x, self._scalar) - return out1 - - mul_scalar_module = MulScalarModule() - model_inputs = (torch.randn(4, 4, 4),) - - self.lower_and_test_with_partitioner( - mul_scalar_module, model_inputs, func_name=inspect.stack()[0].function[5:] - ) - - def test_mps_mul_scalar_int(self): - class MulScalarModule(torch.nn.Module): - def __init__(self): - super().__init__() - self._scalar = 3 - - def forward(self, x): - out1 = torch.ops.aten.mul.Scalar(x, self._scalar) - return out1 - - mul_scalar_module = MulScalarModule() - model_inputs = (torch.randint(11, (4, 4, 4)),) - - self.lower_and_test_with_partitioner( - mul_scalar_module, model_inputs, func_name=inspect.stack()[0].function[5:] - ) - - def test_mps_backend_add_1(self): - class AddModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - z = torch.add(x, y, alpha=0.1) - return z - - add_module = AddModule() - model_inputs = (torch.randn(1), torch.randn(1)) - - self.lower_and_test_with_partitioner( - add_module, model_inputs, func_name=inspect.stack()[0].function[5:] - ) - - def test_mps_backend_add_2(self): - class AddModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - z = torch.ops.aten.add.Scalar(x, 2.0) - return z - - add_module = AddModule() - model_inputs = (torch.randn(2, 5),) - - self.lower_and_test_with_partitioner( - add_module, model_inputs, func_name=inspect.stack()[0].function[5:] - ) - - def test_mps_backend_add_3(self): - class AddModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - z = torch.add(x, y) - return z - - add_module = AddModule() - model_inputs = (torch.randn(1), torch.randn(1)) - - self.lower_and_test_with_partitioner( - add_module, model_inputs, func_name=inspect.stack()[0].function[5:] - ) - - def test_mps_backend_sub_1(self): - class SubModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - z = torch.sub(x, y, alpha=0.1) - return z - - sub_module = SubModule() - model_inputs = (torch.randn(1), torch.randn(1)) - - self.lower_and_test_with_partitioner( - sub_module, model_inputs, func_name=inspect.stack()[0].function[5:] - ) - - def test_mps_backend_sub_2(self): - class SubModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - z = torch.ops.aten.sub.Scalar(x, 2.0) - return z - - sub_module = SubModule() - model_inputs = (torch.randn(2, 5),) - - self.lower_and_test_with_partitioner( - sub_module, model_inputs, func_name=inspect.stack()[0].function[5:] - ) - - def test_mps_backend_sub_3(self): - class SubModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - z = torch.sub(x, y) - return z - - sub_module = SubModule() - model_inputs = (torch.randn(1), torch.randn(1)) - - self.lower_and_test_with_partitioner( - sub_module, model_inputs, func_name=inspect.stack()[0].function[5:] - ) - - def test_mps_backend_add_scalar_float(self): - class AddScalarModule(torch.nn.Module): - def __init__(self): - super().__init__() - self._scalar_float = 3.14 - - def forward(self, x): - out = torch.ops.aten.add.Scalar(x, self._scalar_float) - return out - - add_scalar_module = AddScalarModule() - model_inputs = (torch.randn(4, 4, 4),) - - self.lower_and_test_with_partitioner( - add_scalar_module, model_inputs, func_name=inspect.stack()[0].function[5:] - ) - - def test_mps_backend_add_scalar_int(self): - class AddScalarModule(torch.nn.Module): - def __init__(self): - super().__init__() - self._scalar_int = 3 - - def forward(self, x): - out1 = torch.ops.aten.add.Scalar(x, self._scalar_int) - return out1 - - add_scalar_module = AddScalarModule() - model_inputs = (torch.randint(11, (4, 4, 4), dtype=torch.int32),) - - self.lower_and_test_with_partitioner( - add_scalar_module, model_inputs, func_name=inspect.stack()[0].function[5:] - ) - def test_mps_backend_logit_1(self): class LogitModule(torch.nn.Module): def __init__(self): @@ -891,22 +709,6 @@ def forward(self, x): logit_module, model_inputs, func_name=inspect.stack()[0].function[5:] ) - def test_mps_backend_div(self): - class DivModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - z = x / y - return z - - div_module = DivModule() - model_inputs = (torch.ones(1), torch.ones(1)) - - self.lower_and_test_with_partitioner( - div_module, model_inputs, func_name=inspect.stack()[0].function[5:] - ) - def test_mps_backend_round(self): class RoundModule(torch.nn.Module): def __init__(self): @@ -923,36 +725,6 @@ def forward(self, x): module, model_inputs, func_name=inspect.stack()[0].function[5:] ) - def test_mps_backend_fmod(self): - class FModModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return torch.fmod(x, y) - - module = FModModule() - model_inputs = (torch.randn(2, 3, 4), torch.randn(2, 3, 4)) - - self.lower_and_test_with_partitioner( - module, model_inputs, func_name=inspect.stack()[0].function[5:] - ) - - def test_mps_backend_floor_divide(self): - class FloorDivideModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x, y): - return torch.floor_divide(x, y) - - module = FloorDivideModule() - model_inputs = (torch.randn(2, 3, 4), torch.randn(2, 3, 4)) - - self.lower_and_test_with_partitioner( - module, model_inputs, func_name=inspect.stack()[0].function[5:] - ) - def test_mps_backend_amax(self): class AmaxModule(torch.nn.Module): def __init__(self): @@ -1331,6 +1103,149 @@ def forward(self, x): module, model_inputs, func_name=inspect.stack()[0].function[5:] ) + def test_mps_indexing_get_1(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[0, 1, 2], [0, 1, 0]] + + module = IndexGet() + model_inputs = (torch.tensor([[1, 2], [3, 4], [5, 6]]),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_indexing_get_2(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[:, [0, 4, 2]] + + module = IndexGet() + model_inputs = (torch.randn(5, 7, 3),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_indexing_get_3(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[:, [[0, 1], [4, 3]]] + + module = IndexGet() + model_inputs = (torch.randn(5, 7, 3),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_indexing_get_4(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[0, 4, 2]] + + module = IndexGet() + model_inputs = (torch.randn(5, 7, 3),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_indexing_get_5(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[0, 2, 1], :, 0] + + module = IndexGet() + model_inputs = (torch.ones(3, 2, 4),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_indices2d(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, rows, columns): + return x[rows, columns] + + module = IndexGet() + x = torch.arange(0, 12).resize(4, 3) + rows = torch.tensor([[0, 0], [3, 3]]) + columns = torch.tensor([[0, 2], [0, 2]]) + model_inputs = ( + x, + rows, + columns, + ) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_slicing_using_advanced_index_for_column_0(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[1:4] + + module = IndexGet() + model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_slicing_using_advanced_index_for_column_1(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + # using advanced index for column + return x[1:4, [1, 2]] + + module = IndexGet() + model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_boolean_array_indexing(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[x > 5] + + module = IndexGet() + model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + def test_mps_backend_isinf(self): class IsInfModule(torch.nn.Module): def __init__(self): diff --git a/backends/apple/mps/test/test_mps_binary_ops.py b/backends/apple/mps/test/test_mps_binary_ops.py new file mode 100644 index 00000000000..fdf2d1fbb94 --- /dev/null +++ b/backends/apple/mps/test/test_mps_binary_ops.py @@ -0,0 +1,296 @@ +# +# Copyright (c) 2024 Apple Inc. All rights reserved. +# Provided subject to the LICENSE file in the top level directory. +# + +import inspect + +import torch +from executorch.backends.apple.mps.test.test_mps_utils import TestMPS + + +class TestMPSAdd(TestMPS): + class Add(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + z = x + y + z = z + x + z = z + x + z = z + z + return z + + class Add2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + z = x + x + return z + + class AddConstant(torch.nn.Module): + def __init__(self, constant): + super().__init__() + self._constant1 = constant + self.register_buffer("_constant2", constant, persistent=False) + self.register_parameter("_constant3", torch.nn.Parameter(constant)) + + def forward(self, x): + out1 = x + self._constant1 + torch.ones(1, 1, 1) + out2 = x + self._constant2 + self._constant3 + return out1, out2 + + def test_fp16_add(self): + inputs = (torch.ones(1).to(torch.float16), torch.ones(1).to(torch.float16)) + self.lower_and_test_with_partitioner( + self.Add(), inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_fp32_add(self): + inputs = (torch.ones(1), torch.ones(1)) + self.lower_and_test_with_partitioner( + self.Add(), inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_fp32_add_constant(self): + inputs = (torch.randn(4, 4, 4),) + self.lower_and_test_with_partitioner( + self.AddConstant(torch.ones(4, 4, 4)), + inputs, + func_name=inspect.stack()[0].function[5:], + ) + + def test_add_w_alpha(self): + class AddModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + z = torch.add(x, y, alpha=0.1) + return z + + add_module = AddModule() + model_inputs = (torch.randn(1), torch.randn(1)) + + self.lower_and_test_with_partitioner( + add_module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_add_scalar(self): + class AddModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + z = torch.ops.aten.add.Scalar(x, 2.0) + return z + + add_module = AddModule() + model_inputs = (torch.randn(2, 5),) + + self.lower_and_test_with_partitioner( + add_module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_add_scalar_int(self): + class AddScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + self._scalar_int = 3 + + def forward(self, x): + out1 = torch.ops.aten.add.Scalar(x, self._scalar_int) + return out1 + + add_scalar_module = AddScalarModule() + model_inputs = (torch.randint(11, (4, 4, 4), dtype=torch.int32),) + + self.lower_and_test_with_partitioner( + add_scalar_module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_add_without_alpha(self): + class AddModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + z = torch.add(x, y) + return z + + add_module = AddModule() + model_inputs = (torch.randn(1), torch.randn(1)) + + self.lower_and_test_with_partitioner( + add_module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_add_scalar_float(self): + class AddScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + self._scalar_float = 3.14 + + def forward(self, x): + out = torch.ops.aten.add.Scalar(x, self._scalar_float) + return out + + add_scalar_module = AddScalarModule() + model_inputs = (torch.randn(4, 4, 4),) + + self.lower_and_test_with_partitioner( + add_scalar_module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_constant_add(self): + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self._constant = torch.ones(4, 4, 4) + + def forward(self, x): + out1 = x + self._constant + out2 = x + self._constant + self._constant + return out1, out2 + + const_module = Module() + model_inputs = (torch.randn(4, 4, 4),) + + self.lower_and_test_with_partitioner( + const_module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + +class TestMPSSub(TestMPS): + def test_mps_backend_sub_1(self): + class SubModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + z = torch.sub(x, y, alpha=0.1) + return z + + sub_module = SubModule() + model_inputs = (torch.randn(1), torch.randn(1)) + + self.lower_and_test_with_partitioner( + sub_module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_backend_sub_2(self): + class SubModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + z = torch.ops.aten.sub.Scalar(x, 2.0) + return z + + sub_module = SubModule() + model_inputs = (torch.randn(2, 5),) + + self.lower_and_test_with_partitioner( + sub_module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_backend_sub_3(self): + class SubModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + z = torch.sub(x, y) + return z + + sub_module = SubModule() + model_inputs = (torch.randn(1), torch.randn(1)) + + self.lower_and_test_with_partitioner( + sub_module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + +class TestMPSMul(TestMPS): + def test_mps_mul_scalar_float(self): + class MulScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + self._scalar = 3.14 + + def forward(self, x): + out1 = torch.ops.aten.mul.Scalar(x, self._scalar) + return out1 + + mul_scalar_module = MulScalarModule() + model_inputs = (torch.randn(4, 4, 4),) + + self.lower_and_test_with_partitioner( + mul_scalar_module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_mul_scalar_int(self): + class MulScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + self._scalar = 3 + + def forward(self, x): + out1 = torch.ops.aten.mul.Scalar(x, self._scalar) + return out1 + + mul_scalar_module = MulScalarModule() + model_inputs = (torch.randint(11, (4, 4, 4)),) + + self.lower_and_test_with_partitioner( + mul_scalar_module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + +class TestMPSDiv(TestMPS): + def test_mps_backend_div(self): + class DivModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + z = x / y + return z + + div_module = DivModule() + model_inputs = (torch.ones(1), torch.ones(1)) + + self.lower_and_test_with_partitioner( + div_module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_backend_fmod(self): + class FModModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.fmod(x, y) + + module = FModModule() + model_inputs = (torch.randn(2, 3, 4), torch.randn(2, 3, 4)) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_backend_floor_divide(self): + class FloorDivideModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.floor_divide(x, y) + + module = FloorDivideModule() + model_inputs = (torch.randn(2, 3, 4), torch.randn(2, 3, 4)) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) diff --git a/backends/apple/mps/test/test_mps_indexing_ops.py b/backends/apple/mps/test/test_mps_indexing_ops.py new file mode 100644 index 00000000000..7991f1a165a --- /dev/null +++ b/backends/apple/mps/test/test_mps_indexing_ops.py @@ -0,0 +1,225 @@ +# +# Copyright (c) 2024 Apple Inc. All rights reserved. +# Provided subject to the LICENSE file in the top level directory. +# + +import inspect + +import torch +from executorch.backends.apple.mps.test.test_mps_utils import TestMPS + + +class TestMPSIndexingOps(TestMPS): + def test_mps_indexing_get_1(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[0, 1, 2], [0, 1, 0]] + + module = IndexGet() + model_inputs = (torch.tensor([[1, 2], [3, 4], [5, 6]]),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_indexing_get_2(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[:, [0, 1, 0]] + + module = IndexGet() + model_inputs = (torch.tensor([[1, 2], [3, 4], [5, 6]]),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_indexing_get_3(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[:, [0, 1, 0], [0, 1, 0]] + + module = IndexGet() + model_inputs = (torch.tensor([[[1, 2], [3, 4], [5, 6]]]),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_indexing_get_4(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[:, [0, 1, 0], [0, 1, 0]] + + module = IndexGet() + model_inputs = ( + torch.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]]), + ) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_indexing_get_5(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[:, [0, 4, 2]] + + module = IndexGet() + model_inputs = (torch.randn(5, 7, 3),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_indexing_get_6(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[:, [[0, 1], [4, 3]]] + + module = IndexGet() + model_inputs = (torch.randn(5, 7, 3),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_indexing_get_7(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[0, 4, 2]] + + module = IndexGet() + model_inputs = (torch.randn(5, 7, 3),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_indexing_get_8(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[[0, 2, 1], :, 0] + + module = IndexGet() + model_inputs = (torch.ones(3, 2, 4),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_indices2d(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, rows, columns): + return x[rows, columns] + + module = IndexGet() + x = torch.arange(0, 12).resize(4, 3) + rows = torch.tensor([[0, 0], [3, 3]]) + columns = torch.tensor([[0, 2], [0, 2]]) + model_inputs = ( + x, + rows, + columns, + ) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_slicing_using_advanced_index_for_column_0(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[1:4] + + module = IndexGet() + model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + def test_mps_slicing_using_advanced_index_for_column_1(self): + class IndexGet(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + # using advanced index for column + return x[1:4, [1, 2]] + + module = IndexGet() + model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) + + # def test_boolean_array_indexing(self): + # class IndexGet(torch.nn.Module): + # def __init__(self): + # super().__init__() + + # def forward(self, x): + # return x[x > 5] + + # module = IndexGet() + # model_inputs = (torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]),) + + # self.lower_and_test_with_partitioner( + # module, model_inputs, func_name=inspect.stack()[0].function[5:] + # ) + + def test_mps_indexing_put_1(self): + + class IndexPut(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y, z): + x[:, :, y] = z + return x + + module = IndexPut() + input = torch.ones(1, 8, 128, 8) + indices = torch.tensor([1]) + values = torch.randn(8, 1, 8) + model_inputs = ( + input, + indices, + values, + ) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) diff --git a/backends/apple/mps/test/test_mps_unary_ops.py b/backends/apple/mps/test/test_mps_unary_ops.py new file mode 100644 index 00000000000..69c1f5ba5c6 --- /dev/null +++ b/backends/apple/mps/test/test_mps_unary_ops.py @@ -0,0 +1,26 @@ +# +# Copyright (c) 2024 Apple Inc. All rights reserved. +# Provided subject to the LICENSE file in the top level directory. +# + +import inspect + +import torch +from executorch.backends.apple.mps.test.test_mps_utils import TestMPS + + +class TestMPSLoigcal(TestMPS): + def test_mps_logical_not(self): + class LogicalNot(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.logical_not() + + module = LogicalNot() + model_inputs = (torch.tensor([1, 1, 0, 0], dtype=torch.bool),) + + self.lower_and_test_with_partitioner( + module, model_inputs, func_name=inspect.stack()[0].function[5:] + ) diff --git a/backends/apple/mps/test/test_mps_utils.py b/backends/apple/mps/test/test_mps_utils.py index 0e4a7424cc2..6e569dedb50 100644 --- a/backends/apple/mps/test/test_mps_utils.py +++ b/backends/apple/mps/test/test_mps_utils.py @@ -15,7 +15,6 @@ from executorch.exir import ( EdgeCompileConfig, EdgeProgramManager, - ExecutorchProgram, ExirExportedProgram, to_edge, ) @@ -28,7 +27,6 @@ from executorch.sdk.bundled_program.serialize import ( serialize_from_bundled_program_to_flatbuffer, ) -from torch._export import capture_pre_autograd_graph from torch.export import export, ExportedProgram # Config for Capturing the weights, will be moved in the future @@ -141,7 +139,59 @@ def randomize_bn(num_features: int, dimensionality: int = 2) -> torch.nn.Module: return bn +def dump_bundled_program(sample_inputs, expected_output, executorch_program, func_name): + method_test_suites = [ + MethodTestSuite( + method_name="forward", + test_cases=[ + MethodTestCase(inputs=sample_inputs, expected_outputs=expected_output) + ], + ) + ] + + logging.info(f"Expected output: {expected_output}") + logging.info(" -> Test suites generated successfully") + + bundled_program = BundledProgram(executorch_program, method_test_suites) + bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer( + bundled_program + ) + + filename = f"{func_name}.pte" + logging.info(f"Step 4: Saving bundled program to {filename}") + with open(filename, "wb") as file: + file.write(bundled_program_buffer) + + class TestMPS(unittest.TestCase): + def assert_outputs_equal(self, model_output, ref_output): + """ + Helper testing function that asserts that the model output and the reference output + are equal with some tolerance. Due to numerical differences between eager mode and + the MPS's backend, we relax the detal such that absolute tolerance is 1e-3. and + relative tolerance is 1e-3. + """ + + # Compare the result from executor and eager mode direclty + if isinstance(ref_output, tuple) or isinstance(ref_output, list): + # Multiple outputs executor always returns tuple, even if there is one output + self.assertTrue( + len(ref_output) == len(model_output), + msg="Length of outputs is not matching!", + ) + for i in range(len(ref_output)): + self.assertTrue( + torch.allclose( + model_output[i], ref_output[i], atol=1e-03, rtol=1e-03 + ) + ) + else: + # If one output, eager returns tensor while executor tuple of size 1 + self.assertTrue( + torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03), + msg="Outputs are not matching!", + ) + def lower_module_and_test_output( self, module: Any, @@ -149,26 +199,24 @@ def lower_module_and_test_output( func_name: str, use_partitioner: bool = True, use_fp16: bool = False, + bundled_program=True, ) -> ExirExportedProgram: """ Helper testing function that takes a torch.nn.Module and lowers it to MPS with the given sample inputs. It then runs the lowered module and compares its outputs with the outputs of the eager module. """ - logging.info("Step 1: EXIR capturing of original module") - class WrappedModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.one_module = module + model = module.eval() + original_inputs = [] + for t in sample_inputs: + original_inputs.append(t.detach().clone()) + original_inputs = tuple(original_inputs) - def forward(self, *args): - return self.one_module(*args) + expected_output = model(*sample_inputs) - model = WrappedModule() - model = model.eval() - model = capture_pre_autograd_graph(model, sample_inputs) + model = torch._export.capture_pre_autograd_graph(model, sample_inputs) edge_program = export_to_edge( model, @@ -183,10 +231,15 @@ def forward(self, *args): if use_partitioner: logging.info(f"Edge IR graph:\n{edge_program.exported_program().graph}") - edge = edge_program.to_backend(MPSPartitioner(compile_specs=compile_specs)) - logging.info(f"Lowered graph:\n{edge.exported_program().graph}") + delegated_program = edge_program + delegated_program = edge_program.to_backend( + MPSPartitioner(compile_specs=compile_specs) + ) + logging.info( + f"Lowered graph:\n{delegated_program.exported_program().graph}" + ) - executorch_program = edge.to_executorch( + executorch_program = delegated_program.to_executorch( config=ExecutorchBackendConfig(extract_constant_segment=False) ) else: @@ -206,42 +259,35 @@ def forward(self, *args): ) ) - exported_program: ExirExportedProgram = exir.capture( - WrappedModule(), sample_inputs, _CAPTURE_CONFIG - ).to_edge(_EDGE_COMPILE_CONFIG) - - executorch_program: ExecutorchProgram = exported_program.to_executorch() - - logging.info("Step 3: Generating bundled program") - logging.info( - " -> Number of execution plans: {len(executorch_program.program.execution_plan)}" - ) + if bundled_program: + dump_bundled_program( + sample_inputs, expected_output, executorch_program, func_name + ) + try: + from executorch.extension.pybindings.portable_lib import ( # @manual + _load_for_executorch_from_buffer, + ) - expected_output = module(*sample_inputs) + logging.info("Testing delegated program using pybind") - method_test_suites = [ - MethodTestSuite( - method_name="forward", - test_cases=[ - MethodTestCase( - inputs=sample_inputs, expected_outputs=module(*sample_inputs) - ) - ], + # Test the model with executor + logging.debug("Initializing MPSGraph") + executorch_module = _load_for_executorch_from_buffer( + executorch_program.buffer ) - ] - logging.info(f"Expected output: {expected_output}") - logging.info(" -> Test suites generated successfully") + model_output = executorch_module.forward(original_inputs) - bundled_program = BundledProgram(executorch_program, method_test_suites) - bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer( - bundled_program - ) + logging.info(f"Expected output: {expected_output}") + logging.info(f"MPS delegate output: {model_output}") + self.assert_outputs_equal(model_output, expected_output) + logging.info("Delegated program matches PyTorch Eager mode result!") - filename = f"{func_name}.pte" - logging.info(f"Step 4: Saving bundled program to {filename}") - with open(filename, "wb") as file: - file.write(bundled_program_buffer) + return delegated_program + except ImportError: + logging.info( + "ExecuTorch MPS delegate was built without pybind support. Exiting..." + ) def lower_and_test_with_partitioner( self, @@ -251,7 +297,6 @@ def lower_and_test_with_partitioner( use_fp16: bool = False, ): logging.info(func_name) - # MPS TODO: partitioner support self.lower_module_and_test_output( graph_module, example_inputs, diff --git a/examples/apple/mps/CMakeLists.txt b/examples/apple/mps/CMakeLists.txt index 89c2b141b01..976ecebc979 100644 --- a/examples/apple/mps/CMakeLists.txt +++ b/examples/apple/mps/CMakeLists.txt @@ -42,7 +42,7 @@ add_compile_options("-Wall" "-Werror") include(${EXECUTORCH_ROOT}/build/Utils.cmake) -set(_common_compile_options -Wno-deprecated-declarations -fPIC) +set(_common_compile_options -Wno-deprecated-declarations -fPIC -DET_EVENT_TRACER_ENABLED) # Let files say "include ". set(_common_include_directories ${EXECUTORCH_ROOT}/..) @@ -51,7 +51,7 @@ set(_common_include_directories ${EXECUTORCH_ROOT}/..) # portable_ops_lib, etdump, bundled_program. find_package(executorch CONFIG REQUIRED) target_include_directories(executorch INTERFACE ${_common_include_directories}) -target_compile_options(executorch INTERFACE -DET_EVENT_TRACER_ENABLED) +target_compile_options(executorch INTERFACE ${_common_compile_options}) find_package( gflags REQUIRED PATHS ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party @@ -73,7 +73,7 @@ generate_bindings_for_kernels( FUNCTIONS_YAML ${EXECUTORCH_ROOT}/kernels/portable/functions.yaml ) gen_operators_lib( - "portable_ops_lib" + "mps_portable_ops_lib" KERNEL_LIBS portable_kernels DEPS executorch) @@ -107,9 +107,9 @@ list(TRANSFORM _mps_executor_runner__srcs PREPEND "${EXECUTORCH_ROOT}/") add_executable(mps_executor_runner ${_mps_executor_runner__srcs}) if(CMAKE_BUILD_TYPE MATCHES "Debug") - set(FLATCC_LIB flatcc_d) + set(FLATCC_LIB flatccrt_d) else() - set(FLATCC_LIB flatcc) + set(FLATCC_LIB flatccrt) endif() target_link_libraries(mps_executor_runner bundled_program @@ -117,7 +117,7 @@ target_link_libraries(mps_executor_runner bundled_program etdump ${FLATCC_LIB} mpsdelegate - portable_ops_lib + mps_portable_ops_lib ${mps_executor_runner_libs}) target_compile_options(mps_executor_runner PUBLIC ${_common_compile_options}) endif() diff --git a/examples/apple/mps/scripts/bench_utils.py b/examples/apple/mps/scripts/bench_utils.py new file mode 100644 index 00000000000..c00738987ab --- /dev/null +++ b/examples/apple/mps/scripts/bench_utils.py @@ -0,0 +1,117 @@ +# +# Copyright (c) 2024 Apple Inc. All rights reserved. +# Provided subject to the LICENSE file in the top level directory. +# + +import logging +import time + +import torch +from torch._export.exported_program import ExportedProgram + + +def assert_outputs_equal(model_output, ref_output): + """ + Helper testing function that asserts that the model output and the reference output + are equal with some tolerance. Due to numerical differences between eager mode and + the MPS's backend, we relax the detal such that absolute tolerance is 1e-3. and + relative tolerance is 1e-3. + """ + + # Compare the result from executor and eager mode direclty + if isinstance(ref_output, tuple) or isinstance(ref_output, list): + # Multiple outputs executor always returns tuple, even if there is one output + assert len(ref_output) == len( + model_output + ), "Length of outputs is not matching!" + for i in range(len(ref_output)): + assert torch.allclose( + model_output[i], ref_output[i], atol=1e-03, rtol=1e-03 + ) + else: + # If one output, eager returns tensor while executor tuple of size 1 + assert torch.allclose( + model_output[0], ref_output, atol=1e-03, rtol=1e-03 + ), "Outputs are not matching!" + + +def bench_forward(func, *args): + # warmup + for _ in range(10): + func(*args) + + start = time.time() + for _ in range(100): + func(*args) + end = time.time() + return end - start + + +def executorch_forward_pass(model, inputs): + for _ in range(10): + model.forward(inputs) + + +def synchronize(): + torch.mps.synchronize() + + +def pytorch_forward_pass(model, inputs): + for _ in range(10): + model(*inputs) + synchronize() + + +def get_mps_inputs(inputs): + inputs_mps = [] + for tensor in inputs: + inputs_mps.append(tensor.to("mps")) + inputs_mps = tuple(inputs_mps) + return inputs_mps + + +def get_executorch_model(executorch_program: ExportedProgram): + try: + from executorch.extension.pybindings.portable_lib import ( # @manual + _load_for_executorch_from_buffer, + ) + + return _load_for_executorch_from_buffer(executorch_program.buffer) + except ImportError: + logging.info( + "ExecuTorch MPS delegate was built without pybind support (not possible to run forward pass within python)" + ) + return None + + +def bench_torch(executorch_program: ExportedProgram, model, inputs, model_name): + model = model.to("mps") + inputs_mps = get_mps_inputs(inputs) + + executorch_model = get_executorch_model(executorch_program) + if executorch_model is not None: + t_pytorch = bench_forward(pytorch_forward_pass, model, inputs_mps) + t_executorch = bench_forward(executorch_forward_pass, executorch_model, inputs) + + logging.info(f"Model name: {model_name}") + logging.info(f"Pytorch MPS forward pass: {t_pytorch} seconds") + logging.info(f"ExecuTorch MPS forward pass: {t_executorch} seconds") + logging.info( + f"ExecuTorch speedup: {((t_pytorch - t_executorch) / t_pytorch) * 100}%" + ) + + +def compare_outputs(executorch_program: ExportedProgram, model, inputs, model_name): + inputs_copy = [] + for t in inputs: + inputs_copy.append(t.detach().clone()) + inputs_copy = tuple(inputs_copy) + + pytorch_results = model(*inputs) + executorch_model = get_executorch_model(executorch_program) + if executorch_model is not None: + executorch_results = executorch_model.forward(inputs_copy) + assert_outputs_equal(executorch_results, pytorch_results) + logging.info( + f"Results between ExecuTorch forward pass with MPS backend and PyTorch forward pass for {model_name} are matching!" + ) diff --git a/examples/apple/mps/scripts/mps_example.py b/examples/apple/mps/scripts/mps_example.py index a86a54c4d5c..0bfef7bf4ce 100644 --- a/examples/apple/mps/scripts/mps_example.py +++ b/examples/apple/mps/scripts/mps_example.py @@ -10,6 +10,7 @@ import logging import torch +from examples.apple.mps.scripts.bench_utils import bench_torch, compare_outputs from executorch import exir from executorch.backends.apple.mps.mps_preprocess import MPSBackend from executorch.backends.apple.mps.partition.mps_partitioner import MPSPartitioner @@ -36,7 +37,28 @@ FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) -if __name__ == "__main__": + +def get_bundled_program(executorch_program, example_inputs, expected_output): + method_test_suites = [ + MethodTestSuite( + method_name="forward", + test_cases=[ + MethodTestCase( + inputs=example_inputs, expected_outputs=[expected_output] + ) + ], + ) + ] + logging.info(f"Expected output: {expected_output}") + + bundled_program = BundledProgram(executorch_program, method_test_suites) + bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer( + bundled_program + ) + return bundled_program_buffer + + +def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( "-m", @@ -54,11 +76,18 @@ parser.add_argument( "--use_partitioner", - default=False, + default=True, action=argparse.BooleanOptionalAction, help="Use MPS partitioner to run the model instead of using whole graph lowering.", ) + parser.add_argument( + "--bench_pytorch", + default=False, + action=argparse.BooleanOptionalAction, + help="Bench ExecuTorch MPS foward pass with PyTorch MPS forward pass.", + ) + parser.add_argument( "-b", "--bundled", @@ -68,6 +97,15 @@ help="Flag for bundling inputs and outputs in the final flatbuffer program", ) + parser.add_argument( + "-c", + "--check_correctness", + action="store_true", + required=False, + default=False, + help="Whether to compare the ExecuTorch MPS results with the PyTorch forward pass", + ) + parser.add_argument( "--generate_etrecord", action="store_true", @@ -76,25 +114,64 @@ help="Generate ETRecord metadata to link with runtime results (used for profiling)", ) + parser.add_argument( + "--checkpoint", + required=False, + default=None, + help="checkpoing for llama model", + ) + + parser.add_argument( + "--params", + required=False, + default=None, + help="params for llama model", + ) + args = parser.parse_args() + return args + + +def get_model_config(args): + model_config = {} + model_config["module_name"] = MODEL_NAME_TO_MODEL[args.model_name][0] + model_config["model_class_name"] = MODEL_NAME_TO_MODEL[args.model_name][1] + + if args.model_name == "llama2": + if args.checkpoint: + model_config["checkpoint"] = args.checkpoint + if args.params: + model_config["params"] = args.params + model_config["use_kv_cache"] = True + return model_config + + +if __name__ == "__main__": + args = parse_args() if args.model_name not in MODEL_NAME_TO_MODEL: raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.") - model, example_inputs, _ = EagerModelFactory.create_model( - *MODEL_NAME_TO_MODEL[args.model_name] - ) + model_config = get_model_config(args) + model, example_inputs, _ = EagerModelFactory.create_model(**model_config) model = model.eval() + if args.check_correctness or args.bench_pytorch: + model_copy = copy.deepcopy(model) + inputs_copy = [] + for t in example_inputs: + inputs_copy.append(t.detach().clone()) + inputs_copy = tuple(inputs_copy) # pre-autograd export. eventually this will become torch.export - model = torch._export.capture_pre_autograd_graph(model, example_inputs) + with torch.no_grad(): + model = torch._export.capture_pre_autograd_graph(model, example_inputs) + edge: EdgeProgramManager = export_to_edge( + model, + example_inputs, + edge_compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) - edge: EdgeProgramManager = export_to_edge( - model, - example_inputs, - edge_compile_config=EdgeCompileConfig(_check_ir_validity=False), - ) edge_program_manager_copy = copy.deepcopy(edge) compile_specs = [CompileSpec("use_fp16", bytes([args.use_fp16]))] @@ -120,31 +197,30 @@ model_name = f"{args.model_name}_mps" if args.bundled: - method_test_suites = [ - MethodTestSuite( - method_name="forward", - test_cases=[ - MethodTestCase( - inputs=example_inputs, expected_outputs=[model(*example_inputs)] - ) - ], - ) - ] - logging.info(f"Expected output: {model(*example_inputs)}") - - bundled_program = BundledProgram(executorch_program, method_test_suites) - bundled_program_buffer = serialize_from_bundled_program_to_flatbuffer( - bundled_program + expected_output = model(*example_inputs) + bundled_program_buffer = get_bundled_program( + executorch_program, example_inputs, expected_output ) model_name = f"{model_name}_bundled" extension = "fp16" if not args.use_fp16: extension = "fp32" - model_name = f"{model_name}_{extension}" + model_name = f"{model_name}_{extension}.pte" if args.generate_etrecord: etrecord_path = "etrecord.bin" logging.info("generating etrecord.bin") generate_etrecord(etrecord_path, edge_program_manager_copy, executorch_program) - save_pte_program(executorch_program, model_name) + if args.bundled: + with open(model_name, "wb") as file: + file.write(bundled_program_buffer) + logging.info(f"Saved bundled program to {model_name}") + else: + save_pte_program(executorch_program, model_name) + + if args.bench_pytorch: + bench_torch(executorch_program, model_copy, example_inputs, model_name) + + if args.check_correctness: + compare_outputs(executorch_program, model_copy, inputs_copy, model_name)