Skip to content

Commit c15fab1

Browse files
[Executorch][BE] Rename sdpa_with_kv_cache.py to custom_ops.py (#7210)
Pull Request resolved: #6996

Renamed because this file now contains more than just sdpa_with_kv_cache.

ghstack-source-id: 256711931
@exported-using-ghexport

Note: //oss is complaining of an internal lint failure, and unit-test-arm is broken in trunk.
@bypass-github-export-checks
@exported-using-ghexport

Differential Revision: [D66269486](https://our.internmc.facebook.com/intern/diff/D66269486/)

Co-authored-by: Kimish Patel <[email protected]>
1 parent 98e4dd5 commit c15fab1

File tree

9 files changed

+8
-9
lines changed

9 files changed

+8
-9
lines changed

examples/models/llama/eval_llama_lib.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def __init__(
106106

107107
# Note: import this after portable_lib
108108
from executorch.extension.llm.custom_ops import ( # noqa
109-
sdpa_with_kv_cache, # usort: skip
109+
custom_ops, # usort: skip
110110
)
111111
from executorch.kernels import quantized # noqa
112112

examples/models/llama/runner/native.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from executorch.examples.models.llama.runner.generation import LlamaRunner
2424

2525
# Note: import this after portable_lib
26-
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
26+
from executorch.extension.llm.custom_ops import custom_ops # noqa # usort: skip
2727
from executorch.kernels import quantized # noqa
2828

2929

examples/models/llama/source_transformation/sdpa.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
9999

100100

101101
def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
102-
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa
102+
from executorch.extension.llm.custom_ops import custom_ops # noqa
103103

104104
_replace_sdpa_with_custom_op(module)
105105
return module

examples/models/llava/test/test_llava.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from executorch.extension.pybindings.portable_lib import (
1919
_load_for_executorch_from_buffer,
2020
)
21-
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
21+
from executorch.extension.llm.custom_ops import custom_ops # noqa # usort: skip
2222
from executorch.kernels import quantized # noqa # usort: skip
2323

2424
logging.basicConfig(level=logging.INFO)

examples/models/llava/test/test_pte.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from PIL import Image
1515

1616
# Custom ops has to be loaded after portable_lib.
17-
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
17+
from executorch.extension.llm.custom_ops import custom_ops # noqa # usort: skip
1818
from executorch.kernels import quantized # noqa # usort: skip
1919

2020
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"

extension/llm/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ A sampler class in C++ to sample the logits given some hyperparameters.
3838
## custom_ops
3939
Contains custom op, such as:
4040
- custom sdpa: implements CPU flash attention and avoids copies by taking the kv cache as one of its arguments.
41-
- _sdpa_with_kv_cache.py_, _op_sdpa_aot.cpp_: custom op definition in PyTorch with C++ registration.
41+
- _custom_ops.py_, _op_sdpa_aot.cpp_: custom op definition in PyTorch with C++ registration.
4242
- _op_sdpa.cpp_: the optimized operator implementation and registration of _sdpa_with_kv_cache.out_.
4343

4444
## runner

extension/llm/custom_ops/sdpa_with_kv_cache.py renamed to extension/llm/custom_ops/custom_ops.py

-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
from torch.library import impl
1919

20-
# TODO rename this file to custom_ops_meta_registration.py
2120
try:
2221
op = torch.ops.llama.sdpa_with_kv_cache.default
2322
assert op is not None

extension/llm/custom_ops/targets.bzl

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def define_common_targets():
8181
runtime.python_library(
8282
name = "custom_ops_aot_py",
8383
srcs = [
84-
"sdpa_with_kv_cache.py",
84+
"custom_ops.py",
8585
],
8686
visibility = [
8787
"//executorch/...",

extension/llm/custom_ops/test_sdpa_with_kv_cache.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch
1212
import torch.nn.functional as F
1313

14-
from .sdpa_with_kv_cache import custom_ops_lib # noqa
14+
from .custom_ops import custom_ops_lib # noqa
1515

1616

1717
def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq_len):

0 commit comments

Comments
 (0)