Skip to content

Expand dict API to allow dtype strings #8323

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion codegen/tools/gen_oplist.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,23 @@ def _dump_yaml(
)


def create_kernel_key(maybe_kernel_key: str) -> str:
# It is a kernel key.
if maybe_kernel_key.lstrip().startswith("v1"):
return maybe_kernel_key
# It is a dtype.
else:
# Generate a kernel key based on the dtype provided.
# Note: no dim order is included in this kernel key.
# For a description of the kernel key format, see
# executorch/blob/main/runtime/kernel/operator_registry.h#L97-L123
try:
dtype = ScalarType[maybe_kernel_key]
return "v1/" + str(dtype.value) + ";"
except KeyError:
raise Exception(f"Unknown dtype: {maybe_kernel_key}")


def gen_oplist(
output_path: str,
model_file_path: Optional[str] = None,
Expand Down Expand Up @@ -223,7 +240,11 @@ def gen_oplist(
ops_and_metadata = json.loads(ops_dict)
for op, metadata in ops_and_metadata.items():
op_set.update({op})
op_metadata = metadata if len(metadata) > 0 else ["default"]
op_metadata = (
[create_kernel_key(x) for x in metadata]
if len(metadata) > 0
else ["default"]
)
et_kernel_metadata = merge_et_kernel_metadata(
et_kernel_metadata, {op: op_metadata}
)
Expand Down
5 changes: 3 additions & 2 deletions codegen/tools/test/test_gen_oplist.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import executorch.codegen.tools.gen_oplist as gen_oplist
import yaml
from executorch.codegen.tools.gen_oplist import ScalarType


class TestGenOpList(unittest.TestCase):
Expand Down Expand Up @@ -89,7 +90,7 @@ def test_gen_op_list_with_root_ops_and_dtypes(
) -> None:
output_path = os.path.join(self.temp_dir.name, "output.yaml")
ops_dict = {
"aten::add": ["v1/3;0,1|3;0,1|3;0,1|3;0,1", "v1/6;0,1|6;0,1|6;0,1|6;0,1"],
"aten::add": ["v1/3;0,1|3;0,1|3;0,1|3;0,1", ScalarType.Float.name],
"aten::mul": [],
}
args = [
Expand All @@ -104,7 +105,7 @@ def test_gen_op_list_with_root_ops_and_dtypes(
{
"aten::add": [
"v1/3;0,1|3;0,1|3;0,1|3;0,1",
"v1/6;0,1|6;0,1|6;0,1|6;0,1",
"v1/6;",
],
"aten::mul": ["default"],
},
Expand Down
76 changes: 74 additions & 2 deletions codegen/tools/test/test_gen_selected_op_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import expecttest


class TestGenSelectedMobileOpsHeader(expecttest.TestCase):
class TestGenSelectedOpVariants(expecttest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.addCleanup(self.temp_dir.cleanup)
Expand Down Expand Up @@ -84,7 +84,79 @@ def test_generates_correct_header(self) -> None:
)


class TestGenSelectedMobileOpsHeader_Empty(expecttest.TestCase):
class TestGenSelectedOpVariants_UsingDtypeString(expecttest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.addCleanup(self.temp_dir.cleanup)
self.selected_ops_yaml = os.path.join(
self.temp_dir.name, "selected_operators.yaml"
)
with open(self.selected_ops_yaml, "w") as f:
f.write(
"""
include_all_non_op_selectives: False
include_all_operators: False
debug_info:
- model1@v100
- model2@v50
operators:
aten::add:
is_root_operator: Yes
is_used_for_training: Yes
include_all_overloads: No
aten::add.int:
is_root_operator: No
is_used_for_training: No
include_all_overloads: Yes
kernel_metadata: {}
et_kernel_metadata:
aten::add.out:
# A list of different kernel keys (tensors with dtype-enum/dim-order) combinations used in model
- v1/6; # Float
- v1/3; # Int
aten::mul.out:
- v1/6; # Float
aten::sub.out:
- default
build_features: []
custom_classes: []
"""
)

def tearDown(self):
self.temp_dir.cleanup()

def test_generates_correct_header(self) -> None:
gen_selected_op_variants.write_selected_op_variants(
os.path.join(self.temp_dir.name, "selected_operators.yaml"),
self.temp_dir.name,
)
with open(
os.path.join(self.temp_dir.name, "selected_op_variants.h"), "r"
) as result:
self.assertExpectedInline(
result.read(),
"""#pragma once
/**
* Generated by executorch/codegen/tools/gen_selected_op_variants.py
*/

inline constexpr bool should_include_kernel_dtype(
const char *operator_name,
executorch::aten::ScalarType scalar_type
) {
return ((executorch::aten::string_view(operator_name).compare("add.out") == 0)
&& (scalar_type == executorch::aten::ScalarType::Float || scalar_type == executorch::aten::ScalarType::Int))
|| ((executorch::aten::string_view(operator_name).compare("mul.out") == 0)
&& (scalar_type == executorch::aten::ScalarType::Float))
|| ((executorch::aten::string_view(operator_name).compare("sub.out") == 0)
&& (true));
}
""",
)


class TestGenSelectedOpVariants_Empty(expecttest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
self.addCleanup(self.temp_dir.cleanup)
Expand Down
6 changes: 4 additions & 2 deletions examples/selective_build/targets.bzl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_oss_build_kwargs", "is_xplat", "runtime")
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib")
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "et_operator_library", "executorch_generated_lib", "ScalarType")

def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.
Expand Down Expand Up @@ -49,7 +49,9 @@ def define_common_targets():
et_operator_library(
name = "select_ops_in_dict",
ops_dict = {
"aten::add.out": ["v1/3;0,1", "v1/6;0,1"], # int, float
# 1. Use kernel key, generated with a model, or
# 2. Specify the dtype, from executorch/codegen/codegen.bzl
"aten::add.out": ["v1/3;0,1", ScalarType("Float")], # int, float
"aten::mm.out": [], # all dtypes
},
)
Expand Down
1 change: 0 additions & 1 deletion examples/selective_build/test_selective_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ test_buck2_select_ops_in_dict() {
# select ops and their dtypes using the dictionary API.
$BUCK run //examples/selective_build:selective_build_test \
--config=executorch.select_ops=dict \
--config=executorch.dtype_selective_build_lib=//examples/selective_build:select_ops_in_dict_lib \
-- --model_path=./add_mul.pte

echo "Removing add_mul.pte"
Expand Down
33 changes: 33 additions & 0 deletions shim/xplat/executorch/codegen/codegen.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,39 @@ CUSTOM_OPS_SCHEMA_REGISTRATION_SOURCES = [
"RegisterSchema.cpp",
]

ScalarType = enum(
"Byte",
"Char",
"Short",
"Int",
"Long",
"Half",
"Float",
"Double",
"ComplexHalf",
"ComplexFloat",
"ComplexDouble",
"Bool",
"QInt8",
"QUInt8",
"QInt32",
"BFloat16",
"QUInt4x2",
"QUInt2x4",
"Bits1x8",
"Bits2x4",
"Bits4x2",
"Bits8",
"Bits16",
"Float8_e5m2",
"Float8_e4m3fn",
"Float8_e5m2fnuz",
"Float8_e4m3fnuz",
"UInt16",
"UInt32",
"Uint64",
)

# Hide the dependency to caffe2 internally.
def et_operator_library(
name,
Expand Down
Loading