Skip to content

Commit 44a3e7d

Browse files
author
Yibing Liu
authored
Merge pull request #45 from kuke/fix_rename_error
Fix the fetch var bug when the arg is renamed
2 parents 51cad24 + de7e6df commit 44a3e7d

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

fluid/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,14 @@ def __call__(self, op):
7575

7676
# Instantiate the class to a callable object
7777
op_io_info = OpIOsInfo()
78+
79+
80+
def get_old_name(arg):
81+
"""Get the old rame for a possible renamed argument
82+
"""
83+
84+
idx = arg.find('@')
85+
if idx == -1:
86+
return arg
87+
else:
88+
return arg[:idx]

fluid_onnx/ops.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from onnx import TensorProto
1818
from onnx.helper import make_node, make_tensor
1919
from paddle.fluid.executor import fetch_var
20-
from fluid.utils import op_io_info
20+
from fluid.utils import op_io_info, get_old_name
2121
from fluid_onnx.variables import PADDLE_TO_ONNX_DTYPE
2222
"""
2323
Priority of ops (uniques) to figure out support for.
@@ -86,7 +86,7 @@ def argmin_op():
8686
def batch_norm_op(operator, block):
8787
inputs, attrs, outputs = op_io_info(operator)
8888

89-
x_shape = block.vars[inputs['X'][0]].shape
89+
x_shape = block.vars[get_old_name(inputs['X'][0])].shape
9090
reshape_node = None
9191
if len(x_shape) == 2:
9292
reshaped_x = [inputs['X'][0] + '@reshape_0']
@@ -164,7 +164,7 @@ def constant_op(var, scope):
164164
def conv2d_op(operator, block):
165165
inputs, attrs, outputs = op_io_info(operator)
166166

167-
kernel_shape = block.vars[inputs['Filter'][0]].shape
167+
kernel_shape = block.vars[get_old_name(inputs['Filter'][0])].shape
168168
conv2d = make_node(
169169
'Conv',
170170
inputs=inputs['Input'] + inputs['Filter'],
@@ -180,7 +180,7 @@ def conv2d_op(operator, block):
180180
def conv2d_transpose_op(operator, block):
181181
inputs, attrs, outputs = op_io_info(operator)
182182

183-
kernel_shape = block.vars[inputs['Filter'][0]].shape
183+
kernel_shape = block.vars[get_old_name(inputs['Filter'][0])].shape
184184
conv2d_transpose = make_node(
185185
'ConvTranspose',
186186
inputs=inputs['Input'] + inputs['Filter'],
@@ -222,8 +222,8 @@ def elementwise_ops(op_type, operator, block):
222222
"""
223223

224224
inputs, attrs, outputs = op_io_info(operator)
225-
rank_x = len(block.vars[inputs['X'][0]].shape)
226-
rank_y = len(block.vars[inputs['Y'][0]].shape)
225+
rank_x = len(block.vars[get_old_name(inputs['X'][0])].shape)
226+
rank_y = len(block.vars[get_old_name(inputs['Y'][0])].shape)
227227
axis = rank_x - rank_y if attrs['axis'] == -1 else attrs['axis']
228228
return make_node(
229229
op_type,
@@ -481,7 +481,7 @@ def reduce_ops(op_type, operator, block):
481481
"""
482482

483483
inputs, attrs, outputs = op_io_info(operator)
484-
rank = len(block.vars[inputs['X'][0]].shape)
484+
rank = len(block.vars[get_old_name(inputs['X'][0])].shape)
485485
dim = attrs['dim']
486486
axes = [dim if dim >= 0 else rank + dim]
487487
reduce_out = [outputs['Out'][0] + '@reduce_0'] if attrs[

0 commit comments

Comments
 (0)