@@ -143,19 +143,14 @@ partitioning::GraphAndMapping BuildHybridGraph(
143
143
auto convert_info = cfg.convert_info ;
144
144
auto partitioning_info = cfg.partitioning_info ;
145
145
146
- // Any nonzero block size is valid if full compilation to TRT is desired
147
- if (expect_full_compilation) {
148
- partitioning_info.min_block_size = 1 ;
149
- }
150
-
151
146
auto partitioning_ctx = partitioning::PartitioningCtx (block, partitioning_info);
152
147
partitioning_ctx.input_types_map = first_use_types;
153
148
154
149
// Generate a dictionary of input torch::jit::Value's to their min, opt, max tensors and store in ctx
155
150
// TODO: Combine this within partition call
156
151
partitioning::populateInputIValues (&partitioning_ctx);
157
152
158
- partitioning::partition (&partitioning_ctx);
153
+ partitioning::partition (&partitioning_ctx, expect_full_compilation );
159
154
160
155
for (auto & partitioned_block : partitioning_ctx.partitioned_blocks ) {
161
156
partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second ;
@@ -197,9 +192,11 @@ partitioning::GraphAndMapping BuildHybridGraph(
197
192
if (expect_full_compilation) {
198
193
for (auto torch_node : seg_block.block ()->nodes ()) {
199
194
if (partitioning::CollectionNodeKinds.find (torch_node->kind ()) == partitioning::CollectionNodeKinds.end ()) {
200
- LOG_ERROR (
201
- " Full compilation specified but node " << torch_node->kind ().toQualString ()
202
- << " was executed in Torch." );
195
+ TORCHTRT_THROW_ERROR (
196
+ " Full compilation specified but node "
197
+ << *torch_node
198
+ << " is set to run in PyTorch due to either lack of support in TensorRT or graph partitioning rules."
199
+ << " Try recompiling with require_full_compilation=False." );
203
200
}
204
201
}
205
202
}
@@ -209,10 +206,9 @@ partitioning::GraphAndMapping BuildHybridGraph(
209
206
// If full compilation is expected, cannot have more than 2 Torch segments
210
207
// (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
211
208
if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1 )) {
212
- LOG_ERROR (
213
- " Full compilation specified but number of torch segments was "
214
- << num_torch_segments << " and number of trt segments was " << num_trt_segments
215
- << " . Was expecting at most 2 Torch segments and 1 TRT segment." );
209
+ TORCHTRT_THROW_ERROR (
210
+ " Full compilation was requested but unable to convert all operations to TensorRT."
211
+ << " Try recompiling with require_full_compilation=False." );
216
212
}
217
213
}
218
214
@@ -224,7 +220,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
224
220
std::shared_ptr<torch::jit::Graph>& g,
225
221
ir::StaticParams& static_params,
226
222
ir::CollectionTypeMap& first_use_type_map,
227
- bool expect_full_compilation = false ) {
223
+ bool requires_collection_handling = false ) {
228
224
cfg.convert_info .collection_input_spec_map =
229
225
std::move (ir::associate_specs_with_collection_inputs (g, cfg.graph_inputs , static_params));
230
226
cfg.partitioning_info .collection_input_spec_map =
@@ -259,7 +255,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
259
255
" Cannot infer input type from calcuations in graph for input "
260
256
<< in->debugName () << " . Assuming it is Float32. If not, specify input type explicity" );
261
257
spec[i].dtype = at::kFloat ;
262
- } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info .enabled || expect_full_compilation )) {
258
+ } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info .enabled || requires_collection_handling )) {
263
259
if (!est_type_opt[i]) {
264
260
LOG_INFO (" Cannot infer input tensor dtype in graph, compiler is going to use the user setting" );
265
261
std::stringstream ss;
@@ -330,6 +326,11 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
330
326
return engine;
331
327
}
332
328
329
+ bool userRequestedFallback (CompileSpec& cfg) {
330
+ return cfg.lower_info .forced_fallback_modules .size () != 0 ||
331
+ cfg.partitioning_info .forced_fallback_operators .size () != 0 ;
332
+ }
333
+
333
334
torch::jit::Module CompileGraph (const torch::jit::Module& mod, CompileSpec cfg) {
334
335
torch::jit::Module new_mod (mod._ivalue ()->name () + " _trt" );
335
336
@@ -352,10 +353,13 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
352
353
// whether full compilation can be expected
353
354
auto isBlockConvertible = conversion::VerifyConverterSupportForBlock (g->block (), true );
354
355
auto outputIsCollection = conversion::OutputIsCollection (g->block ());
355
- auto nearly_full_compilation = (isBlockConvertible && outputIsCollection);
356
+ auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
357
+
358
+ // Determine whether user specifications necessitate partitioning
359
+ auto isFallbackRequested = userRequestedFallback (cfg);
356
360
357
361
// Extract map of IValue to DType
358
- auto type_map = MapInputsAndDetermineDTypes (cfg, g, static_params, first_use_types, nearly_full_compilation );
362
+ auto type_map = MapInputsAndDetermineDTypes (cfg, g, static_params, first_use_types, requires_collection_handling );
359
363
360
364
// Check whether any of the input types are Long
361
365
bool user_requested_long = false ;
@@ -369,21 +373,26 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
369
373
user_requested_long &= (casts_inserted > 0 );
370
374
}
371
375
372
- if (cfg.partitioning_info .enabled && !user_requested_long &&
373
- (cfg.lower_info .forced_fallback_modules .size () == 0 &&
374
- cfg.partitioning_info .forced_fallback_operators .size () == 0 && isBlockConvertible) &&
375
- !outputIsCollection) {
376
+ // Partitioning is required if:
377
+ // 1. User requested some modules/operators fallback
378
+ // 2. The block (graph) cannot be converted due to operator coverage
379
+ // 3. The output of the graph is a collection
380
+ // 4. The user requested a non-TRT data type input
381
+ auto isPartitioningRequired =
382
+ (isFallbackRequested || !isBlockConvertible || outputIsCollection || user_requested_long);
383
+
384
+ // The user did not require full compilation, but the model can be fully compiled
385
+ if (cfg.partitioning_info .enabled && !isPartitioningRequired) {
376
386
LOG_INFO (" Skipping partitioning since model is fully supported" );
377
387
}
378
388
379
- if ((cfg.partitioning_info .enabled &&
380
- (!(cfg.lower_info .forced_fallback_modules .size () == 0 &&
381
- cfg.partitioning_info .forced_fallback_operators .size () == 0 && isBlockConvertible) ||
382
- outputIsCollection || user_requested_long)) ||
383
- nearly_full_compilation) {
389
+ // The user did not require full compilation, and the model requires partitioning
390
+ // or, the user required full compilation but the I/O of the graph use collections
391
+ if ((cfg.partitioning_info .enabled && isPartitioningRequired) || requires_collection_handling) {
384
392
// If the model is fully-compilable and the user has specified full compilation, run partitioning
385
393
// to generate collection-processing code in Torch
386
- auto expect_full_compilation = (nearly_full_compilation && !cfg.partitioning_info .enabled );
394
+ auto expect_full_compilation = (requires_collection_handling && !cfg.partitioning_info .enabled );
395
+
387
396
auto graph_and_mapping =
388
397
BuildHybridGraph (new_mod, g->block (), cfg, static_params, first_use_types, expect_full_compilation);
389
398
new_g = graph_and_mapping.first ;
0 commit comments