Skip to content

feat(dataclasses): Introduce mlc.dataclasses.stringify #69

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
361 changes: 360 additions & 1 deletion cpp/c_api.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#include "./dep_graph.h"
#include "./registry.h"
#include <mlc/core/all.h>

Expand Down Expand Up @@ -582,5 +581,365 @@ MLC_REGISTER_FUNC("mlc.testing.FieldSet").set_body([](ObjectRef root, const char
MLC_UNREACHABLE();
});

struct DepNodeObj {
MLCAny _mlc_header;
Any stmt;
UList input_vars;
UList output_vars;
DepNodeObj *prev;
DepNodeObj *next;

MLC_DEF_DYN_TYPE(MLC_EXPORTS, DepNodeObj, Object, "mlc.core.DepNode");

explicit DepNodeObj(Any stmt, UList input_vars, UList output_vars, DepNodeObj *prev, DepNodeObj *next)
: stmt(stmt), input_vars(input_vars), output_vars(output_vars), prev(prev), next(next) {}

void Clear();
Str __str__() const { return this->stmt.str(); }
};

struct DepNode : public ObjectRef {
explicit DepNode(Any stmt, UList input_vars, UList output_vars)
: DepNode(DepNode::New(stmt, input_vars, output_vars, nullptr, nullptr)) {}

MLC_DEF_OBJ_REF(MLC_EXPORTS, DepNode, DepNodeObj, ObjectRef)
.Field("stmt", &DepNodeObj::stmt, /*frozen=*/true)
.Field("input_vars", &DepNodeObj::input_vars, /*frozen=*/true)
.Field("output_vars", &DepNodeObj::output_vars, /*frozen=*/true)
._Field("_prev", offsetof(DepNodeObj, prev), sizeof(DepNodeObj *), true, ::mlc::core::ParseType<ObjectRef>())
._Field("_next", offsetof(DepNodeObj, next), sizeof(DepNodeObj *), true, ::mlc::core::ParseType<ObjectRef>())
.StaticFn("__init__", InitOf<DepNodeObj, Any, UList, UList, DepNodeObj *, DepNodeObj *>)
.MemFn("__str__", &DepNodeObj::__str__);
};

struct DepGraphObj {
MLCAny _mlc_header;
Func stmt_to_inputs;
Func stmt_to_outputs;
UDict stmt_to_node;
UDict var_to_producer;
UDict var_to_consumers;
DepNode head;

MLC_DEF_DYN_TYPE(MLC_EXPORTS, DepGraphObj, Object, "mlc.core.DepGraph");

explicit DepGraphObj(Func stmt_to_inputs, Func stmt_to_outputs, UDict stmt_to_node, UDict var_to_producer,
UDict var_to_consumers, DepNode head)
: stmt_to_inputs(stmt_to_inputs), stmt_to_outputs(stmt_to_outputs), stmt_to_node(stmt_to_node),
var_to_producer(var_to_producer), var_to_consumers(var_to_consumers), head(head) {}

explicit DepGraphObj(UList input_vars, UList stmts, Func stmt_to_inputs, Func stmt_to_outputs)
: stmt_to_inputs(stmt_to_inputs), stmt_to_outputs(stmt_to_outputs), stmt_to_node(), var_to_producer(),
var_to_consumers(), head(DepNode(Any(), UList{}, input_vars)) {
this->stmt_to_node[Any()] = this->head;
for (const Any &var : input_vars) {
this->var_to_producer[var] = this->head;
this->var_to_consumers[var] = UList{};
}
DepNodeObj *prev = this->head.get();
for (const Any &stmt : stmts) {
DepNode node = this->CreateNode(stmt);
this->InsertAfter(prev, node.get());
prev = node.get();
}
}
~DepGraphObj() { this->Clear(); }
void Clear();
DepNode CreateNode(Any stmt) { return DepNode(stmt, this->stmt_to_inputs(stmt), this->stmt_to_outputs(stmt)); }
DepNode GetNodeFromStmt(Any stmt);
void InsertBefore(DepNodeObj *anchor, DepNodeObj *to_insert);
void InsertAfter(DepNodeObj *anchor, DepNodeObj *to_insert);
void EraseNode(DepNodeObj *to_erase);
void Replace(DepNodeObj *old_node, DepNodeObj *new_node);
UList GetNodeProducers(DepNodeObj *node);
UList GetNodeConsumers(DepNodeObj *node);
DepNode GetVarProducer(Any var);
UList GetVarConsumers(Any var);
void _Insert(DepNodeObj *prev, DepNodeObj *next, DepNodeObj *to_insert);
};

