#include "torch/csrc/jit/frontend/function_schema_parser.h"
#include "torch/csrc/jit/ir/ir.h"
- #include "torch/csrc/jit/ir/ir_views.h"
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/loop_unrolling.h"
#include "torch/csrc/jit/passes/lower_graph.h"
@@ -128,179 +127,54 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri
  return conversion::VerifyConverterSupportForBlock(g->block());
}

- void AddSegmentedBlockToGraph(
-     std::shared_ptr<torch::jit::Graph>& g,
-     partitioning::SegmentedBlock& seg,
-     std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
-   // old_to_new_g contains: original global graph value => new global graph value,
-   // mini_to_new_g: mini graph value -> new graph value
-   std::unordered_map<torch::jit::Value*, torch::jit::Value*> mini_to_new_g;
-   size_t input_idx = 0;
-   if (seg.target() == partitioning::SegmentedBlock::kTensorRT && g->inputs().size() > 0) {
-     if (g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
-       auto self = g->insertInput(0, "self_1");
-       self->setType(seg.inputs()[0]->type());
-     }
-     mini_to_new_g[seg.inputs()[input_idx++]] = g->inputs()[0];
-   }
-
-   for (auto& raw_input : seg.raw_inputs()) {
-     if (old_to_new_g.count(raw_input)) {
-       mini_to_new_g[seg.inputs()[input_idx++]] = old_to_new_g[raw_input];
-     }
-   }
-
-   for (const auto n : seg.nodes()) {
-     util::cloneNode(n, g, mini_to_new_g);
-   }
-
-   // original graph value => new global graph value
-   for (size_t i = 0; i < seg.raw_outputs().size(); ++i) {
-     old_to_new_g[seg.raw_outputs()[i]] = mini_to_new_g[seg.outputs()[i]];
-   }
-   size_t offset = seg.target() == partitioning::SegmentedBlock::kTensorRT ? 1 : 0;
-   for (size_t i = 0; i < seg.raw_inputs().size(); ++i) {
-     if (!old_to_new_g.count(seg.raw_inputs()[i])) {
-       old_to_new_g[seg.raw_inputs()[i]] = mini_to_new_g[seg.inputs()[i + offset]];
-     }
-   }
-
-   return;
- }
-
- typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
-     GraphAndMapping;
-
- void AddIfBlockToGraph(
-     std::shared_ptr<torch::jit::Graph>& new_g,
-     torch::jit::Node* if_node,
-     const std::vector<GraphAndMapping>& graph_and_mappings,
-     std::unordered_map<torch::jit::Value*, torch::jit::Value*>& old_to_new_g) {
-   torch::jit::IfView if_view(if_node);
-
-   // create a new if node in new_g and add corresponding inputs
-   auto new_if = new_g->insertNode(new_g->create(torch::jit::prim::If, {}, 0));
-   new_if->addInput(util::getOrAddInputForValue(if_view.cond(), new_g, old_to_new_g));
-
-   // iterate over all blocks and add them to new created prim::If
-   for (auto graph_and_mapping : graph_and_mappings) {
-     auto new_if_block = new_if->addBlock();
-     auto cur_block_graph = graph_and_mapping.first;
-     auto cur_block_mapping = graph_and_mapping.second;
-     std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
-     for (auto& i : cur_block_mapping) {
-       // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then
-       // it's mini graph's input
-       if (old_to_new_g.count(i.first)) {
-         block_graph_to_new_g[i.second] = old_to_new_g[i.first];
-       }
-     }
-
-     auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
-     new_if_block->cloneFrom(cur_block_graph->block(), env);
-     if (cur_block_graph->inputs().size() &&
-         cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
-       if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
-         auto self = new_g->insertInput(0, "self_1");
-         self->setType(cur_block_graph->inputs()[0]->type());
-       }
-       block_graph_to_new_g[cur_block_graph->inputs()[0]] = new_g->inputs()[0];
-     }
-     for (int i = cur_block_graph->inputs().size() - 1; i >= 0; --i) {
-       new_if_block->inputs()[i]->replaceAllUsesWith(block_graph_to_new_g[cur_block_graph->inputs()[i]]);
-       new_if_block->eraseInput(i);
-     }
-   }
-   for (auto ov : if_view.outputs()) {
-     auto no = new_if->addOutput();
-     old_to_new_g[ov] = no;
-     no->copyMetadata(ov);
-   }
-   return;
- }
-
- GraphAndMapping ConstructFallbackGraph(
+ partitioning::GraphAndMapping BuildHybridGraph(
    torch::jit::script::Module& new_mod,
    torch::jit::Block* block,
-     std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
    CompileSpec cfg,
    ir::StaticParams static_params,
-     std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
-   auto convert_cfg = cfg.convert_info;
-   auto partition_info = cfg.partition_info;
-
-   auto new_g = std::make_shared<torch::jit::Graph>();
-
-   auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info, fallback_nodes);
-
-   // the mapping from lowering graph => fallback global graph
-   std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
-   for (auto input : block->inputs()) {
-     util::getOrAddInputForValue(input, new_g, old_to_new_g);
-   }
-
-   for (auto& seg_block : segmented_blocks) {
-     LOG_INFO(seg_block << "(GraphInSegmentedBlock)\n");
-     std::ostringstream trt_engine_id;
-     trt_engine_id << reinterpret_cast<const int*>(&seg_block);
-
-     if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
-       auto shapes = seg_block.in_shapes();
-       auto types = seg_block.in_types();
-       std::vector<ir::Input> inputs;
-       for (size_t i = 0; i < shapes.size(); i++) {
-         auto in = ir::Input(shapes[i]);
-         in.dtype = util::ScalarTypeToTRTDataType(types[i]);
-         inputs.push_back(in);
-       }
-       // update the input ranges for each segments
-       convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
-
-       // TODO mapping Inputs Ivalue to flatten one here
-       auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
-       auto temp_g = std::make_shared<torch::jit::Graph>();
-       auto device_spec = convert_cfg.engine_settings.device;
-       auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-       AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
-
-       seg_block.update_graph(temp_g);
-       AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-     } else {
-       if (seg_block.raw_nodes()[0]->kind() == torch::jit::prim::If) {
-         auto if_node = seg_block.raw_nodes()[0];
-
-         // convert the 2 blocks in prim::if and get the converted graph with mappings
-         std::vector<GraphAndMapping> graph_and_mappings;
-         for (auto cur_block : if_node->blocks()) {
-           graph_and_mappings.push_back(
-               ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params, fallback_nodes));
+     ir::CollectionTypeMap first_use_types) {
+   auto convert_info = cfg.convert_info;
+   auto partitioning_info = cfg.partitioning_info;
+
+   auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
+   auto collection_input_ivalues_map =
+       partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);
+
+   partitioning::partition(&partitioning_ctx, collection_input_ivalues_map);
+
+   for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
+     partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
+
+     for (auto& seg_block : segmented_blocks) {
+       LOG_INFO("Block segment:" << seg_block);
+       std::ostringstream trt_engine_id;
+       trt_engine_id << reinterpret_cast<const int*>(&seg_block);
+
+       if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+         auto shapes = seg_block.in_shapes();
+         auto types = seg_block.in_types();
+         std::vector<ir::Input> inputs;
+         for (size_t i = 0; i < shapes.size(); i++) {
+           auto in = ir::Input(shapes[i]);
+           in.dtype = util::ScalarTypeToTRTDataType(types[i]);
+           inputs.push_back(in);
        }
-         AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);
+         // update the input ranges for each segments
+         convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);

-       } else {
-         AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-       }
-     }
-   }
+         // TODO mapping Inputs Ivalue to flatten one here
+         auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_info, static_params);
+         auto temp_g = std::make_shared<torch::jit::Graph>();
+         auto device_spec = convert_info.engine_settings.device;
+         auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+         AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);

-   if (block->outputs().size() > 1) {
-     std::vector<torch::jit::Value*> fallback_graph_vector;
-     for (auto& output : block->outputs()) {
-       if (old_to_new_g.count(output)) {
-         fallback_graph_vector.push_back(old_to_new_g[output]);
+         seg_block.update_graph(temp_g);
      }
    }
-     torch::jit::ArrayRef<torch::jit::Value*> fallback_graph_outputs(fallback_graph_vector);
-     auto return_tuple_node = new_g->createTuple(fallback_graph_outputs);
-     new_g->block()->appendNode(return_tuple_node);
-     // Set the output as the produced tuple
-     new_g->registerOutput(return_tuple_node->outputs()[0]);
-   } else {
-     if (block->outputs().size() && old_to_new_g.count(block->outputs()[0])) {
-       new_g->registerOutput(old_to_new_g[block->outputs()[0]]);
-     }
  }
-   return {new_g, old_to_new_g};
+
+   return partitioning::stitch(&partitioning_ctx, block);
}

void MapInputsAndDetermineDTypes(
@@ -310,6 +184,8 @@ void MapInputsAndDetermineDTypes(
    ir::CollectionTypeMap& first_use_type_map) {
  cfg.convert_info.collection_input_spec_map =
      std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
+   cfg.partitioning_info.collection_input_spec_map =
+       ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);

  auto collection_inputs = ir::get_collection_inputs(g, static_params);
  LOG_DEBUG(
@@ -339,7 +215,7 @@ void MapInputsAndDetermineDTypes(
            "Cannot infer input type from calcuations in graph for input "
            << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
        spec[i].dtype = nvinfer1::DataType::kFLOAT;
-       } else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) {
+       } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
        if (!est_type_opt[i]) {
          LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
          std::stringstream ss;
@@ -424,22 +300,18 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
  MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
  auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
  auto outputIsCollection = conversion::OutputIsCollection(g->block());
-   if (cfg.partition_info.enabled &&
+   if (cfg.partitioning_info.enabled &&
      (cfg.lower_info.forced_fallback_modules.size() == 0 &&
-        cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
+        cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
      !outputIsCollection) {
    LOG_INFO("Skipping partitioning since model is fully supported");
  }

-   if (cfg.partition_info.enabled &&
+   if (cfg.partitioning_info.enabled &&
      (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
-          cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
+          cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
       outputIsCollection)) {
-     std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-     auto collection_input_ivalues_map =
-         partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
-     auto graph_and_mapping = ConstructFallbackGraph(
-         new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes);
+     auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
    new_g = graph_and_mapping.first;
    // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
    for (size_t i = 0; i < new_g->inputs().size(); ++i) {
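For reviewers who want to exercise the path touched above: BuildHybridGraph now serves the partial-compilation branch of CompileGraph. A minimal sketch of driving that branch through the public C++ frontend follows; the header path, the CompileSpec fields (torch_executed_ops, min_block_size), and the "model.ts" file name are recalled from the 1.x C++ API and should be treated as assumptions, not anything introduced by this diff.

// Sketch only: exercising the hybrid (TensorRT + Torch fallback) path from the
// public C++ frontend. Header path, CompileSpec field names and "model.ts" are
// assumptions about the 1.x API, not part of this diff.
#include <vector>

#include "torch/script.h"
#include "torch_tensorrt/torch_tensorrt.h"

int main() {
  // Placeholder TorchScript module saved ahead of time.
  torch::jit::Module mod = torch::jit::load("model.ts");

  // One static NCHW input; per-segment shapes and dtypes are derived during partitioning.
  auto in = torch_tensorrt::Input(std::vector<int64_t>{1, 3, 224, 224});
  auto spec = torch_tensorrt::ts::CompileSpec({in});

  // Keep some ops in Torch so the compiler produces both TensorRT and Torch
  // segments, i.e. the BuildHybridGraph/stitch path rather than full conversion.
  spec.torch_executed_ops.push_back("aten::relu");
  spec.min_block_size = 3;

  auto hybrid_mod = torch_tensorrt::ts::compile(mod, spec);
  hybrid_mod.save("hybrid_model.ts");
  return 0;
}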