Commit 7828080

Author: Jorge Pineda (committed)

Update on "[ET-VK] Move files using the Vulkan API to vk_api/"

and move from namespace `api` to `vkapi`. This gave me a major headache in the number of places to update. This stack organizes ET-VK neatly into three abstraction levels, both by folder and by namespace:

1. `namespace vkcompute` and `graph/`: for operator computation implementation and scheduling.
2. `namespace vkcompute::api` and `graph/api/`: for tensor objects (and other objects wrapping our VulkanBuffer/VulkanImage).
3. `namespace vkcompute::vkapi` and `graph/api/vk_api/`: for direct users of the Vulkan API.

Additionally, we have:

- `namespace vkcompute::utils` and `graph/api/utils/`: for utils used by both namespace `api` and `vkapi`.

Differential Revision: [D59281539](https://our.internmc.facebook.com/intern/diff/D59281539/)

[ghstack-poisoned]

2 parents 1f0fe59 + e5e34be commit 7828080

43 files changed (+617, −361 lines)
Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+name: 🐛 Bug Report
+description: Create a report to help us reproduce and fix the bug
+
+body:
+- type: markdown
+  attributes:
+    value: >
+      #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/pytorch/executorch/issues?q=is%3Aissue+sort%3Acreated-desc+).
+- type: textarea
+  attributes:
+    label: 🐛 Describe the bug
+    description: |
+      Please provide a clear and concise description of what the bug is.
+
+      If relevant, add a minimal example so that we can reproduce the error by running the code.
+
+      If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
+
+      Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
+    placeholder: |
+      A clear and concise description of what the bug is.
+
+      ```python
+      # Sample code to reproduce the problem
+      ```
+
+      ```
+      The error message you got, with the full traceback.
+      ```
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: Versions
+    description: |
+      Please run the following and paste the output below.
+      ```sh
+      wget https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py
+      # For security purposes, please check the contents of collect_env.py before running it.
+      python collect_env.py
+      ```
+  validations:
+    required: true
+- type: markdown
+  attributes:
+    value: >
+      Thanks for contributing 🎉!

.github/ISSUE_TEMPLATE/config.yml

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+blank_issues_enabled: true

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+name: 📚 Documentation
+description: Report an issue related to https://pytorch.org/executorch/stable/index.html
+
+body:
+- type: textarea
+  attributes:
+    label: 📚 The doc issue
+    description: >
+      A clear and concise description of what content in https://pytorch.org/executorch/stable/index.html is an issue.
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: Suggest a potential alternative/fix
+    description: >
+      Tell us how we could improve the documentation in this regard.
+- type: markdown
+  attributes:
+    value: >
+      Thanks for contributing 🎉!

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+name: 🚀 Feature request
+description: Submit a proposal/request for a new ExecuTorch feature
+
+body:
+- type: textarea
+  attributes:
+    label: 🚀 The feature, motivation and pitch
+    description: >
+      A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too.
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: Alternatives
+    description: >
+      A description of any alternative solutions or features you've considered, if any.
+- type: textarea
+  attributes:
+    label: Additional context
+    description: >
+      Add any other context or screenshots about the feature request.
+- type: textarea
+  attributes:
+    label: RFC (Optional)
+    description: >
+      Explain the design in enough detail.
+- type: markdown
+  attributes:
+    value: >
+      Thanks for contributing 🎉!

backends/arm/quantizer/arm_quantizer.py

Lines changed: 65 additions & 63 deletions

@@ -41,7 +41,7 @@
 )
 from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
 from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
-from torch.fx import Node
+from torch.fx import GraphModule, Node

 __all__ = [
     "ArmQuantizer",
@@ -172,33 +172,40 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]:
     return _get_supported_symmetric_config_and_operators()


-def _get_module_name_filter(module_name: str):
+NodeFilterType = Callable[[Node], bool]
+"""Type for a Node Filter used by annotators. A Node filter is a function that takes
+a Node and returns whether the node should be annotated or not.
+"""
+
+
+def _get_module_name_filter(module_name: str) -> NodeFilterType:
     """Get the module_name_filter function for a given module name, the filter accepts
     a node and checks if the node comes from a module that has certain module name

     For example:
         node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1

-
     >> module_name_filter = _get_module_name_filter("blocks.sub")
     >> print(module_name_filter(node))
     True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
     """

+    name_start = len("L['self'].")
+
     def module_name_filter(n: Node) -> bool:
-        # example: {
+        # node_stack example: {
         #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
         #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
         # }
         # get_attr nodes doesn't have nn_module_stack?
         nn_module_stack = n.meta.get("nn_module_stack", {})
-        names = [n[len("L['self'].") :] for n, klass in nn_module_stack.values()]
+        names = [name[name_start:] for name, _ in nn_module_stack.values()]
         return module_name in names

     return module_name_filter


-def _get_module_type_filter(tp: Callable):
+def _get_module_type_filter(tp: Callable) -> NodeFilterType:
     """Get the module_type_filter function for a given module type, the filter accepts
     a node and checks if the node comes from a module that has certain module type

@@ -212,7 +219,7 @@ def _get_module_type_filter(tp: Callable):
     """

     def module_type_filter(n: Node) -> bool:
-        # example: {
+        # node_stack example: {
         #     'L__self___sub': ("L['self'].sub", <class '....Sub'>),
         #     'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
         # }
@@ -225,7 +232,7 @@ def module_type_filter(n: Node) -> bool:

 def _get_not_module_type_or_name_filter(
     tp_list: List[Callable], module_name_list: List[str]
-) -> Callable[[Node], bool]:
+) -> NodeFilterType:
     module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
     module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]

@@ -238,62 +245,32 @@ def not_module_type_or_name_filter(n: Node) -> bool:
 class ArmQuantizer(Quantizer):
     supported_config_and_operators = _get_supported_config_and_operators()

-    # A list of supported static quantization ops (both PTQ and QAT)
+    # A list of supported static quantization annotators, in order of application.
+    # For example, fusions come before singular ops.
     # The name must match the name used when registering the annotator.
-    # Preserve the order that fusions come before singular ops
-    STATIC_OPS = ["linear", "conv", "adaptive_avg_pool2d", "max_pool2d", "add", "mul"]
-
-    def __init__(self):
+    STATIC_ANNOTATION_ORDER = [
+        "linear",
+        "conv",
+        "adaptive_avg_pool2d",
+        "max_pool2d",
+        "add",
+        "mul",
+    ]
+
+    def __init__(self) -> None:
         super().__init__()
         self.global_config: Optional[QuantizationConfig] = None
-        self.operator_type_config: Dict[
-            torch._ops.OpOverloadPacket, Optional[QuantizationConfig]
-        ] = {}
         self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
         self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {}

-    @classmethod
-    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.supported_config_and_operators:
-            op_configs.add(spec)
-        return list(op_configs)
-
-    @classmethod
-    def get_supported_operator_for_quantization_config(
-        cls, quantization_config: Optional[QuantizationConfig]
-    ) -> List[OperatorPatternType]:
-        if quantization_config is None:
-            all_ops = []
-            for _, ops in cls.supported_config_and_operators:
-                all_ops.extend(ops)
-            return all_ops
-
-        for config, ops in cls.supported_config_and_operators:
-            # note: this assumes each entry in cls.supported_spec_and_operators
-            # corresponds to one spec, e.g. we don't have
-            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
-            # where the first and second entry have the same spec but did not
-            # merge the op list
-            if config == quantization_config:
-                return ops
-        return []
-
     def set_global(self, quantization_config: QuantizationConfig) -> ArmQuantizer:
+        """Set quantization_config for submodules that are not already annotated by name or type filters."""
         self.global_config = quantization_config
         return self

-    def set_operator_type(
-        self,
-        operator_type: torch._ops.OpOverloadPacket,
-        quantization_config: QuantizationConfig,
-    ) -> ArmQuantizer:
-        self.operator_type_config[operator_type] = quantization_config
-        return self
-
     def set_module_type(
         self, module_type: Callable, quantization_config: QuantizationConfig
-    ):
+    ) -> ArmQuantizer:
         """Set quantization_config for a submodule with type: `module_type`, for example:
         quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator
         patterns in the submodule with this module type with the given `quantization_config`
@@ -303,7 +280,7 @@ def set_module_type(

     def set_module_name(
         self, module_name: str, quantization_config: Optional[QuantizationConfig]
-    ):
+    ) -> ArmQuantizer:
         """Set quantization_config for a submodule with name: `module_name`, for example:
         quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
         patterns in the submodule with this module name with the given `quantization_config`
@@ -314,15 +291,13 @@ def set_module_name(
         self.module_name_config[module_name] = quantization_config
         return self

-    def transform_for_annotation(
-        self, model: torch.fx.GraphModule
-    ) -> torch.fx.GraphModule:
+    def transform_for_annotation(self, model: GraphModule) -> GraphModule:
         """An initial pass for transforming the graph to prepare it for annotation.
         Currently transforms scalar values to tensor attributes.
         """
         return convert_scalars_to_attrs(model)

-    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    def annotate(self, model: GraphModule) -> GraphModule:
         """Performs the quantization annotation on the graph.
         Currently only does static quantization annotation.
         Args:
@@ -336,10 +311,10 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:

     def _annotate_all_static_patterns(
         self,
-        model: torch.fx.GraphModule,
+        model: GraphModule,
         quantization_config: Optional[QuantizationConfig],
         filter_fn: Optional[Callable[[Node], bool]] = None,
-    ) -> torch.fx.GraphModule:
+    ) -> GraphModule:
         """Loops over all STATIC_OPS and runs the corresponding registred annotator.
         Args:
             model: The model to annotate statically.
@@ -353,13 +328,13 @@ def _annotate_all_static_patterns(
         if quantization_config is None:
             return model

-        for op in self.STATIC_OPS:
+        for op in self.STATIC_ANNOTATION_ORDER:
             OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
         return model

     def _annotate_for_static_quantization_config(
-        self, model: torch.fx.GraphModule
-    ) -> torch.fx.GraphModule:
+        self, model: GraphModule
+    ) -> GraphModule:
         """Matches the correct QuantizationConfig with the correct module using a filter
         when running _annotate_all_static_patterns.
         """
@@ -382,9 +357,36 @@ def _annotate_for_static_quantization_config(
         )
         return model

-    def validate(self, model: torch.fx.GraphModule) -> None:
+    def validate(self, model: GraphModule) -> None:
         pass

     @classmethod
     def get_supported_operators(cls) -> List[OperatorConfig]:
         return cls.supported_config_and_operators
+
+    @classmethod
+    def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
+        op_configs: Set[QuantizationConfig] = set({})
+        for spec, _ in cls.supported_config_and_operators:
+            op_configs.add(spec)
+        return list(op_configs)
+
+    @classmethod
+    def get_supported_operator_for_quantization_config(
+        cls, quantization_config: Optional[QuantizationConfig]
+    ) -> List[OperatorPatternType]:
+        if quantization_config is None:
+            all_ops = []
+            for _, ops in cls.supported_config_and_operators:
+                all_ops.extend(ops)
+            return all_ops
+
+        for config, ops in cls.supported_config_and_operators:
+            # note: this assumes each entry in cls.supported_spec_and_operators
+            # corresponds to one spec, e.g. we don't have
+            # [(spec1, op_list1), (spec1, op_list2), (spec2, op_list3)]
+            # where the first and second entry have the same spec but did not
+            # merge the op list
+            if config == quantization_config:
+                return ops
+        return []
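
For context, here is a minimal usage sketch of the refactored `ArmQuantizer` API touched by this diff. It is not part of the commit: the import path and the `get_symmetric_quantization_config` helper are assumptions based on the surrounding ExecuTorch Arm quantizer code, while the method names and signatures come from the hunks above.

```python
# Minimal sketch, not from this commit. Exercises the ArmQuantizer methods whose
# signatures changed above; import path and get_symmetric_quantization_config
# are assumed from the surrounding ExecuTorch Arm quantizer code.
import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)

quantizer = ArmQuantizer()
config = get_symmetric_quantization_config()

# Fallback config for submodules not matched by a name or type filter.
quantizer.set_global(config)

# Per-module overrides; both setters now return the quantizer (-> ArmQuantizer),
# so the calls can be chained.
quantizer.set_module_name("blocks.sub", config).set_module_type(torch.nn.Linear, config)

# Given an exported GraphModule `model`, the annotation flow would be:
#   1. transform_for_annotation() converts scalar values to tensor attributes,
#   2. annotate() runs the annotators in STATIC_ANNOTATION_ORDER (fusions such
#      as "linear"/"conv" before singular ops like "add"/"mul").
# model = quantizer.transform_for_annotation(model)
# model = quantizer.annotate(model)
```

One design point visible in the diff: `STATIC_OPS` becomes `STATIC_ANNOTATION_ORDER`, making explicit that the list is ordered and that fusion annotators are applied before the singular ops they contain.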
