@@ -1516,11 +1516,43 @@ def complete(self):
15161516 attrs = {'block' : inside_block })
15171517
15181518
1519+ class IfElseBlockGuard (object ):
1520+ def __init__ (self , is_true , ifelse ):
1521+ if not isinstance (ifelse , IfElse ):
1522+ raise TypeError ("ifelse must be an instance of IfElse class" )
1523+
1524+ if ifelse .status != IfElse .OUT_IF_ELSE_BLOCKS :
1525+ raise ValueError ("You cannot invoke IfElse.block() inside a block" )
1526+
1527+ self .is_true = is_true
1528+ self .ie = ifelse
1529+ if is_true :
1530+ self .cond_block = ifelse .conditional_true_block
1531+ else :
1532+ self .cond_block = ifelse .conditional_false_block
1533+
1534+ if not isinstance (self .cond_block , ConditionalBlock ):
1535+ raise TypeError ("Unexpected situation" )
1536+
1537+ self .cond_block = self .cond_block .block ()
1538+
1539+ def __enter__ (self ):
1540+ self .ie .status = IfElse .IN_IF_ELSE_TRUE_BLOCKS if self .is_true else IfElse .IN_IF_ELSE_FALSE_BLOCKS
1541+ self .cond_block .__enter__ ()
1542+
1543+ def __exit__ (self , exc_type , exc_val , exc_tb ):
1544+ if not self .cond_block .__exit__ (exc_type , exc_val , exc_tb ):
1545+ # re-raise inside exception
1546+ return False
1547+ if len (self .ie .outupt_table [1 if self .is_true else 0 ]) == 0 :
1548+ raise ValueError ("Must set output inside block" )
1549+ self .ie .status = IfElse .OUT_IF_ELSE_BLOCKS
1550+
1551+
15191552class IfElse (object ):
1520- BEFORE_IF_ELSE_BLOCKS = 0
1553+ OUT_IF_ELSE_BLOCKS = 0
15211554 IN_IF_ELSE_TRUE_BLOCKS = 1
1522- AFTER_IF_ELSE_BLOCKS = 2
1523- IN_IF_ELSE_FALSE_BLOCKS = 3
1555+ IN_IF_ELSE_FALSE_BLOCKS = 2
15241556
15251557 def __init__ (self , cond , name = None , main_program = None ,
15261558 startup_program = None ):
@@ -1533,12 +1565,14 @@ def __init__(self, cond, name=None, main_program=None,
15331565 startup_program = startup_program )
15341566 self .cond = cond
15351567 self .input_table = {}
1536- self .status = IfElse .BEFORE_IF_ELSE_BLOCKS
1537-
1538- def input (self , x , level = 0 ):
1539- if self .status not in (IfElse .IN_IF_ELSE_TRUE_BLOCKS ,
1540- IfElse .IN_IF_ELSE_FALSE_BLOCKS ):
1541- raise Exception ("input must in true/false blocks" )
1568+ self .status = IfElse .OUT_IF_ELSE_BLOCKS
1569+ self .conditional_true_block = ConditionalBlock (inputs = [self .cond ])
1570+ self .conditional_false_block = ConditionalBlock (inputs = [self .cond ])
1571+ self .output_table = ([], []) # (true_outs, false_outs)
1572+
1573+ def input (self , x ):
1574+ if self .status == IfElse .OUT_IF_ELSE_BLOCKS :
1575+ raise ValueError ("input must in true/false blocks" )
15421576 if id (x ) not in self .input_table :
15431577 parent_block = self .parent_block ()
15441578 out_true = parent_block .create_var (
@@ -1556,7 +1590,7 @@ def input(self, x, level=0):
15561590 },
15571591 outputs = {'OutTrue' : out_true ,
15581592 'OutFalse' : out_false },
1559- attrs = {'level' : level })
1593+ attrs = {'level' : 0 })
15601594 self .input_table [id (x )] = (out_true , out_false )
15611595 else :
15621596 out_true , out_false = self .input_table [id (x )]
@@ -1569,3 +1603,57 @@ def input(self, x, level=0):
15691603 def parent_block (self ):
15701604 current_block = self .helper .main_program .current_block ()
15711605 return self .helper .main_program .block (current_block .parent_idx )
1606+
1607+ def true_block (self ):
1608+ return IfElseBlockGuard (True , self )
1609+
1610+ def false_block (self ):
1611+ return IfElseBlockGuard (False , self )
1612+
1613+ def output (self , * outs ):
1614+ if self .status == self .OUT_IF_ELSE_BLOCKS :
1615+ raise ValueError ("output can only be invoked in the sub-block" )
1616+
1617+ out_table = self .output_table [1 if self .status ==
1618+ self .IN_IF_ELSE_TRUE_BLOCKS else 0 ]
1619+ parent_block = self .parent_block ()
1620+ for each_out in outs :
1621+ if not isinstance (each_out , Variable ):
1622+ raise TypeError ("Each output should be a variable" )
1623+ # create outside tensor
1624+ outside_out = parent_block .create_var (
1625+ name = unique_name ("_" .join ([self .helper .name , 'output' ])),
1626+ dtype = each_out .data_type )
1627+ out_table .append (outside_out )
1628+
1629+ # assign local var to outside
1630+ assign (
1631+ input = each_out ,
1632+ output = outside_out ,
1633+ main_program = self .helper .main_program ,
1634+ startup_program = self .helper .startup_program )
1635+
1636+ def __call__ (self ):
1637+ if self .status != self .OUT_IF_ELSE_BLOCKS :
1638+ raise ValueError ("IfElse::__call__ must be out of sub-block" )
1639+ false_len , true_len = map (len , self .output_table )
1640+ if false_len == 0 and true_len == 0 :
1641+ raise ValueError ("Must invoke true_block/false_block before "
1642+ "__call__" )
1643+ elif false_len != true_len and false_len != 0 and true_len != 0 :
1644+ raise ValueError ("The output side must be same" )
1645+ elif false_len == 0 or true_len == 0 :
1646+ return self .output_table [0 if false_len != 0 else 1 ]
1647+
1648+ # else none of false_len/true_len is zero
1649+ # merge together
1650+ rlist = []
1651+ for false_var , true_var in zip (* self .output_table ):
1652+ rlist .append (
1653+ merge_lod_tensor (
1654+ in_true = true_var ,
1655+ in_false = false_var ,
1656+ mask = self .cond ,
1657+ x = self .cond ,
1658+ level = 0 ))
1659+ return rlist
0 commit comments