Generating ETDump fails when using XNNPACK delegation #8177


Open
spalatinate opened this issue Feb 4, 2025 · 6 comments
Assignees
Labels
module: user experience - Issues related to reducing friction for users
module: xnnpack - Issues related to xnnpack delegation and the code under backends/xnnpack/
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

spalatinate commented Feb 4, 2025

🐛 Describe the bug

When reproducing the ETDump generation example, the example runner for the Bundled Program aborts as soon as I try to execute the .bp file. The .bp file was generated as follows:

In the ETDump generation example I simply replaced the to_edge() call with the XNNPACK backend API to_edge_transform_and_lower(). See below:

to_edge(aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True))

with

to_edge_transform_and_lower(aten_model, partitioner=[XnnpackPartitioner()])

The output is then:

cmake-out/examples/devtools/example_runner --bundled_program_path="bundled_program_xnn.bp"
Abgebrochen (German for "Aborted")

With to_edge() everything works just fine. Can anyone point me in the right direction? Thanks!
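
For context, here is a minimal sketch of the two routes being compared, using a stand-in toy model rather than the example's actual code (the model and all names here are illustrative only):

import torch
from torch.export import export

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, to_edge, to_edge_transform_and_lower

# Stand-in model; any exportable nn.Module works the same way here.
model = torch.nn.Linear(4, 2).eval()
sample_inputs = (torch.randn(1, 4),)
aten_model = export(model, sample_inputs)

# Route 1: to_edge(), then delegate to XNNPACK explicitly.
edge_manager = to_edge(aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True))
edge_manager = edge_manager.to_backend(XnnpackPartitioner())

# Route 2: lower and delegate in a single step (in practice you pick one route, not both).
edge_manager = to_edge_transform_and_lower(aten_model, partitioner=[XnnpackPartitioner()])

et_program = edge_manager.to_executorch()

Both routes end with to_executorch(); the difference is only where the XNNPACK partitioner is applied.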

Versions

PyTorch version: 2.5.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 12 (bookworm) (aarch64)
GCC version: (Debian 12.2.0-14) 12.2.0
Clang version: 14.0.6
CMake version: version 3.31.2
Libc version: glibc-2.36

Python version: 3.10.0 (default, Mar 3 2022, 09:51:40) [GCC 10.2.0] (64-bit runtime)
Python platform: Linux-6.6.62+rpt-rpi-v8-aarch64-with-glibc2.36
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: aarch64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 4
On-line CPU(s) list: 0-3
Vendor ID: ARM
Model name: Cortex-A76
Model: 1
Thread(s) per core: 1
Core(s) per cluster: 4
Socket(s): -
Cluster(s): 1
Stepping: r4p1
CPU(s) scaling MHz: 100%
CPU max MHz: 2400,0000
CPU min MHz: 1500,0000
BogoMIPS: 108,00
Flags: fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm lrcpc dcpop asimddp
L1d cache: 256 KiB (4 instances)
L1i cache: 256 KiB (4 instances)
L2 cache: 2 MiB (4 instances)
L3 cache: 2 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-3
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; __user pointer sanitization
Vulnerability Spectre v2: Mitigation; CSV2, BHB
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] executorch==0.4.0a0+6a085ff
[pip3] numpy==1.26.4
[pip3] torch==2.5.0
[pip3] torchao==0.5.0+git0916b5b2
[pip3] torchaudio==2.5.0
[pip3] torchsr==1.0.4
[pip3] torchvision==0.20.0
[conda] executorch 0.4.0a0+6a085ff pypi_0 pypi
[conda] numpy 1.26.4 pypi_0 pypi
[conda] torch 2.5.0 pypi_0 pypi
[conda] torchao 0.5.0+git0916b5b2 pypi_0 pypi
[conda] torchaudio 2.5.0 pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchvision 0.20.0 pypi_0 pypi

cc @digantdesai @mcr229 @mergennachin @byjlw

@digantdesai added the bug, module: xnnpack, and triaged labels Feb 4, 2025
digantdesai (Contributor) commented:

cc @mcr229 can you help with this?

mcr229 (Contributor) commented Feb 4, 2025

Hmm, this seems like an issue with generating ETDump? I'm actually not sure how that changes with to_edge vs to_edge_transform_and_lower. @Olivia-liu @tarun292 do you know how ETDump interacts with these different API surfaces?

tarun292 (Contributor) commented Feb 6, 2025

@spalatinate can you share the failure stack trace you're seeing? I don't see the stack trace in the issue.

@tarun292 tarun292 self-assigned this Feb 6, 2025
spalatinate (Author) commented Feb 6, 2025

@tarun292 The problem is that the Python script runs without any error. Subsequently, when I run the .bp file, I get the output "Terminated" and no stack trace. That made it challenging for me to debug.

Further, I was a bit confused by the XNNPACK delegation API. According to the examples, there are two ways of calling it:

  1. via to_edge(): see the Delegation section in https://pytorch.org/executorch/stable/llm/getting-started.html
  2. via to_edge_transform_and_lower(): https://pytorch.org/executorch/stable/_modules/executorch/exir/program/_program.html#to_edge_transform_and_lower

Neither worked for generating the ETDump.

tarun292 (Contributor) commented Feb 6, 2025

@spalatinate you should use to_edge_transform_and_lower(). Can you build this in debug mode and run it again so that we can get a look at the crash logs and see where it's failing in the runner? If you can, feel free to share the .bp file too.
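
One generic way to get an actual stack trace out of the aborting runner is to run it under gdb (standard gdb usage, assuming gdb is installed and the binary was built with debug symbols; the command is the one quoted above):

gdb --args cmake-out/examples/devtools/example_runner --bundled_program_path="bundled_program_xnn.bp"
(gdb) run
(gdb) bt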

@mergennachin mergennachin added the module: user experience Issues related to reducing friction for users label Feb 10, 2025
@github-project-automation github-project-automation bot moved this to To triage in ExecuTorch DevX Feb 10, 2025
spalatinate (Author) commented Feb 11, 2025

@tarun292 Sorry for my delayed answer. I built the runner in debug mode (set -DCMAKE_BUILD_TYPE=Debug in build_example_runner.sh). After running the .bp file again, I got "Terminated" and no stack trace.

I attached the Python script used to generate the .bp file; I thought posting the code would make inspection easier. Without XNNPACK delegation, the ETDump generation works just fine.

import copy
import torch

import torchvision.models as models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights

from executorch.devtools import generate_etrecord
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import (
    EdgeCompileConfig,
    EdgeProgramManager,
    ExecutorchProgramManager,
    to_edge,
    to_edge_transform_and_lower
)
from torch.export import export, ExportedProgram
from executorch.devtools import BundledProgram

from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config

from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
from executorch.devtools.bundled_program.serialize import (serialize_from_bundled_program_to_flatbuffer)

# Export MobileNetV2 to an ATen-dialect program.
mobilenet_v2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

aten_model = export(mobilenet_v2, sample_inputs)

# Lower to edge dialect and delegate to XNNPACK in one step.
edge_manager: EdgeProgramManager = to_edge_transform_and_lower(aten_model, partitioner=[XnnpackPartitioner()])

# Keep a copy for the ETRecord before converting to an ExecuTorch program.
edge_manager_copy = copy.deepcopy(edge_manager)
et_program = edge_manager.to_executorch()

etrecord_path = "etrecord.bin"
generate_etrecord(etrecord_path, edge_manager_copy, et_program)

m_name = "forward"
method_graphs = {m_name: export(mobilenet_v2, sample_inputs)}

inputs = [[torch.randn(1,3,224,224)] for _ in range(2)]

method_test_suites = [
    MethodTestSuite(
        method_name=m_name,
        test_cases=[
            MethodTestCase(inputs=inp, expected_outputs=getattr(mobilenet_v2, m_name)(*inp))
            for inp in inputs
        ],
    )
]

# Lower again and wrap the ExecuTorch program with the test suites.
executorch_program = to_edge_transform_and_lower(aten_model, partitioner=[XnnpackPartitioner()]).to_executorch()

bundled_program = BundledProgram(executorch_program, method_test_suites)

serialized_bundled_program = serialize_from_bundled_program_to_flatbuffer(bundled_program)

save_path = "bundled_program_xnn.bp"
with open(save_path, "wb") as f:
    f.write(serialized_bundled_program)
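
If the runner does eventually emit an ETDump, the usual next step is to pair it with the ETRecord generated above in the devtools Inspector. A minimal sketch, assuming the runner wrote its ETDump to etdump.etdp (that file name is an assumption, not from this issue):

from executorch.devtools import Inspector

# etdump.etdp is an assumed output path; etrecord.bin matches the script above.
inspector = Inspector(etdump_path="etdump.etdp", etrecord="etrecord.bin")
inspector.print_data_tabular()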
