Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 27 additions & 33 deletions src/plugins/intel_cpu/src/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,8 @@ void Graph::InitGraph(bool optimize) {

std::tie(m_executableGraphNodes, m_executableSyncNodesInds) = ExtractExecutableNodesAndSyncPoints(syncNodesInds, graphNodes);

status = hasDynNodes ? Status::ReadyDynamic : Status::ReadyStatic;
status = hasDynNodes ? (parallel_get_max_threads() > 1 ? Status::ReadyDynamic : Status::ReadyDynamicSeq)
: Status::ReadyStatic;

CPU_DEBUG_CAP_ENABLE(serialize(*this));
}
Expand Down Expand Up @@ -1137,16 +1138,11 @@ void Graph::InferStatic(SyncInferRequest* request) {

namespace {

class IUpdateNodes {
public:
virtual void run(size_t stopIndx) = 0;
virtual ~IUpdateNodes() = default;
};

class UpdateNodesSeq : public IUpdateNodes {
class UpdateNodesSeq {
public:
explicit UpdateNodesSeq(std::vector<NodePtr>& executableGraphNodes) : m_executableGraphNodes(executableGraphNodes) {}
void run(size_t stopIndx) override {

void operator()(size_t stopIndx) {
for (; prepareCounter < stopIndx; ++prepareCounter) {
const auto& node = m_executableGraphNodes[prepareCounter];
if (node->isDynamicNode()) {
Expand Down Expand Up @@ -1177,7 +1173,7 @@ class UpdateNodesSeq : public IUpdateNodes {
# define ov_memory_order_acquire std::memory_order::memory_order_acquire
# endif

class UpdateNodesBase : public IUpdateNodes {
class UpdateNodesBase {
public:
explicit UpdateNodesBase(std::vector<NodePtr>& executableGraphNodes) : m_executableGraphNodes(executableGraphNodes) {}
void updateShapes(size_t node_indx, size_t stop_indx) {
Expand Down Expand Up @@ -1248,7 +1244,8 @@ class AsyncTask : public tbb::detail::d1::task {
class UpdateNodes : public UpdateNodesBase {
public:
using UpdateNodesBase::UpdateNodesBase;
void run(size_t stopIndx) override {

void operator()(size_t stopIndx) {
m_completion.store(false);
auto startCounter = m_prepareCounter.load();
tbb::detail::d1::wait_context wait_ctx(2);
Expand Down Expand Up @@ -1289,7 +1286,7 @@ class AsyncTask : public tbb::task {
class UpdateNodes : public UpdateNodesBase {
public:
using UpdateNodesBase::UpdateNodesBase;
void run(size_t stopIndx) override {
void operator()(size_t stopIndx) {
m_completion.store(false);
auto startCounter = m_prepareCounter.load();
tbb::task& root = *new(tbb::task::allocate_root()) tbb::empty_task;
Expand Down Expand Up @@ -1317,7 +1314,7 @@ class UpdateNodes : public UpdateNodesBase {
class UpdateNodes : public UpdateNodesBase {
public:
using UpdateNodesBase::UpdateNodesBase;
void run(size_t stopIndx) override {
void operator()(size_t stopIndx) {
m_completion.store(false);
auto startCounter = m_prepareCounter.load();

Expand All @@ -1340,20 +1337,14 @@ class UpdateNodes : public UpdateNodesBase {
#endif
} // namespace


void Graph::InferDynamic(SyncInferRequest* request) {
template<typename UpdateStrategy>
void Graph::InferDynamic(SyncInferRequest* request, UpdateStrategy&& update) {
dnnl::stream stream(getEngine());

std::unique_ptr<IUpdateNodes> updateNodes{};
if (parallel_get_max_threads() > 1) {
updateNodes.reset(new UpdateNodes(m_executableGraphNodes));
} else {
updateNodes.reset(new UpdateNodesSeq(m_executableGraphNodes));
}

size_t inferCounter = 0;
for (auto stopIndx : m_executableSyncNodesInds) {
updateNodes->run(stopIndx);
update(stopIndx);

for (; inferCounter < stopIndx; ++inferCounter) {
auto& node = m_executableGraphNodes[inferCounter];
VERBOSE(node, getConfig().debugCaps.verbose);
Expand Down Expand Up @@ -1455,17 +1446,20 @@ void Graph::ParalleMtNuma(size_t num_nodes,
}

void Graph::Infer(SyncInferRequest* request) {
DEBUG_LOG("Starting inference of the graph: ", GetName(), ". Status: ", static_cast<int>(status));
if (!IsReady()) {
OPENVINO_THROW("Wrong state of the ov::intel_cpu::Graph. Topology is not ready.");
}

if (Status::ReadyDynamic == status) {
InferDynamic(request);
} else if (Status::ReadyStatic == status) {
DEBUG_LOG("Infer graph: ", GetName(), ". Status: ", static_cast<int>(status));

switch (status) {
case Status::ReadyDynamic:
InferDynamic(request, UpdateNodes(m_executableGraphNodes));
break;
case Status::ReadyDynamicSeq:
InferDynamic(request, UpdateNodesSeq(m_executableGraphNodes));
break;
case Status::ReadyStatic:
InferStatic(request);
} else {
OPENVINO_THROW("Unknown ov::intel_cpu::Graph state: " , static_cast<size_t>(status));
break;
default:
OPENVINO_ASSERT(IsReady(), "Wrong state of the ov::intel_cpu::Graph. Topology is not ready: ", static_cast<int>(status));
Comment on lines +1449 to +1462
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could even drop this branch by simply assigning the corresponding function pointer at the compilation stage; that would allow us to keep the status enumeration simpler.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is an option I have also considered,
as well as using an interface with multiple implementations.
But I think a switch-case is cleaner, more explicit, more lightweight, and more straightforward here, and it does not require strict alignment of the infer strategies' semantics.
If the list of infer strategies grows large, we can try to replace it with a function pointer.

}

if (infer_count != -1) infer_count++;
Expand Down
7 changes: 5 additions & 2 deletions src/plugins/intel_cpu/src/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class Graph {
enum class Status {
NotReady = 0,
ReadyStatic = 1,
ReadyDynamic = 2
ReadyDynamic = 2,
ReadyDynamicSeq = 3,
};

Graph() = default;
Expand Down Expand Up @@ -228,7 +229,9 @@ class Graph {
void ExecuteNode(const NodePtr& node, const dnnl::stream& stream) const;
void CreatePrimitivesAndExecConstants() const;
void InferStatic(SyncInferRequest* request);
void InferDynamic(SyncInferRequest* request);

template<typename UpdateStrategy>
void InferDynamic(SyncInferRequest* request, UpdateStrategy&& update);
void ParalleMtNuma(size_t num_nodes,
ov::threading::CPUStreamsExecutor::Ptr executor,
const std::function<void(size_t, size_t)>& func) const;
Expand Down