File tree 5 files changed +52
-0
lines changed 5 files changed +52
-0
lines changed Original file line number Diff line number Diff line change @@ -126,3 +126,13 @@ def test_resnet50_export_to_executorch(self):
126
126
self ._assert_eager_lowered_same_result (
127
127
eager_model , example_inputs , self .validate_tensor_allclose
128
128
)
129
+
130
+ def test_dl3_export_to_executorch (self ):
131
+ eager_model , example_inputs = EagerModelFactory .create_model (
132
+ * MODEL_NAME_TO_MODEL ["dl3" ]
133
+ )
134
+ eager_model = eager_model .eval ()
135
+
136
+ self ._assert_eager_lowered_same_result (
137
+ eager_model , example_inputs , self .validate_tensor_allclose
138
+ )
Original file line number Diff line number Diff line change @@ -9,6 +9,7 @@ python_library(
9
9
deps = [
10
10
"//caffe2:torch",
11
11
"//executorch/examples/models:model_base", # @manual
12
+ "//executorch/examples/models/deeplab_v3:dl3_model", # @manual
12
13
"//executorch/examples/models/inception_v3:ic3_model", # @manual
13
14
"//executorch/examples/models/inception_v4:ic4_model", # @manual
14
15
"//executorch/examples/models/mobilebert:mobilebert_model", # @manual
Original file line number Diff line number Diff line change 11
11
"linear" : ("toy_model" , "LinearModule" ),
12
12
"add" : ("toy_model" , "AddModule" ),
13
13
"add_mul" : ("toy_model" , "AddMulModule" ),
14
+ "dl3" : ("deeplab_v3" , "DeepLabV3ResNet50Model" ),
14
15
"mobilebert" : ("mobilebert" , "MobileBertModelExample" ),
15
16
"mv2" : ("mobilenet_v2" , "MV2Model" ),
16
17
"mv3" : ("mobilenet_v3" , "MV3Model" ),
Original file line number Diff line number Diff line change
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .model import DeepLabV3ResNet50Model
8
+
9
+ __all__ = [
10
+ DeepLabV3ResNet50Model ,
11
+ ]
Original file line number Diff line number Diff line change
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+
9
+ import torch
10
+ from torchvision .models .segmentation import deeplabv3 , deeplabv3_resnet50 # @manual
11
+
12
+ from ..model_base import EagerModelBase
13
+
14
+
15
+ class DeepLabV3ResNet50Model (EagerModelBase ):
16
+ def __init__ (self ):
17
+ pass
18
+
19
+ def get_eager_model (self ) -> torch .nn .Module :
20
+ logging .info ("loading deeplabv3_resnet50 model" )
21
+ deeplabv3_model = deeplabv3_resnet50 (
22
+ weights = deeplabv3 .DeepLabV3_ResNet50_Weights .DEFAULT
23
+ )
24
+ logging .info ("loaded deeplabv3_resnet50 model" )
25
+ return deeplabv3_model
26
+
27
+ def get_example_inputs (self ):
28
+ input_shape = (1 , 3 , 224 , 224 )
29
+ return (torch .randn (input_shape ),)
You can’t perform that action at this time.
0 commit comments