Skip to content

Commit d5f1e30

Browse files
authored
Add 0d Tensor Test Cases for cond, case, switch_case (#49544)
Add 0d Tensor Test Cases for cond, case, switch_case. Since the 3 APIs are control flow APIs, their support for 0d tensor relies on the underneath APIs. This PR just added test cases to prove that the 3 APIs have already handled 0d tensor well.
1 parent 5feadc0 commit d5f1e30

File tree

3 files changed

+516
-0
lines changed

3 files changed

+516
-0
lines changed

python/paddle/fluid/tests/unittests/test_case.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,67 @@ def fn_3():
8989
np.testing.assert_allclose(res[3], 2, rtol=1e-05)
9090
np.testing.assert_allclose(res[4], 2, rtol=1e-05)
9191

92+
def test_0d_tensor(self):
93+
def fn_1():
94+
return paddle.full(shape=[], dtype='int32', fill_value=1)
95+
96+
def fn_2():
97+
return paddle.full(shape=[], dtype='int32', fill_value=2)
98+
99+
def fn_3():
100+
return paddle.full(shape=[], dtype='int32', fill_value=3)
101+
102+
main_program = Program()
103+
startup_program = Program()
104+
with program_guard(main_program, startup_program):
105+
x = paddle.full(shape=[], dtype='float32', fill_value=0.3)
106+
y = paddle.full(shape=[], dtype='float32', fill_value=0.1)
107+
z = paddle.full(shape=[], dtype='float32', fill_value=0.2)
108+
pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1
109+
pred_1 = paddle.less_than(z, x) # true: 0.2 < 0.3
110+
111+
# call fn_1
112+
out_0 = paddle.static.nn.control_flow.case(
113+
pred_fn_pairs=[(pred_1, fn_1), (pred_1, fn_2)], default=fn_3
114+
)
115+
116+
# call fn_2
117+
out_1 = paddle.static.nn.control_flow.case(
118+
pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3
119+
)
120+
121+
# call default fn_3
122+
out_2 = paddle.static.nn.control_flow.case(
123+
pred_fn_pairs=((pred_2, fn_1), (pred_2, fn_2)), default=fn_3
124+
)
125+
126+
# no default, call fn_2
127+
out_3 = paddle.static.nn.control_flow.case(
128+
pred_fn_pairs=[(pred_1, fn_2)]
129+
)
130+
131+
# no default, call fn_2. but pred_2 is false
132+
out_4 = paddle.static.nn.control_flow.case(
133+
pred_fn_pairs=[(pred_2, fn_2)]
134+
)
135+
136+
place = (
137+
fluid.CUDAPlace(0)
138+
if core.is_compiled_with_cuda()
139+
else fluid.CPUPlace()
140+
)
141+
exe = fluid.Executor(place)
142+
143+
res = exe.run(
144+
main_program, fetch_list=[out_0, out_1, out_2, out_3, out_4]
145+
)
146+
147+
np.testing.assert_allclose(res[0], 1, rtol=1e-05)
148+
np.testing.assert_allclose(res[1], 2, rtol=1e-05)
149+
np.testing.assert_allclose(res[2], 3, rtol=1e-05)
150+
np.testing.assert_allclose(res[3], 2, rtol=1e-05)
151+
np.testing.assert_allclose(res[4], 2, rtol=1e-05)
152+
92153
def test_return_var_tuple(self):
93154
def fn_1():
94155
return layers.fill_constant(
@@ -236,6 +297,106 @@ def fn_3():
236297
np.testing.assert_allclose(res[1], 2, rtol=1e-05)
237298
np.testing.assert_allclose(res[2], 3, rtol=1e-05)
238299

300+
def test_nested_0d_tensor(self):
301+
def fn_1(x=1):
302+
var_5 = paddle.full(shape=[], dtype='int32', fill_value=5)
303+
var_6 = paddle.full(shape=[], dtype='int32', fill_value=6)
304+
out = paddle.static.nn.control_flow.case(
305+
pred_fn_pairs=[
306+
(
307+
var_5 < var_6,
308+
partial(
309+
paddle.full,
310+
shape=[],
311+
dtype='int32',
312+
fill_value=x,
313+
),
314+
),
315+
(
316+
var_5 == var_6,
317+
partial(
318+
paddle.full,
319+
shape=[],
320+
dtype='int32',
321+
fill_value=x,
322+
),
323+
),
324+
]
325+
)
326+
return out
327+
328+
def fn_2(x=2):
329+
var_5 = paddle.full(shape=[], dtype='int32', fill_value=5)
330+
var_6 = paddle.full(shape=[], dtype='int32', fill_value=6)
331+
out = paddle.static.nn.control_flow.case(
332+
pred_fn_pairs=[
333+
(var_5 < var_6, partial(fn_1, x=x)),
334+
(
335+
var_5 == var_6,
336+
partial(
337+
paddle.full,
338+
shape=[],
339+
dtype='int32',
340+
fill_value=x,
341+
),
342+
),
343+
]
344+
)
345+
return out
346+
347+
def fn_3():
348+
var_5 = paddle.full(shape=[], dtype='int32', fill_value=5)
349+
var_6 = paddle.full(shape=[], dtype='int32', fill_value=6)
350+
out = paddle.static.nn.control_flow.case(
351+
pred_fn_pairs=[
352+
(var_5 < var_6, partial(fn_2, x=3)),
353+
(
354+
var_5 == var_6,
355+
partial(
356+
paddle.full,
357+
shape=[],
358+
dtype='int32',
359+
fill_value=7,
360+
),
361+
),
362+
]
363+
)
364+
return out
365+
366+
main_program = Program()
367+
startup_program = Program()
368+
with program_guard(main_program, startup_program):
369+
x = paddle.full(shape=[], dtype='float32', fill_value=0.3)
370+
y = paddle.full(shape=[], dtype='float32', fill_value=0.1)
371+
z = paddle.full(shape=[], dtype='float32', fill_value=0.2)
372+
pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1
373+
pred_1 = paddle.less_than(z, x) # true: 0.2 < 0.3
374+
375+
out_1 = paddle.static.nn.control_flow.case(
376+
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3
377+
)
378+
379+
out_2 = paddle.static.nn.control_flow.case(
380+
pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3
381+
)
382+
383+
out_3 = paddle.static.nn.control_flow.case(
384+
pred_fn_pairs=[(x == y, fn_1), (x == z, fn_2)], default=fn_3
385+
)
386+
387+
place = (
388+
fluid.CUDAPlace(0)
389+
if core.is_compiled_with_cuda()
390+
else fluid.CPUPlace()
391+
)
392+
exe = fluid.Executor(place)
393+
394+
res = exe.run(main_program, fetch_list=[out_1, out_2, out_3])
395+
396+
np.testing.assert_allclose(res[0], 1, rtol=1e-05)
397+
np.testing.assert_allclose(res[1], 2, rtol=1e-05)
398+
np.testing.assert_allclose(res[2], 3, rtol=1e-05)
399+
239400

240401
class TestAPICase_Error(unittest.TestCase):
241402
def test_error(self):

python/paddle/fluid/tests/unittests/test_cond.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,115 @@ def false_func():
6868
np.asarray(ret), np.full((3, 2), -1, np.int32), rtol=1e-05
6969
)
7070

71+
def test_return_0d_tensor(self):
72+
"""
73+
pseudocode:
74+
75+
if 0.23 >= 0.1:
76+
return 2
77+
else:
78+
return -1
79+
"""
80+
81+
paddle.enable_static()
82+
83+
def true_func():
84+
return paddle.full(shape=[], dtype='int32', fill_value=2)
85+
86+
def false_func():
87+
return paddle.full(shape=[], dtype='int32', fill_value=-1)
88+
89+
main_program = Program()
90+
startup_program = Program()
91+
with program_guard(main_program, startup_program):
92+
x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
93+
y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
94+
pred = paddle.greater_equal(y, x)
95+
out = paddle.static.nn.cond(pred, true_func, false_func)
96+
# out is one tensor
97+
98+
place = (
99+
fluid.CUDAPlace(0)
100+
if core.is_compiled_with_cuda()
101+
else fluid.CPUPlace()
102+
)
103+
exe = fluid.Executor(place)
104+
(ret,) = exe.run(main_program, fetch_list=[out.name])
105+
np.testing.assert_allclose(np.asarray(ret), np.array(2), rtol=1e-05)
106+
107+
def test_0d_tensor_as_cond(self):
108+
"""
109+
pseudocode:
110+
111+
if 0.23 >= 0.1:
112+
return 2
113+
else:
114+
return -1
115+
"""
116+
117+
paddle.enable_static()
118+
119+
def true_func():
120+
return paddle.full(shape=[3, 3], dtype='int32', fill_value=2)
121+
122+
def false_func():
123+
return paddle.full(shape=[3, 3], dtype='int32', fill_value=-1)
124+
125+
main_program = Program()
126+
startup_program = Program()
127+
with program_guard(main_program, startup_program):
128+
x = paddle.full(shape=[], dtype='float32', fill_value=0.1)
129+
y = paddle.full(shape=[], dtype='float32', fill_value=0.23)
130+
pred = paddle.greater_equal(y, x)
131+
out = paddle.static.nn.cond(pred, true_func, false_func)
132+
# out is one tensor
133+
134+
place = (
135+
fluid.CUDAPlace(0)
136+
if core.is_compiled_with_cuda()
137+
else fluid.CPUPlace()
138+
)
139+
exe = fluid.Executor(place)
140+
(ret,) = exe.run(main_program, fetch_list=[out.name])
141+
np.testing.assert_allclose(
142+
np.asarray(ret), np.full((3, 3), 2, np.int32), rtol=1e-05
143+
)
144+
145+
def test_0d_tensor_backward(self):
146+
"""
147+
pseudocode:
148+
149+
a = -2.0
150+
if a >= 0:
151+
return a
152+
else:
153+
return -a
154+
"""
155+
156+
paddle.enable_static()
157+
158+
main_program = Program()
159+
startup_program = Program()
160+
with program_guard(main_program, startup_program):
161+
a = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
162+
a.stop_gradient = False
163+
out = paddle.static.nn.cond(a >= 0, lambda: a, lambda: -a)
164+
append_backward(out)
165+
166+
place = (
167+
fluid.CUDAPlace(0)
168+
if core.is_compiled_with_cuda()
169+
else fluid.CPUPlace()
170+
)
171+
exe = fluid.Executor(place)
172+
ret = exe.run(main_program, fetch_list=[out.name, a.grad_name])
173+
np.testing.assert_allclose(
174+
np.asarray(ret[0]), np.array(2.0), rtol=1e-05
175+
)
176+
np.testing.assert_allclose(
177+
np.asarray(ret[1]), np.array(-1.0), rtol=1e-05
178+
)
179+
71180
def test_return_var_tuple(self):
72181
"""
73182
pseudocode:
@@ -358,6 +467,70 @@ def greater_equal_branch(i, a):
358467
self.assertEqual(ret[0][0], expected_ret)
359468
self.assertEqual(ret[1][0], expected_a_grad)
360469

470+
def test_cond_inside_cond_0d_tensor(self):
471+
"""
472+
pseudocode:
473+
i = 3.0
474+
a = 2 * i
475+
if i < 5:
476+
if i >= 3:
477+
return a + 1
478+
else:
479+
return 1 - a
480+
else:
481+
if i < 8:
482+
return a * 2
483+
else:
484+
return a / 2
485+
"""
486+
487+
paddle.enable_static()
488+
489+
def less_than_branch(i, a):
490+
return paddle.static.nn.cond(
491+
i >= 3.0,
492+
lambda: a + 1,
493+
lambda: 1 - a,
494+
)
495+
496+
def greater_equal_branch(i, a):
497+
return paddle.static.nn.cond(
498+
i < 8.0,
499+
lambda: a * 2,
500+
lambda: a / 2,
501+
)
502+
503+
main_program = Program()
504+
startup_program = Program()
505+
with program_guard(main_program, startup_program):
506+
i = paddle.full(fill_value=3.0, shape=[], dtype='float32')
507+
i.stop_gradient = False
508+
a = 2.0 * i
509+
out = paddle.static.nn.cond(
510+
i < 5.0,
511+
lambda: less_than_branch(i, a),
512+
lambda: greater_equal_branch(i, a),
513+
)
514+
mean = paddle.mean(out)
515+
append_backward(out)
516+
517+
place = (
518+
fluid.CUDAPlace(0)
519+
if core.is_compiled_with_cuda()
520+
else fluid.CPUPlace()
521+
)
522+
exe = fluid.Executor(place)
523+
ret = exe.run(
524+
main_program,
525+
fetch_list=[out.name, i.grad_name],
526+
)
527+
np.testing.assert_allclose(
528+
np.asarray(ret[0]), np.array(7.0), rtol=1e-05
529+
)
530+
np.testing.assert_allclose(
531+
np.asarray(ret[1]), np.array(2.0), rtol=1e-05
532+
)
533+
361534
def test_cond_op_in_condition(self):
362535
paddle.enable_static()
363536
main_program = fluid.Program()

0 commit comments

Comments
 (0)