@@ -26,19 +26,44 @@ def config_cpp(self):
26
26
clone_function_template = 'nnet::clone_stream<{input_t}, {output_t}, {size}>({input}, {output1}, {output2});'
27
27
clone_include_list = ['nnet_utils/nnet_stream.h' ]
28
28
29
+ class CloneThree (Layer ):
30
+ ''' Inserted after the layer whose output is used more than twice.'''
31
+
32
+ def initialize (self ):
33
+ inp = self .get_input_variable ()
34
+ self .add_output_variable (inp .shape , inp .dim_names , out_name = self .outputs [0 ], var_name = 'layer{index}_cpy1' )
35
+ self .add_output_variable (inp .shape , inp .dim_names , out_name = self .outputs [1 ], var_name = 'layer{index}_cpy2' )
36
+ self .add_output_variable (inp .shape , inp .dim_names , out_name = self .outputs [2 ], var_name = 'layer{index}_cpy3' )
37
+
38
+ def function_cpp (self ):
39
+ params = self ._default_function_params ()
40
+ params ['size' ] = self .get_attr ('size' )
41
+ params ['output1' ] = self .variables [self .outputs [0 ]].name
42
+ params ['output2' ] = self .variables [self .outputs [1 ]].name
43
+ params ['output3' ] = self .variables [self .outputs [2 ]].name
44
+ return [self ._function_template .format (** params )]
45
+
46
+ def config_cpp (self ):
47
+ return None
48
+
49
+ clonethree_function_template = 'nnet::clone_stream<{input_t}, {output_t}, {size}>({input}, {output1}, {output2}, {output3});'
50
+ clonethree_include_list = ['nnet_utils/nnet_stream.h' ]
51
+
29
52
# Register the layer types to the layer map
30
53
register_layer ('Clone' , Clone )
54
+ register_layer ('CloneThree' , CloneThree )
31
55
32
56
# Register the templates for config and function
33
57
for backend in ['Vivado' , 'VivadoAccelerator' ]:
34
58
templates .get_backend (backend ).register_templates ('Clone' , clone_function_template , None , clone_include_list )
59
+ templates .get_backend (backend ).register_templates ('CloneThree' , clonethree_function_template , None , clonethree_include_list )
35
60
36
61
37
62
class CloneOutput (OptimizerPass ):
38
63
''' Clones streams that are used multiple times '''
39
64
def match (self , node ):
40
65
# We may have already inserted the Clone layer
41
- if node .__class__ .__name__ == 'Clone' :
66
+ if node .__class__ .__name__ in [ 'Clone' , 'CloneThree' ] :
42
67
return False
43
68
44
69
return True
@@ -59,8 +84,8 @@ def transform(self, model, node):
59
84
transformed = False
60
85
for output in node .outputs :
61
86
if len (output_map [output ]) > 1 :
62
- if len (output_map [output ]) > 2 :
63
- print ('WARN: Cannot clone output {} of {} ({})' .format (output , node .__class__ .__name__ , node .name ))
87
+ if len (output_map [output ]) > 3 :
88
+ print ('WARNING: Cloning output {} of {} ({}) more than 3 times not currently supported ' .format (output , node .__class__ .__name__ , node .name ))
64
89
return False
65
90
out_var = node .get_output_variable (output )
66
91
for i , layer in enumerate (output_map [output ], 1 ):
@@ -69,7 +94,10 @@ def transform(self, model, node):
69
94
}
70
95
idx = layer .inputs .index (output )
71
96
layer .inputs [idx ] = output + '_cpy' + str (i )
72
- clone_layer = model .make_node ('Clone' , 'clone_' + node .name , attrs , [output ], [output + '_cpy1' , output + '_cpy2' ])
97
+ if len (output_map [output ]) == 3 :
98
+ clone_layer = model .make_node ('CloneThree' , 'clone_' + node .name , attrs , [output ], [output + '_cpy1' , output + '_cpy2' , output + '_cpy3' ])
99
+ else :
100
+ clone_layer = model .make_node ('Clone' , 'clone_' + node .name , attrs , [output ], [output + '_cpy1' , output + '_cpy2' ])
73
101
model .insert_node (clone_layer )
74
102
transformed = True
75
103
0 commit comments