Skip to content

Commit 00af75e

Browse files
committed
move Model enum to examples
1 parent 053292a commit 00af75e

File tree

4 files changed

+133
-89
lines changed

4 files changed

+133
-89
lines changed

.ci/scripts/wheel/test_base.py

Lines changed: 41 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -7,66 +7,60 @@
77
import os
88
import subprocess
99
import sys
10+
from dataclasses import dataclass
1011
from functools import lru_cache
1112
from typing import List
12-
from dataclasses import dataclass
13-
from enum import Enum
1413

1514

16-
class Model(str, Enum):
17-
Mv3 = "mv3"
15+
@lru_cache()
16+
def _unsafe_get_env(key: str) -> str:
17+
value = os.getenv(key)
18+
if value is None:
19+
raise RuntimeError(f"environment variable '{key}' is not set")
20+
return value
21+
22+
23+
@lru_cache()
24+
def _repository_root_dir() -> str:
25+
return os.path.join(
26+
_unsafe_get_env("GITHUB_WORKSPACE"),
27+
_unsafe_get_env("REPOSITORY"),
28+
)
1829

19-
def __str__(self) -> str:
20-
return self.value
2130

22-
class Backend(str, Enum):
23-
XnnpackQuantizationDelegation = "xnnpack-quantization-delegation"
31+
# For some reason, we are unable to see the entire repo in the python path.
32+
# So manually add it.
33+
sys.path.append(_repository_root_dir())
34+
from examples.models import Backend, Model
2435

25-
def __str__(self) -> str:
26-
return self.value
2736

2837
@dataclass
2938
class ModelTest:
3039
model: Model
3140
backend: Backend
3241

3342

34-
@lru_cache()
35-
def _repository_root_dir() -> str:
36-
workspace_dir = os.getenv("GITHUB_WORKSPACE")
37-
if workspace_dir is None:
38-
print("GITHUB_WORKSPACE is not set")
39-
sys.exit(1)
40-
41-
repository_dir = os.getenv("REPOSITORY")
42-
if repository_dir is None:
43-
print("REPOSITORY is not set")
44-
sys.exit(1)
45-
46-
return os.path.join(workspace_dir, repository_dir)
47-
48-
4943
def run_tests(model_tests: List[ModelTest]) -> None:
50-
# Why are we doing this envvar shenanigans? Since we build the testers, which
51-
# uses buck, we cannot run as root. This is a sneaky of getting around that
52-
# test.
53-
#
54-
# This can be reverted if either:
55-
# - We remove usage of buck in our builds
56-
# - We stop running the Docker image as root: https://github.com/pytorch/test-infra/issues/5091
57-
envvars = os.environ.copy()
58-
envvars.pop("HOME")
44+
# Why are we doing this envvar shenanigans? Since we build the testers, which
45+
# uses buck, we cannot run as root. This is a sneaky of getting around that
46+
# test.
47+
#
48+
# This can be reverted if either:
49+
# - We remove usage of buck in our builds
50+
# - We stop running the Docker image as root: https://github.com/pytorch/test-infra/issues/5091
51+
envvars = os.environ.copy()
52+
envvars.pop("HOME")
5953

60-
for model_test in model_tests:
61-
subprocess.run(
62-
[
63-
os.path.join(_repository_root_dir(), ".ci/scripts/test_model.sh"),
64-
str(model_test.model),
65-
# What to build `executor_runner` with for testing.
66-
"cmake",
67-
str(model_test.backend),
68-
],
69-
env=envvars,
70-
check=True,
71-
cwd=_repository_root_dir(),
72-
)
54+
for model_test in model_tests:
55+
subprocess.run(
56+
[
57+
os.path.join(_repository_root_dir(), ".ci/scripts/test_model.sh"),
58+
str(model_test.model),
59+
# What to build `executor_runner` with for testing.
60+
"cmake",
61+
str(model_test.backend),
62+
],
63+
env=envvars,
64+
check=True,
65+
cwd=_repository_root_dir(),
66+
)

.ci/scripts/wheel/test_linux.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
from test_base import run_tests, ModelTest, Model, Backend
8+
import test_base
9+
from examples.models import Backend, Model
910

1011
if __name__ == "__main__":
11-
run_tests(model_tests=[
12-
ModelTest(
13-
model=Model.Mv3,
14-
backend=Backend.XnnpackQuantizationDelegation,
15-
)
16-
])
12+
test_base.run_tests(
13+
model_tests=[
14+
test_base.ModelTest(
15+
model=Model.Mv3,
16+
backend=Backend.XnnpackQuantizationDelegation,
17+
)
18+
]
19+
)

.ci/scripts/wheel/test_macos.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
from test_base import run_tests, ModelTest, Model, Backend
8+
import test_base
9+
from examples.models import Backend, Model
910

1011
if __name__ == "__main__":
11-
run_tests(model_tests=[
12-
ModelTest(
13-
model=Model.Mv3,
14-
backend=Backend.XnnpackQuantizationDelegation,
15-
)
16-
])
12+
test_base.run_tests(
13+
model_tests=[
14+
test_base.ModelTest(
15+
model=Model.Mv3,
16+
backend=Backend.XnnpackQuantizationDelegation,
17+
)
18+
]
19+
)

examples/models/__init__.py

