2222#include < boost/make_shared.hpp>
2323
2424#ifdef GTSAM_USE_TBB
25- #include < tbb/task_group .h> // tbb::task_group
25+ #include < tbb/task .h> // tbb::task, tbb::task_list
2626#include < tbb/scalable_allocator.h> // tbb::scalable_allocator
2727
2828namespace gtsam {
@@ -34,80 +34,89 @@ namespace gtsam {
3434
3535 /* ************************************************************************* */
3636 template <typename NODE, typename DATA, typename VISITOR_PRE, typename VISITOR_POST>
37- class PreOrderTask
37+ class PreOrderTask : public tbb ::task
3838 {
3939 public:
4040 const boost::shared_ptr<NODE>& treeNode;
4141 boost::shared_ptr<DATA> myData;
4242 VISITOR_PRE& visitorPre;
4343 VISITOR_POST& visitorPost;
4444 int problemSizeThreshold;
45- tbb::task_group& tg;
4645 bool makeNewTasks;
4746
48- // Keep track of order phase across multiple calls to the same functor
49- mutable bool isPostOrderPhase;
47+ bool isPostOrderPhase;
5048
5149 PreOrderTask (const boost::shared_ptr<NODE>& treeNode, const boost::shared_ptr<DATA>& myData,
5250 VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost, int problemSizeThreshold,
53- tbb::task_group& tg, bool makeNewTasks = true )
51+ bool makeNewTasks = true )
5452 : treeNode(treeNode),
5553 myData (myData),
5654 visitorPre(visitorPre),
5755 visitorPost(visitorPost),
5856 problemSizeThreshold(problemSizeThreshold),
59- tg(tg),
6057 makeNewTasks(makeNewTasks),
6158 isPostOrderPhase(false ) {}
6259
63- void operator ()() const
60+ tbb::task* execute () override
6461 {
6562 if (isPostOrderPhase)
6663 {
6764 // Run the post-order visitor since this task was recycled to run the post-order visitor
6865 (void ) visitorPost (treeNode, *myData);
66+ return nullptr ;
6967 }
7068 else
7169 {
7270 if (makeNewTasks)
7371 {
7472 if (!treeNode->children .empty ())
7573 {
74+ // Allocate post-order task as a continuation
75+ isPostOrderPhase = true ;
76+ recycle_as_continuation ();
77+
7678 bool overThreshold = (treeNode->problemSize () >= problemSizeThreshold);
7779
78- // If we have child tasks, start subtasks and wait for them to complete
79- tbb::task_group ctg ;
80+ tbb::task* firstChild = 0 ;
81+ tbb::task_list childTasks ;
8082 for (const boost::shared_ptr<NODE>& child: treeNode->children )
8183 {
8284 // Process child in a subtask. Important: Run visitorPre before calling
8385 // allocate_child so that if visitorPre throws an exception, we will not have
8486 // allocated an extra child, this causes a TBB error.
8587 boost::shared_ptr<DATA> childData = boost::allocate_shared<DATA>(
8688 tbb::scalable_allocator<DATA>(), visitorPre (child, *myData));
87- ctg.run (PreOrderTask (child, childData, visitorPre, visitorPost,
88- problemSizeThreshold, ctg, overThreshold));
89+ tbb::task* childTask =
90+ new (allocate_child ()) PreOrderTask (child, childData, visitorPre, visitorPost,
91+ problemSizeThreshold, overThreshold);
92+ if (firstChild)
93+ childTasks.push_back (*childTask);
94+ else
95+ firstChild = childTask;
8996 }
90- ctg.wait ();
9197
92- // Allocate post-order task as a continuation
93- isPostOrderPhase = true ;
94- tg.run (*this );
98+ // If we have child tasks, start subtasks and wait for them to complete
99+ set_ref_count ((int )treeNode->children .size ());
100+ spawn (childTasks);
101+ return firstChild;
95102 }
96103 else
97104 {
98105 // Run the post-order visitor in this task if we have no children
99106 (void ) visitorPost (treeNode, *myData);
107+ return nullptr ;
100108 }
101109 }
102110 else
103111 {
104112 // Process this node and its children in this task
105113 processNodeRecursively (treeNode, *myData);
114+ return nullptr ;
106115 }
107116 }
108117 }
109118
110- void processNodeRecursively (const boost::shared_ptr<NODE>& node, DATA& myData) const
119+ void processNodeRecursively (const boost::shared_ptr<NODE>& node, DATA& myData)
111120 {
112121 for (const boost::shared_ptr<NODE>& child: node->children )
113122 {
@@ -122,39 +131,46 @@ namespace gtsam {
122131
123132 /* ************************************************************************* */
124133 template <typename ROOTS, typename NODE, typename DATA, typename VISITOR_PRE, typename VISITOR_POST>
125- class RootTask
134+ class RootTask : public tbb ::task
126135 {
127136 public:
128137 const ROOTS& roots;
129138 DATA& myData;
130139 VISITOR_PRE& visitorPre;
131140 VISITOR_POST& visitorPost;
132141 int problemSizeThreshold;
133- tbb::task_group& tg;
134142 RootTask (const ROOTS& roots, DATA& myData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost,
135- int problemSizeThreshold, tbb::task_group& tg ) :
143+ int problemSizeThreshold) :
136144 roots (roots), myData(myData), visitorPre(visitorPre), visitorPost(visitorPost),
137- problemSizeThreshold (problemSizeThreshold), tg(tg) {}
145+ problemSizeThreshold (problemSizeThreshold) {}
138146
139- void operator ()() const
147+ tbb::task* execute () override
140148 {
141149 typedef PreOrderTask<NODE, DATA, VISITOR_PRE, VISITOR_POST> PreOrderTask;
142150 // Create data and tasks for our children
151+ tbb::task_list tasks;
143152 for (const boost::shared_ptr<NODE>& root: roots)
144153 {
145154 boost::shared_ptr<DATA> rootData = boost::allocate_shared<DATA>(tbb::scalable_allocator<DATA>(), visitorPre (root, myData));
146- tg.run (PreOrderTask (root, rootData, visitorPre, visitorPost, problemSizeThreshold, tg));
155+ tasks.push_back (*new (allocate_child ())
156+ PreOrderTask (root, rootData, visitorPre, visitorPost, problemSizeThreshold));
147157 }
158+ // Set TBB ref count
159+ set_ref_count (1 + (int ) roots.size ());
160+ // Spawn tasks
161+ spawn_and_wait_for_all (tasks);
162+ // Return nullptr
163+ return nullptr ;
148164 }
149165 };
150166
151167 template <typename NODE, typename ROOTS, typename DATA, typename VISITOR_PRE, typename VISITOR_POST>
152- void CreateRootTask (const ROOTS& roots, DATA& rootData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost, int problemSizeThreshold)
168+ RootTask<ROOTS, NODE, DATA, VISITOR_PRE, VISITOR_POST>&
169+ CreateRootTask (const ROOTS& roots, DATA& rootData, VISITOR_PRE& visitorPre, VISITOR_POST& visitorPost, int problemSizeThreshold)
153170 {
154171 typedef RootTask<ROOTS, NODE, DATA, VISITOR_PRE, VISITOR_POST> RootTask;
155- tbb::task_group tg;
156- tg.run_and_wait (RootTask (roots, rootData, visitorPre, visitorPost, problemSizeThreshold, tg));
157- }
172+ return *new (tbb::task::allocate_root ()) RootTask (roots, rootData, visitorPre, visitorPost, problemSizeThreshold);
173+ }
158174
159175 }
160176
0 commit comments