@@ -150,6 +150,97 @@ TEST(Partitioning, ResolveNonTensorInputsCorrectly) {
  ASSERT_TRUE(trt_block_cnt == 1 && torch_block_cnt == 1);
}

+TEST(Partitioning, ResolveMultipleNonTensorInputsCorrectly) {
+  const auto graph = R"IR(
+        graph(%x.1 : Tensor):
+          # TensorRT-intended Block
+          %16 : int = prim::Constant[value=8]()
+          %15 : int = prim::Constant[value=64]()
+          %13 : int = prim::Constant[value=0]()
+          %10 : int = prim::Constant[value=1]()
+          %self.linear.bias : Float(4096, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+          %self.linear.weight : Float(4096, 64, strides=[64, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+          %3 : int = prim::Constant[value=-1]()
+          %2 : int = prim::Constant[value=1]()
+          %x.5 : Tensor = aten::flatten(%x.1, %2, %3)
+          %4 : Tensor = aten::t(%self.linear.weight)
+          %6 : Tensor = aten::matmul(%x.5, %4)
+          %7 : Tensor = trt::const(%self.linear.bias)
+          %9 : Tensor = aten::add(%7, %6, %10)
+          %11 : int[] = aten::size(%9) # <string>:13:9
+          %12 : int = aten::__getitem__(%11, %13)
+          %shape.3 : int[] = prim::ListConstruct(%12, %15, %16, %16)
+          %x.13 : Tensor = aten::reshape(%9, %shape.3)
+
+          # Torch-intended Block
+          %num_spatial_dims.2 : int = prim::Constant[value=2]()
+          %11 : int[] = prim::Constant[value=[0, 0]]()
+          %10 : bool = prim::Constant[value=0]()
+          %conv1_bias : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+          %conv1_weight : Float(32, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+          %6 : int = prim::Constant[value=1]()
+          %5 : int[] = prim::Constant[value=[1, 1]]()
+          %4 : int[] = prim::Constant[value=[2, 2]]()
+          %conv_bias : Float(32, strides=[1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+          %conv_weight : Float(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=0, device=cuda:0) = prim::Constant[value=<Tensor>]()
+          %input.16 : Tensor = aten::conv_transpose2d(%x.13, %conv_weight, %conv_bias, %4, %5, %5, %6, %5)
+          %7 : Tensor = aten::_convolution(%input.16, %conv1_weight, %conv1_bias, %5, %5, %5, %10, %11, %6, %10, %10, %10, %10)
+          %12 : int[] = aten::size(%7)
+          %96 : int = aten::len(%12)
+          %14 : int = aten::__range_length(%num_spatial_dims.2, %96, %6)
+
+          # TensorRT-intended Block
+          %15 : float = prim::Constant[value=1e-05]()
+          %14 : float = prim::Constant[value=0.1]()
+          %13 : NoneType = prim::Constant()
+          %num_spatial_dims.2 : int = prim::Constant[value=2]()
+          %300 : int = prim::Constant[value=3]()
+          %345 : int = aten::sub(%300, %96)
+          %3 : int = aten::add(%345, %6)
+          %2 : bool = prim::Constant[value=1]()
+          %size_prods.2 : int = prim::Loop(%3, %2, %6)
+            block0(%loop : int, %size_prods.13 : int):
+              %i.3 : int = aten::__derive_index(%loop, %num_spatial_dims.2, %3)
+              %8 : int = aten::__getitem__(%12, %i.3)
+              %size_prods.15 : int = aten::mul(%size_prods.13, %8)
+              -> (%2, %size_prods.15)
+          %11 : Tensor = aten::instance_norm(%7, %13, %13, %13, %13, %2, %14, %15, %2)
+          return (%11))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get(), true);
+
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  std::vector<torch_tensorrt::core::ir::Input> inputs;
+  inputs.push_back(torch_tensorrt::core::ir::Input({1, 64}));
+
+  torch_tensorrt::core::ir::CollectionInputSpecMap inputs_map;
+  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
+  for (size_t i = 0; i < g->inputs().size(); ++i) {
+    inputs_map.insert({g->inputs()[i], {inputs[i]}});
+    input_types.insert({g->inputs()[i], {{at::kFloat}}});
+  }
+
+  partitioning_info.collection_input_spec_map = inputs_map;
+  partitioning_info.forced_fallback_operators = {"aten::_convolution"};
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  ctx.input_types_map = input_types;
+
+  torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
+  torch_tensorrt::core::partitioning::partition(&ctx);
+  std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
+      ctx.partitioned_blocks.begin()->second;
+
+  // For each TensorRT segmented block, verify that all inputs are of Tensor type
+  for (auto block : segmented_blocks) {
+    if (block.target() == torch_tensorrt::core::partitioning::SegmentedBlock::SegmentedBlockTarget::kTensorRT) {
+      for (auto input : block.raw_inputs())
+        ASSERT_TRUE(input->type()->isSubtypeOf(c10::TensorType::get()));
+    }
+  }
+}
+
TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) {
  const auto graph = R"IR(
        graph(%0 : Float(1, 3, 16, 16, strides=[768, 256, 16, 1]),