|
1 |
| -#include "./dep_graph.h" |
2 | 1 | #include "./registry.h"
|
3 | 2 | #include <mlc/core/all.h>
|
4 | 3 |
|
@@ -582,5 +581,365 @@ MLC_REGISTER_FUNC("mlc.testing.FieldSet").set_body([](ObjectRef root, const char
|
582 | 581 | MLC_UNREACHABLE();
|
583 | 582 | });
|
584 | 583 |
|
| 584 | +struct DepNodeObj { |
| 585 | + MLCAny _mlc_header; |
| 586 | + Any stmt; |
| 587 | + UList input_vars; |
| 588 | + UList output_vars; |
| 589 | + DepNodeObj *prev; |
| 590 | + DepNodeObj *next; |
| 591 | + |
| 592 | + MLC_DEF_DYN_TYPE(MLC_EXPORTS, DepNodeObj, Object, "mlc.core.DepNode"); |
| 593 | + |
| 594 | + explicit DepNodeObj(Any stmt, UList input_vars, UList output_vars, DepNodeObj *prev, DepNodeObj *next) |
| 595 | + : stmt(stmt), input_vars(input_vars), output_vars(output_vars), prev(prev), next(next) {} |
| 596 | + |
| 597 | + void Clear(); |
| 598 | + Str __str__() const { return this->stmt.str(); } |
| 599 | +}; |
| 600 | + |
| 601 | +struct DepNode : public ObjectRef { |
| 602 | + explicit DepNode(Any stmt, UList input_vars, UList output_vars) |
| 603 | + : DepNode(DepNode::New(stmt, input_vars, output_vars, nullptr, nullptr)) {} |
| 604 | + |
| 605 | + MLC_DEF_OBJ_REF(MLC_EXPORTS, DepNode, DepNodeObj, ObjectRef) |
| 606 | + .Field("stmt", &DepNodeObj::stmt, /*frozen=*/true) |
| 607 | + .Field("input_vars", &DepNodeObj::input_vars, /*frozen=*/true) |
| 608 | + .Field("output_vars", &DepNodeObj::output_vars, /*frozen=*/true) |
| 609 | + ._Field("_prev", offsetof(DepNodeObj, prev), sizeof(DepNodeObj *), true, ::mlc::core::ParseType<ObjectRef>()) |
| 610 | + ._Field("_next", offsetof(DepNodeObj, next), sizeof(DepNodeObj *), true, ::mlc::core::ParseType<ObjectRef>()) |
| 611 | + .StaticFn("__init__", InitOf<DepNodeObj, Any, UList, UList, DepNodeObj *, DepNodeObj *>) |
| 612 | + .MemFn("__str__", &DepNodeObj::__str__); |
| 613 | +}; |
| 614 | + |
| 615 | +struct DepGraphObj { |
| 616 | + MLCAny _mlc_header; |
| 617 | + Func stmt_to_inputs; |
| 618 | + Func stmt_to_outputs; |
| 619 | + UDict stmt_to_node; |
| 620 | + UDict var_to_producer; |
| 621 | + UDict var_to_consumers; |
| 622 | + DepNode head; |
| 623 | + |
| 624 | + MLC_DEF_DYN_TYPE(MLC_EXPORTS, DepGraphObj, Object, "mlc.core.DepGraph"); |
| 625 | + |
| 626 | + explicit DepGraphObj(Func stmt_to_inputs, Func stmt_to_outputs, UDict stmt_to_node, UDict var_to_producer, |
| 627 | + UDict var_to_consumers, DepNode head) |
| 628 | + : stmt_to_inputs(stmt_to_inputs), stmt_to_outputs(stmt_to_outputs), stmt_to_node(stmt_to_node), |
| 629 | + var_to_producer(var_to_producer), var_to_consumers(var_to_consumers), head(head) {} |
| 630 | + |
| 631 | + explicit DepGraphObj(UList input_vars, UList stmts, Func stmt_to_inputs, Func stmt_to_outputs) |
| 632 | + : stmt_to_inputs(stmt_to_inputs), stmt_to_outputs(stmt_to_outputs), stmt_to_node(), var_to_producer(), |
| 633 | + var_to_consumers(), head(DepNode(Any(), UList{}, input_vars)) { |
| 634 | + this->stmt_to_node[Any()] = this->head; |
| 635 | + for (const Any &var : input_vars) { |
| 636 | + this->var_to_producer[var] = this->head; |
| 637 | + this->var_to_consumers[var] = UList{}; |
| 638 | + } |
| 639 | + DepNodeObj *prev = this->head.get(); |
| 640 | + for (const Any &stmt : stmts) { |
| 641 | + DepNode node = this->CreateNode(stmt); |
| 642 | + this->InsertAfter(prev, node.get()); |
| 643 | + prev = node.get(); |
| 644 | + } |
| 645 | + } |
| 646 | + ~DepGraphObj() { this->Clear(); } |
| 647 | + void Clear(); |
| 648 | + DepNode CreateNode(Any stmt) { return DepNode(stmt, this->stmt_to_inputs(stmt), this->stmt_to_outputs(stmt)); } |
| 649 | + DepNode GetNodeFromStmt(Any stmt); |
| 650 | + void InsertBefore(DepNodeObj *anchor, DepNodeObj *to_insert); |
| 651 | + void InsertAfter(DepNodeObj *anchor, DepNodeObj *to_insert); |
| 652 | + void EraseNode(DepNodeObj *to_erase); |
| 653 | + void Replace(DepNodeObj *old_node, DepNodeObj *new_node); |
| 654 | + UList GetNodeProducers(DepNodeObj *node); |
| 655 | + UList GetNodeConsumers(DepNodeObj *node); |
| 656 | + DepNode GetVarProducer(Any var); |
| 657 | + UList GetVarConsumers(Any var); |
| 658 | + void _Insert(DepNodeObj *prev, DepNodeObj *next, DepNodeObj *to_insert); |
| 659 | +}; |
| 660 | + |
| 661 | +struct DepGraph : public ObjectRef { |
| 662 | + MLC_DEF_OBJ_REF(MLC_EXPORTS, DepGraph, DepGraphObj, ObjectRef) |
| 663 | + .Field("_stmt_to_inputs", &DepGraphObj::stmt_to_inputs, /*frozen=*/true) |
| 664 | + .Field("_stmt_to_outputs", &DepGraphObj::stmt_to_outputs, /*frozen=*/true) |
| 665 | + .Field("_stmt_to_node", &DepGraphObj::stmt_to_node, /*frozen=*/true) |
| 666 | + .Field("_var_to_producer", &DepGraphObj::var_to_producer, /*frozen=*/true) |
| 667 | + .Field("_var_to_consumers", &DepGraphObj::var_to_consumers, /*frozen=*/true) |
| 668 | + .Field("_head", &DepGraphObj::head, /*frozen=*/true) |
| 669 | + .StaticFn("__init__", InitOf<DepGraphObj, Func, Func, UDict, UDict, UDict, DepNode>) |
| 670 | + .StaticFn("_init_from_stmts", InitOf<DepGraphObj, UList, UList, Func, Func>) |
| 671 | + .MemFn("clear", &DepGraphObj::Clear) |
| 672 | + .MemFn("create_node", &DepGraphObj::CreateNode) |
| 673 | + .MemFn("get_node_from_stmt", &DepGraphObj::GetNodeFromStmt) |
| 674 | + .MemFn("insert_before", &DepGraphObj::InsertBefore) |
| 675 | + .MemFn("insert_after", &DepGraphObj::InsertAfter) |
| 676 | + .MemFn("erase_node", &DepGraphObj::EraseNode) |
| 677 | + .MemFn("replace", &DepGraphObj::Replace) |
| 678 | + .MemFn("get_node_producers", &DepGraphObj::GetNodeProducers) |
| 679 | + .MemFn("get_node_consumers", &DepGraphObj::GetNodeConsumers) |
| 680 | + .MemFn("get_var_producer", &DepGraphObj::GetVarProducer) |
| 681 | + .MemFn("get_var_consumers", &DepGraphObj::GetVarConsumers) |
| 682 | + .MemFn("__str__", ::mlc::core::StringifyOpaque); |
| 683 | + |
| 684 | + explicit DepGraph(Func stmt_to_inputs, Func stmt_to_outputs, UDict stmt_to_node, UDict var_to_producer, |
| 685 | + UDict var_to_consumers, DepNode head) |
| 686 | + : DepGraph( |
| 687 | + DepGraph::New(stmt_to_inputs, stmt_to_outputs, stmt_to_node, var_to_producer, var_to_consumers, head)) {} |
| 688 | + |
| 689 | + explicit DepGraph(UList input_vars, UList stmts, Func stmt_to_inputs, Func stmt_to_outputs) |
| 690 | + : DepGraph(DepGraph::New(input_vars, stmts, stmt_to_inputs, stmt_to_outputs)) {} |
| 691 | +}; |
| 692 | + |
| 693 | +inline void DepNodeObj::Clear() { |
| 694 | + this->stmt = Any(); |
| 695 | + this->prev = nullptr; |
| 696 | + this->next = nullptr; |
| 697 | +} |
| 698 | + |
| 699 | +inline void DepGraphObj::Clear() { |
| 700 | + for (DepNodeObj *node = this->head.get(); node;) { |
| 701 | + DepNodeObj *next = node->next; |
| 702 | + node->Clear(); |
| 703 | + node = next; |
| 704 | + } |
| 705 | + this->var_to_producer.clear(); |
| 706 | + this->var_to_consumers.clear(); |
| 707 | + this->stmt_to_node.clear(); |
| 708 | +} |
| 709 | + |
| 710 | +inline DepNode DepGraphObj::GetNodeFromStmt(Any stmt) { |
| 711 | + if (auto it = this->stmt_to_node.find(stmt); it != this->stmt_to_node.end()) { |
| 712 | + return it->second; |
| 713 | + } |
| 714 | + MLC_THROW(RuntimeError) << "Stmt not in graph: " << stmt; |
| 715 | + MLC_UNREACHABLE(); |
| 716 | +} |
| 717 | + |
| 718 | +inline void DepGraphObj::InsertBefore(DepNodeObj *anchor, DepNodeObj *to_insert) { |
| 719 | + if (anchor->prev == nullptr) { |
| 720 | + MLC_THROW(RuntimeError) << "Can't input before the input node: " << anchor->stmt; |
| 721 | + } |
| 722 | + if (!stmt_to_node.count(anchor->stmt)) { |
| 723 | + MLC_THROW(RuntimeError) << "Anchor node not in graph: " << anchor->stmt; |
| 724 | + } |
| 725 | + DepNodeObj *prev = anchor->prev; |
| 726 | + DepNodeObj *next = anchor; |
| 727 | + return _Insert(prev, next, to_insert); |
| 728 | +} |
| 729 | + |
| 730 | +inline void DepGraphObj::InsertAfter(DepNodeObj *anchor, DepNodeObj *to_insert) { |
| 731 | + if (!stmt_to_node.count(anchor->stmt)) { |
| 732 | + MLC_THROW(RuntimeError) << "Anchor node not in graph: " << anchor->stmt; |
| 733 | + } |
| 734 | + DepNodeObj *prev = anchor; |
| 735 | + DepNodeObj *next = anchor->next; |
| 736 | + return _Insert(prev, next, to_insert); |
| 737 | +} |
| 738 | + |
| 739 | +inline void DepGraphObj::EraseNode(DepNodeObj *to_erase) { |
| 740 | + // Step 1. Unlink the node from the graph |
| 741 | + if (to_erase->prev == nullptr) { |
| 742 | + MLC_THROW(RuntimeError) << "Can't erase the input node: " << to_erase->stmt; |
| 743 | + } |
| 744 | + if (!this->stmt_to_node.count(to_erase->stmt)) { |
| 745 | + MLC_THROW(RuntimeError) << "Node not in graph: " << to_erase->stmt; |
| 746 | + } |
| 747 | + this->stmt_to_node.erase(to_erase->stmt); |
| 748 | + if (to_erase->prev != nullptr) { |
| 749 | + to_erase->prev->next = to_erase->next; |
| 750 | + } else { |
| 751 | + this->head = to_erase->next; |
| 752 | + } |
| 753 | + if (to_erase->next != nullptr) { |
| 754 | + to_erase->next->prev = to_erase->prev; |
| 755 | + } |
| 756 | + // Step 2. For each variable produced by the node |
| 757 | + // 1) check all its consumers are gone |
| 758 | + // 2) remove the producer |
| 759 | + for (const Any &var : to_erase->output_vars) { |
| 760 | + UListObj *consumers = this->var_to_consumers.at(var); |
| 761 | + if (!consumers->empty()) { |
| 762 | + MLC_THROW(RuntimeError) << "Removing a node which produces a variable that still has consumers in graph: " << var; |
| 763 | + } |
| 764 | + this->var_to_producer.erase(var); |
| 765 | + this->var_to_consumers.erase(var); |
| 766 | + } |
| 767 | + // Step 3. For each varibale consumed by the node |
| 768 | + // 1) check if the var is in the graph |
| 769 | + // 2) remove the node from its consumer list |
| 770 | + for (const Any &var : to_erase->input_vars) { |
| 771 | + if (!this->var_to_producer.count(var)) { |
| 772 | + MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var; |
| 773 | + } |
| 774 | + UListObj *consumers = this->var_to_consumers.at(var); |
| 775 | + auto it = std::find_if(consumers->begin(), consumers->end(), |
| 776 | + [to_erase](const Any &v) -> bool { return v.operator DepNodeObj *() == to_erase; }); |
| 777 | + if (it == consumers->end()) { |
| 778 | + MLC_THROW(RuntimeError) << "Node is not a consumer of the variable: " << var; |
| 779 | + } |
| 780 | + consumers->erase(it); |
| 781 | + } |
| 782 | + // Step 4. Clear the node |
| 783 | + to_erase->Clear(); |
| 784 | +} |
| 785 | + |
| 786 | +inline void DepGraphObj::Replace(DepNodeObj *old_node, DepNodeObj *new_node) { |
| 787 | + if (old_node == new_node) { |
| 788 | + return; |
| 789 | + } |
| 790 | + if (old_node->prev == nullptr) { |
| 791 | + MLC_THROW(RuntimeError) << "Can't replace the input node: " << old_node->stmt; |
| 792 | + } |
| 793 | + if (!this->stmt_to_node.count(old_node->stmt)) { |
| 794 | + MLC_THROW(RuntimeError) << "Node not in graph: " << old_node->stmt; |
| 795 | + } |
| 796 | + if (new_node->prev != nullptr || new_node->next != nullptr) { |
| 797 | + MLC_THROW(RuntimeError) << "Node is already in the graph: " << new_node->stmt; |
| 798 | + } |
| 799 | + int64_t num_output_vars = old_node->output_vars.size(); |
| 800 | + if (num_output_vars != new_node->output_vars.size()) { |
| 801 | + MLC_THROW(RuntimeError) << "Mismatched number of output_vars: " << num_output_vars << " vs " |
| 802 | + << new_node->output_vars.size(); |
| 803 | + } |
| 804 | + // Step 1. Replace each variable produced by the old node |
| 805 | + for (int64_t i = 0; i < num_output_vars; ++i) { |
| 806 | + Any old_var = old_node->output_vars[i]; |
| 807 | + Any new_var = new_node->output_vars[i]; |
| 808 | + Ref<UListObj> old_var_consumers = var_to_consumers.at(old_var); |
| 809 | + // Search through its consumers |
| 810 | + for (DepNodeObj *consumer : *old_var_consumers) { |
| 811 | + // Replace the input vars of each consumer |
| 812 | + for (Any &v : consumer->input_vars) { |
| 813 | + if (v.operator Object *() == old_var.operator Object *()) { |
| 814 | + v = new_var; |
| 815 | + } |
| 816 | + } |
| 817 | + } |
| 818 | + this->var_to_producer.erase(old_var); |
| 819 | + this->var_to_consumers.erase(old_var); |
| 820 | + this->var_to_producer[new_var] = new_node; |
| 821 | + this->var_to_consumers[new_var] = old_var_consumers; |
| 822 | + } |
| 823 | + // Step 2. Delete each variable consumed by the old node |
| 824 | + for (const Any &var : old_node->input_vars) { |
| 825 | + UListObj *consumers = this->var_to_consumers.at(var); |
| 826 | + if (auto it = std::find_if(consumers->begin(), consumers->end(), |
| 827 | + [old_node](const Any &v) -> bool { return v.operator DepNodeObj *() == old_node; }); |
| 828 | + it != consumers->end()) { |
| 829 | + consumers->erase(it); |
| 830 | + } else { |
| 831 | + MLC_THROW(RuntimeError) << "Node is not a consumer of the variable: " << var; |
| 832 | + } |
| 833 | + } |
| 834 | + // Step 3. Add variables consumed by the new node |
| 835 | + for (const Any &var : new_node->input_vars) { |
| 836 | + if (!this->var_to_producer.count(var)) { |
| 837 | + MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var; |
| 838 | + } |
| 839 | + this->var_to_consumers.at(var).operator UListObj *()->push_back(new_node); |
| 840 | + } |
| 841 | + // Step 4. Link the new node into the graph |
| 842 | + new_node->prev = old_node->prev; |
| 843 | + new_node->next = old_node->next; |
| 844 | + if (old_node->prev != nullptr) { |
| 845 | + old_node->prev->next = new_node; |
| 846 | + } else { |
| 847 | + this->head = new_node; |
| 848 | + } |
| 849 | + if (old_node->next != nullptr) { |
| 850 | + old_node->next->prev = new_node; |
| 851 | + } |
| 852 | + this->stmt_to_node.erase(old_node->stmt); |
| 853 | + if (this->stmt_to_node.count(new_node->stmt)) { |
| 854 | + MLC_THROW(RuntimeError) << "Stmt already in the graph: " << new_node->stmt; |
| 855 | + } else { |
| 856 | + this->stmt_to_node[new_node->stmt] = new_node; |
| 857 | + } |
| 858 | + // Step 5. Clear the old node |
| 859 | + old_node->Clear(); |
| 860 | +} |
| 861 | + |
| 862 | +inline UList DepGraphObj::GetNodeProducers(DepNodeObj *node) { |
| 863 | + UList ret; |
| 864 | + for (const Any &var : node->input_vars) { |
| 865 | + if (auto it = this->var_to_producer.find(var); it != this->var_to_producer.end()) { |
| 866 | + ret.push_back(it->second); |
| 867 | + } else { |
| 868 | + MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var; |
| 869 | + } |
| 870 | + } |
| 871 | + return ret; |
| 872 | +} |
| 873 | + |
| 874 | +inline UList DepGraphObj::GetNodeConsumers(DepNodeObj *node) { |
| 875 | + UList ret; |
| 876 | + for (const Any &var : node->output_vars) { |
| 877 | + if (auto it = this->var_to_consumers.find(var); it != this->var_to_consumers.end()) { |
| 878 | + UListObj *consumers = it->second; |
| 879 | + ret.insert(ret.end(), consumers->begin(), consumers->end()); |
| 880 | + } else { |
| 881 | + MLC_THROW(RuntimeError) << "Variable is not consumed by any node in the graph: " << var; |
| 882 | + } |
| 883 | + } |
| 884 | + return ret; |
| 885 | +} |
| 886 | + |
| 887 | +inline DepNode DepGraphObj::GetVarProducer(Any var) { |
| 888 | + if (auto it = this->var_to_producer.find(var); it != this->var_to_producer.end()) { |
| 889 | + return it->second; |
| 890 | + } |
| 891 | + MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var; |
| 892 | + MLC_UNREACHABLE(); |
| 893 | +} |
| 894 | + |
| 895 | +inline UList DepGraphObj::GetVarConsumers(Any var) { |
| 896 | + if (auto it = this->var_to_consumers.find(var); it != this->var_to_consumers.end()) { |
| 897 | + return it->second; |
| 898 | + } |
| 899 | + MLC_THROW(RuntimeError) << "Variable is not consumed by any node in the graph: " << var; |
| 900 | + MLC_UNREACHABLE(); |
| 901 | +} |
| 902 | + |
| 903 | +inline void DepGraphObj::_Insert(DepNodeObj *prev, DepNodeObj *next, DepNodeObj *to_insert) { |
| 904 | + if (to_insert->prev != nullptr || to_insert->next != nullptr) { |
| 905 | + MLC_THROW(RuntimeError) << "Node is already in the graph: " << to_insert->stmt; |
| 906 | + } |
| 907 | + // Step 1. Link the node into the graph |
| 908 | + if (this->stmt_to_node.count(to_insert->stmt)) { |
| 909 | + MLC_THROW(RuntimeError) << "Stmt already in the graph: " << to_insert->stmt; |
| 910 | + } |
| 911 | + this->stmt_to_node[to_insert->stmt] = to_insert; |
| 912 | + to_insert->prev = prev; |
| 913 | + to_insert->next = next; |
| 914 | + if (prev != nullptr) { |
| 915 | + prev->next = to_insert; |
| 916 | + } else { |
| 917 | + this->head = to_insert; |
| 918 | + } |
| 919 | + if (next != nullptr) { |
| 920 | + next->prev = to_insert; |
| 921 | + } |
| 922 | + // Step 2. For each variable produced by the node |
| 923 | + // 1) check if it doesn't have a producer yet |
| 924 | + // 2) record its producer as this node |
| 925 | + for (const Any &var : to_insert->output_vars) { |
| 926 | + if (auto it = this->var_to_producer.find(var); it != this->var_to_producer.end()) { |
| 927 | + MLC_THROW(RuntimeError) << "Variable already has a producer by another node: " |
| 928 | + << it->second.operator DepNode()->stmt; |
| 929 | + } else { |
| 930 | + this->var_to_producer[var] = to_insert; |
| 931 | + this->var_to_consumers[var] = UList{}; |
| 932 | + } |
| 933 | + } |
| 934 | + // Step 3. For each variable consumed by the node |
| 935 | + // 1) check if the var is in the graph |
| 936 | + // 1) add a new consumer of this var |
| 937 | + for (const Any &var : to_insert->input_vars) { |
| 938 | + if (!this->var_to_producer.count(var)) { |
| 939 | + MLC_THROW(RuntimeError) << "Variable is not produced by any node in the graph: " << var; |
| 940 | + } |
| 941 | + this->var_to_consumers.at(var).operator UListObj *()->push_back(to_insert); |
| 942 | + } |
| 943 | +} |
585 | 944 | } // namespace
|
586 | 945 | } // namespace mlc
|
0 commit comments