struct DepGraph : public ObjectRef {
MLC_DEF_OBJ_REF(MLC_EXPORTS, DepGraph, DepGraphObj, ObjectRef)
.Field("_stmt_to_inputs", &DepGraphObj::stmt_to_inputs, /*frozen=*/true)
.Field("_stmt_to_outputs", &DepGraphObj::stmt_to_outputs, /*frozen=*/true)
.Field("_stmt_to_node", &DepGraphObj::stmt_to_node, /*frozen=*/true)
.Field("_var_to_producer", &DepGraphObj::var_to_producer, /*frozen=*/true)
.Field("_var_to_consumers", &DepGraphObj::var_to_consumers, /*frozen=*/true)
.Field("_head", &DepGraphObj::head, /*frozen=*/true)
.StaticFn("__init__", InitOf<DepGraphObj, Func, Func, UDict, UDict, UDict, DepNode>)
.StaticFn("_init_from_stmts", InitOf<DepGraphObj, UList, UList, Func, Func>)
.MemFn("clear", &DepGraphObj::Clear)
.MemFn("create_node", &DepGraphObj::CreateNode)
.MemFn("get_node_from_stmt", &DepGraphObj::GetNodeFromStmt)
.MemFn("insert_before", &DepGraphObj::InsertBefore)
.MemFn("insert_after", &DepGraphObj::InsertAfter)
.MemFn("erase_node", &DepGraphObj::EraseNode)
.MemFn("replace", &DepGraphObj::Replace)
.MemFn("get_node_producers", &DepGraphObj::GetNodeProducers)
.MemFn("get_node_consumers", &DepGraphObj::GetNodeConsumers)
.MemFn("get_var_producer", &DepGraphObj::GetVarProducer)
.MemFn("get_var_consumers", &DepGraphObj::GetVarConsumers)
.MemFn("__str__", ::mlc::core::StringifyOpaque);

explicit DepGraph(Func stmt_to_inputs, Func stmt_to_outputs, UDict stmt_to_node, UDict var_to_producer,
UDict var_to_consumers, DepNode head)
: DepGraph(
DepGraph::New(stmt_to_inputs, stmt_to_outputs, stmt_to_node, var_to_producer, var_to_consumers, head)) {}

explicit DepGraph(UList input_vars, UList stmts, Func stmt_to_inputs, Func stmt_to_outputs)
: DepGraph(DepGraph::New(input_vars, stmts, stmt_to_inputs, stmt_to_outputs)) {}
};

inline void DepNodeObj::Clear() {
this->stmt = Any();
this->prev = nullptr;
this->next = nullptr;
}

inline void DepGraphObj::Clear() {
for (DepNodeObj *node = this->head.get(); node;) {
DepNodeObj *next = node->next;
node->Clear();
node = next;
}
this->var_to_producer.clear();
this->var_to_consumers.clear();
this->stmt_to_node.clear();
}

inline DepNode DepGraphObj::GetNodeFromStmt(Any stmt) {
if (auto it = this->stmt_to_node.find(stmt); it != this->stmt_to_node.end()) {
return it->second;
}
MLC_THROW(RuntimeError) << "Stmt not in graph: " << stmt;
MLC_UNREACHABLE();
}

inline void DepGraphObj::InsertBefore(DepNodeObj *anchor, DepNodeObj *to_insert) {
if (anchor->prev == nullptr) {
MLC_THROW(RuntimeError) << "Can't input before the input node: " << anchor->stmt;
}
if (!stmt_to_node.count(anchor->stmt)) {
MLC_THROW(RuntimeError) << "Anchor node not in graph: " << anchor->stmt;
}
DepNodeObj *prev = anchor->prev;
DepNodeObj *next = anchor;
return _Insert(prev, next, to_insert);
}

