
Commit 4bb3112

Merge pull request #44 from ynimmaga/fix_linter_issues
Fixed MyPy TypeChecker Issues
2 parents a9010ac + f74e3e3 commit 4bb3112

10 files changed (+75, -34 lines)

.lintrunner.toml

Lines changed: 2 additions & 0 deletions
@@ -299,12 +299,14 @@ include_patterns = [
     # TODO(https://github.com/pytorch/executorch/issues/7441): Gradually start enabling all folders.
     # 'backends/**/*.py',
     'backends/arm/**/*.py',
+    'backends/openvino/**/*.py',
     'build/**/*.py',
     'codegen/**/*.py',
     # 'devtools/**/*.py',
     'devtools/visualization/**/*.py',
     'docs/**/*.py',
     # 'examples/**/*.py',
+    'examples/openvino/**/*.py',
     # 'exir/**/*.py',
     # 'extension/**/*.py',
     'kernels/**/*.py',
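
Adding these two glob patterns brings the OpenVINO backend and its example under the repo's MyPy lint coverage. For reference, a minimal sketch of reproducing the checks locally through mypy's Python API (assuming mypy is installed; the repo itself normally drives this through lintrunner, and the helper script name here is hypothetical):

    # check_openvino_types.py -- hypothetical helper, not part of the repo
    from mypy import api

    # mypy.api.run takes the same arguments as the mypy CLI and returns
    # (stdout, stderr, exit_status) without spawning a subprocess.
    stdout, stderr, exit_status = api.run(["backends/openvino", "examples/openvino"])
    print(stdout)
    raise SystemExit(exit_status)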

backends/openvino/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@
 from .preprocess import OpenvinoBackend
 from .quantizer.quantizer import OpenVINOQuantizer
 
-__all__ = [OpenvinoBackend, OpenvinoPartitioner, OpenVINOQuantizer]
+__all__ = ["OpenvinoBackend", "OpenvinoPartitioner", "OpenVINOQuantizer"]

backends/openvino/partitioner.py

Lines changed: 8 additions & 3 deletions
@@ -15,7 +15,9 @@
     PartitionResult,
 )
 from executorch.exir.backend.utils import tag_constant_data
-from openvino.frontend.pytorch.torchdynamo.op_support import OperatorSupport
+from openvino.frontend.pytorch.torchdynamo.op_support import (  # type: ignore[import-untyped]
+    OperatorSupport,
+)
 
 from torch.export.exported_program import ExportedProgram
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

@@ -53,8 +55,11 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
         if node.op != "call_function":
             return False
 
-        options = []
-        op_type = node.target.__name__
+        options: list[str] = []
+        if not isinstance(node.target, str):
+            op_type = node.target.__name__
+        else:
+            op_type = str(node.target)
         supported_ops = OperatorSupport(options)._support_dict
         if op_type == "getitem":
             return True
backends/openvino/preprocess.py

Lines changed: 3 additions & 1 deletion
@@ -12,7 +12,9 @@
     PreprocessResult,
 )
 from executorch.exir.backend.compile_spec_schema import CompileSpec
-from openvino.frontend.pytorch.torchdynamo.compile import openvino_compile
+from openvino.frontend.pytorch.torchdynamo.compile import (  # type: ignore[import-untyped]
+    openvino_compile,
+)
 
 
 @final
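
A note on the ignore placement: for untyped third-party packages mypy reports something like "module is installed, but missing library stubs or py.typed marker [import-untyped]" and anchors the error on the first line of the import statement. When the import is wrapped in parentheses, the suppression therefore belongs on the opening line, as in this commit:

    from openvino.frontend.pytorch.torchdynamo.compile import (  # type: ignore[import-untyped]
        openvino_compile,  # putting the ignore on this line would not silence the error
    )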

backends/openvino/quantizer/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 from .quantizer import OpenVINOQuantizer, quantize_model
 
-__all__ = [OpenVINOQuantizer, quantize_model]
+__all__ = ["OpenVINOQuantizer", "quantize_model"]

backends/openvino/quantizer/quantizer.py

Lines changed: 19 additions & 8 deletions
@@ -6,16 +6,20 @@
 
 from collections import defaultdict
 from enum import Enum
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, DefaultDict, Dict, List, Optional, Tuple, Type
 
-import nncf
-import nncf.common.quantization as quantization
-import nncf.experimental.torch.fx as nncf_fx
+import nncf  # type: ignore[import-untyped]
+import nncf.common.quantization as quantization  # type: ignore[import-untyped]
+import nncf.experimental.torch.fx as nncf_fx  # type: ignore[import-untyped]
 
 import torch.fx
 
-from nncf.common.graph.graph import NNCFGraph
-from torch.ao.quantization.observer import HistogramObserver, PerChannelMinMaxObserver
+from nncf.common.graph.graph import NNCFGraph  # type: ignore[import-untyped]
+from torch.ao.quantization.observer import (
+    HistogramObserver,
+    PerChannelMinMaxObserver,
+    UniformQuantizationObserverBase,
+)
 from torch.ao.quantization.quantizer.quantizer import (
     EdgeOrNode,
     QuantizationAnnotation,

@@ -117,13 +121,15 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         quantization_setup = self.get_nncf_quantization_setup(model, nncf_graph)
 
         graph = model.graph
-        node_vs_torch_annotation = defaultdict(QuantizationAnnotation)
+        node_vs_torch_annotation: DefaultDict[torch.fx.Node, QuantizationAnnotation] = (
+            defaultdict(QuantizationAnnotation)
+        )
 
         for qp in quantization_setup.quantization_points.values():
             edge_or_node, annotation = self._get_edge_or_node_and_annotation(
                 graph, nncf_graph, qp, node_vs_torch_annotation
             )
-            qspec = self._get_torch_ao_qspec_from_qp(qp)
+            qspec: QuantizationSpecBase = self._get_torch_ao_qspec_from_qp(qp)
             self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)
 
         for quantizer_ids in quantization_setup.unified_scale_groups.values():

@@ -199,6 +205,9 @@ def _get_unified_scales_root_quantizer_id(
         ):
             root_quantizer_id = quantizer_id
             nncf_node_quantizer_id = nncf_node.node_id
+        if root_quantizer_id is None:
+            msg = "Root quantizer ids can't be None"
+            raise nncf.InternalError(msg)
         return root_quantizer_id
 
     @staticmethod

@@ -299,6 +308,8 @@ def _get_torch_ao_qspec_from_qp(
         qconfig = qp.qconfig
         is_weight = qp.is_weight_quantization_point()
 
+        observer: Type[UniformQuantizationObserverBase]
+
         if qconfig.per_channel:
             torch_qscheme = (
                 torch.per_channel_symmetric
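
Two mypy idioms appear in this file. The `qspec: QuantizationSpecBase` annotation widens the declared type so later assignments of sibling spec types check, and the bare `observer: Type[UniformQuantizationObserverBase]` declaration pins a common supertype before the conditional assignment. A minimal sketch of the latter, with hypothetical observer classes:

    from typing import Type

    class BaseObserver: ...
    class PerTensorObserver(BaseObserver): ...
    class PerChannelObserver(BaseObserver): ...

    def pick_observer(per_channel: bool) -> Type[BaseObserver]:
        # Without the up-front annotation, mypy infers Type[PerChannelObserver]
        # from the first assignment and rejects the other branch.
        observer: Type[BaseObserver]
        if per_channel:
            observer = PerChannelObserver
        else:
            observer = PerTensorObserver
        return observer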

backends/openvino/tests/models/test_classification.py

Lines changed: 3 additions & 3 deletions
@@ -1,10 +1,10 @@
1-
import timm
1+
import timm # type: ignore[import-untyped]
22
import torch
3-
import torchvision.models as torchvision_models
3+
import torchvision.models as torchvision_models # type: ignore[import-untyped]
44
from executorch.backends.openvino.tests.ops.base_openvino_op_test import (
55
BaseOpenvinoOpTest,
66
)
7-
from transformers import AutoModel
7+
from transformers import AutoModel # type: ignore[import-untyped]
88

99
classifier_params = [
1010
{"model": ["torchvision", "resnet50", (1, 3, 224, 224)]},

backends/openvino/tests/ops/base_openvino_op_test.py

Lines changed: 1 addition & 0 deletions
@@ -68,6 +68,7 @@ def execute_layer_test(
         runtime = Runtime.get()
         program = runtime.load_program(exec_prog.buffer)
         method = program.load_method("forward")
+        assert method is not None
         outputs = method.execute(sample_inputs)
 
         # Compare the outputs with the reference outputs
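
The added assert is the usual way to satisfy mypy when an API returns an Optional: `load_method` can evidently return None, and the assert narrows the type before `.execute` is called. In test code an assert is idiomatic; in example code an explicit raise (as in examples/openvino below) is the safer equivalent, since asserts are stripped under `python -O`. A minimal sketch of the narrowing, with `Method` as an illustrative stand-in:

    from typing import Optional

    class Method:  # illustrative stand-in for the runtime's method handle
        def execute(self, inputs: tuple) -> list: ...

    def run(method: Optional[Method], inputs: tuple) -> list:
        # assert narrows Optional[Method] to Method for the type checker;
        # at runtime it raises AssertionError if method is None.
        assert method is not None
        return method.execute(inputs)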

backends/openvino/tests/test_runner.py

Lines changed: 3 additions & 3 deletions
@@ -1,12 +1,12 @@
11
import argparse
22
import unittest
33

4-
import nncf.torch
4+
import nncf.torch # type: ignore[import-untyped]
55

66

77
class OpenvinoTestSuite(unittest.TestSuite):
88

9-
test_params = {}
9+
test_params: dict[str, str] = {}
1010

1111
def __init__(self, *args, **kwargs):
1212
super().__init__(*args, **kwargs)
@@ -51,7 +51,7 @@ def parse_arguments():
5151
)
5252

5353
args, ns_args = parser.parse_known_args(namespace=unittest)
54-
test_params = {}
54+
test_params: dict[str, str] = {}
5555
test_params["device"] = args.device
5656
test_params["pattern"] = args.pattern
5757
test_params["test_type"] = args.test_type

examples/openvino/aot_optimize_and_infer.py

Lines changed: 34 additions & 14 deletions
@@ -4,8 +4,11 @@
44
# except in compliance with the License. See the license file found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# mypy: disable-error-code=import-untyped
8+
79
import argparse
810
import time
11+
from typing import cast, List, Optional
912

1013
import executorch
1114

@@ -15,7 +18,11 @@
1518
import torchvision.models as torchvision_models
1619
from executorch.backends.openvino.partitioner import OpenvinoPartitioner
1720
from executorch.backends.openvino.quantizer import quantize_model
18-
from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
21+
from executorch.exir import (
22+
EdgeProgramManager,
23+
ExecutorchProgramManager,
24+
to_edge_transform_and_lower,
25+
)
1926
from executorch.exir.backend.backend_details import CompileSpec
2027
from executorch.runtime import Runtime
2128
from sklearn.metrics import accuracy_score
@@ -102,7 +109,7 @@ def load_calibration_dataset(
102109

103110

104111
def infer_model(
105-
exec_prog: EdgeProgramManager,
112+
exec_prog: ExecutorchProgramManager,
106113
inputs,
107114
num_iter: int,
108115
warmup_iter: int,
@@ -111,7 +118,7 @@ def infer_model(
111118
"""
112119
Executes inference and reports the average timing.
113120
114-
:param exec_prog: EdgeProgramManager of the lowered model
121+
:param exec_prog: ExecutorchProgramManager of the lowered model
115122
:param inputs: The inputs for the model.
116123
:param num_iter: The number of iterations to execute inference for timing.
117124
:param warmup_iter: The number of iterations to execute inference for warmup before timing.
@@ -122,8 +129,11 @@ def infer_model(
122129
runtime = Runtime.get()
123130
program = runtime.load_program(exec_prog.buffer)
124131
method = program.load_method("forward")
132+
if method is None:
133+
raise ValueError("Load method failed")
125134

126135
# Execute warmup
136+
out = None
127137
for _i in range(warmup_iter):
128138
out = method.execute(inputs)
129139

@@ -137,34 +147,38 @@ def infer_model(
137147

138148
# Save output tensor as raw tensor file
139149
if output_path:
150+
assert out is not None
140151
torch.save(out, output_path)
141152

142153
# Return average inference timing
143154
return time_total / float(num_iter)
144155

145156

146157
def validate_model(
147-
exec_prog: EdgeProgramManager, calibration_dataset: torch.utils.data.DataLoader
158+
exec_prog: ExecutorchProgramManager,
159+
calibration_dataset: torch.utils.data.DataLoader,
148160
) -> float:
149161
"""
150162
Validates the model using the calibration dataset.
151163
152-
:param exec_prog: EdgeProgramManager of the lowered model
164+
:param exec_prog: ExecutorchProgramManager of the lowered model
153165
:param calibration_dataset: A DataLoader containing calibration data.
154166
:return: The accuracy score of the model.
155167
"""
156168
# Load model from buffer
157169
runtime = Runtime.get()
158170
program = runtime.load_program(exec_prog.buffer)
159171
method = program.load_method("forward")
172+
if method is None:
173+
raise ValueError("Load method failed")
160174

161175
# Iterate over the dataset and run the executor
162-
predictions = []
176+
predictions: List[int] = []
163177
targets = []
164178
for _idx, data in enumerate(calibration_dataset):
165179
feature, target = data
166180
targets.extend(target)
167-
out = method.execute((feature,))
181+
out = list(method.execute((feature,)))
168182
predictions.extend(torch.stack(out).reshape(-1, 1000).argmax(-1))
169183

170184
# Check accuracy
@@ -213,12 +227,18 @@ def main( # noqa: C901
213227
model = load_model(suite, model_name)
214228
model = model.eval()
215229

230+
calibration_dataset: Optional[torch.utils.data.DataLoader] = None
231+
216232
if dataset_path:
217233
calibration_dataset = load_calibration_dataset(
218234
dataset_path, batch_size, suite, model, model_name
219235
)
220-
input_shape = tuple(next(iter(calibration_dataset))[0].shape)
221-
print(f"Input shape retrieved from the model config: {input_shape}")
236+
if calibration_dataset is not None:
237+
input_shape = tuple(next(iter(calibration_dataset))[0].shape)
238+
print(f"Input shape retrieved from the model config: {input_shape}")
239+
else:
240+
msg = "Quantization requires a valid calibration dataset"
241+
raise ValueError(msg)
222242
# Ensure input_shape is a tuple
223243
elif isinstance(input_shape, (list, tuple)):
224244
input_shape = tuple(input_shape)
@@ -240,7 +260,7 @@ def main( # noqa: C901
240260
# Export the model to the aten dialect
241261
aten_dialect: ExportedProgram = export(model, example_args)
242262

243-
if quantize:
263+
if quantize and calibration_dataset:
244264
if suite == "huggingface":
245265
msg = f"Quantization of {suite} models did not support yet."
246266
raise ValueError(msg)
@@ -251,20 +271,20 @@ def main( # noqa: C901
251271
raise ValueError(msg)
252272

253273
subset_size = 300
254-
batch_size = calibration_dataset.batch_size
274+
batch_size = calibration_dataset.batch_size or 1
255275
subset_size = (subset_size // batch_size) + int(subset_size % batch_size > 0)
256276

257277
def transform_fn(x):
258278
return x[0]
259279

260280
quantized_model = quantize_model(
261-
aten_dialect.module(),
281+
cast(torch.fx.GraphModule, aten_dialect.module()),
262282
calibration_dataset,
263283
subset_size=subset_size,
264284
transform_fn=transform_fn,
265285
)
266286

267-
aten_dialect: ExportedProgram = export(quantized_model, example_args)
287+
aten_dialect = export(quantized_model, example_args)
268288

269289
# Convert to edge dialect and lower the module to the backend with a custom partitioner
270290
compile_spec = [CompileSpec("device", device.encode())]
@@ -288,7 +308,7 @@ def transform_fn(x):
288308
exec_prog.write_to_file(file)
289309
print(f"Model exported and saved as {model_file_name} on {device}.")
290310

291-
if validate:
311+
if validate and calibration_dataset:
292312
if suite == "huggingface":
293313
msg = f"Validation of {suite} models did not support yet."
294314
raise ValueError(msg)
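
Two fixes in this file are worth calling out. The `Optional[torch.utils.data.DataLoader]` declaration plus the `and calibration_dataset` guards give mypy a definite, non-None binding on every path that uses the dataset. And `cast(torch.fx.GraphModule, aten_dialect.module())` bridges the declared return type of `ExportedProgram.module()` (a plain `torch.nn.Module`) and the `torch.fx.GraphModule` that `quantize_model` expects. A minimal sketch of the cast, under that signature assumption:

    from typing import cast

    import torch
    import torch.fx

    def as_graph_module(m: torch.nn.Module) -> torch.fx.GraphModule:
        # cast() only instructs the type checker; it performs no runtime
        # conversion or validation, so it is safe here only because
        # ExportedProgram.module() returns a GraphModule in practice.
        return cast(torch.fx.GraphModule, m)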
