Skip to content

Commit 5ae8a95

Browse files
committed
feat(dataclasses): Add dataclass style __str__ to c_class
1 parent 03b655c commit 5ae8a95

File tree

11 files changed

+439
-442
lines changed

11 files changed

+439
-442
lines changed

cpp/c_api.cc

Lines changed: 360 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#include "./dep_graph.h"
21
#include "./registry.h"
32
#include <mlc/core/all.h>
43

@@ -582,5 +581,365 @@ MLC_REGISTER_FUNC("mlc.testing.FieldSet").set_body([](ObjectRef root, const char
582581
MLC_UNREACHABLE();
583582
});
584583

584+
struct DepNodeObj {
585+
MLCAny _mlc_header;
586+
Any stmt;
587+
UList input_vars;
588+
UList output_vars;
589+
DepNodeObj *prev;
590+
DepNodeObj *next;
591+
592+
MLC_DEF_DYN_TYPE(MLC_EXPORTS, DepNodeObj, Object, "mlc.core.DepNode");
593+
594+
explicit DepNodeObj(Any stmt, UList input_vars, UList output_vars, DepNodeObj *prev, DepNodeObj *next)
595+
: stmt(stmt), input_vars(input_vars), output_vars(output_vars), prev(prev), next(next) {}
596+
597+
void Clear();
598+
Str __str__() const { return this->stmt.str(); }
599+
};
600+
601+
struct DepNode : public ObjectRef {
602+
explicit DepNode(Any stmt, UList input_vars, UList output_vars)
603+
: DepNode(DepNode::New(stmt, input_vars, output_vars, nullptr, nullptr)) {}
604+
605+
MLC_DEF_OBJ_REF(MLC_EXPORTS, DepNode, DepNodeObj, ObjectRef)
606+
.Field("stmt", &DepNodeObj::stmt, /*frozen=*/true)
607+
.Field("input_vars", &DepNodeObj::input_vars, /*frozen=*/true)
608+
.Field("output_vars", &DepNodeObj::output_vars, /*frozen=*/true)
609+
._Field("_prev", offsetof(DepNodeObj, prev), sizeof(DepNodeObj *), true, ::mlc::core::ParseType<ObjectRef>())
610+
._Field("_next", offsetof(DepNodeObj, next), sizeof(DepNodeObj *), true, ::mlc::core::ParseType<ObjectRef>())
611+
.StaticFn("__init__", InitOf<DepNodeObj, Any, UList, UList, DepNodeObj *, DepNodeObj *>)
612+
.MemFn("__str__", &DepNodeObj::__str__);
613+
};
614+
615+
struct DepGraphObj {
616+
MLCAny _mlc_header;
617+
Func stmt_to_inputs;
618+
Func stmt_to_outputs;
619+
UDict stmt_to_node;
620+
UDict var_to_producer;
621+
UDict var_to_consumers;
622+
DepNode head;
623+
624+
MLC_DEF_DYN_TYPE(MLC_EXPORTS, DepGraphObj, Object, "mlc.core.DepGraph");
625+
626+
explicit DepGraphObj(Func stmt_to_inputs, Func stmt_to_outputs, UDict stmt_to_node, UDict var_to_producer,
627+
UDict var_to_consumers, DepNode head)
628+
: stmt_to_inputs(stmt_to_inputs), stmt_to_outputs(stmt_to_outputs), stmt_to_node(stmt_to_node),
629+
var_to_producer(var_to_producer), var_to_consumers(var_to_consumers), head(head) {}
630+
631+
explicit DepGraphObj(UList input_vars, UList stmts, Func stmt_to_inputs, Func stmt_to_outputs)
632+
: stmt_to_inputs(stmt_to_inputs), stmt_to_outputs(stmt_to_outputs), stmt_to_node(), var_to_producer(),
633+
var_to_consumers(), head(DepNode(Any(), UList{}, input_vars)) {
634+
this->stmt_to_node[Any()] = this->head;
635+
for (const Any &var : input_vars) {
636+
this->var_to_producer[var] = this->head;
637+
this->var_to_consumers[var] = UList{};
638+
}
639+
DepNodeObj *prev = this->head.get();
640+
for (const Any &stmt : stmts) {
641+
DepNode node = this->CreateNode(stmt);
642+
this->InsertAfter(prev, node.get());
643+
prev = node.get();
644+
}
645+
}
646+
~DepGraphObj() { this->Clear(); }
647+
void Clear();
648+
DepNode CreateNode(Any stmt) { return DepNode(stmt, this->stmt_to_inputs(stmt), this->stmt_to_outputs(stmt)); }
649+
DepNode GetNodeFromStmt(Any stmt);
650+
void InsertBefore(DepNodeObj *anchor, DepNodeObj *to_insert);
651+
void InsertAfter(DepNodeObj *anchor, DepNodeObj *to_insert);
652+
void EraseNode(DepNodeObj *to_erase);
653+
void Replace(DepNodeObj *old_node, DepNodeObj *new_node);
654+
UList GetNodeProducers(DepNodeObj *node);
655+
UList GetNodeConsumers(DepNodeObj *node);
656+
DepNode GetVarProducer(Any var);
657+
UList GetVarConsumers(Any var);
658+
void _Insert(DepNodeObj *prev, DepNodeObj *next, DepNodeObj *to_insert);
659+
};
660+
661+
struct DepGraph : public ObjectRef {
662+
MLC_DEF_OBJ_REF(MLC_EXPORTS, DepGraph, DepGraphObj, ObjectRef)
663+
.Field("_stmt_to_inputs", &DepGraphObj::stmt_to_inputs, /*frozen=*/true)
664+
.Field("_stmt_to_outputs", &DepGraphObj::stmt_to_outputs, /*frozen=*/true)
665+
.Field("_stmt_to_node", &DepGraphObj::stmt_to_node, /*frozen=*/true)
666+
.Field("_var_to_producer", &DepGraphObj::var_to_producer, /*frozen=*/true)
667+
.Field("_var_to_consumers", &DepGraphObj::var_to_consumers, /*frozen=*/true)
668+
.Field("_head", &DepGraphObj::head, /*frozen=*/true)
669+
.StaticFn("__init__", InitOf<DepGraphObj, Func, Func, UDict, UDict, UDict, DepNode>)
670+
.StaticFn("_init_from_stmts", InitOf<DepGraphObj, UList, UList, Func, Func>)
671+
.MemFn("clear", &DepGraphObj::Clear)
672+
.MemFn("create_node", &DepGraphObj::CreateNode)
673+
.MemFn("get_node_from_stmt", &DepGraphObj::GetNodeFromStmt)
674+
.MemFn("insert_before", &DepGraphObj::InsertBefore)
675+
.MemFn("insert_after", &DepGraphObj::InsertAfter)
676+
.MemFn("erase_node", &DepGraphObj::EraseNode)
677+
.MemFn("replace", &DepGraphObj::Replace)
678+
.MemFn("get_node_producers", &DepGraphObj::GetNodeProducers)
679+
.MemFn("get_node_consumers", &DepGraphObj::GetNodeConsumers)
680+
.MemFn("get_var_producer", &DepGraphObj::GetVarProducer)
681+
.MemFn("get_var_consumers", &DepGraphObj::GetVarConsumers)
682+
.MemFn("__str__", ::mlc::core::StringifyOpaque);
683+
684+
explicit DepGraph(Func stmt_to_inputs, Func stmt_to_outputs, UDict stmt_to_node, UDict var_to_producer,
685+
UDict var_to_consumers, DepNode head)
686+
: DepGraph(
687+
DepGraph::New(stmt_to_inputs, stmt_to_outputs, stmt_to_node, var_to_producer, var_to_consumers, head)) {}
688+
689+
explicit DepGraph(UList input_vars, UList stmts, Func stmt_to_inputs, Func stmt_to_outputs)
690+
: DepGraph(DepGraph::New(input_vars, stmts, stmt_to_inputs, stmt_to_outputs)) {}
691+
};
692+
693+
inline void DepNodeObj::Clear() {
694+
this->stmt = Any();
695+
this->prev = nullptr;
696+
this->next = nullptr;
697+
}
698+
699+
inline void DepGraphObj::Clear() {
700+
for (DepNodeObj *node = this->head.get(); node;) {
701+
DepNodeObj *next = node->next;
702+
node->Clear();
703+
node = next;
704+
}
705+
this->var_to_producer.clear();
706+
this->var_to_consumers.clear();
707+
this->stmt_to_node.clear();
708+
}
709+
710+
inline DepNode DepGraphObj::GetNodeFromStmt(Any stmt) {
711+
if (auto it = this->stmt_to_node.find(stmt); it != this->stmt_to_node.end()) {
712+
return it->second;
713+
}
714+
MLC_THROW(RuntimeError) << "Stmt not in graph: " << stmt;
715+
MLC_UNREACHABLE();
716+
}
717+
718+
inline void DepGraphObj::InsertBefore(DepNodeObj *anchor, DepNodeObj *to_insert) {
719+
if (anchor->prev == nullptr) {
720+
MLC_THROW(RuntimeError) << "Can't input before the input node: " << anchor->stmt;
721+
}
722+
if (!stmt_to_node.count(anchor->stmt)) {
723+
MLC_THROW(RuntimeError) << "Anchor node not in graph: " << anchor->stmt;
724+
}
725+
DepNodeObj *prev = anchor->prev;
726+
DepNodeObj *next = anchor;
727+
return _Insert(prev, next, to_insert);
728+
}
729+
730+
inline void DepGraphObj::InsertAfter(DepNodeObj *anchor, DepNodeObj *to_insert) {
731+
if (!stmt_to_node.count(anchor->stmt)) {
732+
MLC_THROW(RuntimeError) << "Anchor node not in graph: " << anchor->stmt;
733+
}
734+
DepNodeObj *prev = anchor;
735+
DepNodeObj *next = anchor->next;
736+
return _Insert(prev, next, to_insert);
737+
}
738+
739+
inline void DepGraphObj::EraseNode(DepNodeObj *to_erase) {
740+
// Step 1. Unlink the node from the graph
741+
if (to_erase->prev == nullptr) {
742+
MLC_THROW(RuntimeError) << "Can't erase the input node: " << to_erase->stmt;
743+
}
744+
if (!this->stmt_to_node.count(to_erase->stmt)) {
745+
MLC_THROW(RuntimeError) << "Node not in graph: " << to_erase->stmt;
746+
}
747+
this->stmt_to_node.erase(to_erase->stmt);
748+
if (to_erase->prev != nullptr) {
749+
to_erase->prev->next = to_erase->next;
750+
} else {
751+
this->head = to_erase->next;
752+
}
753+
if (to_erase->next != nullptr) {
754+
to_erase->next->prev = to_erase->prev;
755+
}
756+
// Step 2. For each variable produced by the node
757+
// 1) check all its consumers are gone
758+
// 2) remove the producer
759+
for (const Any &var : to_erase->output_vars) {
760+
UListObj *consumers = this->var_to_consumers.at(var);
761+
if (!consumers->empty()) {
762+
MLC_THROW(RuntimeError) << "Removing a node which produces a variable that still has consumers in graph: " << var;
763+
}
764+
this->var_to_producer.erase(var);
765+
this->var_to_consumers.erase(var);
766+
}
767+
// Step 3. For each varibale consumed by the node
768+
// 1) check if the var is in the graph
769+
// 2) remove the node from its consumer list
770+
for (const Any &var : to_erase->input_vars) {
771+
if (!this->var_to_producer.count(var)) {
772+
MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var;
773+
}
774+
UListObj *consumers = this->var_to_consumers.at(var);
775+
auto it = std::find_if(consumers->begin(), consumers->end(),
776+
[to_erase](const Any &v) -> bool { return v.operator DepNodeObj *() == to_erase; });
777+
if (it == consumers->end()) {
778+
MLC_THROW(RuntimeError) << "Node is not a consumer of the variable: " << var;
779+
}
780+
consumers->erase(it);
781+
}
782+
// Step 4. Clear the node
783+
to_erase->Clear();
784+
}
785+
786+
inline void DepGraphObj::Replace(DepNodeObj *old_node, DepNodeObj *new_node) {
787+
if (old_node == new_node) {
788+
return;
789+
}
790+
if (old_node->prev == nullptr) {
791+
MLC_THROW(RuntimeError) << "Can't replace the input node: " << old_node->stmt;
792+
}
793+
if (!this->stmt_to_node.count(old_node->stmt)) {
794+
MLC_THROW(RuntimeError) << "Node not in graph: " << old_node->stmt;
795+
}
796+
if (new_node->prev != nullptr || new_node->next != nullptr) {
797+
MLC_THROW(RuntimeError) << "Node is already in the graph: " << new_node->stmt;
798+
}
799+
int64_t num_output_vars = old_node->output_vars.size();
800+
if (num_output_vars != new_node->output_vars.size()) {
801+
MLC_THROW(RuntimeError) << "Mismatched number of output_vars: " << num_output_vars << " vs "
802+
<< new_node->output_vars.size();
803+
}
804+
// Step 1. Replace each variable produced by the old node
805+
for (int64_t i = 0; i < num_output_vars; ++i) {
806+
Any old_var = old_node->output_vars[i];
807+
Any new_var = new_node->output_vars[i];
808+
Ref<UListObj> old_var_consumers = var_to_consumers.at(old_var);
809+
// Search through its consumers
810+
for (DepNodeObj *consumer : *old_var_consumers) {
811+
// Replace the input vars of each consumer
812+
for (Any &v : consumer->input_vars) {
813+
if (v.operator Object *() == old_var.operator Object *()) {
814+
v = new_var;
815+
}
816+
}
817+
}
818+
this->var_to_producer.erase(old_var);
819+
this->var_to_consumers.erase(old_var);
820+
this->var_to_producer[new_var] = new_node;
821+
this->var_to_consumers[new_var] = old_var_consumers;
822+
}
823+
// Step 2. Delete each variable consumed by the old node
824+
for (const Any &var : old_node->input_vars) {
825+
UListObj *consumers = this->var_to_consumers.at(var);
826+
if (auto it = std::find_if(consumers->begin(), consumers->end(),
827+
[old_node](const Any &v) -> bool { return v.operator DepNodeObj *() == old_node; });
828+
it != consumers->end()) {
829+
consumers->erase(it);
830+
} else {
831+
MLC_THROW(RuntimeError) << "Node is not a consumer of the variable: " << var;
832+
}
833+
}
834+
// Step 3. Add variables consumed by the new node
835+
for (const Any &var : new_node->input_vars) {
836+
if (!this->var_to_producer.count(var)) {
837+
MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var;
838+
}
839+
this->var_to_consumers.at(var).operator UListObj *()->push_back(new_node);
840+
}
841+
// Step 4. Link the new node into the graph
842+
new_node->prev = old_node->prev;
843+
new_node->next = old_node->next;
844+
if (old_node->prev != nullptr) {
845+
old_node->prev->next = new_node;
846+
} else {
847+
this->head = new_node;
848+
}
849+
if (old_node->next != nullptr) {
850+
old_node->next->prev = new_node;
851+
}
852+
this->stmt_to_node.erase(old_node->stmt);
853+
if (this->stmt_to_node.count(new_node->stmt)) {
854+
MLC_THROW(RuntimeError) << "Stmt already in the graph: " << new_node->stmt;
855+
} else {
856+
this->stmt_to_node[new_node->stmt] = new_node;
857+
}
858+
// Step 5. Clear the old node
859+
old_node->Clear();
860+
}
861+
862+
inline UList DepGraphObj::GetNodeProducers(DepNodeObj *node) {
863+
UList ret;
864+
for (const Any &var : node->input_vars) {
865+
if (auto it = this->var_to_producer.find(var); it != this->var_to_producer.end()) {
866+
ret.push_back(it->second);
867+
} else {
868+
MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var;
869+
}
870+
}
871+
return ret;
872+
}
873+
874+
inline UList DepGraphObj::GetNodeConsumers(DepNodeObj *node) {
875+
UList ret;
876+
for (const Any &var : node->output_vars) {
877+
if (auto it = this->var_to_consumers.find(var); it != this->var_to_consumers.end()) {
878+
UListObj *consumers = it->second;
879+
ret.insert(ret.end(), consumers->begin(), consumers->end());
880+
} else {
881+
MLC_THROW(RuntimeError) << "Variable is not consumed by any node in the graph: " << var;
882+
}
883+
}
884+
return ret;
885+
}
886+
887+
inline DepNode DepGraphObj::GetVarProducer(Any var) {
888+
if (auto it = this->var_to_producer.find(var); it != this->var_to_producer.end()) {
889+
return it->second;
890+
}
891+
MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var;
892+
MLC_UNREACHABLE();
893+
}
894+
895+
inline UList DepGraphObj::GetVarConsumers(Any var) {
896+
if (auto it = this->var_to_consumers.find(var); it != this->var_to_consumers.end()) {
897+
return it->second;
898+
}
899+
MLC_THROW(RuntimeError) << "Variable is not consumed by any node in the graph: " << var;
900+
MLC_UNREACHABLE();
901+
}
902+
903+
inline void DepGraphObj::_Insert(DepNodeObj *prev, DepNodeObj *next, DepNodeObj *to_insert) {
904+
if (to_insert->prev != nullptr || to_insert->next != nullptr) {
905+
MLC_THROW(RuntimeError) << "Node is already in the graph: " << to_insert->stmt;
906+
}
907+
// Step 1. Link the node into the graph
908+
if (this->stmt_to_node.count(to_insert->stmt)) {
909+
MLC_THROW(RuntimeError) << "Stmt already in the graph: " << to_insert->stmt;
910+
}
911+
this->stmt_to_node[to_insert->stmt] = to_insert;
912+
to_insert->prev = prev;
913+
to_insert->next = next;
914+
if (prev != nullptr) {
915+
prev->next = to_insert;
916+
} else {
917+
this->head = to_insert;
918+
}
919+
if (next != nullptr) {
920+
next->prev = to_insert;
921+
}
922+
// Step 2. For each variable produced by the node
923+
// 1) check if it doesn't have a producer yet
924+
// 2) record its producer as this node
925+
for (const Any &var : to_insert->output_vars) {
926+
if (auto it = this->var_to_producer.find(var); it != this->var_to_producer.end()) {
927+
MLC_THROW(RuntimeError) << "Variable already has a producer by another node: "
928+
<< it->second.operator DepNode()->stmt;
929+
} else {
930+
this->var_to_producer[var] = to_insert;
931+
this->var_to_consumers[var] = UList{};
932+
}
933+
}
934+
// Step 3. For each variable consumed by the node
935+
// 1) check if the var is in the graph
936+
// 1) add a new consumer of this var
937+
for (const Any &var : to_insert->input_vars) {
938+
if (!this->var_to_producer.count(var)) {
939+
MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var;
940+
}
941+
this->var_to_consumers.at(var).operator UListObj *()->push_back(to_insert);
942+
}
943+
}
585944
} // namespace
586945
} // namespace mlc

0 commit comments

Comments
 (0)