Lines changed: 72 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,81 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8+
from enum import Enum
9+
10+
11+
class Model(str, Enum):
12+
Mul = "mul"
13+
Linear = "linear"
14+
Add = "add"
15+
AddMul = "add_mul"
16+
Softmax = "softmax"
17+
Dl3 = "dl3"
18+
Edsr = "edsr"
19+
EmformerTranscribe = "emformer_transcribe"
20+
EmformerPredict = "emformer_predict"
21+
EmformerJoin = "emformer_join"
22+
Llama2 = "llama2"
23+
Llama = "llama"
24+
Llama32VisionEncoder = "llama3_2_vision_encoder"
25+
Lstm = "lstm"
26+
MobileBert = "mobilebert"
27+
Mv2 = "mv2"
28+
Mv2Untrained = "mv2_untrained"
29+
Mv3 = "mv3"
30+
Vit = "vit"
31+
W2l = "w2l"
32+
Ic3 = "ic3"
33+
Ic4 = "ic4"
34+
ResNet18 = "resnet18"
35+
ResNet50 = "resnet50"
36+
Llava = "llava"
37+
EfficientSam = "efficient_sam"
38+
Qwen25 = "qwen2_5"
39+
Phi4Mini = "phi-4-mini"
40+
41+
def __str__(self) -> str:
42+
return self.value
43+
44+
45+
class Backend(str, Enum):
46+
XnnpackQuantizationDelegation = "xnnpack-quantization-delegation"
47+
48+
def __str__(self) -> str:
49+
return self.value
50+
51+
852
MODEL_NAME_TO_MODEL = {
9-
"mul": ("toy_model", "MulModule"),
10-
"linear": ("toy_model", "LinearModule"),
11-
"add": ("toy_model", "AddModule"),
12-
"add_mul": ("toy_model", "AddMulModule"),
13-
"softmax": ("toy_model", "SoftmaxModule"),
14-
"dl3": ("deeplab_v3", "DeepLabV3ResNet50Model"),
15-
"edsr": ("edsr", "EdsrModel"),
16-
"emformer_transcribe": ("emformer_rnnt", "EmformerRnntTranscriberModel"),
17-
"emformer_predict": ("emformer_rnnt", "EmformerRnntPredictorModel"),
18-
"emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),
19-
"llama2": ("llama", "Llama2Model"),
20-
"llama": ("llama", "Llama2Model"),
21-
"llama3_2_vision_encoder": ("llama3_2_vision", "FlamingoVisionEncoderModel"),
53+
str(Model.Mul): ("toy_model", "MulModule"),
54+
str(Model.Linear): ("toy_model", "LinearModule"),
55+
str(Model.Add): ("toy_model", "AddModule"),
56+
str(Model.AddMul): ("toy_model", "AddMulModule"),
57+
str(Model.Softmax): ("toy_model", "SoftmaxModule"),
58+
str(Model.Dl3): ("deeplab_v3", "DeepLabV3ResNet50Model"),
59+
str(Model.Edsr): ("edsr", "EdsrModel"),
60+
str(Model.EmformerTranscribe): ("emformer_rnnt", "EmformerRnntTranscriberModel"),
61+
str(Model.EmformerPredict): ("emformer_rnnt", "EmformerRnntPredictorModel"),
62+
str(Model.EmformerJoin): ("emformer_rnnt", "EmformerRnntJoinerModel"),
63+
str(Model.Llama2): ("llama", "Llama2Model"),
64+
str(Model.Llama): ("llama", "Llama2Model"),
65+
str(Model.Llama32VisionEncoder): ("llama3_2_vision", "FlamingoVisionEncoderModel"),
2266
# TODO: This take too long to export on both Linux and MacOS (> 6 hours)
2367
# "llama3_2_text_decoder": ("llama3_2_vision", "Llama3_2Decoder"),
24-
"lstm": ("lstm", "LSTMModel"),
25-
"mobilebert": ("mobilebert", "MobileBertModelExample"),
26-
"mv2": ("mobilenet_v2", "MV2Model"),
27-
"mv2_untrained": ("mobilenet_v2", "MV2UntrainedModel"),
28-
"mv3": ("mobilenet_v3", "MV3Model"),
29-
"vit": ("torchvision_vit", "TorchVisionViTModel"),
30-
"w2l": ("wav2letter", "Wav2LetterModel"),
31-
"ic3": ("inception_v3", "InceptionV3Model"),
32-
"ic4": ("inception_v4", "InceptionV4Model"),
33-
"resnet18": ("resnet", "ResNet18Model"),
34-
"resnet50": ("resnet", "ResNet50Model"),
35-
"llava": ("llava", "LlavaModel"),
36-
"efficient_sam": ("efficient_sam", "EfficientSAM"),
37-
"qwen2_5": ("qwen2_5", "Qwen2_5Model"),
38-
"phi-4-mini": ("phi-4-mini", "Phi4MiniModel"),
68+
str(Model.Lstm): ("lstm", "LSTMModel"),
69+
str(Model.MobileBert): ("mobilebert", "MobileBertModelExample"),
70+
str(Model.Mv2): ("mobilenet_v2", "MV2Model"),
71+
str(Model.Mv2Untrained): ("mobilenet_v2", "MV2UntrainedModel"),
72+
str(Model.Mv3): ("mobilenet_v3", "MV3Model"),
73+
str(Model.Vit): ("torchvision_vit", "TorchVisionViTModel"),
74+
str(Model.W2l): ("wav2letter", "Wav2LetterModel"),
75+
str(Model.Ic3): ("inception_v3", "InceptionV3Model"),
76+
str(Model.Ic4): ("inception_v4", "InceptionV4Model"),
77+
str(Model.ResNet18): ("resnet", "ResNet18Model"),
78+
str(Model.ResNet50): ("resnet", "ResNet50Model"),
79+
str(Model.Llava): ("llava", "LlavaModel"),
80+
str(Model.EfficientSam): ("efficient_sam", "EfficientSAM"),
81+
str(Model.Qwen25): ("qwen2_5", "Qwen2_5Model"),
82+
str(Model.Phi4Mini): ("phi-4-mini", "Phi4MiniModel"),
3983
}
4084

4185
__all__ = [

0 commit comments

Comments
 (0)