Skip to content

Fix pre-autograd transforms not getting persisted during xnnpack export #9118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 11, 2025

Conversation

jackzhxng
Copy link
Contributor

Summary

After moving to to_edge_transform_and_lower for the XNNPack export route in #8624, we were discarding all of the transforms made to the pre-autograd graph module stored in LLMEdgeManager, since the new to_edge_transform_and_lower took in an ExportedProgram instead of a nn.Module as an argument. To solve this, we re-run export for training right before each LLMEdgeManager API that runs the full non-autograd safe torch.export().

Test plan

Tested manually on Llama3.2 1B export:

python -m examples.models.llama.export_llama --checkpoint ~/hf/models--meta-llama--Llama-3.2-1B-Instruct/snapshots/9213176726f574b556790deb65791e0c5aa438b6/original/consolidated.00.pth -p ~/hf/models--meta-llama--Llama-3.2-1B-Instruct/snapshots/9213176726f574b556790deb65791e0c5aa438b6/original/params.json -d fp32 -kv -X --use_sdpa_with_kv_cache -n llama3_1b_regression.pte --verbose

Before ops (contains permute_copy):

Total delegated subgraphs: 113
Number of delegated nodes: 113
Number of non-delegated nodes: 1447
╒════╤══════════════════════════════╤═══════════════════════════════════╤═══════════════════════════════════════╕
│    │ op_type                      │   occurrences_in_delegated_graphs │   occurrences_in_non_delegated_graphs │
╞════╪══════════════════════════════╪═══════════════════════════════════╪═══════════════════════════════════════╡
│  0 │ _assert_scalar               │                                 0 │                                    67 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  1 │ _local_scalar_dense          │                                 0 │                                    33 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  2 │ add                          │                                 0 │                                    17 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  3 │ aten_add_tensor              │                                 0 │                                    97 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  4 │ aten_cat_default             │                                 0 │                                    32 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  5 │ aten_embedding_default       │                                 0 │                                     1 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  6 │ aten_linear_default          │                               113 │                                     0 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  7 │ aten_mean_dim                │                                 0 │                                    33 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  8 │ aten_mul_tensor              │                                 0 │                                   259 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  9 │ aten_permute_copy_default    │                                 0 │                                   160 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 10 │ aten_rsqrt_default           │                                 0 │                                    33 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 11 │ aten_select_copy_int         │                                 0 │                                    34 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 12 │ aten_sigmoid_default         │                                 0 │                                    16 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 13 │ aten_slice_copy_tensor       │                                 0 │                                    67 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 14 │ aten_squeeze_copy_dims       │                                 0 │                                    64 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 15 │ aten_sub_tensor              │                                 0 │                                    32 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 16 │ aten_unsqueeze_copy_default  │                                 0 │                                    64 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 17 │ aten_view_copy_default       │                                 0 │                                   160 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 18 │ auto_functionalized          │                                 0 │                                    32 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 19 │ ge                           │                                 0 │                                    17 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 20 │ getitem                      │                                 0 │                                   145 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 21 │ le                           │                                 0 │                                    17 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 22 │ llama_custom_sdpa_default    │                                 0 │                                    16 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 23 │ lt                           │                                 0 │                                    33 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 24 │ sym_constrain_range_for_size │                                 0 │                                    17 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 25 │ sym_size                     │                                 0 │                                     1 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 26 │ Total                        │                               113 │                                  1447 │
╘════╧══════════════════════════════╧═══════════════════════════════════╧═══════════════════════════════════════╛

After ops (no permute_copy):

Total delegated subgraphs: 113
Number of delegated nodes: 113
Number of non-delegated nodes: 1253

╒════╤══════════════════════════════╤═══════════════════════════════════╤═══════════════════════════════════════╕
│    │ op_type                      │   occurrences_in_delegated_graphs │   occurrences_in_non_delegated_graphs │
╞════╪══════════════════════════════╪═══════════════════════════════════╪═══════════════════════════════════════╡
│  0 │ _assert_scalar               │                                 0 │                                    50 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  1 │ _local_scalar_dense          │                                 0 │                                    33 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  2 │ add                          │                                 0 │                                    17 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  3 │ aten_add_tensor              │                                 0 │                                    97 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  4 │ aten_cat_default             │                                 0 │                                    32 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  5 │ aten_embedding_default       │                                 0 │                                     1 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  6 │ aten_linear_default          │                               113 │                                     0 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  7 │ aten_mean_dim                │                                 0 │                                    33 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  8 │ aten_mul_tensor              │                                 0 │                                   259 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  9 │ aten_rsqrt_default           │                                 0 │                                    33 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 10 │ aten_select_copy_int         │                                 0 │                                    34 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 11 │ aten_sigmoid_default         │                                 0 │                                    16 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 12 │ aten_slice_copy_tensor       │                                 0 │                                    67 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 13 │ aten_squeeze_copy_dims       │                                 0 │                                    64 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 14 │ aten_sub_tensor              │                                 0 │                                    32 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 15 │ aten_unsqueeze_copy_default  │                                 0 │                                    64 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 16 │ aten_view_copy_default       │                                 0 │                                   160 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 17 │ auto_functionalized          │                                 0 │                                    32 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 18 │ ge                           │                                 0 │                                    17 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 19 │ getitem                      │                                 0 │                                   145 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 20 │ le                           │                                 0 │                                    17 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 21 │ llama_custom_sdpa_default    │                                 0 │                                    16 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 22 │ lt                           │                                 0 │                                    16 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 23 │ sym_constrain_range_for_size │                                 0 │                                    17 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 24 │ sym_size                     │                                 0 │                                     1 │
├────┼──────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│ 25 │ Total                        │                               113 │                                  1253 │
╘════╧══════════════════════════════╧═══════════════════════════════════╧═══════════════════════════════════════╛

Copy link

pytorch-bot bot commented Mar 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/9118

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9978148 with merge base cf8ce89 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 11, 2025
@jackzhxng jackzhxng added the release notes: examples Changes to any of our example LLMs integrations, such as Llama3 and Llava label Mar 11, 2025
Copy link
Contributor

@kimishpatel kimishpatel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets keep exported_program as the only source of truth

@@ -394,9 +415,12 @@ def export_to_edge(self) -> "LLMEdgeManager":
return_value=False,
)

# Prior to export, persist the changes to the pre autograd
# graph module back to the source-of-truth ExportedProgram.
self.export(self.pre_autograd_graph_module)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should keep exported_program up-to-date. Thus shouldnt do this here but rather wherever we extract graph_module and apply any transformations. Thus we should not keep self.pre_autograd_graph_module at all. Only source of truth would be exported_program

@github-project-automation github-project-automation bot moved this to To triage in ExecuTorch Core Mar 11, 2025
@jackzhxng jackzhxng moved this from To triage to In progress in ExecuTorch Core Mar 11, 2025
@jackzhxng jackzhxng requested a review from swolchok as a code owner March 11, 2025 18:16
Copy link
Contributor

@kimishpatel kimishpatel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. We discussed to follow up to answer " what should be the source of truth, graph_module or EP"

@jackzhxng jackzhxng merged commit 1b2c60c into main Mar 11, 2025
50 checks passed
@jackzhxng jackzhxng deleted the jz/fix-regression branch March 11, 2025 21:06
@github-project-automation github-project-automation bot moved this from In progress to Done in ExecuTorch Core Mar 11, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. release notes: examples Changes to any of our example LLMs integrations, such as Llama3 and Llava
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

3 participants