Cleanup fusion rewrite database #369

Merged · 3 commits · Jul 4, 2023
5 changes: 0 additions & 5 deletions pytensor/compile/mode.py
@@ -248,11 +248,6 @@ def apply(self, fgraph):
# misc special cases for speed that break canonicalization
optdb.register("uncanonicalize", EquilibriumDB(), "fast_run", position=3)

# misc special cases for speed that are dependent on the device.
optdb.register(
"specialize_device", EquilibriumDB(), "fast_compile", "fast_run", position=48.6
) # must be after gpu stuff at 48.5

# especially constant merge
optdb.register("merge2", MergeOptimizer(), "fast_run", "merge", position=49)

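For orientation, the `optdb.register` calls above all follow the same pattern: a name, a rewrite database or rewriter, one or more tags, and a float `position` that orders the phase relative to its neighbours (e.g. `merge2` at 49, `AddDestroyHandler` at 49.5). The sketch below registers a made-up no-op rewrite the same way; the rewrite body, its name, and its position are purely illustrative, and the import paths are assumptions about the current PyTensor module layout rather than something stated in this PR.

```python
from pytensor import compile
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.tensor.elemwise import Elemwise


@node_rewriter([Elemwise])
def local_noop_example(fgraph, node):
    # Hypothetical rewrite used only to show the interface: returning None
    # means "no replacement"; a real rewrite returns replacement outputs.
    return None


# A float position slots the phase between "merge2" (49) and the inplace
# rewrites added by AddDestroyHandler (49.5).
compile.optdb.register(
    "local_noop_example",
    in2out(local_noop_example),
    "fast_run",
    position=49.1,
)
```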
10 changes: 0 additions & 10 deletions pytensor/configdefaults.py
@@ -640,16 +640,6 @@ def add_tensor_configvars():
in_c_key=False,
)

config.add(
"tensor__local_elemwise_fusion",
(
"Enable or not in fast_run mode(fast_run optimization) the elemwise "
"fusion optimization"
),
BoolParam(True),
in_c_key=False,
)

# http://developer.amd.com/CPU/LIBRARIES/LIBM/Pages/default.aspx
config.add(
"lib__amblibm",
19 changes: 0 additions & 19 deletions pytensor/tensor/rewriting/basic.py
@@ -205,25 +205,6 @@ def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
return node_rewriter


def register_specialize_device(
node_rewriter: Union[RewriteDatabase, Rewriter, str], *tags: str, **kwargs
):
if isinstance(node_rewriter, str):

def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
return register_specialize_device(
inner_rewriter, node_rewriter, *tags, **kwargs
)

return register
else:
name = (kwargs and kwargs.pop("name", None)) or node_rewriter.__name__
compile.optdb["specialize_device"].register(
name, node_rewriter, "fast_run", *tags, **kwargs
)
return node_rewriter


@register_canonicalize
@register_specialize
@node_rewriter([TensorFromScalar])
113 changes: 67 additions & 46 deletions pytensor/tensor/rewriting/elemwise.py
@@ -1085,38 +1085,10 @@ def print_profile(stream, prof, level=0):
print(blanc, " time_toposort", prof[7], file=stream)


if config.tensor__local_elemwise_fusion:
# Must be after gpu(48.5) and before AddDestroyHandler(49.5)
fuse_seqopt = SequenceDB()
fuse_seqopt.register(
"local_add_mul_fusion",
EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
"fast_run",
"fusion",
position=0,
)
fuse_seqopt.register(
"composite_elemwise_fusion",
FusionOptimizer(),
"fast_run",
"fusion",
position=1,
)
compile.optdb.register(
"elemwise_fusion",
fuse_seqopt,
"fast_run",
"fusion",
"local_elemwise_fusion",
"FusionOptimizer",
position=49,
)


@register_canonicalize
@register_specialize
@node_rewriter([Elemwise])
def local_useless_composite(fgraph, node):
def local_useless_composite_outputs(fgraph, node):
"""Remove inputs and outputs of Composite Ops that are not used anywhere."""
if not isinstance(node.op, Elemwise) or not isinstance(
node.op.scalar_op, aes.Composite
@@ -1150,11 +1122,20 @@ def local_careduce_fusion(fgraph, node):
"""Fuse a `CAReduce` applied to an `Elemwise`."""

(car_input,) = node.inputs
car_scalar_op = node.op.scalar_op

# FIXME: This check is needed because of the faulty logic in the FIXME below!
# Right now, rewrite only works for `Sum`/`Prod`
if not isinstance(car_scalar_op, (aes.Add, aes.Mul)):
return None

elm_node = car_input.owner

if elm_node is None or not isinstance(elm_node.op, Elemwise):
return False

elm_scalar_op = elm_node.op.scalar_op

elm_inputs = elm_node.inputs
elm_outputs = elm_node.outputs

@@ -1166,21 +1147,15 @@ def local_careduce_fusion(fgraph, node):
return False

# Don't form the fusion when the target language is Python
elm_scalar_op = elm_node.op.scalar_op
car_scalar_op = node.op.scalar_op

if get_target_language() == ("py",):
return False

try:
elm_scalar_op.c_code(
elm_node,
"test_presence_of_c_code",
["x" for x in elm_inputs],
["z" for z in elm_outputs],
{"fail": "%(fail)s"},
)
if not elm_scalar_op.supports_c_code(elm_inputs, elm_outputs):
return None

# FIXME: This fails with Ops like `Max` whose `c_code` always expects two inputs!
# Should implement a `CAReduce.supports_c_code`?
try:
car_scalar_op.c_code(
node,
"test_presence_of_c_code",
@@ -1191,18 +1166,24 @@
except (NotImplementedError, MethodNotDefined):
return False

car_axis = node.op.axis
car_op = node.op
car_acc_dtype = node.op.acc_dtype

scalar_elm_inputs = [
aes.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
]

elm_output = elm_scalar_op(*scalar_elm_inputs)

# This input represents the previous value in the `CAReduce` binary reduction
carried_car_input = elm_output.type()
scalar_fused_outputs = [car_scalar_op(carried_car_input, elm_output)]
carried_car_input = aes.get_scalar_type(car_acc_dtype).make_variable()

scalar_fused_output = car_scalar_op(carried_car_input, elm_output)
if scalar_fused_output.type.dtype != car_acc_dtype:
scalar_fused_output = aes.cast(scalar_fused_output, car_acc_dtype)

fused_scalar_op = aes.Composite(
inputs=[carried_car_input] + scalar_elm_inputs, outputs=scalar_fused_outputs
inputs=[carried_car_input] + scalar_elm_inputs, outputs=[scalar_fused_output]
)

# The fused `Op` needs to look and behave like a `BinaryScalarOp`
@@ -1211,16 +1192,56 @@
fused_scalar_op.nin = 2
fused_scalar_op.nout = 1

new_car_op = CAReduce(fused_scalar_op, car_axis)
new_car_op = CAReduce(
scalar_op=fused_scalar_op,
axis=car_op.axis,
acc_dtype=car_acc_dtype,
dtype=car_op.dtype,
upcast_discrete_output=car_op.upcast_discrete_output,
)

return [new_car_op(*elm_inputs)]


# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
fuse_seqopt = SequenceDB()
compile.optdb.register(
"elemwise_fusion",
fuse_seqopt,
"fast_run",
"fusion",
"local_elemwise_fusion",
"FusionOptimizer",
position=49,
)

fuse_seqopt.register(
"local_add_mul_fusion",
EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
"fast_run",
"fusion",
position=0,
)
fuse_seqopt.register(
"composite_elemwise_fusion",
FusionOptimizer(),
"fast_run",
"fusion",
position=1,
)
fuse_seqopt.register(
"local_useless_composite_outputs",
in2out(local_useless_composite_outputs),
"fast_run",
"fusion",
position=2,
)
fuse_seqopt.register(
"local_careduce_fusion",
in2out(local_careduce_fusion),
"fast_run",
"fusion",
position=49,
position=10,
)


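To make the reworked `local_careduce_fusion` above concrete: after the rewrite, a graph such as `sum(exp(x))` is executed by a single `CAReduce` whose scalar op is a `Composite` taking the carried accumulator and the current element, so the elemwise op is applied on the fly during the reduction. The loop below is a rough NumPy illustration of that carried binary reduction for the axis=None case, not PyTensor code; the cast to the accumulator dtype mirrors the `aes.cast` added above, and everything else is simplified.

```python
import numpy as np


def fused_sum_of_exp(x, acc_dtype=np.float64):
    """Illustrative equivalent of the fused CAReduce for sum(exp(x)) with axis=None."""
    carry = np.asarray(0, dtype=acc_dtype)  # identity element of the Add reduction
    for elem in x.ravel():
        # One step of the fused scalar op:
        # Composite([carry, elem] -> carry + exp(elem)), cast to the accumulator dtype.
        carry = (carry + np.exp(elem)).astype(acc_dtype)
    return carry


x = np.random.default_rng(0).random((4, 3, 2))
assert np.allclose(fused_sum_of_exp(x), np.exp(x).sum())
```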
11 changes: 6 additions & 5 deletions pytensor/tensor/rewriting/math.py
@@ -88,7 +88,6 @@
local_fill_sink,
register_canonicalize,
register_specialize,
register_specialize_device,
register_stabilize,
register_uncanonicalize,
register_useless,
@@ -2078,12 +2077,14 @@ def local_pow_specialize(fgraph, node):
return False


@register_specialize_device
@register_specialize
@node_rewriter([at_pow])
def local_pow_specialize_device(fgraph, node):
"""
This rewrite is not the same on all device. We do it only on cpu here.
def local_pow_to_nested_squaring(fgraph, node):
"""Convert a large power exponent to multiple squaring operations.

Note: This sounds like the kind of thing any half-decent compiler can do by itself?
"""

if node.op == at_pow:
# the idea here is that we have pow(x, y)
odtype = node.outputs[0].dtype
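The renamed `local_pow_to_nested_squaring` applies ordinary exponentiation by squaring: an integer exponent is decomposed into its binary digits so that `pow(x, y)` needs only O(log y) multiplications instead of y - 1. A plain-Python sketch of the idea, purely illustrative; the rewrite itself builds the equivalent nested squaring operations as graph nodes rather than running a loop.

```python
def pow_by_nested_squaring(x, y):
    """Compute x ** y for a non-negative integer y using repeated squaring."""
    result = 1
    base = x
    while y > 0:
        if y & 1:            # current binary digit of the exponent is 1
            result = result * base
        base = base * base   # square once per exponent bit
        y >>= 1
    return result


assert pow_by_nested_squaring(3.0, 15) == 3.0 ** 15
```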
93 changes: 55 additions & 38 deletions tests/tensor/rewriting/test_elemwise.py
@@ -1177,8 +1177,24 @@ def test_test_values(self, test_value):
)

@pytest.mark.parametrize("linker", ["cvm", "py"])
@pytest.mark.parametrize("inp_dtype", ("floatX", "int32"))
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)])
def test_CAReduce_single_input(self, linker, axis):
@pytest.mark.parametrize(
"careduce_op, numpy_op",
[
(at_sum, np.sum),
pytest.param(
at_all,
np.all,
marks=pytest.mark.xfail(
reason="Rewrite logic does not support all CAReduce"
),
),
Comment on lines +1186 to +1192 (Member): TIL !
],
)
def test_CAReduce_single_input(
self, linker, inp_dtype, axis, careduce_op, numpy_op
):
"""Make sure that `CAReduce` and `Elemwise` fusions work with a single input."""

mode = Mode(linker=linker)
Expand All @@ -1188,8 +1204,8 @@ def test_CAReduce_single_input(self, linker, axis):
"inplace",
)

x = tensor(dtype="floatX", shape=(None, None, None), name="x")
out = exp(x).sum(axis=axis)
x = tensor(dtype=inp_dtype, shape=(None, None, None), name="x")
out = careduce_op(exp(x), axis=axis)

out_fn = function([x], out, mode=mode)

Expand All @@ -1198,9 +1214,9 @@ def test_CAReduce_single_input(self, linker, axis):
assert isinstance(getattr(out_node.op, "scalar_op"), aes.basic.Composite)

rng = np.random.default_rng(2320)
x_val = rng.random((4, 3, 2), dtype=config.floatX)
x_val = rng.random((4, 3, 2)).astype(x.type.dtype)

exp_res = np.exp(x_val).sum(axis=axis)
exp_res = numpy_op(np.exp(x_val), axis=axis)

out_val = out_fn(x_val)
assert out_val.shape == exp_res.shape
Expand All @@ -1216,7 +1232,7 @@ def test_CAReduce_single_input(self, linker, axis):
# `Elemwise`s with more than one client shouldn't be rewritten
x = tensor(dtype="floatX", shape=(None, None, None), name="x")
exp_x = exp(x)
out = exp_x.sum(axis=axis) + exp(x)
out = careduce_op(exp_x, axis=axis) + exp(x)

out_fn = function([x], out, mode=mode)
out_nodes = out_fn.maker.fgraph.toposort()
@@ -1409,39 +1425,40 @@ def test_nested_composite(self):
fval = f([1, 2, 3])
assert np.all(fval == [6, 12, 18])

def test_local_useless_composite(self):
x = aes.float32()
y = aes.float32()
z = aes.float32()
c = aes.Composite([x, y, z], [x + 1, y - 1])
X = matrix("X")
Y = matrix("Y")
Z = matrix("Z")
o1, o2 = Elemwise(scalar_op=c)(X, Y, Z)
mode = get_default_mode().including("local_useless_composite")

f = function([X, Y, Z], [o1, o2], mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 2
assert len(topo[0].outputs) == 2
res1, res2 = f([[1.0]], [[1.0]], [[np.nan]])
utt.assert_allclose(res1, [[2.0]])
utt.assert_allclose(res2, [[0.0]])

f = function([X, Y, Z], o1, mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 1
assert len(topo[0].outputs) == 1
utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]])

f = function([X, Y, Z], o2, mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 1
assert len(topo[0].outputs) == 1
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])
def test_local_useless_composite_outputs():
x = aes.float32()
y = aes.float32()
z = aes.float32()
c = aes.Composite([x, y, z], [x + 1, y - 1])
X = matrix("X")
Y = matrix("Y")
Z = matrix("Z")
o1, o2 = Elemwise(scalar_op=c)(X, Y, Z)
mode = get_default_mode().including("local_useless_composite")

f = function([X, Y, Z], [o1, o2], mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 2
assert len(topo[0].outputs) == 2
res1, res2 = f([[1.0]], [[1.0]], [[np.nan]])
utt.assert_allclose(res1, [[2.0]])
utt.assert_allclose(res2, [[0.0]])

f = function([X, Y, Z], o1, mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 1
assert len(topo[0].outputs) == 1
utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]])

f = function([X, Y, Z], o2, mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert len(topo[0].inputs) == 1
assert len(topo[0].outputs) == 1
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])


def test_local_useless_dimshuffle_makevector():