Skip to content

Commit f7a4340

Browse files
DeepMind Technologies Ltdlanctot
authored andcommitted
Use smart_holder branch machinery to be able to pass a Bot instance by both unique_ptr and shared_ptr interchangeably.
PiperOrigin-RevId: 517914864 Change-Id: Ia4a10b798d81d622f681ef769082ca01d833238d
1 parent c017cdb commit f7a4340

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

open_spiel/python/pybind11/bots.cc

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717
#include <stdint.h>
1818

1919
#include <memory>
20-
#include <new>
2120
#include <string>
22-
#include <utility>
2321

2422
#include "open_spiel/algorithms/evaluate_bots.h"
2523
#include "open_spiel/algorithms/is_mcts.h"
@@ -99,8 +97,7 @@ class PyBot : public Bot {
9997
"inform_action", // Name of function in Python
10098
InformAction, // Name of function in C++
10199
state, // Arguments
102-
player_id,
103-
action);
100+
player_id, action);
104101
}
105102
void InformActions(const State& state,
106103
const std::vector<Action>& actions) override {
@@ -153,7 +150,7 @@ class PyBot : public Bot {
153150
} // namespace
154151

155152
void init_pyspiel_bots(py::module& m) {
156-
py::class_<Bot, PyBot> bot(m, "Bot");
153+
py::classh<Bot, PyBot> bot(m, "Bot");
157154
bot.def(py::init<>())
158155
.def("step", &Bot::Step)
159156
.def("restart", &Bot::Restart)
@@ -227,7 +224,7 @@ void init_pyspiel_bots(py::module& m) {
227224
.def("to_string", &SearchNode::ToString)
228225
.def("children_str", &SearchNode::ChildrenStr);
229226

230-
py::class_<algorithms::MCTSBot, Bot>(m, "MCTSBot")
227+
py::classh<algorithms::MCTSBot, Bot>(m, "MCTSBot")
231228
.def(
232229
py::init([](std::shared_ptr<const Game> game,
233230
std::shared_ptr<Evaluator> evaluator, double uct_c,
@@ -253,7 +250,7 @@ void init_pyspiel_bots(py::module& m) {
253250
algorithms::ISMCTSFinalPolicyType::kMaxVisitCount)
254251
.value("MAX_VALUE", algorithms::ISMCTSFinalPolicyType::kMaxValue);
255252

256-
py::class_<algorithms::ISMCTSBot, Bot>(m, "ISMCTSBot")
253+
py::classh<algorithms::ISMCTSBot, Bot>(m, "ISMCTSBot")
257254
.def(py::init<int, std::shared_ptr<Evaluator>, double, int, int,
258255
algorithms::ISMCTSFinalPolicyType, bool, bool>(),
259256
py::arg("seed"), py::arg("evaluator"), py::arg("uct_c"),

open_spiel/python/pybind11/pybind11.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
// in one place to help with consistency.
3535

3636
namespace open_spiel {
37+
3738
class NormalFormGame;
39+
class Bot;
3840

3941
namespace matrix_game {
4042
class MatrixGame;
@@ -43,13 +45,22 @@ class MatrixGame;
4345
namespace tensor_game {
4446
class TensorGame;
4547
}
48+
49+
namespace algorithms {
50+
class MCTSBot;
51+
class ISMCTSBot;
52+
} // namespace algorithms
53+
4654
} // namespace open_spiel
4755

4856
PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::State);
4957
PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::Game);
5058
PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::NormalFormGame);
5159
PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::matrix_game::MatrixGame);
5260
PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::tensor_game::TensorGame);
61+
PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::Bot);
62+
PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::algorithms::MCTSBot);
63+
PYBIND11_SMART_HOLDER_TYPE_CASTERS(open_spiel::algorithms::ISMCTSBot);
5364

5465
// Custom caster for GameParameter (essentially a variant).
5566
namespace pybind11 {

0 commit comments

Comments
 (0)