Skip to content

Commit 7fbb749

Browse files
committed
further test additions, added enable_shared_from_this to istatetree
1 parent 3da130d commit 7fbb749

File tree

4 files changed

+267
-27
lines changed

4 files changed

+267
-27
lines changed

open_spiel/algorithms/infostate_tree.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,8 @@ absl::optional<DecisionId> InfostateTree::DecisionIdForSequence(
485485
}
486486
}
487487
absl::optional<InfostateNode*> InfostateTree::DecisionForSequence(
488-
const SequenceId& sequence_id) {
488+
const SequenceId& sequence_id) const
489+
{
489490
SPIEL_DCHECK_TRUE(sequence_id.BelongsToTree(this));
490491
InfostateNode* node = sequences_.at(sequence_id.id());
491492
SPIEL_DCHECK_TRUE(node);

open_spiel/algorithms/infostate_tree.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ std::shared_ptr<InfostateTree> MakeInfostateTree(
288288
const std::vector<InfostateNode*>& start_nodes,
289289
int max_move_ahead_limit = 1000);
290290

291-
class InfostateTree final {
291+
class InfostateTree final : public std::enable_shared_from_this<InfostateTree> {
292292
// Note that only MakeInfostateTree is allowed to call the constructor
293293
// to ensure the trees are always allocated on heap. We do this so that all
294294
// the collected pointers are valid throughout the tree's lifetime even if
@@ -308,6 +308,10 @@ class InfostateTree final {
308308
const std::vector<const InfostateNode*>&, int);
309309

310310
public:
311+
// -- gain shared ownership of the allocated infostate object
312+
std::shared_ptr< InfostateTree > shared_ptr() { return shared_from_this(); }
313+
std::shared_ptr< const InfostateTree > shared_ptr() const { return shared_from_this(); }
314+
311315
// -- Root accessors ---------------------------------------------------------
312316
const InfostateNode& root() const { return *root_; }
313317
InfostateNode* mutable_root() { return root_.get(); }
@@ -347,7 +351,7 @@ class InfostateTree final {
347351
// Returns `None` if the sequence is the empty sequence.
348352
absl::optional<DecisionId> DecisionIdForSequence(const SequenceId&) const;
349353
// Returns `None` if the sequence is the empty sequence.
350-
absl::optional<InfostateNode*> DecisionForSequence(const SequenceId&);
354+
absl::optional<InfostateNode*> DecisionForSequence(const SequenceId& sequence_id) const;
351355
// Returns whether the sequence ends with the last action the player can make.
352356
bool IsLeafSequence(const SequenceId&) const;
353357

open_spiel/python/pybind11/algorithms_infostate_tree.cc

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
namespace py = ::pybind11;
2323

24+
2425
namespace open_spiel {
2526

2627
using namespace algorithms;
@@ -31,7 +32,7 @@ using const_node_uniq_ptr = MockUniquePtr< const InfostateNode >;
3132
void init_pyspiel_infostate_node(::pybind11::module &m)
3233
{
3334
py::class_< InfostateNode, infostatenode_holder_ptr >(m, "InfostateNode", py::is_final())
34-
.def("tree", &InfostateNode::tree, py::return_value_policy::reference_internal)
35+
.def("tree", [](const InfostateNode &node) { return node.tree().shared_ptr(); })
3536
.def(
3637
"parent", [](const InfostateNode &node) { return infostatenode_holder_ptr{node.parent()}; }
3738
)
@@ -77,6 +78,7 @@ void init_pyspiel_infostate_node(::pybind11::module &m)
7778
},
7879
py::arg("index")
7980
)
81+
.def("make_certificate", &InfostateNode::MakeCertificate)
8082
.def(
8183
"__copy__",
8284
[](const InfostateNode &node) {
@@ -89,14 +91,22 @@ void init_pyspiel_infostate_node(::pybind11::module &m)
8991
);
9092
}
9193
)
92-
.def("__deepcopy__", [](const InfostateNode &node) {
93-
throw ForbiddenException(
94-
"InfostateNode cannot be copied, because its "
95-
"lifetime is managed by the owning "
96-
"InfostateTree. Store a variable naming the "
97-
"associated tree to ensure the node's "
98-
"lifetime."
99-
);
94+
.def(
95+
"__deepcopy__",
96+
[](const InfostateNode &node) {
97+
throw ForbiddenException(
98+
"InfostateNode cannot be copied, because its "
99+
"lifetime is managed by the owning "
100+
"InfostateTree. Store a variable naming the "
101+
"associated tree to ensure the node's "
102+
"lifetime."
103+
);
104+
}
105+
)
106+
.def("address_str", [](const InfostateNode &node) {
107+
std::stringstream ss;
108+
ss << &node;
109+
return ss.str();
100110
});
101111

102112
py::enum_< InfostateNodeType >(m, "InfostateNodeType")
@@ -162,7 +172,7 @@ void init_pyspiel_infostate_tree(::pybind11::module &m)
162172
m, "InfostateNodeVector2D"
163173
);
164174

165-
py::class_< InfostateTree, std::shared_ptr< InfostateTree > >(m, "InfostateTree")
175+
py::class_< InfostateTree, std::shared_ptr< InfostateTree > >(m, "InfostateTree", py::is_final())
166176
.def(
167177
py::init([](const Game &game, Player acting_player, int max_move_limit) {
168178
return MakeInfostateTree(game, acting_player, max_move_limit);
@@ -240,12 +250,6 @@ void init_pyspiel_infostate_tree(::pybind11::module &m)
240250
.def("is_leaf_sequence", &InfostateTree::IsLeafSequence)
241251
.def(
242252
"decision_infostate",
243-
[](InfostateTree &tree, const DecisionId &id) {
244-
return infostatenode_holder_ptr{tree.decision_infostate(id)};
245-
}
246-
)
247-
.def(
248-
"decision_infostate_view",
249253
[](const InfostateTree &tree, const DecisionId &id) {
250254
return const_node_uniq_ptr{tree.decision_infostate(id)};
251255
}
@@ -308,10 +312,30 @@ void init_pyspiel_infostate_tree(::pybind11::module &m)
308312
)
309313
.def("best_response", &InfostateTree::BestResponse, py::arg("gradient"))
310314
.def("best_response_value", &InfostateTree::BestResponseValue, py::arg("gradient"))
311-
.def("__repr__", [](const InfostateTree &tree) {
312-
std::ostringstream oss;
313-
oss << tree;
314-
return oss.str();
315+
.def(
316+
"__repr__",
317+
[](const InfostateTree &tree) {
318+
std::ostringstream oss;
319+
oss << tree;
320+
return oss.str();
321+
}
322+
)
323+
.def(
324+
"__copy__",
325+
[](const InfostateTree &) {
326+
throw ForbiddenException(
327+
"InfostateTree cannot be copied, because its "
328+
"internal structure is entangled during construction. "
329+
"Create a new tree instead."
330+
);
331+
}
332+
)
333+
.def("__deepcopy__", [](const InfostateTree &) {
334+
throw ForbiddenException(
335+
"InfostateTree cannot be copied, because its "
336+
"internal structure is entangled during construction. "
337+
"Create a new tree instead."
338+
);
315339
});
316340
}
317341

open_spiel/python/tests/infostate_tree_test.py

Lines changed: 215 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,223 @@
1414

1515
"""Test Python bindings for infostate tree and related classes."""
1616

17-
from absl.testing import absltest
17+
from absl.testing import absltest, parameterized
1818

1919
import pyspiel
20+
import gc
21+
from copy import copy, deepcopy
22+
import weakref
2023

2124

22-
class InfostateTreeTest(absltest.TestCase):
25+
class InfostateTreeTest(parameterized.TestCase):
26+
def test_tree_binding(self):
27+
game = pyspiel.load_game("kuhn_poker")
28+
tree = pyspiel.InfostateTree(game, 0)
29+
self.assertEqual(tree.num_sequences(), 13)
2330

24-
def test_binding(self):
25-
return False
31+
# disallowing copying is enforced
32+
with self.assertRaises(pyspiel.ForbiddenError) as context:
33+
copy(tree)
34+
deepcopy(tree)
35+
36+
def test_node_tree_lifetime_management(self):
37+
root0 = tree.root()
38+
# let's maintain a weak ref to the tree to see when the tree object is deallocated
39+
wptr = weakref.ref(tree)
40+
# ensure we can get a shared_ptr from root that keeps tree alive if we lose the 'tree' name
41+
wptr_node = weakref.ref(root0)
42+
tree_sptr = root0.tree()
43+
# grab the tree id
44+
id_tree0 = id(tree)
45+
# now delete the initial tree ptr
46+
del tree
47+
# ensure that we still hold the object
48+
gc.collect() # force garbage collection
49+
self.assertIsNotNone(wptr())
50+
self.assertEqual(id(tree_sptr), id_tree0)
51+
# now delete the last pointer as well
52+
del tree_sptr
53+
gc.collect() # force garbage collection
54+
self.assertIsNone(wptr())
55+
56+
@parameterized.parameters(
57+
[
58+
# test for matrix mp
59+
dict(
60+
game=pyspiel.load_game("matrix_mp"),
61+
players=[0, 1],
62+
expected_certificate="([" "({}{})" "({}{})" "])",
63+
),
64+
# test for imperfect info goofspiel
65+
dict(
66+
game=pyspiel.load_game(
67+
"goofspiel",
68+
{"num_cards": 2, "imp_info": True, "points_order": "ascending"},
69+
),
70+
players=[0, 1],
71+
expected_certificate="([" "({}{})" "({}{})" "])",
72+
),
73+
# test for kuhn poker (0 player only)
74+
dict(
75+
game=pyspiel.load_game("kuhn_poker"),
76+
players=[0],
77+
expected_certificate=(
78+
"((" # Root node, 1st is getting a card
79+
"(" # 2nd is getting card
80+
"[" # 1st acts
81+
"((" # 1st bet, and 2nd acts
82+
"(({}))"
83+
"(({}))"
84+
"(({}))"
85+
"(({}))"
86+
"))"
87+
"((" # 1st checks, and 2nd acts
88+
# 2nd checked
89+
"(({}))"
90+
"(({}))"
91+
# 2nd betted
92+
"[({}"
93+
"{})"
94+
"({}"
95+
"{})]"
96+
"))"
97+
"]"
98+
")"
99+
# Just 2 more copies.
100+
"([(((({}))(({}))(({}))(({}))))(((({}))(({}))[({}{})({}{})]))])"
101+
"([(((({}))(({}))(({}))(({}))))(((({}))(({}))[({}{})({}{})]))])"
102+
"))"
103+
),
104+
),
105+
]
106+
)
107+
def test_root_certificates(self, game, players, expected_certificate):
108+
for i in players:
109+
tree = pyspiel.InfostateTree(game, i)
110+
self.assertEqual(tree.root().make_certificate(), expected_certificate)
111+
112+
def check_tree_leaves(self, tree, move_limit):
113+
for leaf_node in tree.leaf_nodes():
114+
self.assertTrue(leaf_node.is_leaf_node())
115+
self.assertTrue(leaf_node.has_infostate_string())
116+
self.assertNotEmpty(leaf_node.corresponding_states())
117+
118+
num_states = len(leaf_node.corresponding_states())
119+
terminal_cnt = 0
120+
max_move_number = float("-inf")
121+
min_move_number = float("inf")
122+
for state in leaf_node.corresponding_states():
123+
if state.is_terminal():
124+
terminal_cnt += 1
125+
max_move_number = max(max_move_number, state.move_number())
126+
min_move_number = min(min_move_number, state.move_number())
127+
self.assertTrue(terminal_cnt == 0 or terminal_cnt == num_states)
128+
self.assertTrue(max_move_number == min_move_number)
129+
if terminal_cnt == 0:
130+
self.assertEqual(max_move_number, move_limit)
131+
else:
132+
self.assertLessEqual(max_move_number, move_limit)
133+
134+
def check_continuation(self, tree):
135+
leaves = tree.nodes_at_depth(tree.tree_height())
136+
continuation = pyspiel.InfostateTree(leaves)
137+
self.assertEqual(continuation.root_branching_factor(), len(leaves))
138+
for i in range(len(leaves)):
139+
leaf_node = leaves[i]
140+
root_node = continuation.root().child_at(i)
141+
self.assertTrue(leaf_node.is_leaf_node())
142+
if leaf_node.type() != pyspiel.InfostateNodeType.terminal:
143+
self.assertEqual(leaf_node.type(), root_node.type())
144+
self.assertEqual(
145+
leaf_node.has_infostate_string(), root_node.has_infostate_string()
146+
)
147+
if leaf_node.has_infostate_string():
148+
self.assertEqual(
149+
leaf_node.infostate_string(), root_node.infostate_string()
150+
)
151+
else:
152+
terminal_continuation = continuation.root().child_at(i)
153+
while (
154+
terminal_continuation.type()
155+
== pyspiel.InfostateNodeType.observation
156+
):
157+
self.assertFalse(terminal_continuation.is_leaf_node())
158+
self.assertEqual(terminal_continuation.num_children(), 1)
159+
terminal_continuation = terminal_continuation.child_at(0)
160+
self.assertEqual(
161+
terminal_continuation.type(), pyspiel.InfostateNodeType.terminal
162+
)
163+
self.assertEqual(
164+
leaf_node.has_infostate_string(),
165+
terminal_continuation.has_infostate_string(),
166+
)
167+
if leaf_node.has_infostate_string():
168+
self.assertEqual(
169+
leaf_node.infostate_string(),
170+
terminal_continuation.infostate_string(),
171+
)
172+
self.assertEqual(
173+
leaf_node.terminal_utility(),
174+
terminal_continuation.terminal_utility(),
175+
)
176+
self.assertEqual(
177+
leaf_node.terminal_chance_reach_prob(),
178+
terminal_continuation.terminal_chance_reach_prob(),
179+
)
180+
self.assertEqual(
181+
leaf_node.terminal_history(),
182+
terminal_continuation.terminal_history(),
183+
)
184+
185+
def test_depth_limited_tree_kuhn_poker(self):
186+
# Test MakeTree for Kuhn Poker with depth limit 2
187+
expected_certificate = (
188+
"(" # <dummy>
189+
"(" # 1st is getting a card
190+
"(" # 2nd is getting card
191+
"[" # 1st acts - Node J
192+
# Depth cutoff.
193+
"]"
194+
")"
195+
# Repeat the same for the two other cards.
196+
"([])" # Node Q
197+
"([])" # Node K
198+
")"
199+
")" # </dummy>
200+
)
201+
tree = pyspiel.InfostateTree(pyspiel.load_game("kuhn_poker"), 0, 2)
202+
self.assertEqual(tree.root().make_certificate(), expected_certificate)
203+
204+
# Test leaf nodes in Kuhn Poker tree
205+
for acting in tree.leaf_nodes():
206+
self.assertTrue(acting.is_leaf_node())
207+
self.assertEqual(acting.type(), pyspiel.InfostateNodeType.decision)
208+
self.assertEqual(len(acting.corresponding_states()), 2)
209+
self.assertTrue(acting.has_infostate_string())
210+
211+
@parameterized.parameters(
212+
[
213+
"kuhn_poker",
214+
"kuhn_poker(players=3)",
215+
"leduc_poker",
216+
"goofspiel(players=2,num_cards=3,imp_info=True)",
217+
"goofspiel(players=3,num_cards=3,imp_info=True)",
218+
]
219+
)
220+
def test_depth_limited_trees_all_depths(self, game_name):
221+
game = pyspiel.load_game(game_name)
222+
max_moves = game.max_move_number()
223+
for move_limit in range(max_moves):
224+
for pl in range(game.num_players()):
225+
tree = pyspiel.InfostateTree(game, pl, move_limit)
226+
self.check_tree_leaves(tree, move_limit)
227+
self.check_continuation(tree)
228+
229+
def test_node_binding(self):
230+
with self.assertRaises(TypeError) as context:
231+
pyspiel.InfostateNode()
232+
self.assertTrue("No constructor defined" in context.exception)
233+
234+
235+
if __name__ == "__main__":
236+
absltest.main()

0 commit comments

Comments
 (0)