Skip to content

Commit 67d0dd7

Browse files
pytorchbotdbort
andauthored
Add a pure python wrapper to pybindings.portable_lib (#3137) (#3218)
Summary: Pull Request resolved: #3137 When installed as a pip wheel, we must import `torch` before trying to import the pybindings shared library extension. This will load libtorch.so and related libs, ensuring that the pybindings lib can resolve those runtime dependencies. So, add a pure python wrapper that lets us do this when users say `import executorch.extension.pybindings.portable_lib` We only need this for OSS, so don't bother doing this for other pybindings targets. Reviewed By: orionr, mikekgfb Differential Revision: D56317150 fbshipit-source-id: 920382636732aa276c25a76163afb7d28b1846d0 (cherry picked from commit 969aa96) Co-authored-by: Dave Bort <[email protected]>
1 parent 773da4d commit 67d0dd7

File tree

5 files changed

+59
-10
lines changed

5 files changed

+59
-10
lines changed

CMakeLists.txt

+4-1
Original file line numberDiff line numberDiff line change
@@ -548,8 +548,11 @@ if(EXECUTORCH_BUILD_PYBIND)
548548

549549
# pybind portable_lib
550550
pybind11_add_module(portable_lib extension/pybindings/pybindings.cpp)
551+
# The actual output file needs a leading underscore so it can coexist with
552+
# portable_lib.py in the same python package.
553+
set_target_properties(portable_lib PROPERTIES OUTPUT_NAME "_portable_lib")
551554
target_compile_definitions(portable_lib
552-
PUBLIC EXECUTORCH_PYTHON_MODULE_NAME=portable_lib)
555+
PUBLIC EXECUTORCH_PYTHON_MODULE_NAME=_portable_lib)
553556
target_include_directories(portable_lib PRIVATE ${TORCH_INCLUDE_DIRS})
554557
target_compile_options(portable_lib PUBLIC ${_pybind_compile_options})
555558
target_link_libraries(

extension/pybindings/TARGETS

+12-4
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ runtime.genrule(
3030
srcs = [":pybinding_types"],
3131
outs = {
3232
"aten_lib.pyi": ["aten_lib.pyi"],
33-
"portable_lib.pyi": ["portable_lib.pyi"],
33+
"_portable_lib.pyi": ["_portable_lib.pyi"],
3434
},
35-
cmd = "cp $(location :pybinding_types)/* $OUT/portable_lib.pyi && cp $(location :pybinding_types)/* $OUT/aten_lib.pyi",
35+
cmd = "cp $(location :pybinding_types)/* $OUT/_portable_lib.pyi && cp $(location :pybinding_types)/* $OUT/aten_lib.pyi",
3636
visibility = ["//executorch/extension/pybindings/..."],
3737
)
3838

@@ -46,8 +46,9 @@ executorch_pybindings(
4646
executorch_pybindings(
4747
compiler_flags = ["-std=c++17"],
4848
cppdeps = PORTABLE_MODULE_DEPS + MODELS_ATEN_OPS_LEAN_MODE_GENERATED_LIB,
49-
python_module_name = "portable_lib",
50-
types = ["//executorch/extension/pybindings:pybindings_types_gen[portable_lib.pyi]"],
49+
# Give this an underscore prefix because it has a pure python wrapper.
50+
python_module_name = "_portable_lib",
51+
types = ["//executorch/extension/pybindings:pybindings_types_gen[_portable_lib.pyi]"],
5152
visibility = ["PUBLIC"],
5253
)
5354

@@ -58,3 +59,10 @@ executorch_pybindings(
5859
types = ["//executorch/extension/pybindings:pybindings_types_gen[aten_lib.pyi]"],
5960
visibility = ["PUBLIC"],
6061
)
62+
63+
runtime.python_library(
64+
name = "portable_lib",
65+
srcs = ["portable_lib.py"],
66+
visibility = ["@EXECUTORCH_CLIENTS"],
67+
deps = [":_portable_lib"],
68+
)

extension/pybindings/portable_lib.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
# When installed as a pip wheel, we must import `torch` before trying to import
10+
# the pybindings shared library extension. This will load libtorch.so and
11+
# related libs, ensuring that the pybindings lib can resolve those runtime
12+
# dependencies.
13+
import torch as _torch
14+
15+
# Let users import everything from the C++ _portable_lib extension as if this
16+
# python file defined them. Although we could import these dynamically, it
17+
# wouldn't preserve the static type annotations.
18+
from executorch.extension.pybindings._portable_lib import ( # noqa: F401
19+
# Disable "imported but unused" (F401) checks.
20+
_create_profile_block, # noqa: F401
21+
_dump_profile_results, # noqa: F401
22+
_get_operator_names, # noqa: F401
23+
_load_bundled_program_from_buffer, # noqa: F401
24+
_load_for_executorch, # noqa: F401
25+
_load_for_executorch_from_buffer, # noqa: F401
26+
_load_for_executorch_from_bundled_program, # noqa: F401
27+
_reset_profile_results, # noqa: F401
28+
BundledModule, # noqa: F401
29+
ExecuTorchModule, # noqa: F401
30+
)
31+
32+
# Clean up so that `dir(portable_lib)` is the same as `dir(_portable_lib)`
33+
# (apart from some __dunder__ names).
34+
del _torch

extension/pybindings/pybindings.pyi

+8-4
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
# pyre-strict
88
from typing import Any, Dict, List, Sequence, Tuple
99

10-
class ExecutorchModule:
10+
class ExecuTorchModule:
11+
# pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
1112
def __call__(self, inputs: Any) -> List[Any]: ...
13+
# pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
1214
def run_method(self, method_name: str, inputs: Sequence[Any]) -> List[Any]: ...
15+
# pyre-ignore[2, 3]: "Any" in parameter and return type annotations.
1316
def forward(self, inputs: Sequence[Any]) -> List[Any]: ...
1417
# Bundled program methods.
1518
def load_bundled_input(
@@ -30,16 +33,17 @@ class BundledModule: ...
3033

3134
def _load_for_executorch(
3235
path: str, enable_etdump: bool = False
33-
) -> ExecutorchModule: ...
36+
) -> ExecuTorchModule: ...
3437
def _load_for_executorch_from_buffer(
3538
buffer: bytes, enable_etdump: bool = False
36-
) -> ExecutorchModule: ...
39+
) -> ExecuTorchModule: ...
3740
def _load_for_executorch_from_bundled_program(
3841
module: BundledModule, enable_etdump: bool = False
39-
) -> ExecutorchModule: ...
42+
) -> ExecuTorchModule: ...
4043
def _load_bundled_program_from_buffer(
4144
buffer: bytes, non_const_pool_size: int = ...
4245
) -> BundledModule: ...
46+
def _get_operator_names() -> List[str]: ...
4347
def _create_profile_block(name: str) -> None: ...
4448
def _dump_profile_results() -> bytes: ...
4549
def _reset_profile_results() -> None: ...

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def get_ext_modules() -> list[Extension]:
435435
# portable kernels, and a selection of backends. This lets users
436436
# load and execute .pte files from python.
437437
BuiltExtension(
438-
"portable_lib.*", "executorch.extension.pybindings.portable_lib"
438+
"_portable_lib.*", "executorch.extension.pybindings._portable_lib"
439439
)
440440
)
441441

0 commit comments

Comments
 (0)