Skip to content

Commit e66a7d8

Browse files
Guang Yangfacebook-github-bot
Guang Yang
authored andcommitted
Add deeplab_v3 model to examples (#60)
Summary: Add image segmentation model **deeplabv3_resnet50** to `executorch/examples`. Info about the model: https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/ Reviewed By: kirklandsign Differential Revision: D48136966
1 parent 0cd9c57 commit e66a7d8

File tree

5 files changed

+52
-0
lines changed

5 files changed

+52
-0
lines changed

examples/export/test/test_export.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,13 @@ def test_resnet50_export_to_executorch(self):
126126
self._assert_eager_lowered_same_result(
127127
eager_model, example_inputs, self.validate_tensor_allclose
128128
)
129+
130+
def test_dl3_export_to_executorch(self):
131+
eager_model, example_inputs = EagerModelFactory.create_model(
132+
*MODEL_NAME_TO_MODEL["dl3"]
133+
)
134+
eager_model = eager_model.eval()
135+
136+
self._assert_eager_lowered_same_result(
137+
eager_model, example_inputs, self.validate_tensor_allclose
138+
)

examples/models/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ python_library(
99
deps = [
1010
"//caffe2:torch",
1111
"//executorch/examples/models:model_base", # @manual
12+
"//executorch/examples/models/deeplab_v3:dl3_model", # @manual
1213
"//executorch/examples/models/inception_v3:ic3_model", # @manual
1314
"//executorch/examples/models/inception_v4:ic4_model", # @manual
1415
"//executorch/examples/models/mobilebert:mobilebert_model", # @manual

examples/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"linear": ("toy_model", "LinearModule"),
1212
"add": ("toy_model", "AddModule"),
1313
"add_mul": ("toy_model", "AddMulModule"),
14+
"dl3": ("deeplab_v3", "DeepLabV3ResNet50Model"),
1415
"mobilebert": ("mobilebert", "MobileBertModelExample"),
1516
"mv2": ("mobilenet_v2", "MV2Model"),
1617
"mv3": ("mobilenet_v3", "MV3Model"),
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .model import DeepLabV3ResNet50Model
8+
9+
__all__ = [
10+
DeepLabV3ResNet50Model,
11+
]

examples/models/deeplab_v3/model.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
from torchvision.models.segmentation import deeplabv3, deeplabv3_resnet50 # @manual
11+
12+
from ..model_base import EagerModelBase
13+
14+
15+
class DeepLabV3ResNet50Model(EagerModelBase):
16+
def __init__(self):
17+
pass
18+
19+
def get_eager_model(self) -> torch.nn.Module:
20+
logging.info("loading deeplabv3_resnet50 model")
21+
deeplabv3_model = deeplabv3_resnet50(
22+
weights=deeplabv3.DeepLabV3_ResNet50_Weights.DEFAULT
23+
)
24+
logging.info("loaded deeplabv3_resnet50 model")
25+
return deeplabv3_model
26+
27+
def get_example_inputs(self):
28+
input_shape = (1, 3, 224, 224)
29+
return (torch.randn(input_shape),)

0 commit comments

Comments
 (0)