1717from onnx import TensorProto
1818from onnx .helper import make_node , make_tensor
1919from paddle .fluid .executor import fetch_var
20- from fluid .utils import op_io_info
20+ from fluid .utils import op_io_info , get_old_name
2121from fluid_onnx .variables import PADDLE_TO_ONNX_DTYPE
2222"""
2323Priority of ops (uniques) to figure out support for.
@@ -86,7 +86,7 @@ def argmin_op():
8686def batch_norm_op (operator , block ):
8787 inputs , attrs , outputs = op_io_info (operator )
8888
89- x_shape = block .vars [inputs ['X' ][0 ]].shape
89+ x_shape = block .vars [get_old_name ( inputs ['X' ][0 ]) ].shape
9090 reshape_node = None
9191 if len (x_shape ) == 2 :
9292 reshaped_x = [inputs ['X' ][0 ] + '@reshape_0' ]
@@ -164,7 +164,7 @@ def constant_op(var, scope):
164164def conv2d_op (operator , block ):
165165 inputs , attrs , outputs = op_io_info (operator )
166166
167- kernel_shape = block .vars [inputs ['Filter' ][0 ]].shape
167+ kernel_shape = block .vars [get_old_name ( inputs ['Filter' ][0 ]) ].shape
168168 conv2d = make_node (
169169 'Conv' ,
170170 inputs = inputs ['Input' ] + inputs ['Filter' ],
@@ -180,7 +180,7 @@ def conv2d_op(operator, block):
180180def conv2d_transpose_op (operator , block ):
181181 inputs , attrs , outputs = op_io_info (operator )
182182
183- kernel_shape = block .vars [inputs ['Filter' ][0 ]].shape
183+ kernel_shape = block .vars [get_old_name ( inputs ['Filter' ][0 ]) ].shape
184184 conv2d_transpose = make_node (
185185 'ConvTranspose' ,
186186 inputs = inputs ['Input' ] + inputs ['Filter' ],
@@ -222,8 +222,8 @@ def elementwise_ops(op_type, operator, block):
222222 """
223223
224224 inputs , attrs , outputs = op_io_info (operator )
225- rank_x = len (block .vars [inputs ['X' ][0 ]].shape )
226- rank_y = len (block .vars [inputs ['Y' ][0 ]].shape )
225+ rank_x = len (block .vars [get_old_name ( inputs ['X' ][0 ]) ].shape )
226+ rank_y = len (block .vars [get_old_name ( inputs ['Y' ][0 ]) ].shape )
227227 axis = rank_x - rank_y if attrs ['axis' ] == - 1 else attrs ['axis' ]
228228 return make_node (
229229 op_type ,
@@ -481,7 +481,7 @@ def reduce_ops(op_type, operator, block):
481481 """
482482
483483 inputs , attrs , outputs = op_io_info (operator )
484- rank = len (block .vars [inputs ['X' ][0 ]].shape )
484+ rank = len (block .vars [get_old_name ( inputs ['X' ][0 ]) ].shape )
485485 dim = attrs ['dim' ]
486486 axes = [dim if dim >= 0 else rank + dim ]
487487 reduce_out = [outputs ['Out' ][0 ] + '@reduce_0' ] if attrs [
0 commit comments