Skip to content

Commit 146b896

Browse files
committed
Complete IfElse Op
1 parent 4f7f3c9 commit 146b896

File tree

1 file changed

+98
-10
lines changed

1 file changed

+98
-10
lines changed

python/paddle/v2/framework/layers.py

Lines changed: 98 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,11 +1516,43 @@ def complete(self):
15161516
attrs={'block': inside_block})
15171517

15181518

1519+
class IfElseBlockGuard(object):
1520+
def __init__(self, is_true, ifelse):
1521+
if not isinstance(ifelse, IfElse):
1522+
raise TypeError("ifelse must be an instance of IfElse class")
1523+
1524+
if ifelse.status != IfElse.OUT_IF_ELSE_BLOCKS:
1525+
raise ValueError("You cannot invoke IfElse.block() inside a block")
1526+
1527+
self.is_true = is_true
1528+
self.ie = ifelse
1529+
if is_true:
1530+
self.cond_block = ifelse.conditional_true_block
1531+
else:
1532+
self.cond_block = ifelse.conditional_false_block
1533+
1534+
if not isinstance(self.cond_block, ConditionalBlock):
1535+
raise TypeError("Unexpected situation")
1536+
1537+
self.cond_block = self.cond_block.block()
1538+
1539+
def __enter__(self):
1540+
self.ie.status = IfElse.IN_IF_ELSE_TRUE_BLOCKS if self.is_true else IfElse.IN_IF_ELSE_FALSE_BLOCKS
1541+
self.cond_block.__enter__()
1542+
1543+
def __exit__(self, exc_type, exc_val, exc_tb):
1544+
if not self.cond_block.__exit__(exc_type, exc_val, exc_tb):
1545+
# re-raise inside exception
1546+
return False
1547+
if len(self.ie.outupt_table[1 if self.is_true else 0]) == 0:
1548+
raise ValueError("Must set output inside block")
1549+
self.ie.status = IfElse.OUT_IF_ELSE_BLOCKS
1550+
1551+
15191552
class IfElse(object):
1520-
BEFORE_IF_ELSE_BLOCKS = 0
1553+
OUT_IF_ELSE_BLOCKS = 0
15211554
IN_IF_ELSE_TRUE_BLOCKS = 1
1522-
AFTER_IF_ELSE_BLOCKS = 2
1523-
IN_IF_ELSE_FALSE_BLOCKS = 3
1555+
IN_IF_ELSE_FALSE_BLOCKS = 2
15241556

15251557
def __init__(self, cond, name=None, main_program=None,
15261558
startup_program=None):
@@ -1533,12 +1565,14 @@ def __init__(self, cond, name=None, main_program=None,
15331565
startup_program=startup_program)
15341566
self.cond = cond
15351567
self.input_table = {}
1536-
self.status = IfElse.BEFORE_IF_ELSE_BLOCKS
1537-
1538-
def input(self, x, level=0):
1539-
if self.status not in (IfElse.IN_IF_ELSE_TRUE_BLOCKS,
1540-
IfElse.IN_IF_ELSE_FALSE_BLOCKS):
1541-
raise Exception("input must in true/false blocks")
1568+
self.status = IfElse.OUT_IF_ELSE_BLOCKS
1569+
self.conditional_true_block = ConditionalBlock(inputs=[self.cond])
1570+
self.conditional_false_block = ConditionalBlock(inputs=[self.cond])
1571+
self.output_table = ([], []) # (true_outs, false_outs)
1572+
1573+
def input(self, x):
1574+
if self.status == IfElse.OUT_IF_ELSE_BLOCKS:
1575+
raise ValueError("input must in true/false blocks")
15421576
if id(x) not in self.input_table:
15431577
parent_block = self.parent_block()
15441578
out_true = parent_block.create_var(
@@ -1556,7 +1590,7 @@ def input(self, x, level=0):
15561590
},
15571591
outputs={'OutTrue': out_true,
15581592
'OutFalse': out_false},
1559-
attrs={'level': level})
1593+
attrs={'level': 0})
15601594
self.input_table[id(x)] = (out_true, out_false)
15611595
else:
15621596
out_true, out_false = self.input_table[id(x)]
@@ -1569,3 +1603,57 @@ def input(self, x, level=0):
15691603
def parent_block(self):
15701604
current_block = self.helper.main_program.current_block()
15711605
return self.helper.main_program.block(current_block.parent_idx)
1606+
1607+
def true_block(self):
1608+
return IfElseBlockGuard(True, self)
1609+
1610+
def false_block(self):
1611+
return IfElseBlockGuard(False, self)
1612+
1613+
def output(self, *outs):
1614+
if self.status == self.OUT_IF_ELSE_BLOCKS:
1615+
raise ValueError("output can only be invoked in the sub-block")
1616+
1617+
out_table = self.output_table[1 if self.status ==
1618+
self.IN_IF_ELSE_TRUE_BLOCKS else 0]
1619+
parent_block = self.parent_block()
1620+
for each_out in outs:
1621+
if not isinstance(each_out, Variable):
1622+
raise TypeError("Each output should be a variable")
1623+
# create outside tensor
1624+
outside_out = parent_block.create_var(
1625+
name=unique_name("_".join([self.helper.name, 'output'])),
1626+
dtype=each_out.data_type)
1627+
out_table.append(outside_out)
1628+
1629+
# assign local var to outside
1630+
assign(
1631+
input=each_out,
1632+
output=outside_out,
1633+
main_program=self.helper.main_program,
1634+
startup_program=self.helper.startup_program)
1635+
1636+
def __call__(self):
1637+
if self.status != self.OUT_IF_ELSE_BLOCKS:
1638+
raise ValueError("IfElse::__call__ must be out of sub-block")
1639+
false_len, true_len = map(len, self.output_table)
1640+
if false_len == 0 and true_len == 0:
1641+
raise ValueError("Must invoke true_block/false_block before "
1642+
"__call__")
1643+
elif false_len != true_len and false_len != 0 and true_len != 0:
1644+
raise ValueError("The output side must be same")
1645+
elif false_len == 0 or true_len == 0:
1646+
return self.output_table[0 if false_len != 0 else 1]
1647+
1648+
# else none of false_len/true_len is zero
1649+
# merge together
1650+
rlist = []
1651+
for false_var, true_var in zip(*self.output_table):
1652+
rlist.append(
1653+
merge_lod_tensor(
1654+
in_true=true_var,
1655+
in_false=false_var,
1656+
mask=self.cond,
1657+
x=self.cond,
1658+
level=0))
1659+
return rlist

0 commit comments

Comments
 (0)