|
10 | 10 | from typing import Any, Dict
|
11 | 11 |
|
12 | 12 | import torch
|
13 |
| -from executorch.exir import ExecutorchBackendConfig |
| 13 | +from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig |
14 | 14 | from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
|
15 | 15 | from executorch.exir.dialects._ops import ops as exir_ops
|
16 | 16 | from executorch.exir.error import ExportError
|
|
26 | 26 | from executorch.extension.pybindings.portable_lib import (
|
27 | 27 | _load_for_executorch_from_buffer,
|
28 | 28 | )
|
29 |
| -from torch.export import export, ExportedProgram |
| 29 | +from torch.export import Dim, export, ExportedProgram |
30 | 30 |
|
31 | 31 | from torch.library import impl, Library
|
32 | 32 |
|
@@ -225,6 +225,38 @@ def test_edge_manager_transform(self):
|
225 | 225 | original_res, # x * y + x
|
226 | 226 | )
|
227 | 227 |
|
| 228 | + def test_issue_3659(self): |
| 229 | + |
| 230 | + class Mul(torch.nn.Module): |
| 231 | + def __init__(self): |
| 232 | + super(Mul, self).__init__() |
| 233 | + |
| 234 | + def forward(self, x: torch.Tensor, y: torch.Tensor): |
| 235 | + return torch.matmul(x, y) |
| 236 | + |
| 237 | + def get_eager_model(self) -> torch.nn.Module: |
| 238 | + return self |
| 239 | + |
| 240 | + def get_example_inputs(self): |
| 241 | + return (torch.randn(1, 3, 10), torch.randn(1, 10, 3)) |
| 242 | + |
| 243 | + def get_dynamic_shapes(self): |
| 244 | + dim1_x = Dim("Dot_dim1_x", min=2, max=100) |
| 245 | + dim2_x = Dim("Dot_dim2_x", min=2, max=100) |
| 246 | + return {"x": {1: dim1_x, 2: dim2_x}, "y": {1: dim2_x, 2: dim1_x}} |
| 247 | + |
| 248 | + model = Mul() |
| 249 | + ep = torch.export.export( |
| 250 | + model, model.get_example_inputs(), dynamic_shapes=model.get_dynamic_shapes() |
| 251 | + ) |
| 252 | + |
| 253 | + to_edge( |
| 254 | + ep, |
| 255 | + compile_config=EdgeCompileConfig( |
| 256 | + _check_ir_validity=True, |
| 257 | + ), |
| 258 | + ) |
| 259 | + |
228 | 260 | def test_transform_dict_api(self):
|
229 | 261 | edge_manager = to_edge(get_exported_programs(), get_config_methods())
|
230 | 262 |
|
|
0 commit comments