Skip to content

Commit 14ec0ae

Browse files
authored
Merge pull request #1005 from borglab/feature/better_decision_tree
2 parents 53b4053 + 7f87a4c commit 14ec0ae

File tree

4 files changed

+244
-68
lines changed

4 files changed

+244
-68
lines changed

gtsam/discrete/DecisionTree-inl.h

Lines changed: 117 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <boost/tuple/tuple.hpp>
2929
#include <boost/type_traits/has_dereference.hpp>
3030
#include <boost/unordered_set.hpp>
31+
#include <boost/make_shared.hpp>
3132
#include <cmath>
3233
#include <fstream>
3334
#include <list>
@@ -82,13 +83,7 @@ namespace gtsam {
8283
return compare(this->constant_, other->constant_);
8384
}
8485

85-
/**
86-
* @brief Print method.
87-
*
88-
* @param s Prefix string.
89-
* @param labelFormatter Functor to format the labels of type L.
90-
* @param valueFormatter Functor to format the values of type Y.
91-
*/
86+
/** print */
9287
void print(const std::string& s, const LabelFormatter& labelFormatter,
9388
const ValueFormatter& valueFormatter) const override {
9489
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
@@ -332,7 +327,7 @@ namespace gtsam {
332327

333328
/** apply unary operator */
334329
NodePtr apply(const Unary& op) const override {
335-
boost::shared_ptr<Choice> r(new Choice(label_, *this, op));
330+
auto r = boost::make_shared<Choice>(label_, *this, op);
336331
return Unique(r);
337332
}
338333

@@ -347,24 +342,24 @@ namespace gtsam {
347342

348343
// If second argument of binary op is Leaf node, recurse on branches
349344
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
350-
boost::shared_ptr<Choice> h(new Choice(label(), nrChoices()));
351-
for(NodePtr branch: branches_)
352-
h->push_back(fL.apply_f_op_g(*branch, op));
345+
auto h = boost::make_shared<Choice>(label(), nrChoices());
346+
for (auto&& branch : branches_)
347+
h->push_back(fL.apply_f_op_g(*branch, op));
353348
return Unique(h);
354349
}
355350

356351
// If second argument of binary op is Choice, call constructor
357352
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
358-
boost::shared_ptr<Choice> h(new Choice(fC, *this, op));
353+
auto h = boost::make_shared<Choice>(fC, *this, op);
359354
return Unique(h);
360355
}
361356

362357
// If second argument of binary op is Leaf
363358
template<typename OP>
364359
NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const {
365-
boost::shared_ptr<Choice> h(new Choice(label(), nrChoices()));
366-
for(const NodePtr& branch: branches_)
367-
h->push_back(branch->apply_f_op_g(gL, op));
360+
auto h = boost::make_shared<Choice>(label(), nrChoices());
361+
for (auto&& branch : branches_)
362+
h->push_back(branch->apply_f_op_g(gL, op));
368363
return Unique(h);
369364
}
370365

@@ -374,9 +369,9 @@ namespace gtsam {
374369
return branches_[index]; // choose branch
375370

376371
// second case, not label of interest, just recurse
377-
boost::shared_ptr<Choice> r(new Choice(label_, branches_.size()));
378-
for(const NodePtr& branch: branches_)
379-
r->push_back(branch->choose(label, index));
372+
auto r = boost::make_shared<Choice>(label_, branches_.size());
373+
for (auto&& branch : branches_)
374+
r->push_back(branch->choose(label, index));
380375
return Unique(r);
381376
}
382377

@@ -401,23 +396,22 @@ namespace gtsam {
401396
}
402397

403398
/*********************************************************************************/
404-
template<typename L, typename Y>
405-
DecisionTree<L, Y>::DecisionTree(//
406-
const L& label, const Y& y1, const Y& y2) {
407-
boost::shared_ptr<Choice> a(new Choice(label, 2));
399+
template <typename L, typename Y>
400+
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
401+
auto a = boost::make_shared<Choice>(label, 2);
408402
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
409403
a->push_back(l1);
410404
a->push_back(l2);
411405
root_ = Choice::Unique(a);
412406
}
413407

414408
/*********************************************************************************/
415-
template<typename L, typename Y>
416-
DecisionTree<L, Y>::DecisionTree(//
417-
const LabelC& labelC, const Y& y1, const Y& y2) {
409+
template <typename L, typename Y>
410+
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
411+
const Y& y2) {
418412
if (labelC.second != 2) throw std::invalid_argument(
419413
"DecisionTree: binary constructor called with non-binary label");
420-
boost::shared_ptr<Choice> a(new Choice(labelC.first, 2));
414+
auto a = boost::make_shared<Choice>(labelC.first, 2);
421415
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
422416
a->push_back(l1);
423417
a->push_back(l2);
@@ -465,23 +459,20 @@ namespace gtsam {
465459

466460
/*********************************************************************************/
467461
template <typename L, typename Y>
468-
template <typename X>
462+
template <typename X, typename Func>
469463
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
470-
std::function<Y(const X&)> Y_of_X) {
464+
Func Y_of_X) {
471465
// Define functor for identity mapping of node label.
472-
auto L_of_L = [](const L& label) { return label; };
466+
auto L_of_L = [](const L& label) { return label; };
473467
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
474468
}
475469

476470
/*********************************************************************************/
477471
template <typename L, typename Y>
478-
template <typename M, typename X>
472+
template <typename M, typename X, typename Func>
479473
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
480-
const std::map<M, L>& map,
481-
std::function<Y(const X&)> Y_of_X) {
482-
std::function<L(const M&)> L_of_M = [&map](const M& label) -> L {
483-
return map.at(label);
484-
};
474+
const std::map<M, L>& map, Func Y_of_X) {
475+
auto L_of_M = [&map](const M& label) -> L { return map.at(label); };
485476
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
486477
}
487478

@@ -511,13 +502,14 @@ namespace gtsam {
511502

512503
// if label is already in correct order, just put together a choice on label
513504
if (!nrChoices || !highestLabel || label > *highestLabel) {
514-
boost::shared_ptr<Choice> choiceOnLabel(new Choice(label, end - begin));
505+
auto choiceOnLabel = boost::make_shared<Choice>(label, end - begin);
515506
for (Iterator it = begin; it != end; it++)
516507
choiceOnLabel->push_back(it->root_);
517508
return Choice::Unique(choiceOnLabel);
518509
} else {
519510
// Set up a new choice on the highest label
520-
boost::shared_ptr<Choice> choiceOnHighestLabel(new Choice(*highestLabel, nrChoices));
511+
auto choiceOnHighestLabel =
512+
boost::make_shared<Choice>(*highestLabel, nrChoices);
521513
// now, for all possible values of highestLabel
522514
for (size_t index = 0; index < nrChoices; index++) {
523515
// make a new set of functions for composing by iterating over the given
@@ -576,7 +568,7 @@ namespace gtsam {
576568
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl;
577569
throw std::invalid_argument("DecisionTree::create invalid argument");
578570
}
579-
boost::shared_ptr<Choice> choice(new Choice(begin->first, endY - beginY));
571+
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
580572
for (ValueIt y = beginY; y != endY; y++)
581573
choice->push_back(NodePtr(new Leaf(*y)));
582574
return Choice::Unique(choice);
@@ -589,7 +581,7 @@ namespace gtsam {
589581
size_t split = size / nrChoices;
590582
for (size_t i = 0; i < nrChoices; i++, beginY += split) {
591583
NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY + split);
592-
functions += DecisionTree(f);
584+
functions.emplace_back(f);
593585
}
594586
return compose(functions.begin(), functions.end(), begin->first);
595587
}
@@ -601,18 +593,16 @@ namespace gtsam {
601593
const typename DecisionTree<M, X>::NodePtr& f,
602594
std::function<L(const M&)> L_of_M,
603595
std::function<Y(const X&)> Y_of_X) const {
604-
using MX = DecisionTree<M, X>;
605-
using MXLeaf = typename MX::Leaf;
606-
using MXChoice = typename MX::Choice;
607-
using MXNodePtr = typename MX::NodePtr;
608596
using LY = DecisionTree<L, Y>;
609597

610598
// ugliness below because apparently we can't have templated virtual functions
611599
// If leaf, apply unary conversion "op" and create a unique leaf
612-
auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f);
613-
if (leaf) return NodePtr(new Leaf(Y_of_X(leaf->constant())));
600+
using MXLeaf = typename DecisionTree<M, X>::Leaf;
601+
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
602+
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
614603

615604
// Check if Choice
605+
using MXChoice = typename DecisionTree<M, X>::Choice;
616606
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
617607
if (!choice) throw std::invalid_argument(
618608
"DecisionTree::Convert: Invalid NodePtr");
@@ -623,14 +613,93 @@ namespace gtsam {
623613

624614
// put together via Shannon expansion otherwise not sorted.
625615
std::vector<LY> functions;
626-
for(const MXNodePtr& branch: choice->branches()) {
627-
LY converted(convertFrom<M, X>(branch, L_of_M, Y_of_X));
628-
functions += converted;
616+
for(auto && branch: choice->branches()) {
617+
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
629618
}
630619
return LY::compose(functions.begin(), functions.end(), newLabel);
631620
}
632621

633622
/*********************************************************************************/
623+
// Functor performing depth-first visit without Assignment<L> argument.
624+
template <typename L, typename Y>
625+
struct Visit {
626+
using F = std::function<void(const Y&)>;
627+
Visit(F f) : f(f) {} ///< Construct from folding function.
628+
F f; ///< folding function object.
629+
630+
/// Do a depth-first visit on the tree rooted at node.
631+
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
632+
using Leaf = typename DecisionTree<L, Y>::Leaf;
633+
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
634+
return f(leaf->constant());
635+
636+
using Choice = typename DecisionTree<L, Y>::Choice;
637+
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
638+
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
639+
}
640+
};
641+
642+
template <typename L, typename Y>
643+
template <typename Func>
644+
void DecisionTree<L, Y>::visit(Func f) const {
645+
Visit<L, Y> visit(f);
646+
visit(root_);
647+
}
648+
649+
/*********************************************************************************/
650+
// Functor performing depth-first visit with Assignment<L> argument.
651+
template <typename L, typename Y>
652+
struct VisitWith {
653+
using Choices = Assignment<L>;
654+
using F = std::function<void(const Choices&, const Y&)>;
655+
VisitWith(F f) : f(f) {} ///< Construct from folding function.
656+
Choices choices; ///< Assignment, mutating through recursion.
657+
F f; ///< folding function object.
658+
659+
/// Do a depth-first visit on the tree rooted at node.
660+
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
661+
using Leaf = typename DecisionTree<L, Y>::Leaf;
662+
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
663+
return f(choices, leaf->constant());
664+
665+
using Choice = typename DecisionTree<L, Y>::Choice;
666+
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
667+
for (size_t i = 0; i < choice->nrChoices(); i++) {
668+
choices[choice->label()] = i; // Set assignment for label to i
669+
(*this)(choice->branches()[i]); // recurse!
670+
}
671+
}
672+
};
673+
674+
template <typename L, typename Y>
675+
template <typename Func>
676+
void DecisionTree<L, Y>::visitWith(Func f) const {
677+
VisitWith<L, Y> visit(f);
678+
visit(root_);
679+
}
680+
681+
/*********************************************************************************/
682+
// fold is just done with a visit
683+
template <typename L, typename Y>
684+
template <typename Func, typename X>
685+
X DecisionTree<L, Y>::fold(Func f, X x0) const {
686+
visit([&](const Y& y) { x0 = f(y, x0); });
687+
return x0;
688+
}
689+
690+
/*********************************************************************************/
691+
// labels is just done with a visit
692+
template <typename L, typename Y>
693+
std::set<L> DecisionTree<L, Y>::labels() const {
694+
std::set<L> unique;
695+
auto f = [&](const Assignment<L>& choices, const Y&) {
696+
for (auto&& kv : choices) unique.insert(kv.first);
697+
};
698+
visitWith(f);
699+
return unique;
700+
}
701+
702+
/*********************************************************************************/
634703
template <typename L, typename Y>
635704
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
636705
const CompareFunc& compare) const {

gtsam/discrete/DecisionTree.h

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <map>
2929
#include <sstream>
3030
#include <vector>
31+
#include <set>
3132

3233
namespace gtsam {
3334

@@ -176,9 +177,8 @@ namespace gtsam {
176177
* @param other The DecisionTree to convert from.
177178
* @param Y_of_X Functor to convert from value type X to type Y.
178179
*/
179-
template <typename X>
180-
DecisionTree(const DecisionTree<L, X>& other,
181-
std::function<Y(const X&)> Y_of_X);
180+
template <typename X, typename Func>
181+
DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
182182

183183
/**
184184
* @brief Convert from a different value type X to value type Y, also transate
@@ -190,9 +190,9 @@ namespace gtsam {
190190
* @param L_of_M Map from label type M to type L.
191191
* @param Y_of_X Functor to convert from type X to type Y.
192192
*/
193-
template <typename M, typename X>
194-
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& L_of_M,
195-
std::function<Y(const X&)> Y_of_X);
193+
template <typename M, typename X, typename Func>
194+
DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& map,
195+
Func Y_of_X);
196196

197197
/// @}
198198
/// @name Testable
@@ -229,6 +229,52 @@ namespace gtsam {
229229
/** evaluate */
230230
const Y& operator()(const Assignment<L>& x) const;
231231

232+
/**
233+
* @brief Visit all leaves in depth-first fashion.
234+
*
235+
* @param f side-effect taking a value.
236+
*
237+
* Example:
238+
* int sum = 0;
239+
* auto visitor = [&](int y) { sum += y; };
240+
* tree.visitWith(visitor);
241+
*/
242+
template <typename Func>
243+
void visit(Func f) const;
244+
245+
/**
246+
* @brief Visit all leaves in depth-first fashion.
247+
*
248+
* @param f side-effect taking an assignment and a value.
249+
*
250+
* Example:
251+
* int sum = 0;
252+
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
253+
* tree.visitWith(visitor);
254+
*/
255+
template <typename Func>
256+
void visitWith(Func f) const;
257+
258+
/**
259+
* @brief Fold a binary function over the tree, returning accumulator.
260+
*
261+
* @tparam X type for accumulator.
262+
* @param f binary function: Y * X -> X returning an updated accumulator.
263+
* @param x0 initial value for accumulator.
264+
* @return X final value for accumulator.
265+
*
266+
* @note X is always passed by value.
267+
*
268+
* Example:
269+
* auto add = [](const double& y, double x) { return y + x; };
270+
* double sum = tree.fold(add, 0.0);
271+
*/
272+
template <typename Func, typename X>
273+
X fold(Func f, X x0) const;
274+
275+
/** Retrieve all unique labels as a set. */
276+
std::set<L> labels() const;
277+
232278
/** apply Unary operation "op" to f */
233279
DecisionTree apply(const Unary& op) const;
234280

0 commit comments

Comments
 (0)