inline void DepGraphObj::InsertAfter(DepNodeObj *anchor, DepNodeObj *to_insert) {
if (!stmt_to_node.count(anchor->stmt)) {
MLC_THROW(RuntimeError) << "Anchor node not in graph: " << anchor->stmt;
}
DepNodeObj *prev = anchor;
DepNodeObj *next = anchor->next;
return _Insert(prev, next, to_insert);
}

inline void DepGraphObj::EraseNode(DepNodeObj *to_erase) {
// Step 1. Unlink the node from the graph
if (to_erase->prev == nullptr) {
MLC_THROW(RuntimeError) << "Can't erase the input node: " << to_erase->stmt;
}
if (!this->stmt_to_node.count(to_erase->stmt)) {
MLC_THROW(RuntimeError) << "Node not in graph: " << to_erase->stmt;
}
this->stmt_to_node.erase(to_erase->stmt);
if (to_erase->prev != nullptr) {
to_erase->prev->next = to_erase->next;
} else {
this->head = to_erase->next;
}
if (to_erase->next != nullptr) {
to_erase->next->prev = to_erase->prev;
}
// Step 2. For each variable produced by the node
// 1) check all its consumers are gone
// 2) remove the producer
for (const Any &var : to_erase->output_vars) {
UListObj *consumers = this->var_to_consumers.at(var);
if (!consumers->empty()) {
MLC_THROW(RuntimeError) << "Removing a node which produces a variable that still has consumers in graph: " << var;
}
this->var_to_producer.erase(var);
this->var_to_consumers.erase(var);
}
// Step 3. For each varibale consumed by the node
// 1) check if the var is in the graph
// 2) remove the node from its consumer list
for (const Any &var : to_erase->input_vars) {
if (!this->var_to_producer.count(var)) {
MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var;
}
UListObj *consumers = this->var_to_consumers.at(var);
auto it = std::find_if(consumers->begin(), consumers->end(),
[to_erase](const Any &v) -> bool { return v.operator DepNodeObj *() == to_erase; });
if (it == consumers->end()) {
MLC_THROW(RuntimeError) << "Node is not a consumer of the variable: " << var;
}
consumers->erase(it);
}
// Step 4. Clear the node
to_erase->Clear();
}

inline void DepGraphObj::Replace(DepNodeObj *old_node, DepNodeObj *new_node) {
if (old_node == new_node) {
return;
}
if (old_node->prev == nullptr) {
MLC_THROW(RuntimeError) << "Can't replace the input node: " << old_node->stmt;
}
if (!this->stmt_to_node.count(old_node->stmt)) {
MLC_THROW(RuntimeError) << "Node not in graph: " << old_node->stmt;
}
if (new_node->prev != nullptr || new_node->next != nullptr) {
MLC_THROW(RuntimeError) << "Node is already in the graph: " << new_node->stmt;
}
int64_t num_output_vars = old_node->output_vars.size();
if (num_output_vars != new_node->output_vars.size()) {
MLC_THROW(RuntimeError) << "Mismatched number of output_vars: " << num_output_vars << " vs "
<< new_node->output_vars.size();
}
// Step 1. Replace each variable produced by the old node
for (int64_t i = 0; i < num_output_vars; ++i) {
Any old_var = old_node->output_vars[i];
Any new_var = new_node->output_vars[i];
Ref<UListObj> old_var_consumers = var_to_consumers.at(old_var);
// Search through its consumers
for (DepNodeObj *consumer : *old_var_consumers) {
// Replace the input vars of each consumer
for (Any &v : consumer->input_vars) {
if (v.operator Object *() == old_var.operator Object *()) {
v = new_var;
}
}
}
this->var_to_producer.erase(old_var);
this->var_to_consumers.erase(old_var);
this->var_to_producer[new_var] = new_node;
this->var_to_consumers[new_var] = old_var_consumers;
}
// Step 2. Delete each variable consumed by the old node
for (const Any &var : old_node->input_vars) {
UListObj *consumers = this->var_to_consumers.at(var);
if (auto it = std::find_if(consumers->begin(), consumers->end(),
[old_node](const Any &v) -> bool { return v.operator DepNodeObj *() == old_node; });
it != consumers->end()) {
consumers->erase(it);
} else {
MLC_THROW(RuntimeError) << "Node is not a consumer of the variable: " << var;
}
}
// Step 3. Add variables consumed by the new node
for (const Any &var : new_node->input_vars) {
if (!this->var_to_producer.count(var)) {
MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var;
}
this->var_to_consumers.at(var).operator UListObj *()->push_back(new_node);
}
// Step 4. Link the new node into the graph
new_node->prev = old_node->prev;
new_node->next = old_node->next;
if (old_node->prev != nullptr) {
old_node->prev->next = new_node;
} else {
this->head = new_node;
}
if (old_node->next != nullptr) {
old_node->next->prev = new_node;
}
this->stmt_to_node.erase(old_node->stmt);
if (this->stmt_to_node.count(new_node->stmt)) {
MLC_THROW(RuntimeError) << "Stmt already in the graph: " << new_node->stmt;
} else {
this->stmt_to_node[new_node->stmt] = new_node;
}
// Step 5. Clear the old node
old_node->Clear();
}

