Skip to content

Commit 81ebcca

Browse files
committed
Assert expected results in test_multiple_outs_taps
1 parent 4eded29 commit 81ebcca

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tests/scan/test_basic.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3753,8 +3753,6 @@ def f_rnn_cmpl(u1_t, u2_tm1, u2_t, u2_tp1, x_tm1, y_tm1, y_tm3, W_in1):
37533753
[u1, u2, x0, y0, W_in1], outputs, updates=updates, allow_input_downcast=True
37543754
)
37553755

3756-
f(v_u1, v_u2, v_x0, v_y0, vW_in1)
3757-
37583756
ny0 = np.zeros((5, 2))
37593757
ny1 = np.zeros((5,))
37603758
ny2 = np.zeros((5, 2))
@@ -3802,7 +3800,10 @@ def f_rnn_cmpl(u1_t, u2_tm1, u2_t, u2_tp1, x_tm1, y_tm1, y_tm3, W_in1):
38023800
ny1[4] = (ny1[3] + ny1[1]) * np.dot(ny0[3], vWout)
38033801
ny2[4] = np.dot(v_u1[4], vW_in1)
38043802

3805-
# TODO FIXME: What is this testing? At least assert something.
3803+
res = f(v_u1, v_u2, v_x0, v_y0, vW_in1)
3804+
np.testing.assert_almost_equal(res[0], ny0)
3805+
np.testing.assert_almost_equal(res[1], ny1)
3806+
np.testing.assert_almost_equal(res[2], ny2)
38063807

38073808
def _grad_mout_helper(self, n_iters, mode):
38083809
rng = np.random.default_rng(utt.fetch_seed())

0 commit comments

Comments
 (0)