8
8
9
9
import copy
10
10
import operator
11
+ from collections import defaultdict
11
12
from typing import Any , Dict , List , Optional , Set , Tuple , Union
12
13
13
14
import torch
@@ -488,8 +489,12 @@ def _get_new_signature( # noqa: C901
488
489
else {}
489
490
)
490
491
492
+ toplevel_output_node_to_sig : Dict [str , List [OutputSpec ]] = defaultdict (list )
493
+ if not is_submodule :
494
+ for output_spec in old_signature .output_specs :
495
+ toplevel_output_node_to_sig [output_spec .arg .name ].append (output_spec )
496
+
491
497
for node in gm .graph .nodes :
492
- is_tagged = tag is None or node .meta .get ("delegation_tag" , None ) == tag
493
498
if node .op == "placeholder" :
494
499
495
500
if node .name not in input_node_to_sig :
@@ -507,7 +512,7 @@ def _get_new_signature( # noqa: C901
507
512
if not isinstance (orig_input_spec .arg , TensorArgument ):
508
513
input_specs .append (orig_input_spec )
509
514
510
- elif is_tagged :
515
+ elif node . meta . get ( "delegation_tag" , None ) == tag :
511
516
input_specs .append (orig_input_spec )
512
517
513
518
if orig_input_spec .kind == InputKind .USER_INPUT :
@@ -551,11 +556,67 @@ def _get_new_signature( # noqa: C901
551
556
)
552
557
553
558
if node .op == "output" :
554
- output_nodes = pytree .tree_leaves ((node .args , node .kwargs ))
555
-
556
- for output_node in output_nodes :
559
+ buffer_mutation_idxs : Dict [int , List [OutputSpec ]] = defaultdict (list )
560
+ for user in call_module_node .users .keys ():
561
+ if user .name in toplevel_output_node_to_sig :
562
+ assert (
563
+ user .op == "call_function" and user .target == operator .getitem
564
+ ), f"Invalid user { user } , node.op is { user .op } and node.target is { user .target } "
565
+ getitem_idx = user .args [1 ]
566
+ assert isinstance (
567
+ getitem_idx , int
568
+ ), f"Invalid getitem type: { type (getitem_idx )} "
569
+ buffer_mutation_idxs [getitem_idx ].extend (
570
+ toplevel_output_node_to_sig [user .name ]
571
+ )
557
572
558
- if not isinstance (output_node , torch .fx .Node ):
573
+ for i , output_node in enumerate (node .args [0 ]):
574
+ if i in buffer_mutation_idxs :
575
+ assert isinstance (output_node , torch .fx .Node )
576
+ orig_output_specs = buffer_mutation_idxs [i ]
577
+
578
+ for orig_output_spec in orig_output_specs :
579
+
580
+ if (
581
+ orig_output_spec .kind == OutputKind .BUFFER_MUTATION
582
+ and orig_output_spec .target in new_state_dict
583
+ ):
584
+ # If the delegate wants to consume the buffer, then
585
+ # the delegate should also consume the buffer
586
+ # mutation (output spec would be a BUFFER_MUTATION).
587
+ # Otherwise the delegate will just return the result
588
+ # of the mutation as a USER_OUTPUT.
589
+
590
+ assert len (orig_output_specs ) == 1 , (
591
+ f"Constant { orig_output_spec .target } was tagged to be "
592
+ "consumed by the buffer, and was found to also contain "
593
+ "a buffer mutation. However this buffer mutation node "
594
+ "was found to also be used as other types of outputs "
595
+ "which is currently not supported. Please file an "
596
+ "issue on Github. \n \n "
597
+ f"The toplevel program: { original_program } \n "
598
+ )
599
+ output_specs .append (
600
+ OutputSpec (
601
+ kind = OutputKind .BUFFER_MUTATION ,
602
+ arg = TensorArgument (name = output_node .name ),
603
+ target = orig_output_spec .target ,
604
+ )
605
+ )
606
+ output_specs_to_delete [orig_output_spec .arg .name ] = (
607
+ orig_output_spec
608
+ )
609
+
610
+ else :
611
+ output_specs .append (
612
+ OutputSpec (
613
+ kind = OutputKind .USER_OUTPUT ,
614
+ arg = TensorArgument (name = output_node .name ),
615
+ target = None ,
616
+ )
617
+ )
618
+
619
+ elif not isinstance (output_node , torch .fx .Node ):
559
620
output_specs .append (
560
621
OutputSpec (
561
622
kind = OutputKind .USER_OUTPUT ,
@@ -774,7 +835,7 @@ def get_lowered_backend_modules(
774
835
return lowered_programs
775
836
776
837
777
- def _unsafe_adjust_original_program (
838
+ def _unsafe_adjust_original_program ( # noqa: C901
778
839
original_program : ExportedProgram ,
779
840
call_delegate_node : torch .fx .Node ,
780
841
input_specs_to_delete : Dict [str , InputSpec ],
@@ -830,3 +891,50 @@ def _unsafe_adjust_original_program(
830
891
del original_program ._constants [input_spec .target ]
831
892
else :
832
893
raise RuntimeError (f"Invalid input spec { input_spec } received" )
894
+
895
+ # Delete buffer mutations from the output which were consumed by the delegate
896
+ toplevel_output_node = None
897
+ for node in reversed (original_program .graph .nodes ):
898
+ if node .op == "output" :
899
+ toplevel_output_node = node
900
+ break
901
+
902
+ assert toplevel_output_node is not None
903
+ assert (
904
+ len (toplevel_output_node .args ) == 1
905
+ ), f"Invalid output node: { toplevel_output_node } with args { toplevel_output_node .args } "
906
+
907
+ new_output_args = [
908
+ arg
909
+ for arg in toplevel_output_node .args [0 ]
910
+ if not isinstance (arg , torch .fx .Node ) or arg .name not in output_specs_to_delete
911
+ ]
912
+ toplevel_output_node .args = (tuple (new_output_args ),)
913
+
914
+ # Delete the buffer mutation getitem nodes
915
+ getitem_idxs : List [int ] = []
916
+ user_nodes = list (call_delegate_node .users .keys ())
917
+ for user in user_nodes :
918
+ if user .name in output_specs_to_delete :
919
+ assert (
920
+ user .op == "call_function" and user .target == operator .getitem
921
+ ), f"Invalid user { user } , node.op is { node .op } and node.target is { node .target } "
922
+ user_idx = user .args [1 ]
923
+ assert isinstance (user_idx , int ), f"Invalid getitem type: { type (user_idx )} "
924
+ getitem_idxs .append (user_idx )
925
+ original_program .graph .erase_node (user )
926
+
927
+ getitem_idxs .sort (reverse = True )
928
+
929
+ # Adjust all the getitem indices after the deleted getitems
930
+ user_nodes = list (call_delegate_node .users .keys ())
931
+ for user in user_nodes :
932
+ assert user .op == "call_function" and user .target == operator .getitem
933
+ user_idx = user .args [1 ]
934
+ assert isinstance (user_idx , int )
935
+ for i , idx in enumerate (getitem_idxs ):
936
+ if user_idx > idx :
937
+ user .args = (user .args [0 ], user_idx - (len (getitem_idxs ) - i ))
938
+ break
939
+
940
+ original_program ._validate ()
0 commit comments