Skip to content

Commit a5ce151

Browse files
committed
[CPU] Refactor Infer update strategy
by using a template function with a functor template argument to reduce Infer overhead
1 parent 2a9af43 commit a5ce151

File tree

2 files changed

+20
-24
lines changed

2 files changed

+20
-24
lines changed

src/plugins/intel_cpu/src/graph.cpp

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1137,16 +1137,11 @@ void Graph::InferStatic(SyncInferRequest* request) {
11371137

11381138
namespace {
11391139

1140-
class IUpdateNodes {
1141-
public:
1142-
virtual void run(size_t stopIndx) = 0;
1143-
virtual ~IUpdateNodes() = default;
1144-
};
1145-
1146-
class UpdateNodesSeq : public IUpdateNodes {
1140+
class UpdateNodesSeq {
11471141
public:
11481142
explicit UpdateNodesSeq(std::vector<NodePtr>& executableGraphNodes) : m_executableGraphNodes(executableGraphNodes) {}
1149-
void run(size_t stopIndx) override {
1143+
1144+
void operator()(size_t stopIndx) {
11501145
for (; prepareCounter < stopIndx; ++prepareCounter) {
11511146
const auto& node = m_executableGraphNodes[prepareCounter];
11521147
if (node->isDynamicNode()) {
@@ -1177,7 +1172,7 @@ class UpdateNodesSeq : public IUpdateNodes {
11771172
# define ov_memory_order_acquire std::memory_order::memory_order_acquire
11781173
# endif
11791174

1180-
class UpdateNodesBase : public IUpdateNodes {
1175+
class UpdateNodesBase {
11811176
public:
11821177
explicit UpdateNodesBase(std::vector<NodePtr>& executableGraphNodes) : m_executableGraphNodes(executableGraphNodes) {}
11831178
void updateShapes(size_t node_indx, size_t stop_indx) {
@@ -1248,7 +1243,8 @@ class AsyncTask : public tbb::detail::d1::task {
12481243
class UpdateNodes : public UpdateNodesBase {
12491244
public:
12501245
using UpdateNodesBase::UpdateNodesBase;
1251-
void run(size_t stopIndx) override {
1246+
1247+
void operator()(size_t stopIndx) {
12521248
m_completion.store(false);
12531249
auto startCounter = m_prepareCounter.load();
12541250
tbb::detail::d1::wait_context wait_ctx(2);
@@ -1289,7 +1285,7 @@ class AsyncTask : public tbb::task {
12891285
class UpdateNodes : public UpdateNodesBase {
12901286
public:
12911287
using UpdateNodesBase::UpdateNodesBase;
1292-
void run(size_t stopIndx) override {
1288+
void operator()(size_t stopIndx) {
12931289
m_completion.store(false);
12941290
auto startCounter = m_prepareCounter.load();
12951291
tbb::task& root = *new(tbb::task::allocate_root()) tbb::empty_task;
@@ -1317,7 +1313,7 @@ class UpdateNodes : public UpdateNodesBase {
13171313
class UpdateNodes : public UpdateNodesBase {
13181314
public:
13191315
using UpdateNodesBase::UpdateNodesBase;
1320-
void run(size_t stopIndx) override {
1316+
void operator()(size_t stopIndx) {
13211317
m_completion.store(false);
13221318
auto startCounter = m_prepareCounter.load();
13231319

@@ -1340,20 +1336,14 @@ class UpdateNodes : public UpdateNodesBase {
13401336
#endif
13411337
} // namespace
13421338

1343-
1344-
void Graph::InferDynamic(SyncInferRequest* request) {
1339+
template<typename UpdateStrategy>
1340+
void Graph::InferDynamic(SyncInferRequest* request, UpdateStrategy&& update) {
13451341
dnnl::stream stream(getEngine());
13461342

1347-
std::unique_ptr<IUpdateNodes> updateNodes{};
1348-
if (parallel_get_max_threads() > 1) {
1349-
updateNodes.reset(new UpdateNodes(m_executableGraphNodes));
1350-
} else {
1351-
updateNodes.reset(new UpdateNodesSeq(m_executableGraphNodes));
1352-
}
1353-
13541343
size_t inferCounter = 0;
13551344
for (auto stopIndx : m_executableSyncNodesInds) {
1356-
updateNodes->run(stopIndx);
1345+
update(stopIndx);
1346+
13571347
for (; inferCounter < stopIndx; ++inferCounter) {
13581348
auto& node = m_executableGraphNodes[inferCounter];
13591349
VERBOSE(node, getConfig().debugCaps.verbose);
@@ -1461,7 +1451,11 @@ void Graph::Infer(SyncInferRequest* request) {
14611451
}
14621452

14631453
if (Status::ReadyDynamic == status) {
1464-
InferDynamic(request);
1454+
if (parallel_get_max_threads() > 1) {
1455+
InferDynamic(request, UpdateNodes(m_executableGraphNodes));
1456+
} else {
1457+
InferDynamic(request, UpdateNodesSeq(m_executableGraphNodes));
1458+
}
14651459
} else if (Status::ReadyStatic == status) {
14661460
InferStatic(request);
14671461
} else {

src/plugins/intel_cpu/src/graph.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,9 @@ class Graph {
228228
void ExecuteNode(const NodePtr& node, const dnnl::stream& stream) const;
229229
void CreatePrimitivesAndExecConstants() const;
230230
void InferStatic(SyncInferRequest* request);
231-
void InferDynamic(SyncInferRequest* request);
231+
232+
template<typename UpdateStrategy>
233+
void InferDynamic(SyncInferRequest* request, UpdateStrategy&& update);
232234
void ParalleMtNuma(size_t num_nodes,
233235
ov::threading::CPUStreamsExecutor::Ptr executor,
234236
const std::function<void(size_t, size_t)>& func) const;

0 commit comments

Comments
 (0)