@@ -138,10 +138,16 @@ partitioning::GraphAndMapping BuildHybridGraph(
     torch::jit::Block* block,
     CompileSpec cfg,
     ir::StaticParams static_params,
-    ir::CollectionTypeMap first_use_types) {
+    ir::CollectionTypeMap first_use_types,
+    bool expect_full_compilation = false) {
   auto convert_info = cfg.convert_info;
   auto partitioning_info = cfg.partitioning_info;
 
+  // Any nonzero block size is valid if full compilation to TRT is desired
+  if (expect_full_compilation) {
+    partitioning_info.min_block_size = 1;
+  }
+
   auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
   partitioning_ctx.input_types_map = first_use_types;
 
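A note on the min_block_size override above: the partitioner normally refuses to emit TensorRT segments smaller than min_block_size, so a fully supported model could still be split up if the user's block-size setting were honored here. The following minimal, self-contained sketch illustrates the override in isolation; PartitioningSettings is a hypothetical stand-in, not the real Torch-TensorRT type.

#include <iostream>

// Hypothetical stand-in for the partitioning settings consumed above; the
// real field lives on torch_tensorrt's partitioning info struct.
struct PartitioningSettings {
  int min_block_size = 3; // segments with fewer ops than this stay in Torch
};

// Mirrors the override in BuildHybridGraph: when full compilation is
// expected, even a single supported op must land in a TensorRT segment.
void ApplyFullCompilationOverride(PartitioningSettings& s, bool expect_full_compilation) {
  if (expect_full_compilation) {
    s.min_block_size = 1;
  }
}

int main() {
  PartitioningSettings s;
  ApplyFullCompilationOverride(s, /*expect_full_compilation=*/true);
  std::cout << "min_block_size = " << s.min_block_size << "\n"; // prints 1
  return 0;
}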
@@ -153,13 +159,16 @@ partitioning::GraphAndMapping BuildHybridGraph(
 
   for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
     partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
+    int num_torch_segments = 0;
+    int num_trt_segments = 0;
 
     for (auto& seg_block : segmented_blocks) {
       LOG_INFO("Block segment:" << seg_block);
       std::ostringstream trt_engine_id;
       trt_engine_id << reinterpret_cast<const int*>(&seg_block);
 
       if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+        num_trt_segments++;
         auto inputs = seg_block.construct_inputs_spec();
         // update the input ranges for each segments
         convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
@@ -180,8 +189,32 @@ partitioning::GraphAndMapping BuildHybridGraph(
             true);
 
         seg_block.update_graph(temp_g);
+      } else {
+        num_torch_segments++;
+
+        // If full compilation is expected, ensure that all operators in Torch blocks are
+        // for collections processing
+        if (expect_full_compilation) {
+          for (auto torch_node : seg_block.block()->nodes()) {
+            if (partitioning::CollectionSchemas.find(torch_node->kind().toQualString()) ==
+                partitioning::CollectionSchemas.end()) {
+              LOG_WARNING(
+                  "Full compilation specified but node " << torch_node->kind().toQualString()
+                                                         << " was executed in Torch.");
+            }
+          }
+        }
       }
     }
+
+    // If full compilation is expected, cannot have more than 2 Torch segments
+    // (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
+    if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1)) {
+      LOG_WARNING(
+          "Full compilation specified but number of torch segments was "
+          << num_torch_segments << " and number of trt segments was " << num_trt_segments
+          << ". Was expecting at most 2 Torch segments and 1 TRT segment.");
+    }
   }
 
   return partitioning::stitch(&partitioning_ctx, block);
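The counters and warnings added above encode an invariant of the full-compilation path: at most two Torch segments (pre-processing of collection inputs and post-processing of collection outputs) and exactly one TensorRT segment. A self-contained sketch of that check, with hypothetical names mirroring the diff:

#include <iostream>

// Mirrors the invariant checked at the end of BuildHybridGraph: under full
// compilation, only collection pre-/post-processing may remain in Torch.
bool SegmentCountsValid(int num_torch_segments, int num_trt_segments) {
  return num_torch_segments <= 2 && num_trt_segments == 1;
}

int main() {
  std::cout << SegmentCountsValid(2, 1) << "\n"; // 1: input + output glue in Torch, one engine
  std::cout << SegmentCountsValid(0, 1) << "\n"; // 1: everything inside the engine
  std::cout << SegmentCountsValid(3, 1) << "\n"; // 0: an unsupported op fell back to Torch
  std::cout << SegmentCountsValid(2, 2) << "\n"; // 0: the engine was split in two
  return 0;
}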
@@ -191,7 +224,8 @@ ir::TypeMap MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
-    ir::CollectionTypeMap& first_use_type_map) {
+    ir::CollectionTypeMap& first_use_type_map,
+    bool expect_full_compilation = false) {
   cfg.convert_info.collection_input_spec_map =
       std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
   cfg.partitioning_info.collection_input_spec_map =
@@ -226,7 +260,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
           "Cannot infer input type from calcuations in graph for input "
           << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
       spec[i].dtype = at::kFloat;
-    } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
+    } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info.enabled || expect_full_compilation)) {
       if (!est_type_opt[i]) {
         LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
         std::stringstream ss;
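The widened guard above means a user-supplied dtype is now honored not only when partitioning is explicitly enabled, but also when full compilation is expected, since that path still runs the partitioner internally. A one-function sketch of the condition (names hypothetical):

// Mirrors the else-if guard in MapInputsAndDetermineDTypes: the user's dtype
// wins whenever some partitioning pass will consume it, whether enabled
// explicitly or implied by an expected full compilation.
constexpr bool HonorUserDtype(bool dtype_is_user_defined, bool partitioning_enabled, bool expect_full_compilation) {
  return dtype_is_user_defined && (partitioning_enabled || expect_full_compilation);
}

static_assert(HonorUserDtype(true, false, true), "full compilation alone is enough");
static_assert(!HonorUserDtype(true, false, false), "otherwise fall back to the estimated type");
static_assert(!HonorUserDtype(false, true, true), "never overrides when the user set nothing");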
@@ -315,8 +349,14 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
+  // Determine if the block is convertible/has collection output, and based on the result,
+  // whether full compilation can be expected
+  auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+  auto outputIsCollection = conversion::OutputIsCollection(g->block());
+  auto nearly_full_compilation = (isBlockConvertible && outputIsCollection);
+
   // Extract map of IValue to DType
-  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types, nearly_full_compilation);
 
   // Check whether any of the input types are Long
   bool user_requested_long = false;
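The hoisted block above combines two properties of the lowered graph: whether every node has a TensorRT converter and whether the graph returns a collection (e.g. a tuple) rather than plain tensors. A minimal sketch of the resulting decision, using plain bools as stand-ins for the conversion:: queries:

// Stand-ins for the results of conversion::VerifyConverterSupportForBlock and
// conversion::OutputIsCollection queried in the diff above.
constexpr bool NearlyFullCompilation(bool is_block_convertible, bool output_is_collection) {
  // Fully convertible, but the collection outputs themselves cannot live in
  // the engine, so a thin Torch wrapper around one TRT segment is still needed.
  return is_block_convertible && output_is_collection;
}

static_assert(NearlyFullCompilation(true, true), "hybrid path with trivial Torch glue");
static_assert(!NearlyFullCompilation(true, false), "plain tensor outputs: pure TRT conversion");
static_assert(!NearlyFullCompilation(false, true), "unsupported ops: ordinary fallback rules apply");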
@@ -330,20 +370,23 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
     user_requested_long &= (casts_inserted > 0);
   }
 
-  auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
-  auto outputIsCollection = conversion::OutputIsCollection(g->block());
   if (cfg.partitioning_info.enabled && !user_requested_long &&
       (cfg.lower_info.forced_fallback_modules.size() == 0 &&
        cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
       !outputIsCollection) {
     LOG_INFO("Skipping partitioning since model is fully supported");
   }
 
-  if (cfg.partitioning_info.enabled &&
-      (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
-         cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
-       outputIsCollection || user_requested_long)) {
-    auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
+  if ((cfg.partitioning_info.enabled &&
+       (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
+          cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
+        outputIsCollection || user_requested_long)) ||
+      nearly_full_compilation) {
+    // If the model is fully-compilable and the user has specified full compilation, run partitioning
+    // to generate collection-processing code in Torch
+    auto expect_full_compilation = (nearly_full_compilation && !cfg.partitioning_info.enabled);
+    auto graph_and_mapping =
+        BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types, expect_full_compilation);
     new_g = graph_and_mapping.first;
     // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
     for (size_t i = 0; i < new_g->inputs().size(); ++i) {
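Taken together, the rewritten condition routes a model through BuildHybridGraph either when partitioning is enabled and actually needed, or when the model is fully convertible but produces collection outputs. The sketch below compresses that routing into one self-contained predicate; every parameter is a hypothetical stand-in for the cfg fields and helper results read in CompileGraph, with has_forced_fallback folding the two forced-fallback size() checks together.

#include <iostream>

// Compresses the final if-condition of CompileGraph into one predicate.
bool UseHybridGraphPath(
    bool partitioning_enabled,
    bool has_forced_fallback,
    bool is_block_convertible,
    bool output_is_collection,
    bool user_requested_long) {
  bool nearly_full_compilation = is_block_convertible && output_is_collection;
  bool partitioning_needed = !(!has_forced_fallback && is_block_convertible) ||
      output_is_collection || user_requested_long;
  return (partitioning_enabled && partitioning_needed) || nearly_full_compilation;
}

int main() {
  // Fully supported model with tuple outputs and partitioning disabled: still
  // routed through BuildHybridGraph, with expect_full_compilation = true there.
  std::cout << UseHybridGraphPath(false, false, true, true, false) << "\n"; // 1
  // Fully supported model with tensor outputs: straight TRT conversion instead.
  std::cout << UseHybridGraphPath(true, false, true, false, false) << "\n"; // 0
  return 0;
}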