1717from framework import Program , default_main_program
1818from . import core
1919
20- __all__ = ['Executor' , 'global_scope' , 'scope_guard' , 'switch_scope' ]
20+ __all__ = [
21+ 'Executor' , 'global_scope' , 'scope_guard' , 'switch_scope' , 'fetch_var'
22+ ]
2123
2224g_scope = core .Scope ()
2325
@@ -80,12 +82,12 @@ def has_feed_operators(block, feed_targets, feed_holder_name):
8082 Args:
8183 block: a block instance (typically global block of a program)
8284 feed_targets: a dictionary of {feed_target_name: feed_target_data}
83- feed_holder_name: the name of the variable that holds the data of
84- all feed targets. The type of this feed_holder variable is
85+ feed_holder_name: the name of the variable that holds the data of
86+ all feed targets. The type of this feed_holder variable is
8587 FEED_MINIBATCH, which is essentially vector<LoDTensor>.
8688
8789 Returns:
88- A boolean value that indicates whether a block has feed operators
90+ A boolean value that indicates whether a block has feed operators
8991 that match the info contained in feed_targets and feed_holder_name.
9092 """
9193
@@ -108,7 +110,7 @@ def has_feed_operators(block, feed_targets, feed_holder_name):
108110
109111def has_fetch_operators (block , fetch_targets , fetch_holder_name ):
110112 """ Check whether the block already has fetch operators.
111-
113+
112114 Return false if the block does not have any fetch operators.
113115 If some fetch operators have been appended to the block, check that
114116 the info contained in these fetch operators matches the fetch_targets
@@ -118,13 +120,13 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name):
118120 Args:
119121 block: a block instance (typically global block of a program)
120122 fetch_targets: a dictionary of {fetch_target_name: fetch_target_data}
121- fetch_holder_name: the name of the variable that holds the data of
122- all fetch targets. The type of this fetch_holder variable is
123- FETCH_LIST, which is essentially vector<LoDTensor>.
123+ fetch_holder_name: the name of the variable that holds the data of
124+ all fetch targets. The type of this fetch_holder variable is
125+ FETCH_LIST, which is essentially vector<LoDTensor>.
124126
125- Return:
126- A boolean value that indicates whether a block has fetch operators
127- that match the info contained in fetch_targets and fetch_holder_name.
127+ Return:
128+ A boolean value that indicates whether a block has fetch operators
129+ that match the info contained in fetch_targets and fetch_holder_name.
128130 """
129131
130132 fetch_count = 0
@@ -146,6 +148,35 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name):
146148 return fetch_count > 0
147149
148150
def fetch_var(name, scope=None, return_numpy=True):
    """
    Fetch the value of the variable with the given name from the given scope.

    Args:
        name(str): name of the variable. Typically, only persistable variables
            can be found in the scope used for running the program.
        scope(core.Scope|None): scope object. It should be the scope where
            you pass to Executor.run() when running your program.
            If None, global_scope() will be used.
        return_numpy(bool): whether to convert the LoDTensor to numpy.ndarray.

    Returns:
        LoDTensor|numpy.ndarray: the value of the variable (converted to a
        numpy.ndarray when return_numpy is True).
    """
    assert isinstance(name, str)
    if scope is None:
        scope = global_scope()
    assert isinstance(scope, core.Scope)

    # Look the variable up in the caller-supplied scope. (Previously this
    # called global_scope().find_var() unconditionally, silently ignoring
    # the `scope` argument that was validated just above.)
    var = scope.find_var(name)
    assert var is not None, (
        "Cannot find " + name + " in scope. Perhaps you need to make the"
        " variable persistable by using var.persistable = True in your"
        " program.")
    tensor = var.get_tensor()
    if return_numpy:
        tensor = as_numpy(tensor)
    return tensor
179+
149180class Executor (object ):
150181 def __init__ (self , places ):
151182 if not isinstance (places , list ) and not isinstance (places , tuple ):
0 commit comments