2828#include < boost/tuple/tuple.hpp>
2929#include < boost/type_traits/has_dereference.hpp>
3030#include < boost/unordered_set.hpp>
31+ #include < boost/make_shared.hpp>
3132#include < cmath>
3233#include < fstream>
3334#include < list>
@@ -82,13 +83,7 @@ namespace gtsam {
8283 return compare (this ->constant_ , other->constant_ );
8384 }
8485
85- /* *
86- * @brief Print method.
87- *
88- * @param s Prefix string.
89- * @param labelFormatter Functor to format the labels of type L.
90- * @param valueFormatter Functor to format the values of type Y.
91- */
86+ /* * print */
9287 void print (const std::string& s, const LabelFormatter& labelFormatter,
9388 const ValueFormatter& valueFormatter) const override {
9489 std::cout << s << " Leaf " << valueFormatter (constant_) << std::endl;
@@ -332,7 +327,7 @@ namespace gtsam {
332327
333328 /* * apply unary operator */
334329 NodePtr apply (const Unary& op) const override {
335- boost::shared_ptr <Choice> r ( new Choice ( label_, *this , op) );
330+ auto r = boost::make_shared <Choice>( label_, *this , op);
336331 return Unique (r);
337332 }
338333
@@ -347,24 +342,24 @@ namespace gtsam {
347342
348343 // If second argument of binary op is Leaf node, recurse on branches
349344 NodePtr apply_g_op_fL (const Leaf& fL , const Binary& op) const override {
350- boost::shared_ptr <Choice> h ( new Choice ( label (), nrChoices () ));
351- for (NodePtr branch: branches_)
352- h->push_back (fL .apply_f_op_g (*branch, op));
345+ auto h = boost::make_shared <Choice>( label (), nrChoices ());
346+ for ( auto && branch : branches_)
347+ h->push_back (fL .apply_f_op_g (*branch, op));
353348 return Unique (h);
354349 }
355350
356351 // If second argument of binary op is Choice, call constructor
357352 NodePtr apply_g_op_fC (const Choice& fC , const Binary& op) const override {
358- boost::shared_ptr <Choice> h ( new Choice ( fC , *this , op) );
353+ auto h = boost::make_shared <Choice>( fC , *this , op);
359354 return Unique (h);
360355 }
361356
362357 // If second argument of binary op is Leaf
363358 template <typename OP>
364359 NodePtr apply_fC_op_gL (const Leaf& gL , OP op) const {
365- boost::shared_ptr <Choice> h ( new Choice ( label (), nrChoices () ));
366- for ( const NodePtr& branch: branches_)
367- h->push_back (branch->apply_f_op_g (gL , op));
360+ auto h = boost::make_shared <Choice>( label (), nrChoices ());
361+ for ( auto && branch : branches_)
362+ h->push_back (branch->apply_f_op_g (gL , op));
368363 return Unique (h);
369364 }
370365
@@ -374,9 +369,9 @@ namespace gtsam {
374369 return branches_[index]; // choose branch
375370
376371 // second case, not label of interest, just recurse
377- boost::shared_ptr <Choice> r ( new Choice ( label_, branches_.size () ));
378- for ( const NodePtr& branch: branches_)
379- r->push_back (branch->choose (label, index));
372+ auto r = boost::make_shared <Choice>( label_, branches_.size ());
373+ for ( auto && branch : branches_)
374+ r->push_back (branch->choose (label, index));
380375 return Unique (r);
381376 }
382377
@@ -401,23 +396,22 @@ namespace gtsam {
401396 }
402397
403398 /* ********************************************************************************/
404- template <typename L, typename Y>
405- DecisionTree<L, Y>::DecisionTree(//
406- const L& label, const Y& y1, const Y& y2) {
407- boost::shared_ptr<Choice> a (new Choice (label, 2 ));
399+ template <typename L, typename Y>
400+ DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
401+ auto a = boost::make_shared<Choice>(label, 2 );
408402 NodePtr l1 (new Leaf (y1)), l2 (new Leaf (y2));
409403 a->push_back (l1);
410404 a->push_back (l2);
411405 root_ = Choice::Unique (a);
412406 }
413407
414408 /* ********************************************************************************/
415- template <typename L, typename Y>
416- DecisionTree<L, Y>::DecisionTree(//
417- const LabelC& labelC, const Y& y1, const Y& y2) {
409+ template <typename L, typename Y>
410+ DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
411+ const Y& y2) {
418412 if (labelC.second != 2 ) throw std::invalid_argument (
419413 " DecisionTree: binary constructor called with non-binary label" );
420- boost::shared_ptr <Choice> a ( new Choice ( labelC.first , 2 ) );
414+ auto a = boost::make_shared <Choice>( labelC.first , 2 );
421415 NodePtr l1 (new Leaf (y1)), l2 (new Leaf (y2));
422416 a->push_back (l1);
423417 a->push_back (l2);
@@ -465,23 +459,20 @@ namespace gtsam {
465459
466460 /* ********************************************************************************/
467461 template <typename L, typename Y>
468- template <typename X>
462+ template <typename X, typename Func >
469463 DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
470- std::function<Y( const X&)> Y_of_X) {
464+ Func Y_of_X) {
471465 // Define functor for identity mapping of node label.
472- auto L_of_L = [](const L& label) { return label; };
466+ auto L_of_L = [](const L& label) { return label; };
473467 root_ = convertFrom<L, X>(other.root_ , L_of_L, Y_of_X);
474468 }
475469
476470 /* ********************************************************************************/
477471 template <typename L, typename Y>
478- template <typename M, typename X>
472+ template <typename M, typename X, typename Func >
479473 DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
480- const std::map<M, L>& map,
481- std::function<Y(const X&)> Y_of_X) {
482- std::function<L (const M&)> L_of_M = [&map](const M& label) -> L {
483- return map.at (label);
484- };
474+ const std::map<M, L>& map, Func Y_of_X) {
475+ auto L_of_M = [&map](const M& label) -> L { return map.at (label); };
485476 root_ = convertFrom<M, X>(other.root_ , L_of_M, Y_of_X);
486477 }
487478
@@ -511,13 +502,14 @@ namespace gtsam {
511502
512503 // if label is already in correct order, just put together a choice on label
513504 if (!nrChoices || !highestLabel || label > *highestLabel) {
514- boost::shared_ptr <Choice> choiceOnLabel ( new Choice ( label, end - begin) );
505+ auto choiceOnLabel = boost::make_shared <Choice>( label, end - begin);
515506 for (Iterator it = begin; it != end; it++)
516507 choiceOnLabel->push_back (it->root_ );
517508 return Choice::Unique (choiceOnLabel);
518509 } else {
519510 // Set up a new choice on the highest label
520- boost::shared_ptr<Choice> choiceOnHighestLabel (new Choice (*highestLabel, nrChoices));
511+ auto choiceOnHighestLabel =
512+ boost::make_shared<Choice>(*highestLabel, nrChoices);
521513 // now, for all possible values of highestLabel
522514 for (size_t index = 0 ; index < nrChoices; index++) {
523515 // make a new set of functions for composing by iterating over the given
@@ -576,7 +568,7 @@ namespace gtsam {
576568 std::cout << boost::format (" DecisionTree::create: expected %d values but got %d instead" ) % nrChoices % size << std::endl;
577569 throw std::invalid_argument (" DecisionTree::create invalid argument" );
578570 }
579- boost::shared_ptr <Choice> choice ( new Choice ( begin->first , endY - beginY) );
571+ auto choice = boost::make_shared <Choice>( begin->first , endY - beginY);
580572 for (ValueIt y = beginY; y != endY; y++)
581573 choice->push_back (NodePtr (new Leaf (*y)));
582574 return Choice::Unique (choice);
@@ -589,7 +581,7 @@ namespace gtsam {
589581 size_t split = size / nrChoices;
590582 for (size_t i = 0 ; i < nrChoices; i++, beginY += split) {
591583 NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY + split);
592- functions += DecisionTree (f);
584+ functions. emplace_back (f);
593585 }
594586 return compose (functions.begin (), functions.end (), begin->first );
595587 }
@@ -601,18 +593,16 @@ namespace gtsam {
601593 const typename DecisionTree<M, X>::NodePtr& f,
602594 std::function<L(const M&)> L_of_M,
603595 std::function<Y(const X&)> Y_of_X) const {
604- using MX = DecisionTree<M, X>;
605- using MXLeaf = typename MX::Leaf;
606- using MXChoice = typename MX::Choice;
607- using MXNodePtr = typename MX::NodePtr;
608596 using LY = DecisionTree<L, Y>;
609597
610598 // ugliness below because apparently we can't have templated virtual functions
611599 // If leaf, apply unary conversion "op" and create a unique leaf
612- auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f);
613- if (leaf) return NodePtr (new Leaf (Y_of_X (leaf->constant ())));
600+ using MXLeaf = typename DecisionTree<M, X>::Leaf;
601+ if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
602+ return NodePtr (new Leaf (Y_of_X (leaf->constant ())));
614603
615604 // Check if Choice
605+ using MXChoice = typename DecisionTree<M, X>::Choice;
616606 auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
617607 if (!choice) throw std::invalid_argument (
618608 " DecisionTree::Convert: Invalid NodePtr" );
@@ -623,14 +613,93 @@ namespace gtsam {
623613
624614 // put together via Shannon expansion otherwise not sorted.
625615 std::vector<LY> functions;
626- for (const MXNodePtr& branch: choice->branches ()) {
627- LY converted (convertFrom<M, X>(branch, L_of_M, Y_of_X));
628- functions += converted;
616+ for (auto && branch: choice->branches ()) {
617+ functions.emplace_back (convertFrom<M, X>(branch, L_of_M, Y_of_X));
629618 }
630619 return LY::compose (functions.begin (), functions.end (), newLabel);
631620 }
632621
633622 /* ********************************************************************************/
623+ // Functor performing depth-first visit without Assignment<L> argument.
624+ template <typename L, typename Y>
625+ struct Visit {
626+ using F = std::function<void (const Y&)>;
627+ Visit (F f) : f(f) {} // /< Construct from folding function.
628+ F f; // /< folding function object.
629+
630+ // / Do a depth-first visit on the tree rooted at node.
631+ void operator ()(const typename DecisionTree<L, Y>::NodePtr& node) const {
632+ using Leaf = typename DecisionTree<L, Y>::Leaf;
633+ if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
634+ return f (leaf->constant ());
635+
636+ using Choice = typename DecisionTree<L, Y>::Choice;
637+ auto choice = boost::dynamic_pointer_cast<const Choice>(node);
638+ for (auto && branch : choice->branches ()) (*this )(branch); // recurse!
639+ }
640+ };
641+
642+ template <typename L, typename Y>
643+ template <typename Func>
644+ void DecisionTree<L, Y>::visit(Func f) const {
645+ Visit<L, Y> visit (f);
646+ visit (root_);
647+ }
648+
649+ /* ********************************************************************************/
650+ // Functor performing depth-first visit with Assignment<L> argument.
651+ template <typename L, typename Y>
652+ struct VisitWith {
653+ using Choices = Assignment<L>;
654+ using F = std::function<void (const Choices&, const Y&)>;
655+ VisitWith (F f) : f(f) {} // /< Construct from folding function.
656+ Choices choices; // /< Assignment, mutating through recursion.
657+ F f; // /< folding function object.
658+
659+ // / Do a depth-first visit on the tree rooted at node.
660+ void operator ()(const typename DecisionTree<L, Y>::NodePtr& node) {
661+ using Leaf = typename DecisionTree<L, Y>::Leaf;
662+ if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
663+ return f (choices, leaf->constant ());
664+
665+ using Choice = typename DecisionTree<L, Y>::Choice;
666+ auto choice = boost::dynamic_pointer_cast<const Choice>(node);
667+ for (size_t i = 0 ; i < choice->nrChoices (); i++) {
668+ choices[choice->label ()] = i; // Set assignment for label to i
669+ (*this )(choice->branches ()[i]); // recurse!
670+ }
671+ }
672+ };
673+
674+ template <typename L, typename Y>
675+ template <typename Func>
676+ void DecisionTree<L, Y>::visitWith(Func f) const {
677+ VisitWith<L, Y> visit (f);
678+ visit (root_);
679+ }
680+
681+ /* ********************************************************************************/
682+ // fold is just done with a visit
683+ template <typename L, typename Y>
684+ template <typename Func, typename X>
685+ X DecisionTree<L, Y>::fold(Func f, X x0) const {
686+ visit ([&](const Y& y) { x0 = f (y, x0); });
687+ return x0;
688+ }
689+
690+ /* ********************************************************************************/
691+ // labels is just done with a visit
692+ template <typename L, typename Y>
693+ std::set<L> DecisionTree<L, Y>::labels() const {
694+ std::set<L> unique;
695+ auto f = [&](const Assignment<L>& choices, const Y&) {
696+ for (auto && kv : choices) unique.insert (kv.first );
697+ };
698+ visitWith (f);
699+ return unique;
700+ }
701+
702+ /* ********************************************************************************/
634703 template <typename L, typename Y>
635704 bool DecisionTree<L, Y>::equals(const DecisionTree& other,
636705 const CompareFunc& compare) const {
0 commit comments