)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
-from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, broadcast_shape
+from aesara.tensor.extra_ops import (
+    BroadcastTo,
+    Repeat,
+    Unique,
+    broadcast_shape,
+    broadcast_to,
+)
from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq
from aesara.tensor.shape import (
@@ -1491,26 +1497,11 @@ def local_elemwise_alloc(fgraph, node):
    introduces them as a canonicalization of `Alloc`'s with leading
    broadcastable dimensions.
    """
-    if not isinstance(node.op, Elemwise):
-        return False
-
    # Rewrite is only applicable when there are at least two inputs
    if len(node.inputs) == 1:
-        return None
+        return False

    if len(node.outputs) > 1:
-        # Ensure all outputs have the same broadcast pattern
-        # This is a supposition that I'm not sure is always true.
-        assert all(
-            o.type.broadcastable == node.outputs[0].type.broadcastable
-            for o in node.outputs[1:]
-        )
-
-    # The broadcast pattern of the output must match the broadcast
-    # pattern of at least one of the inputs.
-    if not any(
-        i.type.broadcastable == node.outputs[0].type.broadcastable for i in node.inputs
-    ):
        return False

    def dimshuffled_alloc(i):
@@ -1523,103 +1514,74 @@ def dimshuffled_alloc(i):
    # At least one input must have an owner that is either a `Alloc` or a
    # `DimShuffle` with an owner that is a `Alloc` -- otherwise there is
    # nothing to optimize.
-    if not any(
-        i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
-        for i in node.inputs
-    ):
+    alloc_idxs = [
+        idx
+        for idx, i in enumerate(node.inputs)
+        if i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
+    ]
+    if len(alloc_idxs) == 0:
        return False

    # Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a
    # baseline for the dimensions.
-    assert_op_idx = None
+    ref_var_idx = None
    for idx, i in enumerate(node.inputs):
        if i.type.broadcastable == node.outputs[0].type.broadcastable:
            # Prefer an input that is not a `Alloc` nor a `DimShuffle` of a
            # `Alloc` so that all `Alloc`s can be optimized.
-            if not (
-                i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
-            ):
-                assert_op_idx = idx
+            if idx not in alloc_idxs:
+                ref_var_idx = idx
                break

    # If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one
-    if assert_op_idx is None:
+    if ref_var_idx is None:
        for idx, i in enumerate(node.inputs):
-            if (i.type.broadcastable == node.outputs[0].type.broadcastable) and (
-                i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
-            ):
-                assert_op_idx = idx
+            # XXX: This broadcastable comparison doesn't work
+            if (
+                i.type.broadcastable == node.outputs[0].type.broadcastable
+            ) and idx in alloc_idxs:
+                ref_var_idx = idx
                break

-    assert_op_in = node.inputs[assert_op_idx]
-    cmp_op = assert_op_in
-    new_i = []
-    same_shape = fgraph.shape_feature.same_shape
-    for i in node.inputs:
+    if not hasattr(fgraph, "shape_feature"):
+        return False
+
+    input_shapes = [
+        tuple(fgraph.shape_feature.get_shape(i, j) for j in range(i.type.ndim))
+        for i in node.inputs
+    ]
+    bcasted_shape = broadcast_shape(
+        *input_shapes,
+        arrays_are_shapes=True,
+    )
+
+    new_inputs = list(node.inputs)
+    for idx in alloc_idxs:
+        i = node.inputs[idx]
+
        # Remove `Alloc`
-        if i.owner and isinstance(i.owner.op, Alloc):
-            assert i.type.ndim == cmp_op.ndim
-            if config.experimental__local_alloc_elemwise_assert:
-                get_shape = fgraph.shape_feature.get_shape
-                cond = []
-                for idx in range(i.type.ndim):
-                    if not i.type.broadcastable[idx] and not same_shape(
-                        i, cmp_op, idx, idx
-                    ):
-                        i_shp = get_shape(i, idx)
-                        cmp_shp = get_shape(cmp_op, idx)
-                        cond.append(eq(i_shp, cmp_shp))
-                if cond:
-                    assert_op_in = assert_op(assert_op_in, *cond)
-            alloc_input = i.owner.inputs[0]
-            if alloc_input.ndim != i.ndim:
-                # The `Alloc` can add dimensions to the value.
-                # We replace those cases with a `DimShuffle` here.
-                nb_dim_to_add = i.ndim - alloc_input.ndim
-                alloc_input = alloc_input.dimshuffle(
-                    ["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
-                )
-            copy_stack_trace(i, alloc_input)
-            new_i.append(alloc_input)
+        if isinstance(i.owner.op, Alloc):
+            new_alloc = broadcast_to(i.owner.inputs[0], bcasted_shape)

+        # TODO FIXME: This shouldn't be handled here.
+        # `DimShuffle`s should be lifted through `Alloc`s
+        # by other, more general rewrites.
        # Remove `Alloc` in `DimShuffle`
-        elif i.owner and dimshuffled_alloc(i):
-            assert i.type.ndim == cmp_op.type.ndim
-            if config.experimental__local_alloc_elemwise_assert:
-                assert_cond = [
-                    eq(i.shape[idx], cmp_op.shape[idx])
-                    for idx in range(i.type.ndim)
-                    if not i.type.broadcastable[idx]
-                    and not same_shape(i, cmp_op, idx, idx)
-                ]
-                if assert_cond:
-                    assert_op_in = assert_op(assert_op_in, *assert_cond)
-            alloc_input = i.owner.inputs[0].owner.inputs[0]
-            if alloc_input.ndim != i.owner.inputs[0].ndim:
-                # The `Alloc` can add dimensions to the value.
-                # We replace those cases with a `DimShuffle` here.
-                # We let later optimizations merge the nested `DimShuffle`s
-                nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
-                alloc_input = alloc_input.dimshuffle(
-                    ["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
-                )
-
+        elif isinstance(i.owner.op, DimShuffle):
+            new_alloc = i.owner.inputs[0].owner.inputs[0]
            # We need to keep the old `DimShuffle`. It could swap axes or
            # add dimensions anywhere.
-            r_i = i.owner.op(alloc_input)
-            copy_stack_trace(i, r_i)
-            new_i.append(r_i)
+            new_alloc = broadcast_to(i.owner.op(new_alloc), bcasted_shape)

-        else:
-            new_i.append(i)
-    new_i[assert_op_idx] = assert_op_in
+        copy_stack_trace(i, new_alloc)
+        new_inputs[idx] = new_alloc

    # If this assert is triggered, it means we are recreating an equivalent graph
    # which would result in a cyclical merge optimization.
-    if all(new is old for new, old in zip(new_i, node.inputs)):
+    if all(new is old for new, old in zip(new_inputs, node.inputs)):
        return

-    ret = node.op(*new_i, return_list=True)
+    ret = node.op(*new_inputs, return_list=True)
    copy_stack_trace(node.outputs, ret)
    return ret
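
For context, a minimal usage sketch (not part of the diff above) of what the rewritten `local_elemwise_alloc` is meant to do, assuming only the public Aesara API (`aesara.tensor.alloc`, `aesara.function`, `aesara.printing.debugprint`); the graph and variable names are illustrative only:

# Sketch: an Elemwise whose input is an explicit Alloc.
import aesara
import aesara.tensor as at

x = at.vector("x")
y = at.matrix("y")

# `at.alloc` materializes `x` broadcast to shape (3, 4); the surrounding
# `Elemwise` addition then receives that `Alloc` as an input.
out = at.alloc(x, 3, 4) + y

# During compilation, `local_elemwise_alloc` should replace the `Alloc` input
# with a `broadcast_to` (i.e. `BroadcastTo`) of `x`, using the broadcasted
# shape of all inputs, and leave any further simplification to later rewrites.
fn = aesara.function([x, y], out)
aesara.printing.debugprint(fn)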