Commit 655ce22

fix: Allow full model compilation with collection outputs
- Update graph-building in the compiler to account for the case where all operations are supported by Torch-TRT but the output is a collection.
- Enable "pseudo-partitioning" for nearly-fully-compiled models whose only unsupported aspect is the format of the output (TRT cannot output complex collections).
- Define a small subset of operation schemas which are allowed despite the flag `require_full_compilation`. These operations pack and unpack Tuples/Lists, and some are already used in cases of `require_full_compilation`.
- Display warnings to users if any portion of the pseudo-partitioning is unexpected, for example if the partitioned model ends up in more than 3 segments (at most: a Torch segment to preprocess collection inputs, a TRT segment to perform model logic, and a Torch segment to post-process collection outputs), or if schemas outside the collection subset are encountered in a Torch segment.
- Add an end-to-end test case with a minimal reproducing example of a failing model, repaired by the changes to the compiler.
- Add a minor fix to lowering to remediate a C++ compiler warning.
1 parent b2a5da6 · commit 655ce22
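For context, a minimal usage sketch of the scenario this commit enables, mirroring the new end-to-end test included below: a module in which every operation is TRT-supported but the output is a nested tuple, compiled with require_full_compilation=True. The module name NestedOutput and the (5, 5) input shapes are illustrative only, not part of the commit.

import torch
import torch_tensorrt as torchtrt

# Illustrative module: all ops are TRT-supported, but the output is a nested
# tuple, which a TRT engine cannot return directly.
class NestedOutput(torch.nn.Module):
    def forward(self, x, y):
        a = x + y
        b = a * 2.0
        return (a, (b, a))

scripted = torch.jit.script(NestedOutput().eval().to("cuda"))

# With this change, require_full_compilation=True no longer rejects the model;
# the compiler runs a "pseudo-partitioning" that leaves at most a small Torch
# segment to repack the collection output around a single TRT engine.
trt_mod = torchtrt.ts.compile(
    scripted,
    inputs=[
        torchtrt.Input((5, 5), dtype=torch.float),
        torchtrt.Input((5, 5), dtype=torch.float),
    ],
    require_full_compilation=True,
)

out_a, (out_b, out_a2) = trt_mod(
    torch.ones((5, 5), device="cuda"), torch.ones((5, 5), device="cuda")
)

Per the commit message, only the tuple/list packing and unpacking operations (the CollectionSchemas subset) are expected to remain in Torch; the model logic runs in one TRT segment.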

File tree

4 files changed: +100 −12 lines

- core/compiler.cpp
- core/lowering/lowering.cpp
- core/partitioning/partitioning.h
- tests/py/api/test_e2e_behavior.py


core/compiler.cpp

Lines changed: 54 additions & 11 deletions
@@ -138,10 +138,16 @@ partitioning::GraphAndMapping BuildHybridGraph(
     torch::jit::Block* block,
     CompileSpec cfg,
     ir::StaticParams static_params,
-    ir::CollectionTypeMap first_use_types) {
+    ir::CollectionTypeMap first_use_types,
+    bool expect_full_compilation = false) {
   auto convert_info = cfg.convert_info;
   auto partitioning_info = cfg.partitioning_info;
 
+  // Any nonzero block size is valid if full compilation to TRT is desired
+  if (expect_full_compilation) {
+    partitioning_info.min_block_size = 1;
+  }
+
   auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
   partitioning_ctx.input_types_map = first_use_types;
 
@@ -153,13 +159,16 @@ partitioning::GraphAndMapping BuildHybridGraph(
 
   for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
     partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
+    int num_torch_segments = 0;
+    int num_trt_segments = 0;
 
     for (auto& seg_block : segmented_blocks) {
       LOG_INFO("Block segment:" << seg_block);
       std::ostringstream trt_engine_id;
       trt_engine_id << reinterpret_cast<const int*>(&seg_block);
 
       if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+        num_trt_segments++;
         auto inputs = seg_block.construct_inputs_spec();
         // update the input ranges for each segments
         convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
@@ -180,8 +189,32 @@ partitioning::GraphAndMapping BuildHybridGraph(
             true);
 
         seg_block.update_graph(temp_g);
+      } else {
+        num_torch_segments++;
+
+        // If full compilation is expected, ensure that all operators in Torch blocks are
+        // for collections processing
+        if (expect_full_compilation) {
+          for (auto torch_node : seg_block.block()->nodes()) {
+            if (partitioning::CollectionSchemas.find(torch_node->kind().toQualString()) ==
+                partitioning::CollectionSchemas.end()) {
+              LOG_WARNING(
+                  "Full compilation specified but node " << torch_node->kind().toQualString()
+                                                         << " was executed in Torch.");
+            }
+          }
+        }
       }
     }
+
+    // If full compilation is expected, cannot have more than 2 Torch segments
+    // (one for preprocessing inputs, one for post-processing outputs) and 1 TRT segment
+    if (expect_full_compilation && !(num_torch_segments <= 2 && num_trt_segments == 1)) {
+      LOG_WARNING(
+          "Full compilation specified but number of torch segments was "
+          << num_torch_segments << " and number of trt segments was " << num_trt_segments
+          << ". Was expecting at most 2 Torch segments and 1 TRT segment.");
+    }
   }
 
   return partitioning::stitch(&partitioning_ctx, block);
@@ -191,7 +224,8 @@ ir::TypeMap MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
-    ir::CollectionTypeMap& first_use_type_map) {
+    ir::CollectionTypeMap& first_use_type_map,
+    bool expect_full_compilation = false) {
   cfg.convert_info.collection_input_spec_map =
       std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
   cfg.partitioning_info.collection_input_spec_map =
@@ -226,7 +260,7 @@ ir::TypeMap MapInputsAndDetermineDTypes(
           "Cannot infer input type from calcuations in graph for input "
           << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
       spec[i].dtype = at::kFloat;
-    } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
+    } else if (spec[i].dtype_is_user_defined && (cfg.partitioning_info.enabled || expect_full_compilation)) {
       if (!est_type_opt[i]) {
         LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
         std::stringstream ss;
@@ -315,8 +349,14 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
+  // Determine if the block is convertible/has collection output, and based on the result,
+  // whether full compilation can be expected
+  auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+  auto outputIsCollection = conversion::OutputIsCollection(g->block());
+  auto nearly_full_compilation = (isBlockConvertible && outputIsCollection);
+
   // Extract map of IValue to DType
-  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types, nearly_full_compilation);
 
   // Check whether any of the input types are Long
   bool user_requested_long = false;
@@ -330,20 +370,23 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
     user_requested_long &= (casts_inserted > 0);
   }
 
-  auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
-  auto outputIsCollection = conversion::OutputIsCollection(g->block());
   if (cfg.partitioning_info.enabled && !user_requested_long &&
       (cfg.lower_info.forced_fallback_modules.size() == 0 &&
        cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
      !outputIsCollection) {
     LOG_INFO("Skipping partitioning since model is fully supported");
   }
 
-  if (cfg.partitioning_info.enabled &&
-      (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
-         cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
-       outputIsCollection || user_requested_long)) {
-    auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
+  if ((cfg.partitioning_info.enabled &&
+       (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
+          cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
+        outputIsCollection || user_requested_long)) ||
+      nearly_full_compilation) {
+    // If the model is fully-compilable and the user has specified full compilation, run partitioning
+    // to generate collection-processing code in Torch
+    auto expect_full_compilation = (nearly_full_compilation && !cfg.partitioning_info.enabled);
+    auto graph_and_mapping =
+        BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types, expect_full_compilation);
     new_g = graph_and_mapping.first;
     // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
     for (size_t i = 0; i < new_g->inputs().size(); ++i) {

core/lowering/lowering.cpp

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ int AutocastLongInputs(
     std::string target_device_name) {
   int num_autocasts = 0;
   // For each graph input, determine if it can be autocasted
-  for (int i = 0; i < g->inputs().size(); i++) {
+  for (size_t i = 0; i < g->inputs().size(); i++) {
     auto input = g->inputs()[i];
 
     // Autocasted inputs must be Tensor-type

core/partitioning/partitioning.h

Lines changed: 13 additions & 0 deletions
@@ -18,6 +18,19 @@ typedef std::unordered_map<const torch::jit::Value*, torch::jit::IValue> Example
 typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>>
     GraphAndMapping;
 
+// Set of schemas allowed to be executed in Torch, even with require_full_compilation=true,
+// as necessary for returning collections of Tensors or other complex constructs, and for
+// processing inputs to TRT engines
+const std::unordered_set<std::string> CollectionSchemas = {
+    "prim::Constant",
+    "aten::__getitem__",
+    "prim::ListConstruct",
+    "prim::ListUnpack",
+    "prim::TupleIndex",
+    "prim::TupleConstruct",
+    "prim::TupleUnpack",
+};
+
 ExampleIValues generateRandomInputs(
     ir::CollectionInputSpecMap& input_ranges,
     ir::CollectionTypeMap& input_types,

tests/py/api/test_e2e_behavior.py

Lines changed: 32 additions & 0 deletions
@@ -109,6 +109,38 @@ def test_input_respect_user_setting_fp16_weights_fp32_in_non_constuctor(self):
         )
         trt_mod(self.input)
 
+    def test_nested_tuple_output_with_full_compilation(self):
+        class Sample(torch.nn.Module):
+            def __init__(self):
+                super(Sample, self).__init__()
+
+            def forward(self, x, y, z):
+                c = 1.0
+                b = x + 2.0 * z
+                b = y + b
+                a = b + c
+                return (a, (b, c))
+
+        self.model = Sample().eval().to("cuda")
+        self.input_1 = torch.zeros((5, 5), dtype=torch.float, device="cuda:0")
+        self.input_2 = torch.ones((5, 5), dtype=torch.float, device="cuda:0")
+        self.input_3 = torch.ones((5, 5), dtype=torch.float, device="cuda:0")
+        scripted_mod = torch.jit.script(self.model)
+
+        inputs = [
+            torchtrt.Input((5, 5), dtype=torch.float),
+            torchtrt.Input((5, 5), dtype=torch.float),
+            torchtrt.Input((5, 5), dtype=torch.float),
+        ]
+
+        trt_mod = torchtrt.ts.compile(
+            scripted_mod,
+            inputs=inputs,
+            require_full_compilation=True,
+            enabled_precisions={torch.float, torch.half},
+        )
+        trt_mod(self.input_1, self.input_2, self.input_3)
+
 
 if __name__ == "__main__":
     unittest.main()
