|
14 | 14 |
|
15 | 15 | """Test Python bindings for infostate tree and related classes."""
|
16 | 16 |
|
17 |
| -from absl.testing import absltest |
| 17 | +from absl.testing import absltest, parameterized |
18 | 18 |
|
19 | 19 | import pyspiel
|
| 20 | +import gc |
| 21 | +from copy import copy, deepcopy |
| 22 | +import weakref |
20 | 23 |
|
21 | 24 |
|
22 |
| -class InfostateTreeTest(absltest.TestCase): |
| 25 | +class InfostateTreeTest(parameterized.TestCase): |
| 26 | + def test_tree_binding(self): |
| 27 | + game = pyspiel.load_game("kuhn_poker") |
| 28 | + tree = pyspiel.InfostateTree(game, 0) |
| 29 | + self.assertEqual(tree.num_sequences(), 13) |
23 | 30 |
|
24 |
| - def test_binding(self): |
25 |
| - return False |
| 31 | + # disallowing copying is enforced |
| 32 | + with self.assertRaises(pyspiel.ForbiddenError) as context: |
| 33 | + copy(tree) |
| 34 | + deepcopy(tree) |
| 35 | + |
| 36 | + def test_node_tree_lifetime_management(self): |
| 37 | + root0 = tree.root() |
| 38 | + # let's maintain a weak ref to the tree to see when the tree object is deallocated |
| 39 | + wptr = weakref.ref(tree) |
| 40 | + # ensure we can get a shared_ptr from root that keeps tree alive if we lose the 'tree' name |
| 41 | + wptr_node = weakref.ref(root0) |
| 42 | + tree_sptr = root0.tree() |
| 43 | + # grab the tree id |
| 44 | + id_tree0 = id(tree) |
| 45 | + # now delete the initial tree ptr |
| 46 | + del tree |
| 47 | + # ensure that we still hold the object |
| 48 | + gc.collect() # force garbage collection |
| 49 | + self.assertIsNotNone(wptr()) |
| 50 | + self.assertEqual(id(tree_sptr), id_tree0) |
| 51 | + # now delete the last pointer as well |
| 52 | + del tree_sptr |
| 53 | + gc.collect() # force garbage collection |
| 54 | + self.assertIsNone(wptr()) |
| 55 | + |
| 56 | + @parameterized.parameters( |
| 57 | + [ |
| 58 | + # test for matrix mp |
| 59 | + dict( |
| 60 | + game=pyspiel.load_game("matrix_mp"), |
| 61 | + players=[0, 1], |
| 62 | + expected_certificate="([" "({}{})" "({}{})" "])", |
| 63 | + ), |
| 64 | + # test for imperfect info goofspiel |
| 65 | + dict( |
| 66 | + game=pyspiel.load_game( |
| 67 | + "goofspiel", |
| 68 | + {"num_cards": 2, "imp_info": True, "points_order": "ascending"}, |
| 69 | + ), |
| 70 | + players=[0, 1], |
| 71 | + expected_certificate="([" "({}{})" "({}{})" "])", |
| 72 | + ), |
| 73 | + # test for kuhn poker (0 player only) |
| 74 | + dict( |
| 75 | + game=pyspiel.load_game("kuhn_poker"), |
| 76 | + players=[0], |
| 77 | + expected_certificate=( |
| 78 | + "((" # Root node, 1st is getting a card |
| 79 | + "(" # 2nd is getting card |
| 80 | + "[" # 1st acts |
| 81 | + "((" # 1st bet, and 2nd acts |
| 82 | + "(({}))" |
| 83 | + "(({}))" |
| 84 | + "(({}))" |
| 85 | + "(({}))" |
| 86 | + "))" |
| 87 | + "((" # 1st checks, and 2nd acts |
| 88 | + # 2nd checked |
| 89 | + "(({}))" |
| 90 | + "(({}))" |
| 91 | + # 2nd betted |
| 92 | + "[({}" |
| 93 | + "{})" |
| 94 | + "({}" |
| 95 | + "{})]" |
| 96 | + "))" |
| 97 | + "]" |
| 98 | + ")" |
| 99 | + # Just 2 more copies. |
| 100 | + "([(((({}))(({}))(({}))(({}))))(((({}))(({}))[({}{})({}{})]))])" |
| 101 | + "([(((({}))(({}))(({}))(({}))))(((({}))(({}))[({}{})({}{})]))])" |
| 102 | + "))" |
| 103 | + ), |
| 104 | + ), |
| 105 | + ] |
| 106 | + ) |
| 107 | + def test_root_certificates(self, game, players, expected_certificate): |
| 108 | + for i in players: |
| 109 | + tree = pyspiel.InfostateTree(game, i) |
| 110 | + self.assertEqual(tree.root().make_certificate(), expected_certificate) |
| 111 | + |
| 112 | + def check_tree_leaves(self, tree, move_limit): |
| 113 | + for leaf_node in tree.leaf_nodes(): |
| 114 | + self.assertTrue(leaf_node.is_leaf_node()) |
| 115 | + self.assertTrue(leaf_node.has_infostate_string()) |
| 116 | + self.assertNotEmpty(leaf_node.corresponding_states()) |
| 117 | + |
| 118 | + num_states = len(leaf_node.corresponding_states()) |
| 119 | + terminal_cnt = 0 |
| 120 | + max_move_number = float("-inf") |
| 121 | + min_move_number = float("inf") |
| 122 | + for state in leaf_node.corresponding_states(): |
| 123 | + if state.is_terminal(): |
| 124 | + terminal_cnt += 1 |
| 125 | + max_move_number = max(max_move_number, state.move_number()) |
| 126 | + min_move_number = min(min_move_number, state.move_number()) |
| 127 | + self.assertTrue(terminal_cnt == 0 or terminal_cnt == num_states) |
| 128 | + self.assertTrue(max_move_number == min_move_number) |
| 129 | + if terminal_cnt == 0: |
| 130 | + self.assertEqual(max_move_number, move_limit) |
| 131 | + else: |
| 132 | + self.assertLessEqual(max_move_number, move_limit) |
| 133 | + |
| 134 | + def check_continuation(self, tree): |
| 135 | + leaves = tree.nodes_at_depth(tree.tree_height()) |
| 136 | + continuation = pyspiel.InfostateTree(leaves) |
| 137 | + self.assertEqual(continuation.root_branching_factor(), len(leaves)) |
| 138 | + for i in range(len(leaves)): |
| 139 | + leaf_node = leaves[i] |
| 140 | + root_node = continuation.root().child_at(i) |
| 141 | + self.assertTrue(leaf_node.is_leaf_node()) |
| 142 | + if leaf_node.type() != pyspiel.InfostateNodeType.terminal: |
| 143 | + self.assertEqual(leaf_node.type(), root_node.type()) |
| 144 | + self.assertEqual( |
| 145 | + leaf_node.has_infostate_string(), root_node.has_infostate_string() |
| 146 | + ) |
| 147 | + if leaf_node.has_infostate_string(): |
| 148 | + self.assertEqual( |
| 149 | + leaf_node.infostate_string(), root_node.infostate_string() |
| 150 | + ) |
| 151 | + else: |
| 152 | + terminal_continuation = continuation.root().child_at(i) |
| 153 | + while ( |
| 154 | + terminal_continuation.type() |
| 155 | + == pyspiel.InfostateNodeType.observation |
| 156 | + ): |
| 157 | + self.assertFalse(terminal_continuation.is_leaf_node()) |
| 158 | + self.assertEqual(terminal_continuation.num_children(), 1) |
| 159 | + terminal_continuation = terminal_continuation.child_at(0) |
| 160 | + self.assertEqual( |
| 161 | + terminal_continuation.type(), pyspiel.InfostateNodeType.terminal |
| 162 | + ) |
| 163 | + self.assertEqual( |
| 164 | + leaf_node.has_infostate_string(), |
| 165 | + terminal_continuation.has_infostate_string(), |
| 166 | + ) |
| 167 | + if leaf_node.has_infostate_string(): |
| 168 | + self.assertEqual( |
| 169 | + leaf_node.infostate_string(), |
| 170 | + terminal_continuation.infostate_string(), |
| 171 | + ) |
| 172 | + self.assertEqual( |
| 173 | + leaf_node.terminal_utility(), |
| 174 | + terminal_continuation.terminal_utility(), |
| 175 | + ) |
| 176 | + self.assertEqual( |
| 177 | + leaf_node.terminal_chance_reach_prob(), |
| 178 | + terminal_continuation.terminal_chance_reach_prob(), |
| 179 | + ) |
| 180 | + self.assertEqual( |
| 181 | + leaf_node.terminal_history(), |
| 182 | + terminal_continuation.terminal_history(), |
| 183 | + ) |
| 184 | + |
| 185 | + def test_depth_limited_tree_kuhn_poker(self): |
| 186 | + # Test MakeTree for Kuhn Poker with depth limit 2 |
| 187 | + expected_certificate = ( |
| 188 | + "(" # <dummy> |
| 189 | + "(" # 1st is getting a card |
| 190 | + "(" # 2nd is getting card |
| 191 | + "[" # 1st acts - Node J |
| 192 | + # Depth cutoff. |
| 193 | + "]" |
| 194 | + ")" |
| 195 | + # Repeat the same for the two other cards. |
| 196 | + "([])" # Node Q |
| 197 | + "([])" # Node K |
| 198 | + ")" |
| 199 | + ")" # </dummy> |
| 200 | + ) |
| 201 | + tree = pyspiel.InfostateTree(pyspiel.load_game("kuhn_poker"), 0, 2) |
| 202 | + self.assertEqual(tree.root().make_certificate(), expected_certificate) |
| 203 | + |
| 204 | + # Test leaf nodes in Kuhn Poker tree |
| 205 | + for acting in tree.leaf_nodes(): |
| 206 | + self.assertTrue(acting.is_leaf_node()) |
| 207 | + self.assertEqual(acting.type(), pyspiel.InfostateNodeType.decision) |
| 208 | + self.assertEqual(len(acting.corresponding_states()), 2) |
| 209 | + self.assertTrue(acting.has_infostate_string()) |
| 210 | + |
| 211 | + @parameterized.parameters( |
| 212 | + [ |
| 213 | + "kuhn_poker", |
| 214 | + "kuhn_poker(players=3)", |
| 215 | + "leduc_poker", |
| 216 | + "goofspiel(players=2,num_cards=3,imp_info=True)", |
| 217 | + "goofspiel(players=3,num_cards=3,imp_info=True)", |
| 218 | + ] |
| 219 | + ) |
| 220 | + def test_depth_limited_trees_all_depths(self, game_name): |
| 221 | + game = pyspiel.load_game(game_name) |
| 222 | + max_moves = game.max_move_number() |
| 223 | + for move_limit in range(max_moves): |
| 224 | + for pl in range(game.num_players()): |
| 225 | + tree = pyspiel.InfostateTree(game, pl, move_limit) |
| 226 | + self.check_tree_leaves(tree, move_limit) |
| 227 | + self.check_continuation(tree) |
| 228 | + |
| 229 | + def test_node_binding(self): |
| 230 | + with self.assertRaises(TypeError) as context: |
| 231 | + pyspiel.InfostateNode() |
| 232 | + self.assertTrue("No constructor defined" in context.exception) |
| 233 | + |
| 234 | + |
| 235 | +if __name__ == "__main__": |
| 236 | + absltest.main() |
0 commit comments