@@ -10,6 +10,7 @@
 from typing import Any, Dict, List, Optional, Tuple

 import torch
+import torch._export as export
 from executorch import exir
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
     XnnpackFloatingPointPartitioner,
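The new torch._export import brings in capture_pre_autograd_graph, which the rewritten Quantize2.run below uses to trace the eager module before observers are inserted.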
@@ -145,23 +146,23 @@ def __init__(

         self.quantizer.set_global(self.quantization_config)

-        self.converted_program = None
+        self.converted_graph = None

     def run(
-        self, artifact: ExirExportedProgram, inputs: Optional[Tuple[torch.Tensor]]
+        self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
     ) -> None:
-        prepared = prepare_pt2e(artifact.exported_program.graph_module, self.quantizer)
+        captured_graph = export.capture_pre_autograd_graph(artifact, inputs)
+        prepared = prepare_pt2e(captured_graph, self.quantizer)
         converted = convert_pt2e(prepared)
-        artifact.exported_program._graph_module = converted
-        self.converted_program = artifact
+        self.converted_graph = converted

     @property
-    def artifact(self) -> ExirExportedProgram:
-        return self.converted_program
+    def artifact(self) -> torch.fx.GraphModule:
+        return self.converted_graph

     @property
     def graph_module(self) -> str:
-        return self.converted_program.exported_program.graph_module
+        return self.converted_graph


 @register_stage
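The rewritten Quantize2 stage no longer mutates an exported program in place; it takes the eager torch.nn.Module, captures a pre-autograd ATen graph, inserts observers with prepare_pt2e, and replaces them with quantize/dequantize ops via convert_pt2e. Below is a minimal standalone sketch of that capture → prepare → convert flow, with an explicit calibration call added for completeness; the SimpleLinear module, example_inputs, and the torch.ao import path for XNNPACKQuantizer are illustrative assumptions, not part of this change.

```python
# Minimal sketch of the PT2E quantization flow used by Quantize2.run.
# SimpleLinear and example_inputs are made up for illustration.
import torch
import torch._export as export
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class SimpleLinear(torch.nn.Module):  # hypothetical test module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(1, 8),)

# Capture a pre-autograd ATen graph from the eager module.
captured = export.capture_pre_autograd_graph(SimpleLinear().eval(), example_inputs)

# Insert observers according to the quantizer's annotations.
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(captured, quantizer)

# Calibrate by running representative inputs through the prepared graph.
prepared(*example_inputs)

# Replace observers with quantize/dequantize ops.
converted = convert_pt2e(prepared)
```

The converted torch.fx.GraphModule is what the stage now stores in self.converted_graph and hands on to the Export stage.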
@@ -274,12 +275,11 @@ def __init__(
         self.inputs = inputs
         self.stages: Dict[str, Stage] = OrderedDict.fromkeys(list(_stages_.keys()))
         self.pipeline = {
+            self._stage_name(Quantize2): [self._stage_name(Export)],
             self._stage_name(Quantize): [self._stage_name(Export)],
             self._stage_name(Export): [
-                self._stage_name(Quantize2),
                 self._stage_name(ToEdge),
             ],
-            self._stage_name(Quantize2): [self._stage_name(ToEdge)],
             self._stage_name(ToEdge): [self._stage_name(Partition)],
             # TODO Make this Stage optional
             self._stage_name(Partition): [self._stage_name(ToExecutorch)],