@@ -879,135 +879,142 @@ def local_set_to_inc_subtensor(fgraph, node):

 @register_canonicalize
 @register_specialize
-@local_optimizer([Subtensor, AdvancedSubtensor1])
+@local_optimizer([Subtensor])
 def local_useless_subtensor(fgraph, node):
-    """
-    Remove Subtensor/AdvancedSubtensor1 if it takes the full input. In the
-    AdvancedSubtensor1 case, the full input is taken when the indices are
-    equivalent to `arange(0, input.shape[0], 1)` using either an explicit
-    list/vector or the ARange op.
-
-    """
+    """Remove `Subtensor` if it takes the full input."""
     # This optimization needs ShapeOpt and fgraph.shape_feature
     if not hasattr(fgraph, "shape_feature"):
         return

     shape_of = fgraph.shape_feature.shape_of

-    if isinstance(node.op, Subtensor):
-        cdata = get_constant_idx(
-            node.op.idx_list,
-            node.inputs,
-            allow_partial=True,
-            only_process_constants=True,
-        )
-        for pos, idx in enumerate(cdata):
-            if not isinstance(idx, slice):
-                # If idx is not a slice, this means we remove this dimension
-                # from the output, so the subtensor is not useless
-                return False
-            if idx.start is not None and idx.start != 0:
-                # If the start of the slice is different from 0, or is a
-                # variable, then we assume the subtensor is not useless
-                return False
-            if idx.step is not None and idx.step != 1:
-                # If we are going backwards, or skipping elements, then this
-                # is not a useless subtensor
-                return False
-
-        for pos, idx in enumerate(cdata):
-
-            length_pos = shape_of[node.inputs[0]][pos]
-
-            if isinstance(idx.stop, (int, np.integer)):
-                length_pos_data = sys.maxsize
-                try:
-                    length_pos_data = get_scalar_constant_value(
-                        length_pos, only_process_constants=True
-                    )
-                except NotScalarConstantError:
-                    pass
-
-                if idx.stop < length_pos_data:
-                    return False
-            elif isinstance(idx.stop, Variable):
-                length_pos_shape_i = idx.stop
-                # length_pos is a tensor variable, but length_pos_shape_i
-                # is a scalar variable. We try to see if they represent
-                # the same underlying variable.
-                if length_pos_shape_i.owner and isinstance(
-                    length_pos_shape_i.owner.op, ScalarFromTensor
-                ):
-                    length_pos_shape_i = length_pos_shape_i.owner.inputs[0]
-                elif length_pos.owner and isinstance(
-                    length_pos.owner.op, TensorFromScalar
-                ):
-                    length_pos = length_pos.owner.inputs[0]
-                else:
-                    # We did not find underlying variables of the same type
-                    return False
-
-                # The type can be different: int32 vs int64. length_pos
-                # should always be int64 as that is what the shape
-                # tracker keep. Subtensor accept any scalar int{8,16,32,64}
-                # as index type.
-                assert str(length_pos.type.dtype) == "int64"
-                assert str(length_pos_shape_i.type.dtype) in [
-                    "int8",
-                    "int16",
-                    "int32",
-                    "int64",
-                ]
-
-                # length_pos_shape_i cannot be None
-                if length_pos_shape_i != length_pos:
-                    return False
-            elif idx.stop is None:
-                pass
-            else:
-                return False
-    elif isinstance(node.op, AdvancedSubtensor1):
-        # get length of the indexed tensor along the first axis
-        try:
-            length = get_scalar_constant_value(
-                shape_of[node.inputs[0]][0], only_process_constants=True
-            )
-        except NotScalarConstantError:
+    cdata = get_constant_idx(
+        node.op.idx_list,
+        node.inputs,
+        allow_partial=True,
+        only_process_constants=True,
+    )
+    for pos, idx in enumerate(cdata):
+        if not isinstance(idx, slice):
+            # If idx is not a slice, this means we remove this dimension
+            # from the output, so the subtensor is not useless
+            return False
+        if idx.start is not None and idx.start != 0:
+            # If the start of the slice is different from 0, or is a
+            # variable, then we assume the subtensor is not useless
+            return False
+        if idx.step is not None and idx.step != 1:
+            # If we are going backwards, or skipping elements, then this
+            # is not a useless subtensor
             return False

-        # get index (which must be a vector by definition)
-        idx = node.inputs[1]
+    for pos, idx in enumerate(cdata):

-        # `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for
-        # this optimization
-        if isinstance(idx, Constant):
-            idx = idx.value
-            if len(idx) != length:
-                return False
-            if np.any(idx != np.arange(length)):
-                return False
-        elif idx.owner is not None and isinstance(idx.owner.op, ARange):
+        length_pos = shape_of[node.inputs[0]][pos]
+
+        if isinstance(idx.stop, (int, np.integer)):
+            length_pos_data = sys.maxsize
            try:
-                start, stop, step = map(
-                    lambda x: get_scalar_constant_value(x, only_process_constants=True),
-                    idx.owner.inputs,
+                length_pos_data = get_scalar_constant_value(
+                    length_pos, only_process_constants=True
                )
            except NotScalarConstantError:
-                return False
+                pass

-            if start != 0:
+            if idx.stop < length_pos_data:
                return False
-            if stop != length:
+        elif isinstance(idx.stop, Variable):
+            length_pos_shape_i = idx.stop
+            # length_pos is a tensor variable, but length_pos_shape_i
+            # is a scalar variable. We try to see if they represent
+            # the same underlying variable.
+            if length_pos_shape_i.owner and isinstance(
+                length_pos_shape_i.owner.op, ScalarFromTensor
+            ):
+                length_pos_shape_i = length_pos_shape_i.owner.inputs[0]
+            elif length_pos.owner and isinstance(length_pos.owner.op, TensorFromScalar):
+                length_pos = length_pos.owner.inputs[0]
+            else:
+                # We did not find underlying variables of the same type
                return False
-            if step != 1:
+
+            # The type can be different: int32 vs int64. length_pos
+            # should always be int64 as that is what the shape
+            # tracker keep. Subtensor accept any scalar int{8,16,32,64}
+            # as index type.
+            assert str(length_pos.type.dtype) == "int64"
+            assert str(length_pos_shape_i.type.dtype) in [
+                "int8",
+                "int16",
+                "int32",
+                "int64",
+            ]
+
+            # length_pos_shape_i cannot be None
+            if length_pos_shape_i != length_pos:
                return False
+        elif idx.stop is None:
+            continue
        else:
            return False
+
+    return [node.inputs[0]]
+
+
+@register_canonicalize
+@register_specialize
+@local_optimizer([AdvancedSubtensor1])
+def local_useless_AdvancedSubtensor1(fgraph, node):
+    """Remove `AdvancedSubtensor1` if it takes the full input.
+
+    In the `AdvancedSubtensor1` case, the full input is taken when the indices
+    are equivalent to ``arange(0, input.shape[0], 1)`` using either an explicit
+    list/vector or the `ARange` `Op`.
+
+    """
+    # This optimization needs ShapeOpt and fgraph.shape_feature
+    if not hasattr(fgraph, "shape_feature"):
+        return
+
+    shape_of = fgraph.shape_feature.shape_of
+
+    # get length of the indexed tensor along the first axis
+    try:
+        length = get_scalar_constant_value(
+            shape_of[node.inputs[0]][0], only_process_constants=True
+        )
+    except NotScalarConstantError:
+        return False
+
+    # get index (which must be a vector by definition)
+    idx = node.inputs[1]
+
+    # `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for
+    # this optimization
+    if isinstance(idx, Constant):
+        idx = idx.value
+        if len(idx) != length:
+            return False
+        if np.any(idx != np.arange(length)):
+            return False
+    elif idx.owner is not None and isinstance(idx.owner.op, ARange):
+        try:
+            start, stop, step = map(
+                lambda x: get_scalar_constant_value(x, only_process_constants=True),
+                idx.owner.inputs,
+            )
+        except NotScalarConstantError:
+            return False
+
+        if start != 0:
+            return False
+        if stop != length:
+            return False
+        if step != 1:
+            return False
    else:
        return False

-    # We don't need to copy over any stacktrace here,
-    # because previous stacktrace should suffice.
    return [node.inputs[0]]

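As a minimal usage sketch (not part of the patch): after this split, the `Subtensor`-only rewrite should still drop a slice that spans the whole input, e.g. `x[0:x.shape[0]]`. The snippet below assumes the Aesara frontend; the import paths and the printed graph are illustrative assumptions, not taken from this diff.

import aesara
import aesara.tensor as at

x = at.matrix("x")
y = x[0 : x.shape[0]]  # a Subtensor that keeps the entire first axis

# Compiling triggers canonicalization/specialization, where
# local_useless_subtensor is expected to replace the Subtensor
# with (a copy of) x itself.
f = aesara.function([x], y)
print(f.maker.fgraph.toposort())  # expected: no Subtensor node remains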