|
32 | 32 | from executorch.extension.pybindings.portable_lib import (
|
33 | 33 | _load_for_executorch_from_buffer,
|
34 | 34 | )
|
35 |
| -from torch.export import export, ExportedProgram |
| 35 | +from torch.export import Dim, export, ExportedProgram |
36 | 36 | from torch.export._trace import _export
|
37 | 37 |
|
38 | 38 | from torch.library import impl, Library
|
@@ -272,6 +272,38 @@ def test_edge_manager_transform(self):
|
272 | 272 | original_res, # x * y + x
|
273 | 273 | )
|
274 | 274 |
|
| 275 | + def test_issue_3659(self): |
| 276 | + |
| 277 | + class Mul(torch.nn.Module): |
| 278 | + def __init__(self): |
| 279 | + super(Mul, self).__init__() |
| 280 | + |
| 281 | + def forward(self, x: torch.Tensor, y: torch.Tensor): |
| 282 | + return torch.matmul(x, y) |
| 283 | + |
| 284 | + def get_eager_model(self) -> torch.nn.Module: |
| 285 | + return self |
| 286 | + |
| 287 | + def get_example_inputs(self): |
| 288 | + return (torch.randn(1, 3, 10), torch.randn(1, 10, 3)) |
| 289 | + |
| 290 | + def get_dynamic_shapes(self): |
| 291 | + dim1_x = Dim("Dot_dim1_x", min=2, max=100) |
| 292 | + dim2_x = Dim("Dot_dim2_x", min=2, max=100) |
| 293 | + return {"x": {1: dim1_x, 2: dim2_x}, "y": {1: dim2_x, 2: dim1_x}} |
| 294 | + |
| 295 | + model = Mul() |
| 296 | + ep = torch.export.export( |
| 297 | + model, model.get_example_inputs(), dynamic_shapes=model.get_dynamic_shapes() |
| 298 | + ) |
| 299 | + |
| 300 | + to_edge( |
| 301 | + ep, |
| 302 | + compile_config=EdgeCompileConfig( |
| 303 | + _check_ir_validity=True, |
| 304 | + ), |
| 305 | + ) |
| 306 | + |
275 | 307 | def test_transform_dict_api(self):
|
276 | 308 | edge_manager = to_edge(get_exported_programs(), get_config_methods())
|
277 | 309 |
|
|
0 commit comments