inline UList DepGraphObj::GetNodeProducers(DepNodeObj *node) {
UList ret;
for (const Any &var : node->input_vars) {
if (auto it = this->var_to_producer.find(var); it != this->var_to_producer.end()) {
ret.push_back(it->second);
} else {
MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var;
}
}
return ret;
}

inline UList DepGraphObj::GetNodeConsumers(DepNodeObj *node) {
UList ret;
for (const Any &var : node->output_vars) {
if (auto it = this->var_to_consumers.find(var); it != this->var_to_consumers.end()) {
UListObj *consumers = it->second;
ret.insert(ret.end(), consumers->begin(), consumers->end());
} else {
MLC_THROW(RuntimeError) << "Variable is not consumed by any node in the graph: " << var;
}
}
return ret;
}

inline DepNode DepGraphObj::GetVarProducer(Any var) {
if (auto it = this->var_to_producer.find(var); it != this->var_to_producer.end()) {
return it->second;
}
MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var;
MLC_UNREACHABLE();
}

inline UList DepGraphObj::GetVarConsumers(Any var) {
if (auto it = this->var_to_consumers.find(var); it != this->var_to_consumers.end()) {
return it->second;
}
MLC_THROW(RuntimeError) << "Variable is not consumed by any node in the graph: " << var;
MLC_UNREACHABLE();
}

inline void DepGraphObj::_Insert(DepNodeObj *prev, DepNodeObj *next, DepNodeObj *to_insert) {
if (to_insert->prev != nullptr || to_insert->next != nullptr) {
MLC_THROW(RuntimeError) << "Node is already in the graph: " << to_insert->stmt;
}
// Step 1. Link the node into the graph
if (this->stmt_to_node.count(to_insert->stmt)) {
MLC_THROW(RuntimeError) << "Stmt already in the graph: " << to_insert->stmt;
}
this->stmt_to_node[to_insert->stmt] = to_insert;
to_insert->prev = prev;
to_insert->next = next;
if (prev != nullptr) {
prev->next = to_insert;
} else {
this->head = to_insert;
}
if (next != nullptr) {
next->prev = to_insert;
}
// Step 2. For each variable produced by the node
// 1) check if it doesn't have a producer yet
// 2) record its producer as this node
for (const Any &var : to_insert->output_vars) {
if (auto it = this->var_to_producer.find(var); it != this->var_to_producer.end()) {
MLC_THROW(RuntimeError) << "Variable already has a producer by another node: "
<< it->second.operator DepNode()->stmt;
} else {
this->var_to_producer[var] = to_insert;
this->var_to_consumers[var] = UList{};
}
}
// Step 3. For each variable consumed by the node
// 1) check if the var is in the graph
// 1) add a new consumer of this var
for (const Any &var : to_insert->input_vars) {
if (!this->var_to_producer.count(var)) {
MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var;
}
this->var_to_consumers.at(var).operator UListObj *()->push_back(to_insert);
}
}
} // namespace
} // namespace mlc
Loading
Loading