@@ -59,10 +59,10 @@ def match_and_rewrite(self, op: riscv.ParallelMovOp, rewriter: PatternRewriter):
5959 ):
6060 raise PassFailedException ("All registers must be allocated" )
6161
62- srcs = cast (SSAValues [SSAValue [riscv .IntRegisterType ]], op .inputs )
63- dsts = cast (SSAValues [SSAValue [riscv .IntRegisterType ]], op .outputs )
64- src_types = cast ( Sequence [ riscv . IntRegisterType ], input_types )
65- dst_types = cast ( Sequence [ riscv . IntRegisterType ], output_types )
62+ srcs = cast (SSAValues [SSAValue [riscv .RISCVRegisterType ]], op .inputs )
63+ dsts = cast (SSAValues [SSAValue [riscv .RISCVRegisterType ]], op .outputs )
64+ src_types = input_types
65+ dst_types = output_types
6666
6767 # make a list of free registers for each type so we can add to it later
6868 free_registers : dict [
@@ -95,7 +95,7 @@ def match_and_rewrite(self, op: riscv.ParallelMovOp, rewriter: PatternRewriter):
9595
9696 # store the back edges of the graph
9797 src_by_dst_type : dict [
98- riscv .IntRegisterType , SSAValue [riscv .IntRegisterType ]
98+ riscv .RISCVRegisterType , SSAValue [riscv .RISCVRegisterType ]
9999 ] = {}
100100 leaves = set (dst_types )
101101 unprocessed_children = Counter [SSAValue ]()
@@ -159,6 +159,9 @@ def match_and_rewrite(self, op: riscv.ParallelMovOp, rewriter: PatternRewriter):
159159 inp = src_by_dst_type [out .type ]
160160
161161 while inp .type != out .type :
162+ # we know these are ints since input and output are of the same type
163+ inp = cast (SSAValue [riscv .IntRegisterType ], inp )
164+ out = cast (SSAValue [riscv .IntRegisterType ], out )
162165 nw_out , nw_inp = _insert_swap_ops (rewriter , inp , out )
163166 # after the swap, the input is in the right place, the input's input
164167 # needs to be moved to the new output
0 commit comments