Skip to content

Commit 09cc743

Browse files
authored
Expand dict API to allow dtype strings
Differential Revision: D69335674 Pull Request resolved: #8323
1 parent d99970b commit 09cc743

File tree

6 files changed

+136
-8
lines changed

6 files changed

+136
-8
lines changed

codegen/tools/gen_oplist.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,23 @@ def _dump_yaml(
189189
)
190190

191191

192+
def create_kernel_key(maybe_kernel_key: str) -> str:
193+
# It is a kernel key.
194+
if maybe_kernel_key.lstrip().startswith("v1"):
195+
return maybe_kernel_key
196+
# It is a dtype.
197+
else:
198+
# Generate a kernel key based on the dtype provided.
199+
# Note: no dim order is included in this kernel key.
200+
# For a description of the kernel key format, see
201+
# executorch/blob/main/runtime/kernel/operator_registry.h#L97-L123
202+
try:
203+
dtype = ScalarType[maybe_kernel_key]
204+
return "v1/" + str(dtype.value) + ";"
205+
except KeyError:
206+
raise Exception(f"Unknown dtype: {maybe_kernel_key}")
207+
208+
192209
def gen_oplist(
193210
output_path: str,
194211
model_file_path: Optional[str] = None,
@@ -223,7 +240,11 @@ def gen_oplist(
223240
ops_and_metadata = json.loads(ops_dict)
224241
for op, metadata in ops_and_metadata.items():
225242
op_set.update({op})
226-
op_metadata = metadata if len(metadata) > 0 else ["default"]
243+
op_metadata = (
244+
[create_kernel_key(x) for x in metadata]
245+
if len(metadata) > 0
246+
else ["default"]
247+
)
227248
et_kernel_metadata = merge_et_kernel_metadata(
228249
et_kernel_metadata, {op: op_metadata}
229250
)

codegen/tools/test/test_gen_oplist.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import executorch.codegen.tools.gen_oplist as gen_oplist
1515
import yaml
16+
from executorch.codegen.tools.gen_oplist import ScalarType
1617

1718

1819
class TestGenOpList(unittest.TestCase):
@@ -89,7 +90,7 @@ def test_gen_op_list_with_root_ops_and_dtypes(
8990
) -> None:
9091
output_path = os.path.join(self.temp_dir.name, "output.yaml")
9192
ops_dict = {
92-
"aten::add": ["v1/3;0,1|3;0,1|3;0,1|3;0,1", "v1/6;0,1|6;0,1|6;0,1|6;0,1"],
93+
"aten::add": ["v1/3;0,1|3;0,1|3;0,1|3;0,1", ScalarType.Float.name],
9394
"aten::mul": [],
9495
}
9596
args = [
@@ -104,7 +105,7 @@ def test_gen_op_list_with_root_ops_and_dtypes(
104105
{
105106
"aten::add": [
106107
"v1/3;0,1|3;0,1|3;0,1|3;0,1",
107-
"v1/6;0,1|6;0,1|6;0,1|6;0,1",
108+
"v1/6;",
108109
],
109110
"aten::mul": ["default"],
110111
},

codegen/tools/test/test_gen_selected_op_variants.py

+74-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import expecttest
1313

1414

15-
class TestGenSelectedMobileOpsHeader(expecttest.TestCase):
15+
class TestGenSelectedOpVariants(expecttest.TestCase):
1616
def setUp(self):
1717
self.temp_dir = tempfile.TemporaryDirectory()
1818
self.addCleanup(self.temp_dir.cleanup)
@@ -84,7 +84,79 @@ def test_generates_correct_header(self) -> None:
8484
)
8585

8686

87-
class TestGenSelectedMobileOpsHeader_Empty(expecttest.TestCase):
87+
class TestGenSelectedOpVariants_UsingDtypeString(expecttest.TestCase):
88+
def setUp(self):
89+
self.temp_dir = tempfile.TemporaryDirectory()
90+
self.addCleanup(self.temp_dir.cleanup)
91+
self.selected_ops_yaml = os.path.join(
92+
self.temp_dir.name, "selected_operators.yaml"
93+
)
94+
with open(self.selected_ops_yaml, "w") as f:
95+
f.write(
96+
"""
97+
include_all_non_op_selectives: False
98+
include_all_operators: False
99+
debug_info:
100+
- model1@v100
101+
- model2@v50
102+
operators:
103+
aten::add:
104+
is_root_operator: Yes
105+
is_used_for_training: Yes
106+
include_all_overloads: No
107+
aten::add.int:
108+
is_root_operator: No
109+
is_used_for_training: No
110+
include_all_overloads: Yes
111+
kernel_metadata: {}
112+
et_kernel_metadata:
113+
aten::add.out:
114+
# A list of different kernel keys (tensors with dtype-enum/dim-order) combinations used in model
115+
- v1/6; # Float
116+
- v1/3; # Int
117+
aten::mul.out:
118+
- v1/6; # Float
119+
aten::sub.out:
120+
- default
121+
build_features: []
122+
custom_classes: []
123+
"""
124+
)
125+
126+
def tearDown(self):
127+
self.temp_dir.cleanup()
128+
129+
def test_generates_correct_header(self) -> None:
130+
gen_selected_op_variants.write_selected_op_variants(
131+
os.path.join(self.temp_dir.name, "selected_operators.yaml"),
132+
self.temp_dir.name,
133+
)
134+
with open(
135+
os.path.join(self.temp_dir.name, "selected_op_variants.h"), "r"
136+
) as result:
137+
self.assertExpectedInline(
138+
result.read(),
139+
"""#pragma once
140+
/**
141+
* Generated by executorch/codegen/tools/gen_selected_op_variants.py
142+
*/
143+
144+
inline constexpr bool should_include_kernel_dtype(
145+
const char *operator_name,
146+
executorch::aten::ScalarType scalar_type
147+
) {
148+
return ((executorch::aten::string_view(operator_name).compare("add.out") == 0)
149+
&& (scalar_type == executorch::aten::ScalarType::Float || scalar_type == executorch::aten::ScalarType::Int))
150+
|| ((executorch::aten::string_view(operator_name).compare("mul.out") == 0)
151+
&& (scalar_type == executorch::aten::ScalarType::Float))
152+
|| ((executorch::aten::string_view(operator_name).compare("sub.out") == 0)
153+
&& (true));
154+
}
155+
""",
156+
)
157+
158+
159+
class TestGenSelectedOpVariants_Empty(expecttest.TestCase):
88160
def setUp(self):
89161
self.temp_dir = tempfile.TemporaryDirectory()
90162
self.addCleanup(self.temp_dir.cleanup)

examples/selective_build/targets.bzl

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_oss_build_kwargs", "is_xplat", "runtime")
2-
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib")
2+
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib", "ScalarType")
33

44
def define_common_targets():
55
"""Defines targets that should be shared between fbcode and xplat.
@@ -49,7 +49,9 @@ def define_common_targets():
4949
et_operator_library(
5050
name = "select_ops_in_dict",
5151
ops_dict = {
52-
"aten::add.out": ["v1/3;0,1", "v1/6;0,1"], # int, float
52+
# 1. Use kernel key, generated with a model, or
53+
# 2. Specify the dtype, from executorch/codegen/codegen.bzl
54+
"aten::add.out": ["v1/3;0,1", ScalarType("Float")], # int, float
5355
"aten::mm.out": [], # all dtypes
5456
},
5557
)

examples/selective_build/test_selective_build.sh

-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ test_buck2_select_ops_in_dict() {
6666
# select ops and their dtypes using the dictionary API.
6767
$BUCK run //examples/selective_build:selective_build_test \
6868
--config=executorch.select_ops=dict \
69-
--config=executorch.dtype_selective_build_lib=//examples/selective_build:select_ops_in_dict_lib \
7069
-- --model_path=./add_mul.pte
7170

7271
echo "Removing add_mul.pte"

shim/xplat/executorch/codegen/codegen.bzl

+33
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,39 @@ CUSTOM_OPS_SCHEMA_REGISTRATION_SOURCES = [
4242
"RegisterSchema.cpp",
4343
]
4444

45+
ScalarType = enum(
46+
"Byte",
47+
"Char",
48+
"Short",
49+
"Int",
50+
"Long",
51+
"Half",
52+
"Float",
53+
"Double",
54+
"ComplexHalf",
55+
"ComplexFloat",
56+
"ComplexDouble",
57+
"Bool",
58+
"QInt8",
59+
"QUInt8",
60+
"QInt32",
61+
"BFloat16",
62+
"QUInt4x2",
63+
"QUInt2x4",
64+
"Bits1x8",
65+
"Bits2x4",
66+
"Bits4x2",
67+
"Bits8",
68+
"Bits16",
69+
"Float8_e5m2",
70+
"Float8_e4m3fn",
71+
"Float8_e5m2fnuz",
72+
"Float8_e4m3fnuz",
73+
"UInt16",
74+
"UInt32",
75+
"Uint64",
76+
)
77+
4578
# Hide the dependency to caffe2 internally.
4679
def et_operator_library(
4780
name,

0 commit comments

Comments
 (0)