
Commit 62c3aa6

Split local_useless_AdvancedSubtensor1 from local_useless_subtensor
1 parent 9417b49 commit 62c3aa6

File tree

1 file changed (+114 -107 lines)

aesara/tensor/subtensor_opt.py

Lines changed: 114 additions & 107 deletions
@@ -879,135 +879,142 @@ def local_set_to_inc_subtensor(fgraph, node):
 
 @register_canonicalize
 @register_specialize
-@local_optimizer([Subtensor, AdvancedSubtensor1])
+@local_optimizer([Subtensor])
 def local_useless_subtensor(fgraph, node):
-    """
-    Remove Subtensor/AdvancedSubtensor1 if it takes the full input. In the
-    AdvancedSubtensor1 case, the full input is taken when the indices are
-    equivalent to `arange(0, input.shape[0], 1)` using either an explicit
-    list/vector or the ARange op.
-
-    """
+    """Remove `Subtensor` if it takes the full input."""
     # This optimization needs ShapeOpt and fgraph.shape_feature
     if not hasattr(fgraph, "shape_feature"):
         return
 
     shape_of = fgraph.shape_feature.shape_of
 
-    if isinstance(node.op, Subtensor):
-        cdata = get_constant_idx(
-            node.op.idx_list,
-            node.inputs,
-            allow_partial=True,
-            only_process_constants=True,
-        )
-        for pos, idx in enumerate(cdata):
-            if not isinstance(idx, slice):
-                # If idx is not a slice, this means we remove this dimension
-                # from the output, so the subtensor is not useless
-                return False
-            if idx.start is not None and idx.start != 0:
-                # If the start of the slice is different from 0, or is a
-                # variable, then we assume the subtensor is not useless
-                return False
-            if idx.step is not None and idx.step != 1:
-                # If we are going backwards, or skipping elements, then this
-                # is not a useless subtensor
-                return False
-
-        for pos, idx in enumerate(cdata):
-
-            length_pos = shape_of[node.inputs[0]][pos]
-
-            if isinstance(idx.stop, (int, np.integer)):
-                length_pos_data = sys.maxsize
-                try:
-                    length_pos_data = get_scalar_constant_value(
-                        length_pos, only_process_constants=True
-                    )
-                except NotScalarConstantError:
-                    pass
-
-                if idx.stop < length_pos_data:
-                    return False
-            elif isinstance(idx.stop, Variable):
-                length_pos_shape_i = idx.stop
-                # length_pos is a tensor variable, but length_pos_shape_i
-                # is a scalar variable. We try to see if they represent
-                # the same underlying variable.
-                if length_pos_shape_i.owner and isinstance(
-                    length_pos_shape_i.owner.op, ScalarFromTensor
-                ):
-                    length_pos_shape_i = length_pos_shape_i.owner.inputs[0]
-                elif length_pos.owner and isinstance(
-                    length_pos.owner.op, TensorFromScalar
-                ):
-                    length_pos = length_pos.owner.inputs[0]
-                else:
-                    # We did not find underlying variables of the same type
-                    return False
-
-                # The type can be different: int32 vs int64. length_pos
-                # should always be int64 as that is what the shape
-                # tracker keep. Subtensor accept any scalar int{8,16,32,64}
-                # as index type.
-                assert str(length_pos.type.dtype) == "int64"
-                assert str(length_pos_shape_i.type.dtype) in [
-                    "int8",
-                    "int16",
-                    "int32",
-                    "int64",
-                ]
-
-                # length_pos_shape_i cannot be None
-                if length_pos_shape_i != length_pos:
-                    return False
-            elif idx.stop is None:
-                pass
-            else:
-                return False
-    elif isinstance(node.op, AdvancedSubtensor1):
-        # get length of the indexed tensor along the first axis
-        try:
-            length = get_scalar_constant_value(
-                shape_of[node.inputs[0]][0], only_process_constants=True
-            )
-        except NotScalarConstantError:
+    cdata = get_constant_idx(
+        node.op.idx_list,
+        node.inputs,
+        allow_partial=True,
+        only_process_constants=True,
+    )
+    for pos, idx in enumerate(cdata):
+        if not isinstance(idx, slice):
+            # If idx is not a slice, this means we remove this dimension
+            # from the output, so the subtensor is not useless
+            return False
+        if idx.start is not None and idx.start != 0:
+            # If the start of the slice is different from 0, or is a
+            # variable, then we assume the subtensor is not useless
+            return False
+        if idx.step is not None and idx.step != 1:
+            # If we are going backwards, or skipping elements, then this
+            # is not a useless subtensor
             return False
 
-        # get index (which must be a vector by definition)
-        idx = node.inputs[1]
+    for pos, idx in enumerate(cdata):
 
-        # `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for
-        # this optimization
-        if isinstance(idx, Constant):
-            idx = idx.value
-            if len(idx) != length:
-                return False
-            if np.any(idx != np.arange(length)):
-                return False
-        elif idx.owner is not None and isinstance(idx.owner.op, ARange):
+        length_pos = shape_of[node.inputs[0]][pos]
+
+        if isinstance(idx.stop, (int, np.integer)):
+            length_pos_data = sys.maxsize
             try:
-                start, stop, step = map(
-                    lambda x: get_scalar_constant_value(x, only_process_constants=True),
-                    idx.owner.inputs,
+                length_pos_data = get_scalar_constant_value(
+                    length_pos, only_process_constants=True
                 )
             except NotScalarConstantError:
-                return False
+                pass
 
-            if start != 0:
+            if idx.stop < length_pos_data:
                 return False
-            if stop != length:
+        elif isinstance(idx.stop, Variable):
+            length_pos_shape_i = idx.stop
+            # length_pos is a tensor variable, but length_pos_shape_i
+            # is a scalar variable. We try to see if they represent
+            # the same underlying variable.
+            if length_pos_shape_i.owner and isinstance(
+                length_pos_shape_i.owner.op, ScalarFromTensor
+            ):
+                length_pos_shape_i = length_pos_shape_i.owner.inputs[0]
+            elif length_pos.owner and isinstance(length_pos.owner.op, TensorFromScalar):
+                length_pos = length_pos.owner.inputs[0]
+            else:
+                # We did not find underlying variables of the same type
                 return False
-            if step != 1:
+
+            # The type can be different: int32 vs int64. length_pos
+            # should always be int64 as that is what the shape
+            # tracker keep. Subtensor accept any scalar int{8,16,32,64}
+            # as index type.
+            assert str(length_pos.type.dtype) == "int64"
+            assert str(length_pos_shape_i.type.dtype) in [
+                "int8",
+                "int16",
+                "int32",
+                "int64",
+            ]
+
+            # length_pos_shape_i cannot be None
+            if length_pos_shape_i != length_pos:
                 return False
+        elif idx.stop is None:
+            continue
         else:
             return False
+
+    return [node.inputs[0]]
+
+
+@register_canonicalize
+@register_specialize
+@local_optimizer([AdvancedSubtensor1])
+def local_useless_AdvancedSubtensor1(fgraph, node):
+    """Remove `AdvancedSubtensor1` if it takes the full input.
+
+    In the `AdvancedSubtensor1` case, the full input is taken when the indices
+    are equivalent to ``arange(0, input.shape[0], 1)`` using either an explicit
+    list/vector or the `ARange` `Op`.
+
+    """
+    # This optimization needs ShapeOpt and fgraph.shape_feature
+    if not hasattr(fgraph, "shape_feature"):
+        return
+
+    shape_of = fgraph.shape_feature.shape_of
+
+    # get length of the indexed tensor along the first axis
+    try:
+        length = get_scalar_constant_value(
+            shape_of[node.inputs[0]][0], only_process_constants=True
+        )
+    except NotScalarConstantError:
+        return False
+
+    # get index (which must be a vector by definition)
+    idx = node.inputs[1]
+
+    # `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for
+    # this optimization
+    if isinstance(idx, Constant):
+        idx = idx.value
+        if len(idx) != length:
+            return False
+        if np.any(idx != np.arange(length)):
+            return False
+    elif idx.owner is not None and isinstance(idx.owner.op, ARange):
+        try:
+            start, stop, step = map(
+                lambda x: get_scalar_constant_value(x, only_process_constants=True),
+                idx.owner.inputs,
+            )
+        except NotScalarConstantError:
+            return False
+
+        if start != 0:
+            return False
+        if stop != length:
+            return False
+        if step != 1:
+            return False
     else:
         return False
 
-    # We don't need to copy over any stacktrace here,
-    # because previous stacktrace should suffice.
     return [node.inputs[0]]
 
 
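For context, here is a minimal sketch (not part of this commit) of the graph patterns the two rewrites now handle separately. The variable names and the constant input are illustrative only, and it assumes the default canonicalize/specialize rewrites are enabled when compiling.

    import numpy as np

    import aesara
    import aesara.tensor as at

    x = at.matrix("x")

    # Target of `local_useless_subtensor`: a `Subtensor` that slices the
    # full input, e.g. x[0:x.shape[0]]; the rewrite replaces it with `x`.
    y = x[0 : x.shape[0]]

    # Target of `local_useless_AdvancedSubtensor1`: integer-vector indexing
    # equivalent to `arange(0, input.shape[0], 1)`.  It only applies when the
    # length of the first dimension is known at compile time, hence the
    # constant input used here.
    c = at.as_tensor_variable(np.arange(12.0).reshape(4, 3))
    z = c[np.arange(4)]

    f = aesara.function([x], [y, z])
    # After rewriting, neither output should contain a Subtensor or
    # AdvancedSubtensor1 node.
    aesara.printing.debugprint(f)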
