Skip to content

Commit b5abb02

Browse files
kimishpatelfacebook-github-bot
authored andcommitted
Fix delegate export scripts
Summary: Two changes: - Delegate module example: rename get_random_inputs to get_example_inputs to be consistent with other model examples - Fix model.py Reviewed By: mergennachin Differential Revision: D47918752 fbshipit-source-id: a73e02d29bf44e2d36f767c92dae384ddaaadf4e
1 parent c8f9bd5 commit b5abb02

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

examples/export/export_and_delegate.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,11 @@ def forward(self, a, x, b):
116116
z = y + b
117117
return z
118118

119-
def get_random_inputs(self):
119+
def get_example_inputs(self):
120120
return (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
121121

122122
m = Model()
123-
edge = exir.capture(m, m.get_random_inputs(), _CAPTURE_CONFIG).to_edge(
123+
edge = exir.capture(m, m.get_example_inputs(), _CAPTURE_CONFIG).to_edge(
124124
_EDGE_COMPILE_CONFIG
125125
)
126126
print("Exported graph:\n", edge.exported_program.graph)

examples/models/models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def get_example_inputs():
6868
return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
6969

7070
def get_compile_spec(self):
71-
max_value = self.get_random_inputs()[0].shape[0]
71+
max_value = self.get_example_inputs()[0].shape[0]
7272
return [CompileSpec("max_value", bytes([max_value]))]
7373

7474

0 commit comments

Comments
 (0)