
Commit b15c675

Merge pull request #7421 from emailweixu/fetch_var
helper functions fetch_var and get_var
2 parents 1ead6c2 + 37a251e commit b15c675

4 files changed, +104 −14 lines changed

python/paddle/v2/fluid/executor.py

Lines changed: 42 additions & 11 deletions
@@ -17,7 +17,9 @@
 from framework import Program, default_main_program
 from . import core

-__all__ = ['Executor', 'global_scope', 'scope_guard', 'switch_scope']
+__all__ = [
+    'Executor', 'global_scope', 'scope_guard', 'switch_scope', 'fetch_var'
+]

 g_scope = core.Scope()

@@ -80,12 +82,12 @@ def has_feed_operators(block, feed_targets, feed_holder_name):
     Args:
         block: a block instance (typically global block of a program)
         feed_targets: a dictionary of {feed_target_name: feed_target_data}
-        feed_holder_name: the name of the variable that holds the data of
-            all feed targets. The type of this feed_holder variable is
+        feed_holder_name: the name of the variable that holds the data of
+            all feed targets. The type of this feed_holder variable is
             FEED_MINIBATCH, which is essentially vector<LoDTensor>.

     Returns:
-        A boolean value that indicates whether a block has feed operators
+        A boolean value that indicates whether a block has feed operators
         that match the info contained in feed_targets and feed_holder_name.
     """

@@ -108,7 +110,7 @@ def has_feed_operators(block, feed_targets, feed_holder_name):

 def has_fetch_operators(block, fetch_targets, fetch_holder_name):
     """ Check whether the block already has fetch operators.
-
+
     Return false if the block does not have any fetch operators.
     If some fetch operators have been appended to the block, check that
     the info contained in these fetch operators matches the fetch_targets
@@ -118,13 +120,13 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name):
     Args:
         block: a block instance (typically global block of a program)
         fetch_targets: a dictionary of {fetch_target_name: fetch_target_data}
-        fetch_holder_name: the name of the variable that holds the data of
-            all fetch targets. The type of this fetch_holder variable is
-            FETCH_LIST, which is essentially vector<LoDTensor>.
+        fetch_holder_name: the name of the variable that holds the data of
+            all fetch targets. The type of this fetch_holder variable is
+            FETCH_LIST, which is essentially vector<LoDTensor>.

-    Return:
-        A boolean value that indicates whether a block has fetch operators
-        that match the info contained in fetch_targets and fetch_holder_name.
+    Return:
+        A boolean value that indicates whether a block has fetch operators
+        that match the info contained in fetch_targets and fetch_holder_name.
     """

     fetch_count = 0
@@ -146,6 +148,35 @@ def has_fetch_operators(block, fetch_targets, fetch_holder_name):
     return fetch_count > 0


+def fetch_var(name, scope=None, return_numpy=True):
+    """
+    Fetch the value of the variable with the given name from the given scope.
+    Args:
+        name(str): name of the variable. Typically, only persistable variables
+            can be found in the scope used for running the program.
+        scope(core.Scope|None): scope object. It should be the scope you
+            passed to Executor.run() when running your program.
+            If None, global_scope() will be used.
+        return_numpy(bool): whether to convert the tensor to a numpy.ndarray.
+    Returns:
+        LoDTensor|numpy.ndarray
+    """
+    assert isinstance(name, str)
+    if scope is None:
+        scope = global_scope()
+    assert isinstance(scope, core.Scope)
+
+    var = scope.find_var(name)
+    assert var is not None, (
+        "Cannot find " + name + " in scope. Perhaps you need to make the"
+        " variable persistable by using var.persistable = True in your"
+        " program.")
+    tensor = var.get_tensor()
+    if return_numpy:
+        tensor = as_numpy(tensor)
+    return tensor
+
+
 class Executor(object):
     def __init__(self, places):
         if not isinstance(places, list) and not isinstance(places, tuple):
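For context, a minimal usage sketch of the new fetch_var helper, modeled on the test added in this commit (the variable name "counter" and the tiny program are illustrative, not part of the change). The variable must be persistable so its value is still in the scope after Executor.run():

import numpy
import paddle.v2.fluid as fluid
import paddle.v2.fluid.layers as layers

# Build a tiny program that writes a constant into a persistable tensor.
counter = layers.create_tensor(dtype="int32", persistable=True, name="counter")
layers.assign(input=numpy.array([7], dtype=numpy.int32), output=counter)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_main_program(), feed={}, fetch_list=[])

# Read the value back from the scope without putting it in fetch_list.
value = fluid.fetch_var("counter")                       # numpy.ndarray by default
tensor = fluid.fetch_var("counter", return_numpy=False)  # raw LoDTensor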

python/paddle/v2/fluid/framework.py

Lines changed: 20 additions & 0 deletions
@@ -31,6 +31,7 @@
     'program_guard',
     'switch_startup_program',
     'switch_main_program',
+    'get_var',
 ]

 EMPTY_VAR_NAME = core.kEmptyVarName()
@@ -1123,3 +1124,22 @@ def program_guard(main_program, startup_program=None):
     switch_main_program(main_program)
     if startup_program is not None:
         switch_startup_program(startup_program)
+
+
+def get_var(name, program=None):
+    """
+    Get a variable by name from the global block of a program.
+    Args:
+        name(str): name of the variable
+        program(Program|None): program object.
+            If None, default_main_program() will be used.
+
+    Returns:
+        Variable
+    """
+    if program is None:
+        program = default_main_program()
+    assert isinstance(name, str)
+    assert isinstance(program, Program)
+
+    return program.global_block().var(name)
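A short sketch of how the new get_var helper might be used (the variable name "x" mirrors the test; the rest is illustrative). It returns the framework Variable registered in the program's global block, defaulting to the default main program:

import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.framework import default_main_program, get_var

# A variable created by a layer lands in the global block of the
# default main program.
x = layers.create_tensor(dtype="int32", persistable=True, name="x")

var = get_var("x")                                       # uses default_main_program()
same_var = get_var("x", program=default_main_program())  # or an explicit Program
assert var.name == same_var.name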

python/paddle/v2/fluid/layers/tensor.py

Lines changed: 5 additions & 3 deletions
@@ -35,13 +35,15 @@
 ]


-def create_tensor(dtype, name=None):
+def create_tensor(dtype, name=None, persistable=False):
     helper = LayerHelper("create_tensor", **locals())
-    return helper.create_variable(name=helper.name, dtype=dtype)
+    return helper.create_variable(
+        name=helper.name, dtype=dtype, persistable=persistable)


 def create_parameter(shape,
                      dtype,
+                     name=None,
                      attr=None,
                      is_bias=False,
                      default_initializer=None):
@@ -62,7 +64,7 @@ def create_parameter(shape,
     """
     helper = LayerHelper("create_parameter", **locals())
     if attr is None:
-        attr = ParamAttr()
+        attr = ParamAttr(name=name)
     return helper.create_parameter(attr, shape, dtype, is_bias,
                                    default_initializer)
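A brief sketch of the two new keyword arguments (the names "state" and "fc_w" are illustrative). persistable=True keeps the tensor in the scope after Executor.run(), which is what lets fetch_var read it back later; name= on create_parameter is forwarded into the ParamAttr so the parameter gets a predictable name instead of a generated one:

import paddle.v2.fluid.layers as layers

# Persistable tensor: survives in the scope after the run, so
# fluid.fetch_var("state") can read it back.
state = layers.create_tensor(dtype="float32", persistable=True, name="state")

# Parameter whose ParamAttr carries an explicit name.
w = layers.create_parameter(shape=[784, 10], dtype="float32", name="fc_w")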

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.v2.fluid as fluid
+import paddle.v2.fluid.layers as layers
+import op_test
+import numpy
+import unittest
+
+
+class TestFetchVar(op_test.OpTest):
+    def test_fetch_var(self):
+        val = numpy.array([1, 3, 5]).astype(numpy.int32)
+        x = layers.create_tensor(dtype="int32", persistable=True, name="x")
+        layers.assign(input=val, output=x)
+        exe = fluid.Executor(fluid.CPUPlace())
+        exe.run(fluid.default_main_program(), feed={}, fetch_list=[])
+        fetched_x = fluid.fetch_var("x")
+        self.assertTrue(
+            numpy.array_equal(fetched_x, val),
+            "fetch_x=%s val=%s" % (fetched_x, val))
+        self.assertEqual(fetched_x.dtype, val.dtype)
+
+
+if __name__ == '__main__':
+    unittest.main()
