diff --git a/cpp/c_api.cc b/cpp/c_api.cc index 9a5e961e..6185411b 100644 --- a/cpp/c_api.cc +++ b/cpp/c_api.cc @@ -1,4 +1,3 @@ -#include "./dep_graph.h" #include "./registry.h" #include @@ -582,5 +581,365 @@ MLC_REGISTER_FUNC("mlc.testing.FieldSet").set_body([](ObjectRef root, const char MLC_UNREACHABLE(); }); +struct DepNodeObj { + MLCAny _mlc_header; + Any stmt; + UList input_vars; + UList output_vars; + DepNodeObj *prev; + DepNodeObj *next; + + MLC_DEF_DYN_TYPE(MLC_EXPORTS, DepNodeObj, Object, "mlc.core.DepNode"); + + explicit DepNodeObj(Any stmt, UList input_vars, UList output_vars, DepNodeObj *prev, DepNodeObj *next) + : stmt(stmt), input_vars(input_vars), output_vars(output_vars), prev(prev), next(next) {} + + void Clear(); + Str __str__() const { return this->stmt.str(); } +}; + +struct DepNode : public ObjectRef { + explicit DepNode(Any stmt, UList input_vars, UList output_vars) + : DepNode(DepNode::New(stmt, input_vars, output_vars, nullptr, nullptr)) {} + + MLC_DEF_OBJ_REF(MLC_EXPORTS, DepNode, DepNodeObj, ObjectRef) + .Field("stmt", &DepNodeObj::stmt, /*frozen=*/true) + .Field("input_vars", &DepNodeObj::input_vars, /*frozen=*/true) + .Field("output_vars", &DepNodeObj::output_vars, /*frozen=*/true) + ._Field("_prev", offsetof(DepNodeObj, prev), sizeof(DepNodeObj *), true, ::mlc::core::ParseType()) + ._Field("_next", offsetof(DepNodeObj, next), sizeof(DepNodeObj *), true, ::mlc::core::ParseType()) + .StaticFn("__init__", InitOf) + .MemFn("__str__", &DepNodeObj::__str__); +}; + +struct DepGraphObj { + MLCAny _mlc_header; + Func stmt_to_inputs; + Func stmt_to_outputs; + UDict stmt_to_node; + UDict var_to_producer; + UDict var_to_consumers; + DepNode head; + + MLC_DEF_DYN_TYPE(MLC_EXPORTS, DepGraphObj, Object, "mlc.core.DepGraph"); + + explicit DepGraphObj(Func stmt_to_inputs, Func stmt_to_outputs, UDict stmt_to_node, UDict var_to_producer, + UDict var_to_consumers, DepNode head) + : stmt_to_inputs(stmt_to_inputs), stmt_to_outputs(stmt_to_outputs), stmt_to_node(stmt_to_node), + var_to_producer(var_to_producer), var_to_consumers(var_to_consumers), head(head) {} + + explicit DepGraphObj(UList input_vars, UList stmts, Func stmt_to_inputs, Func stmt_to_outputs) + : stmt_to_inputs(stmt_to_inputs), stmt_to_outputs(stmt_to_outputs), stmt_to_node(), var_to_producer(), + var_to_consumers(), head(DepNode(Any(), UList{}, input_vars)) { + this->stmt_to_node[Any()] = this->head; + for (const Any &var : input_vars) { + this->var_to_producer[var] = this->head; + this->var_to_consumers[var] = UList{}; + } + DepNodeObj *prev = this->head.get(); + for (const Any &stmt : stmts) { + DepNode node = this->CreateNode(stmt); + this->InsertAfter(prev, node.get()); + prev = node.get(); + } + } + ~DepGraphObj() { this->Clear(); } + void Clear(); + DepNode CreateNode(Any stmt) { return DepNode(stmt, this->stmt_to_inputs(stmt), this->stmt_to_outputs(stmt)); } + DepNode GetNodeFromStmt(Any stmt); + void InsertBefore(DepNodeObj *anchor, DepNodeObj *to_insert); + void InsertAfter(DepNodeObj *anchor, DepNodeObj *to_insert); + void EraseNode(DepNodeObj *to_erase); + void Replace(DepNodeObj *old_node, DepNodeObj *new_node); + UList GetNodeProducers(DepNodeObj *node); + UList GetNodeConsumers(DepNodeObj *node); + DepNode GetVarProducer(Any var); + UList GetVarConsumers(Any var); + void _Insert(DepNodeObj *prev, DepNodeObj *next, DepNodeObj *to_insert); +}; + +struct DepGraph : public ObjectRef { + MLC_DEF_OBJ_REF(MLC_EXPORTS, DepGraph, DepGraphObj, ObjectRef) + .Field("_stmt_to_inputs", &DepGraphObj::stmt_to_inputs, /*frozen=*/true) + .Field("_stmt_to_outputs", &DepGraphObj::stmt_to_outputs, /*frozen=*/true) + .Field("_stmt_to_node", &DepGraphObj::stmt_to_node, /*frozen=*/true) + .Field("_var_to_producer", &DepGraphObj::var_to_producer, /*frozen=*/true) + .Field("_var_to_consumers", &DepGraphObj::var_to_consumers, /*frozen=*/true) + .Field("_head", &DepGraphObj::head, /*frozen=*/true) + .StaticFn("__init__", InitOf) + .StaticFn("_init_from_stmts", InitOf) + .MemFn("clear", &DepGraphObj::Clear) + .MemFn("create_node", &DepGraphObj::CreateNode) + .MemFn("get_node_from_stmt", &DepGraphObj::GetNodeFromStmt) + .MemFn("insert_before", &DepGraphObj::InsertBefore) + .MemFn("insert_after", &DepGraphObj::InsertAfter) + .MemFn("erase_node", &DepGraphObj::EraseNode) + .MemFn("replace", &DepGraphObj::Replace) + .MemFn("get_node_producers", &DepGraphObj::GetNodeProducers) + .MemFn("get_node_consumers", &DepGraphObj::GetNodeConsumers) + .MemFn("get_var_producer", &DepGraphObj::GetVarProducer) + .MemFn("get_var_consumers", &DepGraphObj::GetVarConsumers) + .MemFn("__str__", ::mlc::core::StringifyOpaque); + + explicit DepGraph(Func stmt_to_inputs, Func stmt_to_outputs, UDict stmt_to_node, UDict var_to_producer, + UDict var_to_consumers, DepNode head) + : DepGraph( + DepGraph::New(stmt_to_inputs, stmt_to_outputs, stmt_to_node, var_to_producer, var_to_consumers, head)) {} + + explicit DepGraph(UList input_vars, UList stmts, Func stmt_to_inputs, Func stmt_to_outputs) + : DepGraph(DepGraph::New(input_vars, stmts, stmt_to_inputs, stmt_to_outputs)) {} +}; + +inline void DepNodeObj::Clear() { + this->stmt = Any(); + this->prev = nullptr; + this->next = nullptr; +} + +inline void DepGraphObj::Clear() { + for (DepNodeObj *node = this->head.get(); node;) { + DepNodeObj *next = node->next; + node->Clear(); + node = next; + } + this->var_to_producer.clear(); + this->var_to_consumers.clear(); + this->stmt_to_node.clear(); +} + +inline DepNode DepGraphObj::GetNodeFromStmt(Any stmt) { + if (auto it = this->stmt_to_node.find(stmt); it != this->stmt_to_node.end()) { + return it->second; + } + MLC_THROW(RuntimeError) << "Stmt not in graph: " << stmt; + MLC_UNREACHABLE(); +} + +inline void DepGraphObj::InsertBefore(DepNodeObj *anchor, DepNodeObj *to_insert) { + if (anchor->prev == nullptr) { + MLC_THROW(RuntimeError) << "Can't input before the input node: " << anchor->stmt; + } + if (!stmt_to_node.count(anchor->stmt)) { + MLC_THROW(RuntimeError) << "Anchor node not in graph: " << anchor->stmt; + } + DepNodeObj *prev = anchor->prev; + DepNodeObj *next = anchor; + return _Insert(prev, next, to_insert); +} + +inline void DepGraphObj::InsertAfter(DepNodeObj *anchor, DepNodeObj *to_insert) { + if (!stmt_to_node.count(anchor->stmt)) { + MLC_THROW(RuntimeError) << "Anchor node not in graph: " << anchor->stmt; + } + DepNodeObj *prev = anchor; + DepNodeObj *next = anchor->next; + return _Insert(prev, next, to_insert); +} + +inline void DepGraphObj::EraseNode(DepNodeObj *to_erase) { + // Step 1. Unlink the node from the graph + if (to_erase->prev == nullptr) { + MLC_THROW(RuntimeError) << "Can't erase the input node: " << to_erase->stmt; + } + if (!this->stmt_to_node.count(to_erase->stmt)) { + MLC_THROW(RuntimeError) << "Node not in graph: " << to_erase->stmt; + } + this->stmt_to_node.erase(to_erase->stmt); + if (to_erase->prev != nullptr) { + to_erase->prev->next = to_erase->next; + } else { + this->head = to_erase->next; + } + if (to_erase->next != nullptr) { + to_erase->next->prev = to_erase->prev; + } + // Step 2. For each variable produced by the node + // 1) check all its consumers are gone + // 2) remove the producer + for (const Any &var : to_erase->output_vars) { + UListObj *consumers = this->var_to_consumers.at(var); + if (!consumers->empty()) { + MLC_THROW(RuntimeError) << "Removing a node which produces a variable that still has consumers in graph: " << var; + } + this->var_to_producer.erase(var); + this->var_to_consumers.erase(var); + } + // Step 3. For each varibale consumed by the node + // 1) check if the var is in the graph + // 2) remove the node from its consumer list + for (const Any &var : to_erase->input_vars) { + if (!this->var_to_producer.count(var)) { + MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var; + } + UListObj *consumers = this->var_to_consumers.at(var); + auto it = std::find_if(consumers->begin(), consumers->end(), + [to_erase](const Any &v) -> bool { return v.operator DepNodeObj *() == to_erase; }); + if (it == consumers->end()) { + MLC_THROW(RuntimeError) << "Node is not a consumer of the variable: " << var; + } + consumers->erase(it); + } + // Step 4. Clear the node + to_erase->Clear(); +} + +inline void DepGraphObj::Replace(DepNodeObj *old_node, DepNodeObj *new_node) { + if (old_node == new_node) { + return; + } + if (old_node->prev == nullptr) { + MLC_THROW(RuntimeError) << "Can't replace the input node: " << old_node->stmt; + } + if (!this->stmt_to_node.count(old_node->stmt)) { + MLC_THROW(RuntimeError) << "Node not in graph: " << old_node->stmt; + } + if (new_node->prev != nullptr || new_node->next != nullptr) { + MLC_THROW(RuntimeError) << "Node is already in the graph: " << new_node->stmt; + } + int64_t num_output_vars = old_node->output_vars.size(); + if (num_output_vars != new_node->output_vars.size()) { + MLC_THROW(RuntimeError) << "Mismatched number of output_vars: " << num_output_vars << " vs " + << new_node->output_vars.size(); + } + // Step 1. Replace each variable produced by the old node + for (int64_t i = 0; i < num_output_vars; ++i) { + Any old_var = old_node->output_vars[i]; + Any new_var = new_node->output_vars[i]; + Ref old_var_consumers = var_to_consumers.at(old_var); + // Search through its consumers + for (DepNodeObj *consumer : *old_var_consumers) { + // Replace the input vars of each consumer + for (Any &v : consumer->input_vars) { + if (v.operator Object *() == old_var.operator Object *()) { + v = new_var; + } + } + } + this->var_to_producer.erase(old_var); + this->var_to_consumers.erase(old_var); + this->var_to_producer[new_var] = new_node; + this->var_to_consumers[new_var] = old_var_consumers; + } + // Step 2. Delete each variable consumed by the old node + for (const Any &var : old_node->input_vars) { + UListObj *consumers = this->var_to_consumers.at(var); + if (auto it = std::find_if(consumers->begin(), consumers->end(), + [old_node](const Any &v) -> bool { return v.operator DepNodeObj *() == old_node; }); + it != consumers->end()) { + consumers->erase(it); + } else { + MLC_THROW(RuntimeError) << "Node is not a consumer of the variable: " << var; + } + } + // Step 3. Add variables consumed by the new node + for (const Any &var : new_node->input_vars) { + if (!this->var_to_producer.count(var)) { + MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var; + } + this->var_to_consumers.at(var).operator UListObj *()->push_back(new_node); + } + // Step 4. Link the new node into the graph + new_node->prev = old_node->prev; + new_node->next = old_node->next; + if (old_node->prev != nullptr) { + old_node->prev->next = new_node; + } else { + this->head = new_node; + } + if (old_node->next != nullptr) { + old_node->next->prev = new_node; + } + this->stmt_to_node.erase(old_node->stmt); + if (this->stmt_to_node.count(new_node->stmt)) { + MLC_THROW(RuntimeError) << "Stmt already in the graph: " << new_node->stmt; + } else { + this->stmt_to_node[new_node->stmt] = new_node; + } + // Step 5. Clear the old node + old_node->Clear(); +} + +inline UList DepGraphObj::GetNodeProducers(DepNodeObj *node) { + UList ret; + for (const Any &var : node->input_vars) { + if (auto it = this->var_to_producer.find(var); it != this->var_to_producer.end()) { + ret.push_back(it->second); + } else { + MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var; + } + } + return ret; +} + +inline UList DepGraphObj::GetNodeConsumers(DepNodeObj *node) { + UList ret; + for (const Any &var : node->output_vars) { + if (auto it = this->var_to_consumers.find(var); it != this->var_to_consumers.end()) { + UListObj *consumers = it->second; + ret.insert(ret.end(), consumers->begin(), consumers->end()); + } else { + MLC_THROW(RuntimeError) << "Variable is not consumed by any node in the graph: " << var; + } + } + return ret; +} + +inline DepNode DepGraphObj::GetVarProducer(Any var) { + if (auto it = this->var_to_producer.find(var); it != this->var_to_producer.end()) { + return it->second; + } + MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var; + MLC_UNREACHABLE(); +} + +inline UList DepGraphObj::GetVarConsumers(Any var) { + if (auto it = this->var_to_consumers.find(var); it != this->var_to_consumers.end()) { + return it->second; + } + MLC_THROW(RuntimeError) << "Variable is not consumed by any node in the graph: " << var; + MLC_UNREACHABLE(); +} + +inline void DepGraphObj::_Insert(DepNodeObj *prev, DepNodeObj *next, DepNodeObj *to_insert) { + if (to_insert->prev != nullptr || to_insert->next != nullptr) { + MLC_THROW(RuntimeError) << "Node is already in the graph: " << to_insert->stmt; + } + // Step 1. Link the node into the graph + if (this->stmt_to_node.count(to_insert->stmt)) { + MLC_THROW(RuntimeError) << "Stmt already in the graph: " << to_insert->stmt; + } + this->stmt_to_node[to_insert->stmt] = to_insert; + to_insert->prev = prev; + to_insert->next = next; + if (prev != nullptr) { + prev->next = to_insert; + } else { + this->head = to_insert; + } + if (next != nullptr) { + next->prev = to_insert; + } + // Step 2. For each variable produced by the node + // 1) check if it doesn't have a producer yet + // 2) record its producer as this node + for (const Any &var : to_insert->output_vars) { + if (auto it = this->var_to_producer.find(var); it != this->var_to_producer.end()) { + MLC_THROW(RuntimeError) << "Variable already has a producer by another node: " + << it->second.operator DepNode()->stmt; + } else { + this->var_to_producer[var] = to_insert; + this->var_to_consumers[var] = UList{}; + } + } + // Step 3. For each variable consumed by the node + // 1) check if the var is in the graph + // 1) add a new consumer of this var + for (const Any &var : to_insert->input_vars) { + if (!this->var_to_producer.count(var)) { + MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var; + } + this->var_to_consumers.at(var).operator UListObj *()->push_back(to_insert); + } +} } // namespace } // namespace mlc diff --git a/cpp/dep_graph.h b/cpp/dep_graph.h deleted file mode 100644 index 91c565f2..00000000 --- a/cpp/dep_graph.h +++ /dev/null @@ -1,430 +0,0 @@ -#ifndef MLC_CORE_DEP_GRAPH_H_ -#define MLC_CORE_DEP_GRAPH_H_ - -#include - -namespace mlc { -namespace core { - -/*! - * \brief A dependency node in the dependency graph, which contains - * information about the statement, its input and output vars, and - * pointers to the previous and next nodes in the linked list. - * All the nodes are linked together in a doubly linked list. - */ -struct DepNodeObj { - MLCAny _mlc_header; - /*! \brief The statement that this node represents */ - Any stmt; - /*! \brief The list of input variables for this node */ - UList input_vars; - /*! \brief The list of output variables for this node */ - UList output_vars; - /*! \brief The previous node in the linked list */ - DepNodeObj *prev; - /*! \brief The next node in the linked list */ - DepNodeObj *next; - - MLC_DEF_DYN_TYPE(MLC_EXPORTS, DepNodeObj, Object, "mlc.core.DepNode"); - - explicit DepNodeObj(Any stmt, UList input_vars, UList output_vars, DepNodeObj *prev, DepNodeObj *next) - : stmt(stmt), input_vars(input_vars), output_vars(output_vars), prev(prev), next(next) {} - - void Clear() { - this->stmt = Null; - this->input_vars.clear(); - this->output_vars.clear(); - this->prev = nullptr; - this->next = nullptr; - } -}; - -struct DepNode : public ObjectRef { - explicit DepNode(Any stmt, UList input_vars, UList output_vars) - : DepNode(DepNode::New(stmt, input_vars, output_vars, nullptr, nullptr)) {} - - MLC_DEF_OBJ_REF(MLC_EXPORTS, DepNode, DepNodeObj, ObjectRef) - .Field("stmt", &DepNodeObj::stmt, /*frozen=*/true) - .Field("input_vars", &DepNodeObj::input_vars, /*frozen=*/true) - .Field("output_vars", &DepNodeObj::output_vars, /*frozen=*/true) - ._Field("_prev", offsetof(DepNodeObj, prev), sizeof(DepNodeObj *), true, ::mlc::core::ParseType()) - ._Field("_next", offsetof(DepNodeObj, next), sizeof(DepNodeObj *), true, ::mlc::core::ParseType()) - .StaticFn("__init__", InitOf); -}; - -struct DepGraphObj { - MLCAny _mlc_header; - /*! \brief A function that maps a stmt to a list of variables it consumes */ - Func stmt_to_inputs; - /*! \brief A function that maps a stmt to a list of variables it produces */ - Func stmt_to_outputs; - /*! \brief A map from a stmt to its node in the linked list */ - UDict stmt_to_node; - /*! \brief Map from a variable to its producer nodes */ - UDict var_to_producer; - /*! \brief Map from a variable to a list of consumer nodes */ - UDict var_to_consumers; - /*! \brief The first node in the linked list */ - DepNode head; - - MLC_DEF_DYN_TYPE(MLC_EXPORTS, DepGraphObj, Object, "mlc.core.DepGraph"); - - explicit DepGraphObj(Func stmt_to_inputs, Func stmt_to_outputs, UDict stmt_to_node, UDict var_to_producer, - UDict var_to_consumers, DepNode head) - : stmt_to_inputs(stmt_to_inputs), stmt_to_outputs(stmt_to_outputs), stmt_to_node(stmt_to_node), - var_to_producer(var_to_producer), var_to_consumers(var_to_consumers), head(head) {} - - explicit DepGraphObj(UList input_vars, UList stmts, Func stmt_to_inputs, Func stmt_to_outputs) - : stmt_to_inputs(stmt_to_inputs), stmt_to_outputs(stmt_to_outputs), stmt_to_node(), var_to_producer(), - var_to_consumers(), head(DepNode(Any(), UList{}, input_vars)) { - this->stmt_to_node[Any()] = this->head; - for (const Any &var : input_vars) { - this->var_to_producer[var] = this->head; - this->var_to_consumers[var] = UList{}; - } - DepNodeObj *prev = this->head.get(); - for (const Any &stmt : stmts) { - DepNode node = this->CreateNode(stmt); - this->InsertAfter(prev, node.get()); - prev = node.get(); - } - } - - ~DepGraphObj() { this->Clear(); } - - /*! - * \brief Clear the dependency graph. - * \note This will unlink all nodes from the graph and clear the maps. - */ - void Clear() { - for (DepNodeObj *node = this->head.get(); node;) { - DepNodeObj *next = node->next; - node->Clear(); - node = next; - } - this->var_to_producer.clear(); - this->var_to_consumers.clear(); - this->stmt_to_node.clear(); - } - /*! - * \brief Create a new node which is not linked to the dependency graph. - * \param stmt The statement that this node represents - * \return The new node - * \note This node is not part of the dependency graph until it's explicitly - * inserted using InsertBefore or InsertAfter. - */ - DepNode CreateNode(Any stmt) { return DepNode(stmt, this->stmt_to_inputs(stmt), this->stmt_to_outputs(stmt)); } - /*! - * \brief Get a node containing the given statement - * \param stmt The statement to get the node for - * \return The node containing the statement - * \note This will throw an error if the statement is not inserted into the graph. - */ - DepNode GetNodeFromStmt(Any stmt) { - if (auto it = this->stmt_to_node.find(stmt); it != this->stmt_to_node.end()) { - return it->second; - } - MLC_THROW(RuntimeError) << "Stmt not in graph: " << stmt; - MLC_UNREACHABLE(); - } - /*! - * \brief Insert a node before or after an anchor node. - * \param anchor The anchor node, not nullptr - * \param to_insert The node to insert - * \note This will link the new node to the dependency graph. - */ - void InsertBefore(DepNodeObj *anchor, DepNodeObj *to_insert) { - if (anchor->prev == nullptr) { - MLC_THROW(RuntimeError) << "Can't input before the input node: " << anchor->stmt; - } - if (!stmt_to_node.count(anchor->stmt)) { - MLC_THROW(RuntimeError) << "Anchor node not in graph: " << anchor->stmt; - } - DepNodeObj *prev = anchor->prev; - DepNodeObj *next = anchor; - return _Insert(prev, next, to_insert); - } - /*! - * \brief Insert a node before or after an anchor node. - * \param anchor The anchor node, not nullptr - * \param to_insert The node to insert - * \note This will link the new node to the dependency graph. - */ - void InsertAfter(DepNodeObj *anchor, DepNodeObj *to_insert) { - if (!stmt_to_node.count(anchor->stmt)) { - MLC_THROW(RuntimeError) << "Anchor node not in graph: " << anchor->stmt; - } - DepNodeObj *prev = anchor; - DepNodeObj *next = anchor->next; - return _Insert(prev, next, to_insert); - } - /*! - * \brief Erase a node from the dependency graph. - * \param to_erase The node to erase, not nullptr - * \note This will unlink the node from the dependency graph, remove the variables - * it produces, and remove itself from the consumer lists of the variables it consumes. - */ - void EraseNode(DepNodeObj *to_erase) { - // Step 1. Unlink the node from the graph - if (to_erase->prev == nullptr) { - MLC_THROW(RuntimeError) << "Can't erase the input node: " << to_erase->stmt; - } - if (!this->stmt_to_node.count(to_erase->stmt)) { - MLC_THROW(RuntimeError) << "Node not in graph: " << to_erase->stmt; - } - this->stmt_to_node.erase(to_erase->stmt); - if (to_erase->prev != nullptr) { - to_erase->prev->next = to_erase->next; - } else { - this->head = to_erase->next; - } - if (to_erase->next != nullptr) { - to_erase->next->prev = to_erase->prev; - } - // Step 2. For each variable produced by the node - // 1) check all its consumers are gone - // 2) remove the producer - for (const Any &var : to_erase->output_vars) { - UListObj *consumers = this->var_to_consumers.at(var); - if (!consumers->empty()) { - MLC_THROW(RuntimeError) << "Removing a node which produces a variable that still has consumers in graph: " - << var; - } - this->var_to_producer.erase(var); - this->var_to_consumers.erase(var); - } - // Step 3. For each varibale consumed by the node - // 1) check if the var is in the graph - // 2) remove the node from its consumer list - for (const Any &var : to_erase->input_vars) { - if (!this->var_to_producer.count(var)) { - MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var; - } - UListObj *consumers = this->var_to_consumers.at(var); - auto it = std::find_if(consumers->begin(), consumers->end(), - [to_erase](const Any &v) -> bool { return v.operator DepNodeObj *() == to_erase; }); - if (it == consumers->end()) { - MLC_THROW(RuntimeError) << "Node is not a consumer of the variable: " << var; - } - consumers->erase(it); - } - // Step 4. Clear the node - to_erase->Clear(); - } - /*! - * \brief Replace a node in the dependency graph with another node. - * \param old_node The node to replace - * \param new_node The new node to insert - * \note This will unlink the old node from the dependency graph and link the new node. - */ - void Replace(DepNodeObj *old_node, DepNodeObj *new_node) { - if (old_node == new_node) { - return; - } - if (old_node->prev == nullptr) { - MLC_THROW(RuntimeError) << "Can't replace the input node: " << old_node->stmt; - } - if (!this->stmt_to_node.count(old_node->stmt)) { - MLC_THROW(RuntimeError) << "Node not in graph: " << old_node->stmt; - } - if (new_node->prev != nullptr || new_node->next != nullptr) { - MLC_THROW(RuntimeError) << "Node is already in the graph: " << new_node->stmt; - } - int64_t num_output_vars = old_node->output_vars.size(); - if (num_output_vars != new_node->output_vars.size()) { - MLC_THROW(RuntimeError) << "Mismatched number of output_vars: " << num_output_vars << " vs " - << new_node->output_vars.size(); - } - // Step 1. Replace each variable produced by the old node - for (int64_t i = 0; i < num_output_vars; ++i) { - Any old_var = old_node->output_vars[i]; - Any new_var = new_node->output_vars[i]; - Ref old_var_consumers = var_to_consumers.at(old_var); - // Search through its consumers - for (DepNodeObj *consumer : *old_var_consumers) { - // Replace the input vars of each consumer - for (Any &v : consumer->input_vars) { - if (v.operator Object *() == old_var.operator Object *()) { - v = new_var; - } - } - } - this->var_to_producer.erase(old_var); - this->var_to_consumers.erase(old_var); - this->var_to_producer[new_var] = new_node; - this->var_to_consumers[new_var] = old_var_consumers; - } - // Step 2. Delete each variable consumed by the old node - for (const Any &var : old_node->input_vars) { - UListObj *consumers = this->var_to_consumers.at(var); - if (auto it = std::find_if(consumers->begin(), consumers->end(), - [old_node](const Any &v) -> bool { return v.operator DepNodeObj *() == old_node; }); - it != consumers->end()) { - consumers->erase(it); - } else { - MLC_THROW(RuntimeError) << "Node is not a consumer of the variable: " << var; - } - } - // Step 3. Add variables consumed by the new node - for (const Any &var : new_node->input_vars) { - if (!this->var_to_producer.count(var)) { - MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var; - } - this->var_to_consumers.at(var).operator UListObj *()->push_back(new_node); - } - // Step 4. Link the new node into the graph - new_node->prev = old_node->prev; - new_node->next = old_node->next; - if (old_node->prev != nullptr) { - old_node->prev->next = new_node; - } else { - this->head = new_node; - } - if (old_node->next != nullptr) { - old_node->next->prev = new_node; - } - this->stmt_to_node.erase(old_node->stmt); - if (this->stmt_to_node.count(new_node->stmt)) { - MLC_THROW(RuntimeError) << "Stmt already in the graph: " << new_node->stmt; - } else { - this->stmt_to_node[new_node->stmt] = new_node; - } - // Step 5. Clear the old node - old_node->Clear(); - } - /*! - * \brief For a given node, returns its producers, i.e. a list of nodes that produce the input variables of the node. - * \param node The node to get the input statements for - * \return The list of input nodes for the node - */ - UList GetNodeProducers(DepNodeObj *node) { - UList ret; - for (const Any &var : node->input_vars) { - if (auto it = this->var_to_producer.find(var); it != this->var_to_producer.end()) { - ret.push_back(it->second); - } else { - MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var; - } - } - return ret; - } - /*! - * \brief For a given node, returns its consumers, i.e. a list of nodes that consume the output variables of the node. - * \param node The node to get the output statements for - * \return The list of output statements for the node - */ - UList GetNodeConsumers(DepNodeObj *node) { - UList ret; - for (const Any &var : node->output_vars) { - if (auto it = this->var_to_consumers.find(var); it != this->var_to_consumers.end()) { - UListObj *consumers = it->second; - ret.insert(ret.end(), consumers->begin(), consumers->end()); - } else { - MLC_THROW(RuntimeError) << "Variable is not consumed by any node in the graph: " << var; - } - } - return ret; - } - /*! - * \brief Find the producer of a variable in the dependency graph. - * \param var The variable to find the producer for - * \return The producer node for the variable - */ - DepNode GetVarProducer(Any var) { - if (auto it = this->var_to_producer.find(var); it != this->var_to_producer.end()) { - return it->second; - } - MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var; - MLC_UNREACHABLE(); - } - /*! - * \brief Find the consumers of a variable in the dependency graph. - * \param var The variable to find the consumers for - * \return The list of consumer nodes for the variable - */ - UList GetVarConsumers(Any var) { - if (auto it = this->var_to_consumers.find(var); it != this->var_to_consumers.end()) { - return it->second; - } - MLC_THROW(RuntimeError) << "Variable is not consumed by any node in the graph: " << var; - MLC_UNREACHABLE(); - } - - void _Insert(DepNodeObj *prev, DepNodeObj *next, DepNodeObj *to_insert) { - if (to_insert->prev != nullptr || to_insert->next != nullptr) { - MLC_THROW(RuntimeError) << "Node is already in the graph: " << to_insert->stmt; - } - // Step 1. Link the node into the graph - if (this->stmt_to_node.count(to_insert->stmt)) { - MLC_THROW(RuntimeError) << "Stmt already in the graph: " << to_insert->stmt; - } - this->stmt_to_node[to_insert->stmt] = to_insert; - to_insert->prev = prev; - to_insert->next = next; - if (prev != nullptr) { - prev->next = to_insert; - } else { - this->head = to_insert; - } - if (next != nullptr) { - next->prev = to_insert; - } - // Step 2. For each variable produced by the node - // 1) check if it doesn't have a producer yet - // 2) record its producer as this node - for (const Any &var : to_insert->output_vars) { - if (auto it = this->var_to_producer.find(var); it != this->var_to_producer.end()) { - MLC_THROW(RuntimeError) << "Variable already has a producer by another node: " - << it->second.operator DepNode()->stmt; - } else { - this->var_to_producer[var] = to_insert; - this->var_to_consumers[var] = UList{}; - } - } - // Step 3. For each variable consumed by the node - // 1) check if the var is in the graph - // 1) add a new consumer of this var - for (const Any &var : to_insert->input_vars) { - if (!this->var_to_producer.count(var)) { - MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var; - } - this->var_to_consumers.at(var).operator UListObj *()->push_back(to_insert); - } - } -}; - -struct DepGraph : public ObjectRef { - MLC_DEF_OBJ_REF(MLC_EXPORTS, DepGraph, DepGraphObj, ObjectRef) - .Field("_stmt_to_inputs", &DepGraphObj::stmt_to_inputs, /*frozen=*/true) - .Field("_stmt_to_outputs", &DepGraphObj::stmt_to_outputs, /*frozen=*/true) - .Field("_stmt_to_node", &DepGraphObj::stmt_to_node, /*frozen=*/true) - .Field("_var_to_producer", &DepGraphObj::var_to_producer, /*frozen=*/true) - .Field("_var_to_consumers", &DepGraphObj::var_to_consumers, /*frozen=*/true) - .Field("_head", &DepGraphObj::head, /*frozen=*/true) - .StaticFn("__init__", InitOf) - .StaticFn("_init_from_stmts", InitOf) - .MemFn("clear", &DepGraphObj::Clear) - .MemFn("create_node", &DepGraphObj::CreateNode) - .MemFn("get_node_from_stmt", &DepGraphObj::GetNodeFromStmt) - .MemFn("insert_before", &DepGraphObj::InsertBefore) - .MemFn("insert_after", &DepGraphObj::InsertAfter) - .MemFn("erase_node", &DepGraphObj::EraseNode) - .MemFn("replace", &DepGraphObj::Replace) - .MemFn("get_node_producers", &DepGraphObj::GetNodeProducers) - .MemFn("get_node_consumers", &DepGraphObj::GetNodeConsumers) - .MemFn("get_var_producer", &DepGraphObj::GetVarProducer) - .MemFn("get_var_consumers", &DepGraphObj::GetVarConsumers); - - explicit DepGraph(Func stmt_to_inputs, Func stmt_to_outputs, UDict stmt_to_node, UDict var_to_producer, - UDict var_to_consumers, DepNode head) - : DepGraph( - DepGraph::New(stmt_to_inputs, stmt_to_outputs, stmt_to_node, var_to_producer, var_to_consumers, head)) {} - - explicit DepGraph(UList input_vars, UList stmts, Func stmt_to_inputs, Func stmt_to_outputs) - : DepGraph(DepGraph::New(input_vars, stmts, stmt_to_inputs, stmt_to_outputs)) {} -}; - -} // namespace core -} // namespace mlc - -#endif // MLC_CORE_DEP_GRAPH_H_ diff --git a/cpp/registry.h b/cpp/registry.h index 03384f1d..fb3f0ef6 100644 --- a/cpp/registry.h +++ b/cpp/registry.h @@ -645,6 +645,7 @@ inline TypeTable *TypeTable::New() { self->SetFunc("mlc.base.DataTypeFromStr", Func([self](const char *str) { return self->DataTypeFromStr(str); }).get()); self->SetFunc("mlc.base.DeviceTypeRegister", Func([self](const char *name) { return self->DeviceTypeRegister(name); }).get()); + self->SetFunc("mlc.core.Stringify", Func(::mlc::core::StringifyWithFields).get()); self->SetFunc("mlc.core.JSONLoads", Func(::mlc::registry::JSONLoads).get()); self->SetFunc("mlc.core.JSONSerialize", Func(::mlc::registry::JSONSerialize).get()); self->SetFunc("mlc.core.JSONDeserialize", Func(::mlc::registry::JSONDeserialize).get()); diff --git a/include/mlc/base/traits_scalar.h b/include/mlc/base/traits_scalar.h index c03e289e..17e76c89 100644 --- a/include/mlc/base/traits_scalar.h +++ b/include/mlc/base/traits_scalar.h @@ -2,6 +2,7 @@ #define MLC_BASE_TRAITS_SCALAR_H_ #include "./utils.h" +#include #include namespace mlc { @@ -88,7 +89,7 @@ template <> struct TypeTraits { return "None"; } else { std::ostringstream oss; - oss << src; + oss << "0x" << std::setfill('0') << std::setw(12) << std::hex << (uintptr_t)(src); return oss.str(); } } diff --git a/include/mlc/core/all.h b/include/mlc/core/all.h index a98b4d04..08c02165 100644 --- a/include/mlc/core/all.h +++ b/include/mlc/core/all.h @@ -14,6 +14,7 @@ #include "./typing.h" // IWYU pragma: export #include "./utils.h" // IWYU pragma: export #include "./visitor.h" // IWYU pragma: export +#include namespace mlc { namespace core { @@ -110,6 +111,54 @@ MLC_INLINE void DeleteExternObject(Object *objptr) { } } +inline std::string StringifyOpaque(const Object *self) { + std::ostringstream os; + os << self->GetTypeKey(); + os << "@0x" << std::setfill('0') << std::setw(12) << std::hex << (uintptr_t)(self->_mlc_header.v.v_ptr); + return os.str(); +} + +inline std::string StringifyWithFields(const Object *self) { + std::ostringstream os; + os << self->GetTypeKey(); + os << "@0x" << std::setfill('0') << std::setw(12) << std::hex << (uintptr_t)(self->_mlc_header.v.v_ptr); + os.copyfmt(std::ostringstream{}); + struct Printer { + void operator()(MLCTypeField *f, const Any *any) { Print(f, AnyView(*any)); } + void operator()(MLCTypeField *f, ObjectRef *obj) { Print(f, AnyView(*obj)); } + void operator()(MLCTypeField *f, Optional *opt) { Print(f, AnyView(*opt)); } + void operator()(MLCTypeField *f, Optional *opt) { Print(f, AnyView(*opt)); } + void operator()(MLCTypeField *f, Optional *opt) { Print(f, AnyView(*opt)); } + void operator()(MLCTypeField *f, Optional *opt) { Print(f, AnyView(*opt)); } + void operator()(MLCTypeField *f, Optional *opt) { Print(f, AnyView(*opt)); } + void operator()(MLCTypeField *f, Optional *opt) { Print(f, AnyView(*opt)); } + void operator()(MLCTypeField *f, bool *v) { Print(f, AnyView(*v)); } + void operator()(MLCTypeField *f, int8_t *v) { Print(f, AnyView(*v)); } + void operator()(MLCTypeField *f, int16_t *v) { Print(f, AnyView(*v)); } + void operator()(MLCTypeField *f, int32_t *v) { Print(f, AnyView(*v)); } + void operator()(MLCTypeField *f, int64_t *v) { Print(f, AnyView(*v)); } + void operator()(MLCTypeField *f, float *v) { Print(f, AnyView(*v)); } + void operator()(MLCTypeField *f, double *v) { Print(f, AnyView(*v)); } + void operator()(MLCTypeField *f, DLDataType *v) { Print(f, AnyView(*v)); } + void operator()(MLCTypeField *f, DLDevice *v) { Print(f, AnyView(*v)); } + void operator()(MLCTypeField *f, Optional *v) { Print(f, AnyView(*v)); } + void operator()(MLCTypeField *f, void **v) { Print(f, AnyView(*v)); } + void operator()(MLCTypeField *f, const char **v) { Print(f, AnyView(*v)); } + void Print(MLCTypeField *f, const AnyView &any) { + if (f->index > 0) { + (*os) << ", "; + } + (*os) << f->name << "="; + (*os) << any; + } + std::ostringstream *os; + }; + os << "("; + VisitFields(const_cast(self), Lib::GetTypeInfo(self->GetTypeIndex()), Printer{&os}); + os << ")"; + return os.str(); +} + } // namespace core } // namespace mlc @@ -229,7 +278,6 @@ template inline VTable &VTable::Set(Func func) { MLC_CHECK_ERR(::MLCVTableSetFunc(this->self, Obj::_type_index, func.get(), override_mode)); return *this; } - } // namespace mlc #endif // MLC_CORE_ALL_H_ diff --git a/include/mlc/core/func.h b/include/mlc/core/func.h index 8e009f70..a474e1f5 100644 --- a/include/mlc/core/func.h +++ b/include/mlc/core/func.h @@ -55,8 +55,7 @@ struct Func : public ObjectRef { static FuncObj *GetGlobal(const char *name, bool allow_missing = false) { return Lib::FuncGetGlobal(name, allow_missing); } - // A dummy function to trigger reflection - otherwise reflection registration will be skipped - MLC_DEF_OBJ_REF(MLC_EXPORTS, Func, FuncObj, ObjectRef).StaticFn("__nothing__", []() {}); + MLC_DEF_OBJ_REF(MLC_EXPORTS, Func, FuncObj, ObjectRef).MemFn("__str__", ::mlc::core::StringifyOpaque); }; } // namespace mlc diff --git a/include/mlc/core/object.h b/include/mlc/core/object.h index e483437f..4d9e30f2 100644 --- a/include/mlc/core/object.h +++ b/include/mlc/core/object.h @@ -2,7 +2,6 @@ #define MLC_CORE_OBJECT_H_ #include "./reflection.h" -#include /******************* Section 0. Dummy root *******************/ @@ -17,6 +16,8 @@ struct ObjectRefDummyRoot : protected ::mlc::base::PtrBase { ObjectRefDummyRoot() : PtrBase() {} ObjectRefDummyRoot(NullType) : PtrBase() {} }; +std::string StringifyWithFields(const Object *self); +std::string StringifyOpaque(const Object *self); } // namespace core } // namespace mlc @@ -61,12 +62,7 @@ struct Object { MLC_INLINE Object &operator=(const Object &) { return *this; } MLC_INLINE Object &operator=(Object &&) { return *this; } Str str() const; - std::string __str__() const { - std::ostringstream os; - os << this->GetTypeKey() << "@0x" << std::setfill('0') << std::setw(12) << std::hex - << (uintptr_t)(this->_mlc_header.v.v_ptr); - return os.str(); - } + std::string __str__() const { return ::mlc::core::StringifyWithFields(this); } friend std::ostream &operator<<(std::ostream &os, const Object &src); MLC_DEF_STATIC_TYPE(MLC_EXPORTS, Object, ::mlc::core::ObjectDummyRoot, MLCTypeIndex::kMLCObject, "object.Object"); diff --git a/include/mlc/core/str.h b/include/mlc/core/str.h index 1ac64e85..a4f14314 100644 --- a/include/mlc/core/str.h +++ b/include/mlc/core/str.h @@ -3,6 +3,7 @@ #include "./object.h" #include #include +#include #include #include #include diff --git a/python/mlc/dataclasses/__init__.py b/python/mlc/dataclasses/__init__.py index 85f8a69e..0bd81e32 100644 --- a/python/mlc/dataclasses/__init__.py +++ b/python/mlc/dataclasses/__init__.py @@ -6,5 +6,6 @@ field, prototype, replace, + stringify, vtable_method, ) diff --git a/python/mlc/dataclasses/utils.py b/python/mlc/dataclasses/utils.py index 179ef294..10c21fb5 100644 --- a/python/mlc/dataclasses/utils.py +++ b/python/mlc/dataclasses/utils.py @@ -456,3 +456,9 @@ def prototype( def replace(obj: Any, /, **changes: Any) -> Any: return obj.__replace__(**changes) + + +def stringify(obj: Any) -> str: + from mlc.core.func import Func + + return Func.get("mlc.core.Stringify")(obj) diff --git a/tests/cpp/test_base_any.cc b/tests/cpp/test_base_any.cc index e9f9cbb1..a726e7ef 100644 --- a/tests/cpp/test_base_any.cc +++ b/tests/cpp/test_base_any.cc @@ -229,11 +229,7 @@ TEST(Any, Constructor_Bool) { template struct Checker_Constructor_Ptr_NotNull { static void Check(const AnyType &v) { CheckAnyPOD(v, MLCTypeIndex::kMLCPtr, reinterpret_cast(0x1234)); -#ifndef _MSC_VER - EXPECT_STREQ(v.str()->c_str(), "0x1234"); -#else - EXPECT_STREQ(v.str()->c_str(), "0000000000001234"); -#endif + EXPECT_STREQ(v.str()->c_str(), "0x000000001234"); EXPECT_EQ(v.operator void *(), reinterpret_cast(0x1234)); CheckConvertFail([&]() { return v.operator int(); }, v.type_index, "int"); CheckConvertFail([&]() { return v.operator double(); }, v.type_index, "float"); diff --git a/tests/python/test_dataclasses_fields.py b/tests/python/test_dataclasses_fields.py index d312863d..17b816b8 100644 --- a/tests/python/test_dataclasses_fields.py +++ b/tests/python/test_dataclasses_fields.py @@ -1,3 +1,4 @@ +import re from typing import Union import mlc @@ -641,3 +642,17 @@ def test_mlc_class_mem_fn(mlc_class_for_test: MLCClassForTest) -> None: obj = mlc_class_for_test assert obj.i64 == 64 assert obj.i64_plus_one() == 65 + + +def test_stringify(mlc_class_for_test: MLCClassForTest) -> None: + obj = mlc_class_for_test + type_key = type(mlc_class_for_test)._mlc_type_info.type_key + expected = ( + type_key + + """@0x(bool_=False, i8=8, i16=16, i32=32, i64=64, f32=1.500000, f64=2.500000, raw_ptr=0x0000deadbeef, dtype=float8, device=cuda:0, any="hello", func=object.Func@0x, ulist=[1, 2.000000, "three", object.Func@0x], udict={"2": 2.000000, "4": object.Func@0x, "1": 1, "3": "three"}, str_="world", str_readonly="world", list_any=[1, 2.000000, "three", object.Func@0x], list_list_int=[[1, 2, 3], [4, 5, 6]], dict_any_any={2.000000: 2, 4: object.Func@0x, 1: 1.000000, "three": "four"}, dict_str_any={"4": object.Func@0x, "1": 1.000000, "2.0": 2, "three": "four"}, dict_any_str={2.000000: "2", 4: "5", 1: "1.0", "three": "four"}, dict_str_list_int={"2": [4, 5, 6], "1": [1, 2, 3]}, opt_bool=True, opt_i64=-64, opt_f64=None, opt_raw_ptr=None, opt_dtype=None, opt_device=cuda:0, opt_func=None, opt_ulist=None, opt_udict=None, opt_str=None, opt_list_any=[1, 2.000000, "three", object.Func@0x], opt_list_list_int=[[1, 2, 3], [4, 5, 6]], opt_dict_any_any=None, opt_dict_str_any={"4": object.Func@0x, "1": 1.000000, "2.0": 2, "three": "four"}, opt_dict_any_str={2.000000: "2", 4: "5", 1: "1.0", "three": "four"}, opt_dict_str_list_int={"2": [4, 5, 6], "1": [1, 2, 3]})""" + ) + actual = re.compile(r"@0x[0-9A-Fa-f]{12}\b").sub( + "@0x", + mlc.dataclasses.stringify(obj), + ) + assert actual == expected