
[Mutable Buffer] [Core ML Delegate] Let Core ML Handle Mutable Buffer #4209


Closed
YifanShenSZ opened this issue Jul 11, 2024 · 15 comments
Labels
good first issue Good for newcomers module: coreml Issues related to Apple's Core ML delegation and code under backends/apple/coreml/ triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@YifanShenSZ
Collaborator

YifanShenSZ commented Jul 11, 2024

🚀 The feature, motivation and pitch

Starting from iOS 18, Core ML has state, which is the counterpart of ExecuTorch's mutable buffers. As a result, ExecuTorch can now let Core ML handle buffer mutation.

Additional context

The change will probably be based on #2876.
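
For concreteness, here is a minimal sketch of the kind of model this feature targets; the module and shapes are illustrative, not taken from this issue:

```python
import torch

class StatefulModel(torch.nn.Module):
    """Toy model with a mutable buffer, analogous to a KV cache:
    every forward() call updates self.cache in place."""

    def __init__(self):
        super().__init__()
        # register_buffer marks `cache` as state that torch.export tracks.
        self.register_buffer("cache", torch.zeros(3, 3))

    def forward(self, x):
        self.cache.add_(x)  # in-place buffer mutation
        return x + self.cache

# torch.export records the mutation as a BUFFER_MUTATION output spec;
# Core ML state (iOS 18+) is the on-device counterpart that can back it.
ep = torch.export.export(StatefulModel(), (torch.randn(3, 3),))
print(ep.graph_signature)
```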

@lucylq lucylq added the module: coreml Issues related to Apple's Core ML delegation and code under backends/apple/coreml/ label Jul 11, 2024
@lucylq
Contributor

lucylq commented Jul 11, 2024

@cccclai on Core ML

@lucylq lucylq added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jul 11, 2024
@cccclai
Contributor

cccclai commented Jul 18, 2024

@YifanShenSZ any chance we know the ETA for this feature?

@YifanShenSZ
Collaborator Author

I'm actively working on the coremltools side, aiming for our next release (end of July):

  1. (Done) Conversion of an ExecuTorch-exported mutable-buffer model
  2. (Ongoing) Make sure the in-place ops have the correct pattern
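
As a rough sketch of what step 1 looks like from the user's side (assuming the coremltools 8.0 API, which added ct.target.iOS18 and direct conversion of torch.export programs; this is not code from this issue):

```python
import coremltools as ct
import torch

class StatefulModel(torch.nn.Module):  # same toy model as sketched above
    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(3, 3))

    def forward(self, x):
        self.cache.add_(x)
        return x + self.cache

ep = torch.export.export(StatefulModel(), (torch.randn(3, 3),))

# Assumption: coremltools converts the exported program directly and maps
# the mutable buffer to a Core ML state, which requires an iOS 18 target.
mlmodel = ct.convert(ep, minimum_deployment_target=ct.target.iOS18)
```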

@YifanShenSZ
Collaborator Author

YifanShenSZ commented Jul 22, 2024

The coremltools-side change is almost done; I can now export a stateful ExecuTorch model to Core ML. We will need some changes in the ExecuTorch to_backend implementation to make the delegation path work.

@cccclai
Contributor

cccclai commented Jul 23, 2024

From this error log:

  File "/Volumes/Models/LLM/Framework/CoreMLTools-Dev_ExecuTorch-0.2/envs/llama-py310/lib/python3.10/site-packages/executorch/exir/backend/backend_api.py", line 113, in _
    copied_edge_program = copy.deepcopy(edge_program)

Looks like the program is not copyable. Can you try running copy.deepcopy(exported_program) before calling the to_backend API?
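
For reference, the suggested check is just this (where exported_program is the torch.export.ExportedProgram you pass to to_backend):

```python
import copy

# If this raises, the program itself is not deep-copyable, so the failure
# is independent of the to_backend() machinery.
copied = copy.deepcopy(exported_program)
```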

@YifanShenSZ
Collaborator Author

YifanShenSZ commented Jul 23, 2024

I dug a bit and found the issue to be more involved. Please reproduce with:

```
git clone --branch coreml-state https://github.com/YifanShenSZ/executorch.git
cd executorch
git submodule sync
git submodule update --init
./install_requirements.sh

./backends/apple/coreml/scripts/install_requirements.sh
pip install numpy==1.26.4

cd backends/apple/coreml
python test/test_coreml_partitioner.py
```

@kimishpatel
Contributor

Are you asking for state management to be completely handed over to the delegate? If so, would the delegate allow access to this state? We have had requests from users who want to manipulate the KV cache state, and I'm not sure how this will line up with that.

@YifanShenSZ
Collaborator Author

> Are you asking for state management to be completely handed over to the delegate?

Yes

> If so, would the delegate allow access to this state? We have had requests from users who want to manipulate the KV cache state, and I'm not sure how this will line up with that.

@cymbalrush does the Core ML runtime allow users to access state?

@cymbalrush
Contributor

@cccclai
Contributor

cccclai commented Jul 28, 2024

I was able to repro; looking into it.

```
Traceback (most recent call last):
  File "/Users/chenlai/coreml_debug/executorch/backends/apple/coreml/test/test_coreml_partitioner.py", line 118, in <module>
    test_runner.test_buffer()
  File "/Users/chenlai/coreml_debug/executorch/backends/apple/coreml/test/test_coreml_partitioner.py", line 102, in test_buffer
    delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner())
  File "/opt/homebrew/anaconda3/envs/coreml_debug/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 1204, in to_backend
    new_edge_programs[name] = to_backend(program, partitioner)
  File "/opt/homebrew/anaconda3/envs/coreml_debug/lib/python3.10/functools.py", line 878, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/opt/homebrew/anaconda3/envs/coreml_debug/lib/python3.10/site-packages/executorch/exir/backend/backend_api.py", line 394, in _
    return ExportedProgram(
  File "/opt/homebrew/anaconda3/envs/coreml_debug/lib/python3.10/site-packages/torch/export/exported_program.py", line 682, in __init__
    self._validate()
  File "/opt/homebrew/anaconda3/envs/coreml_debug/lib/python3.10/site-packages/torch/export/exported_program.py", line 1101, in _validate
    v().check(self)
  File "/opt/homebrew/anaconda3/envs/coreml_debug/lib/python3.10/site-packages/torch/_export/verifier.py", line 157, in check
    _verify_exported_program_signature(ep)
  File "/opt/homebrew/anaconda3/envs/coreml_debug/lib/python3.10/site-packages/torch/_export/verifier.py", line 408, in _verify_exported_program_signature
    raise SpecViolationError(
torch._export.verifier.SpecViolationError: Buffer output getitem does not point to a buffer that exists.
Dict of buffers that are mutated, in order: {'getitem': 'state_1'}
Buffer nodes available: []
```
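
One hedged way to see the mismatch the verifier is complaining about is to dump the partitioned program's signature (variable names here are illustrative):

```python
from torch.export.graph_signature import OutputKind

# `partitioned_ep` stands for the ExportedProgram coming out of to_backend().
# The verifier fails because an output spec of kind BUFFER_MUTATION still
# targets buffer 'state_1', but the buffer itself was moved into the
# delegate, so the toplevel signature no longer lists any buffers.
for spec in partitioned_ep.graph_signature.output_specs:
    if spec.kind == OutputKind.BUFFER_MUTATION:
        print(spec.arg.name, "->", spec.target)
print("buffers:", list(partitioned_ep.graph_signature.buffers))
```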

@cccclai
Contributor

cccclai commented Jul 29, 2024

After checking, some changes need to be added to the delegate infra to support consuming the mutable buffer. This is our first case of having a backend consume a mutable buffer. @angelayi will help add the feature.

@cccclai cccclai added the good first issue Good for newcomers label Jul 29, 2024
@kimishpatel
Contributor

@cccclai can you describe the nature of the change?

@cccclai
Contributor

cccclai commented Jul 30, 2024

> @cccclai can you describe the nature of the change?

We'd need to make some changes in lowered_backend_module.py and backend_api.py to support a backend consuming in-place ops. Since we haven't had a backend that can consume in-place ops before, this is the first time we're exercising this branch.
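
To make the nature of the change concrete, here is a hypothetical sketch of the signature rewrite the delegate infra needs; the function name and structure are illustrative, not the actual lowered_backend_module.py code:

```python
from torch.export.graph_signature import InputKind, OutputKind

def strip_consumed_buffer_specs(signature, consumed_buffers):
    """Drop the BUFFER input specs and BUFFER_MUTATION output specs for
    buffers the delegate now owns, so the toplevel program stops
    referencing state it no longer holds."""
    input_specs = [
        s for s in signature.input_specs
        if not (s.kind == InputKind.BUFFER and s.target in consumed_buffers)
    ]
    output_specs = [
        s for s in signature.output_specs
        if not (s.kind == OutputKind.BUFFER_MUTATION and s.target in consumed_buffers)
    ]
    return input_specs, output_specs
```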

@cccclai
Contributor

cccclai commented Aug 7, 2024

There has been some progress from @angelayi on this. #4566 is the first PR, and there will be one more PR to resolve it.

angelayi added a commit to angelayi/executorch-1 that referenced this issue Aug 22, 2024
Summary:
Fixing pytorch#4209

Edge Program:
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b: "f32[3, 3]", x: "f32[3, 3]"):
             # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:631 in forward, code: self.b.add_(x)
            aten_add_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_b, x);  b_b = None

             # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            aten_add_tensor_1: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(x, aten_add_tensor);  x = None
            return (aten_add_tensor, aten_add_tensor_1)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_b'), target='b', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_add_tensor'), target='b'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)])
```

Partitioned / lowered Exported Program (buffer mutation gets removed):
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
            # No stacktrace found for following nodes
            lowered_module_0 = self.lowered_module_0
            executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, x);  lowered_module_0 = x = None

             # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            getitem_1: "f32[3, 3]" = executorch_call_delegate[0];  executorch_call_delegate = None
            return (getitem_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem_1'), target=None)])
```

Delegate (consumes the buffer mutation):
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b: "f32[3, 3]", x: "f32[3, 3]"):
             # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:631 in forward, code: self.b.add_(x)
            aten_add_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_b, x);  b_b = None

             # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            aten_add_tensor_1: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(x, aten_add_tensor);  x = None
            return (aten_add_tensor, aten_add_tensor_1)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_b'), target='b', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_add_tensor'), target='b'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)])
```

Differential Revision: D60838243
angelayi added a commit to angelayi/executorch-1 that referenced this issue Aug 28, 2024
Summary:
Pull Request resolved: pytorch#4830

Fixing pytorch#4209

Edge Program:
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b: "f32[3, 3]", x: "f32[3, 3]"):
             # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:631 in forward, code: self.b.add_(x)
            aten_add_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_b, x);  b_b = None

             # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            aten_add_tensor_1: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(x, aten_add_tensor);  x = None
            return (aten_add_tensor, aten_add_tensor_1)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_b'), target='b', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_add_tensor'), target='b'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)])
```

Partitioned / lowered Exported Program (buffer mutation gets removed):
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
            # No stacktrace found for following nodes
            lowered_module_0 = self.lowered_module_0
            executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, x);  lowered_module_0 = x = None

             # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            getitem_1: "f32[3, 3]" = executorch_call_delegate[0];  executorch_call_delegate = None
            return (getitem_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem_1'), target=None)])
```

Delegate (consumes the buffer mutation):
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b: "f32[3, 3]", x: "f32[3, 3]"):
             # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:631 in forward, code: self.b.add_(x)
            aten_add_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_b, x);  b_b = None

             # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            aten_add_tensor_1: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(x, aten_add_tensor);  x = None
            return (aten_add_tensor, aten_add_tensor_1)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_b'), target='b', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_add_tensor'), target='b'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)])
```

Differential Revision: D60838243
@YifanShenSZ
Collaborator Author

Verified. Thanks team!
