@@ -1026,6 +1026,52 @@ def local_Shape_of_SpecifyShape(fgraph, node):
1026
1026
return [stack (shape ).astype (np .int64 )]
1027
1027
1028
1028
1029
+ @register_canonicalize
1030
+ @register_specialize
1031
+ @node_rewriter ([SpecifyShape ])
1032
+ def local_specify_shape_lift (fgraph , node ):
1033
+ """Lift SpecifyShape of Elemwise towards the inputs."""
1034
+ inp , * shape = node .inputs
1035
+ if inp .owner and isinstance (inp .owner .op , Elemwise ):
1036
+ if len (inp .owner .outputs ) != 1 :
1037
+ return None
1038
+
1039
+ elem_inps = inp .owner .inputs
1040
+ if len (elem_inps ) == 1 :
1041
+ new_elem_inps = [specify_shape (elem_inps [0 ], shape )]
1042
+ else :
1043
+ # Rewrite does not support case where specify_shape provides new broadcastable information,
1044
+ # As that may require a specify_shape for each input
1045
+ out_broadcastable = node .outputs [0 ].type .broadcastable
1046
+ if out_broadcastable != inp .type .broadcastable :
1047
+ return None
1048
+
1049
+ # All non-broadcastable dimensions of inputs must match the non-broadcastbale specify_shape dims
1050
+ # We look for a sufficient input to assign all the specify_shape dims
1051
+ # We could consider distributing the SpecifyShape across multiple inputs, when none is sufficient
1052
+
1053
+ nonbcast_dims = {
1054
+ i
1055
+ for i , (dim , bcast ) in enumerate (zip (shape , out_broadcastable ))
1056
+ if (not bcast and not NoneConst .equals (dim ))
1057
+ }
1058
+ new_elem_inps = elem_inps .copy ()
1059
+ for i , elem_inp in enumerate (elem_inps ):
1060
+ if all (
1061
+ bcast_dim is False
1062
+ for dim , bcast_dim in enumerate (elem_inp .type .broadcastable )
1063
+ if dim in nonbcast_dims
1064
+ ):
1065
+ new_elem_inps [i ] = specify_shape (elem_inp , shape )
1066
+ break
1067
+ else : # no-break, no sufficient candidate found
1068
+ return None
1069
+
1070
+ new_out = inp .owner .op .make_node (* new_elem_inps ).outputs
1071
+ copy_stack_trace (node .outputs , new_out )
1072
+ return new_out
1073
+
1074
+
1029
1075
@register_useless
1030
1076
@register_canonicalize
1031
1077
@node_rewriter ([Shape_i ])
0 commit comments