diff --git a/core/compiler.cpp b/core/compiler.cpp
index 118ca7aa1c..58af1e6cd8 100644
--- a/core/compiler.cpp
+++ b/core/compiler.cpp
@@ -137,10 +137,13 @@ partitioning::GraphAndMapping BuildHybridGraph(
   auto partitioning_info = cfg.partitioning_info;
 
   auto partitioning_ctx = partitioning::PartitioningCtx(block, partitioning_info);
-  auto collection_input_ivalues_map =
-      partitioning::generateRandomInputs(partitioning_info.collection_input_spec_map, first_use_types);
+  partitioning_ctx.input_types_map = first_use_types;
 
-  partitioning::partition(&partitioning_ctx, collection_input_ivalues_map);
+  // Generate a dictionary of input torch::jit::Value's to their min, opt, max tensors and store in ctx
+  // TODO: Combine this within partition call
+  partitioning::populateInputIValues(&partitioning_ctx);
+
+  partitioning::partition(&partitioning_ctx);
 
   for (auto& partitioned_block : partitioning_ctx.partitioned_blocks) {
     partitioning::PartitionedGraph& segmented_blocks = partitioned_block.second;
@@ -151,14 +154,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
       trt_engine_id << reinterpret_cast<const int*>(&seg_block);
 
       if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
-        auto shapes = seg_block.in_shapes();
-        auto types = seg_block.in_types();
-        std::vector<ir::Input> inputs;
-        for (size_t i = 0; i < shapes.size(); i++) {
-          auto in = ir::Input(shapes[i]);
-          in.dtype = util::ScalarTypeToTRTDataType(types[i]);
-          inputs.push_back(in);
-        }
+        auto inputs = seg_block.construct_inputs_spec();
         // update the input ranges for each segments
         convert_info.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
 
diff --git a/core/ir/ir.h b/core/ir/ir.h
index 6c78908d5b..141dd24aa0 100644
--- a/core/ir/ir.h
+++ b/core/ir/ir.h
@@ -11,6 +11,12 @@ namespace torch_tensorrt {
 namespace core {
 namespace ir {
 
+enum class ShapeMode {
+  kMIN,
+  kOPT,
+  kMAX,
+};
+
 struct Device {
   nvinfer1::DeviceType device_type;
   int64_t gpu_id;
diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp
index 21da0c9f0f..4d74461454 100644
--- a/core/partitioning/partitioning.cpp
+++ b/core/partitioning/partitioning.cpp
@@ -536,7 +536,35 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
   return;
 }
 
-void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
+bool isInputDynamic(PartitioningCtx* ctx) {
+  // Check if inputs have dynamic shapes
+  bool input_is_dynamic = true;
+  auto inputs_map = ctx->settings.collection_input_spec_map;
+  for (auto inputs : inputs_map) {
+    for (auto input : inputs.second) {
+      if (!input.input_is_dynamic) {
+        input_is_dynamic = false;
+      }
+    }
+  }
+  return input_is_dynamic;
+}
+
+void populateInputIValues(PartitioningCtx* ctx) {
+  if (isInputDynamic(ctx)) {
+    ctx->min_input_ivalues_map = partitioning::generateRandomInputs(
+        ctx->settings.collection_input_spec_map, ctx->input_types_map, ir::ShapeMode::kMIN);
+    ctx->opt_input_ivalues_map = partitioning::generateRandomInputs(
+        ctx->settings.collection_input_spec_map, ctx->input_types_map, ir::ShapeMode::kOPT);
+    ctx->max_input_ivalues_map = partitioning::generateRandomInputs(
+        ctx->settings.collection_input_spec_map, ctx->input_types_map, ir::ShapeMode::kMAX);
+  } else {
+    ctx->opt_input_ivalues_map = partitioning::generateRandomInputs(
+        ctx->settings.collection_input_spec_map, ctx->input_types_map, ir::ShapeMode::kOPT);
+  }
+}
+
+void partition(PartitioningCtx* ctx) {
   LOG_DEBUG(ctx->settings);
 
   // Go through all the blocks to do the partitioning
@@ -546,15 +574,24 @@ void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
 
     // It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks
     // resolve nonTensor inputs/outputs
+    LOG_DEBUG("Resolving non-tensor inputs for segmented blocks");
     resolveTRTNonTensorInputs(ctx, block);
 
    // register input/output torch::jit::Value for segmented graphs
     LOG_DEBUG("Registering input/output torch::jit::Value for segmented graphs");
     registerSegmentsOutputs(ctx, block);
 
-    // run shape analysis on each segmented block
-    LOG_DEBUG("Running shape analysis for segmented graphs");
-    runShapeAnalysis(ctx, block, example_tensor_map);
+    // Incase of dynamic shape inputs, run shape analysis on each segmented block for min/opt/max ranges and register
+    // output shapes for each block accordingly
+    if (isInputDynamic(ctx)) {
+      LOG_DEBUG("Performing shape analysis for segmented blocks using min/opt/max shapes for inputs");
+      runShapeAnalysis(ctx, block, ctx->min_input_ivalues_map, ir::ShapeMode::kMIN);
+      runShapeAnalysis(ctx, block, ctx->opt_input_ivalues_map, ir::ShapeMode::kOPT);
+      runShapeAnalysis(ctx, block, ctx->max_input_ivalues_map, ir::ShapeMode::kMAX);
+    } else {
+      LOG_DEBUG("Performing shape analysis for segmented blocks using static shapes for inputs");
+      runShapeAnalysis(ctx, block, ctx->opt_input_ivalues_map, ir::ShapeMode::kOPT);
+    }
   }
 }
 
diff --git a/core/partitioning/partitioning.h b/core/partitioning/partitioning.h
index 3038f6c52f..7c72d091b6 100644
--- a/core/partitioning/partitioning.h
+++ b/core/partitioning/partitioning.h
@@ -18,15 +18,24 @@ typedef std::unordered_map<const torch::jit::Value*, torch::jit::IValue> ExampleIValues;
 
 typedef std::pair<std::shared_ptr<torch::jit::Graph>, std::unordered_map<torch::jit::Value*, torch::jit::Value*>> GraphAndMapping;
 
-ExampleIValues generateRandomInputs(ir::CollectionInputSpecMap& input_ranges, ir::CollectionTypeMap& input_types);
+ExampleIValues generateRandomInputs(
+    ir::CollectionInputSpecMap& input_ranges,
+    ir::CollectionTypeMap& input_types,
+    const ir::ShapeMode& shape_mode = ir::ShapeMode::kOPT);
 
-void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& ivalues_maps);
+void populateInputIValues(PartitioningCtx* ctx);
+
+void runShapeAnalysis(
+    PartitioningCtx* ctx,
+    torch::jit::Block* block,
+    ExampleIValues& ivalues_maps,
+    const ir::ShapeMode& shape_mode);
 
 void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block);
 
 GraphAndMapping stitch(PartitioningCtx* ctx, torch::jit::Block* block);
 
-void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map);
+void partition(PartitioningCtx* ctx);
 
 } // namespace partitioning
 } // namespace core
diff --git a/core/partitioning/partitioningctx/PartitioningCtx.h b/core/partitioning/partitioningctx/PartitioningCtx.h
index ed8e705be5..91e376eab3 100644
--- a/core/partitioning/partitioningctx/PartitioningCtx.h
+++ b/core/partitioning/partitioningctx/PartitioningCtx.h
@@ -47,6 +47,9 @@ struct UsageInfo {
 struct PartitioningCtx {
   // TODO: Make the set a part of settings not stand alone
   PartitioningInfo settings;
+  std::unordered_map<const torch::jit::Value*, torch::jit::IValue> min_input_ivalues_map;
+  std::unordered_map<const torch::jit::Value*, torch::jit::IValue> opt_input_ivalues_map;
+  std::unordered_map<const torch::jit::Value*, torch::jit::IValue> max_input_ivalues_map;
   // records all the original blocks topologically in the module
   std::vector<torch::jit::Block*> original_blocks;
   // mapping: node=> execution status
@@ -60,6 +63,7 @@ struct PartitioningCtx {
   bool shouldNodeRunInTorch(torch::jit::Node* n);
   bool shouldNodeRunInTensorRT(torch::jit::Node* n);
   std::vector<torch::jit::Node*> getNodesRunInTorch();
+  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types_map;
 
  private:
  void _load_nodes_into_decision_map(torch::jit::Block* b);
diff --git a/core/partitioning/segmentedblock/SegmentedBlock.cpp b/core/partitioning/segmentedblock/SegmentedBlock.cpp
index 6a370c83ad..583e67ca4d 100644
--- a/core/partitioning/segmentedblock/SegmentedBlock.cpp
+++ b/core/partitioning/segmentedblock/SegmentedBlock.cpp
@@ -1,4 +1,5 @@
 #include "SegmentedBlock.h"
+#include "core/util/prelude.h"
 
 namespace torch_tensorrt {
 namespace core {
@@ -56,6 +57,24 @@ torch::jit::Value* SegmentedBlock::getOrAddInputForValue(torch::jit::Value* old_
   }
 }
 
+std::vector<ir::Input> SegmentedBlock::construct_inputs_spec() const {
+  std::vector<ir::Input> inputs;
+  if (min_shapes_.size() == opt_shapes_.size() && opt_shapes_.size() == max_shapes_.size()) {
+    for (uint64_t i = 0; i < opt_shapes_.size(); i++) {
+      auto in = ir::Input(min_shapes_[i], opt_shapes_[i], max_shapes_[i]);
+      in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]);
+      inputs.push_back(in);
+    }
+  } else {
+    for (uint64_t i = 0; i < opt_shapes_.size(); i++) {
+      auto in = ir::Input(opt_shapes_[i]);
+      in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]);
+      inputs.push_back(in);
+    }
+  }
+  return inputs;
+}
+
 torch::jit::Node* SegmentedBlock::cloneNode(torch::jit::Node* node) {
   auto* block = g_->block();
   auto env = [&](torch::jit::Value* v) { return getOrAddInputForValue(v); };
diff --git a/core/partitioning/segmentedblock/SegmentedBlock.h b/core/partitioning/segmentedblock/SegmentedBlock.h
index 0cea11e99d..db5e8fedd9 100644
--- a/core/partitioning/segmentedblock/SegmentedBlock.h
+++ b/core/partitioning/segmentedblock/SegmentedBlock.h
@@ -35,6 +35,7 @@ struct SegmentedBlock {
   SegmentedBlock(BlockID id, SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);
 
   torch::jit::Value* getOrAddInputForValue(torch::jit::Value* v);
+  std::vector<ir::Input> construct_inputs_spec() const;
   torch::jit::Node* cloneNode(torch::jit::Node* node);
   void appendNode(torch::jit::Node* n) {
     cloneNode(n);
@@ -72,11 +73,23 @@ struct SegmentedBlock {
   bool contain_raw_value(torch::jit::Value* input) const {
     return old_to_new_.count(input);
   }
-  void register_inshapes(std::vector<ir::Input>& in_shapes) {
-    in_shapes_ = in_shapes;
+  void register_inshapes(std::vector<std::vector<int64_t>>& in_shapes, const ir::ShapeMode& shape_mode) {
+    if (shape_mode == ir::ShapeMode::kMIN) {
+      min_shapes_ = in_shapes;
+    } else if (shape_mode == ir::ShapeMode::kOPT) {
+      opt_shapes_ = in_shapes;
+    } else {
+      max_shapes_ = in_shapes;
+    }
+  }
+  const std::vector<std::vector<int64_t>> in_opt_shapes() const {
+    return opt_shapes_;
   }
-  const std::vector<ir::Input>& in_shapes() const {
-    return in_shapes_;
+  const std::vector<std::vector<int64_t>> in_min_shapes() const {
+    return min_shapes_;
+  }
+  const std::vector<std::vector<int64_t>> in_max_shapes() const {
+    return max_shapes_;
   }
   void register_intypes(std::vector<at::ScalarType>& in_types) {
     in_types_ = in_types;
@@ -84,6 +97,7 @@ struct SegmentedBlock {
   const std::vector<at::ScalarType>& in_types() const {
     return in_types_;
   }
+
   void update_id(BlockID new_id) {
     id_ = new_id;
   }
@@ -107,7 +121,9 @@ struct SegmentedBlock {
  private:
   BlockID id_;
   SegmentedBlockTarget target_;
-  std::vector<ir::Input> in_shapes_;
+  std::vector<std::vector<int64_t>> min_shapes_;
+  std::vector<std::vector<int64_t>> opt_shapes_;
+  std::vector<std::vector<int64_t>> max_shapes_;
   std::vector<at::ScalarType> in_types_;
   std::vector<torch::jit::Value*> inputs_;
   std::vector<torch::jit::Value*> outputs_;
diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp
index beab08aa90..81220e3af8 100644
--- a/core/partitioning/shape_analysis.cpp
+++ b/core/partitioning/shape_analysis.cpp
@@ -10,16 +10,25 @@ namespace torch_tensorrt {
 namespace core {
 namespace partitioning {
 
-at::Tensor generateSingleInput(ir::Input& input, c10::optional<at::ScalarType>& type_opt) {
-  auto cur_shape = input.input_shape;
-  std::vector<int64_t> shape;
+at::Tensor generateSingleInput(
+    ir::Input& input,
+    c10::optional<at::ScalarType>& type_opt,
+    const ir::ShapeMode& shape_mode) {
+  nvinfer1::Dims input_shape = input.input_shape;
+  if (input.input_is_dynamic) {
+    if (shape_mode == ir::ShapeMode::kMIN) {
+      input_shape = input.min;
+    } else if (shape_mode == ir::ShapeMode::kOPT) {
+      input_shape = input.opt;
+    } else {
+      input_shape = input.max;
+    }
+  }
 
   // Initialize min and max ranges for random number selection
   int LoValIncl = 0;
   int HiValExcl = 2;
 
-  shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims);
-
   auto type = at::kFloat;
   if (type_opt) {
     type = type_opt.value();
@@ -29,14 +38,15 @@ at::Tensor generateSingleInput(ir::Input& input, c10::optional<at::ScalarType>&
 
   // Make the value range for input tensor a uniform (float) distribution
   // over [LoValIncl, HiValExcl), then cast to the desired dtype
-  auto in = ((HiValExcl - LoValIncl) * at::rand(shape, {at::kCUDA}) + LoValIncl).to(type);
+  auto in = ((HiValExcl - LoValIncl) * at::rand(util::toVec(input_shape), {at::kCUDA}) + LoValIncl).to(type);
 
   return in;
 }
 
 std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
     std::unordered_map<const torch::jit::Value*, std::vector<ir::Input>>& inputs,
-    std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>>& types) {
+    std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>>& types,
+    const ir::ShapeMode& shape_mode) {
   // generate random inputs for running pytorch segments
   std::unordered_map<const torch::jit::Value*, torch::jit::IValue> ivalue_map;
 
@@ -45,7 +55,7 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
       c10::TypePtr elementType = c10::TensorType::get();
       auto generic_list = c10::impl::GenericList(elementType);
       for (size_t i = 0; i < input.second.size(); i++) {
-        auto in = generateSingleInput(input.second[i], types[input.first][i]);
+        auto in = generateSingleInput(input.second[i], types[input.first][i], shape_mode);
         generic_list.push_back(in.clone());
       }
       ivalue_map[input.first] = c10::IValue(generic_list);
@@ -53,13 +63,13 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
       // create tuple
       std::vector<torch::jit::IValue> list;
       for (size_t i = 0; i < input.second.size(); i++) {
-        auto in = generateSingleInput(input.second[i], types[input.first][i]);
+        auto in = generateSingleInput(input.second[i], types[input.first][i], shape_mode);
         list.push_back(in.clone());
       }
       auto tuple = c10::ivalue::Tuple::create(list); // create tuple ptr
       ivalue_map[input.first] = c10::IValue(tuple);
     } else {
-      auto in = generateSingleInput(input.second[0], types[input.first][0]);
+      auto in = generateSingleInput(input.second[0], types[input.first][0], shape_mode);
       ivalue_map[input.first] = in.clone();
     }
   }
@@ -124,7 +134,8 @@ torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool i
 void getSegmentsOutputByRunning(
     SegmentedBlock& seg_block,
     std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
-    const PartitioningInfo& partitioning_info) {
+    const PartitioningInfo& partitioning_info,
+    const ir::ShapeMode& shape_mode) {
   // create a module to run the graph
   auto g = seg_block.g();
   auto copy_g = g->copy();
@@ -235,7 +246,7 @@ void getSegmentsOutputByRunning(
   }
 
   // set input shape for each segmented block so we wil use it in conversion process
-  std::vector<ir::Input> input_shapes;
+  std::vector<std::vector<int64_t>> input_shapes;
   std::vector<at::ScalarType> input_types;
   for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
     if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
@@ -270,15 +281,19 @@ void getSegmentsOutputByRunning(
     }
     // TODO: tuple and list inputs in subgraph
   }
-  seg_block.register_inshapes(input_shapes);
+  seg_block.register_inshapes(input_shapes, shape_mode);
   seg_block.register_intypes(input_types);
 }
 
-void runShapeAnalysis(PartitioningCtx* ctx, torch::jit::Block* block, ExampleIValues& example_tensor_map) {
+void runShapeAnalysis(
+    PartitioningCtx* ctx,
+    torch::jit::Block* block,
+    ExampleIValues& example_tensor_map,
+    const ir::ShapeMode& shape_mode) {
   // register every segment's input shape, and it's running output IValues
   for (auto& seg_block : ctx->partitioned_blocks[block]) {
     torch::jit::ConstantPooling(seg_block.g());
-    getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings);
+    getSegmentsOutputByRunning(seg_block, example_tensor_map, ctx->settings, shape_mode);
   }
   return;
 }
diff --git a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp
index 44ac0a06e7..5f590fa5ab 100644
--- a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp
+++ b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp
@@ -122,9 +122,13 @@ TEST(Partitioning, ResolveNonTensorInputsCorrectly) {
     inputs_map.insert({g->inputs()[i], {inputs[i]}});
     input_types.insert({g->inputs()[i], {{at::kFloat}}});
   }
-  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
+
+  partitioning_info.collection_input_spec_map = inputs_map;
   torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
-  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
+  ctx.input_types_map = input_types;
+
+  torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
+  torch_tensorrt::core::partitioning::partition(&ctx);
 
   std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
       ctx.partitioned_blocks.begin()->second;
@@ -182,10 +186,12 @@ TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) {
     inputs_map.insert({g->inputs()[i], {inputs[i]}});
     input_types.insert({g->inputs()[i], {{at::kFloat}}});
   }
 
-  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
-  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
-  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
+  partitioning_info.collection_input_spec_map = inputs_map;
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  ctx.input_types_map = input_types;
+  torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
+  torch_tensorrt::core::partitioning::partition(&ctx);
 
   std::vector<torch_tensorrt::core::partitioning::SegmentedBlock> segmented_blocks =
       ctx.partitioned_blocks.begin()->second;
@@ -262,7 +268,7 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
   int count = count_trt_engines(fallback_g);
   ASSERT_TRUE(count == 1);
 }
-
+//
 TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
   /* parseIR does not support "= aten::_set_item" so we will build this graph manually
    const auto graph = R"IR(
@@ -376,9 +382,11 @@ TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
     inputs_map.insert({g->inputs()[i], {inputs[i]}});
     input_types.insert({g->inputs()[i], {{at::kFloat}}});
   }
-  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
+  partitioning_info.collection_input_spec_map = inputs_map;
   torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
-  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
+  ctx.input_types_map = input_types;
+  torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
+  torch_tensorrt::core::partitioning::partition(&ctx);
 
   auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
   int torch_block_cnt = 0, trt_block_cnt = 0;
diff --git a/tests/core/partitioning/test_shape_analysis.cpp b/tests/core/partitioning/test_shape_analysis.cpp
index 87c42c0e47..7558fbec69 100644
--- a/tests/core/partitioning/test_shape_analysis.cpp
+++ b/tests/core/partitioning/test_shape_analysis.cpp
@@ -15,7 +15,7 @@ bool checkSegmentedBlockInputShape(
     if (cur_block_in_shapes.size() != in_shape[i].size())
       return false;
     for (size_t j = 0; j < cur_block_in_shapes.size(); ++j) {
-      auto cur_input_shape = torch_tensorrt::core::util::toVec(cur_block_in_shapes[j].input_shape);
+      auto cur_input_shape = cur_block_in_shapes[j];
      for (size_t k = 0; k < cur_input_shape.size(); ++k) {
         if (cur_input_shape[k] != in_shape[i][j][k])
           return false;
@@ -61,14 +61,18 @@ TEST(Partitioning, InferSequentialModelSegmentedBlockShapeCorrectly) {
 
   std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
   std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
+
   for (size_t i = 0; i < g->inputs().size(); ++i) {
     inputs_map.insert({g->inputs()[i], {inputs[i]}});
     input_types.insert({g->inputs()[i], {{at::kFloat}}});
   }
-  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
-
+  // Store a map of torch::jit::Value to ir::Input
+  partitioning_info.collection_input_spec_map = inputs_map;
   torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
-  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
+  ctx.input_types_map = input_types;
+
+  torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
+  torch_tensorrt::core::partitioning::partition(&ctx);
 
   auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
   ASSERT_TRUE(checkSegmentedBlockInputShape(
@@ -117,10 +121,14 @@ TEST(Partitioning, InferBranchModelSegmentedBlockShapeCorrectly) {
     inputs_map.insert({g->inputs()[i], {inputs[i]}});
     input_types.insert({g->inputs()[i], {{at::kFloat}}});
   }
 
-  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
+  // Store a map of torch::jit::Value to ir::Input
+  partitioning_info.collection_input_spec_map = inputs_map;
   torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
-  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
+  ctx.input_types_map = input_types;
+
+  torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
+  torch_tensorrt::core::partitioning::partition(&ctx);
 
   auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
   ASSERT_TRUE(checkSegmentedBlockInputShape(
diff --git a/tests/cpp/BUILD b/tests/cpp/BUILD
index 3d56682189..41d31f5275 100644
--- a/tests/cpp/BUILD
+++ b/tests/cpp/BUILD
@@ -15,6 +15,7 @@ test_suite(
         ":test_collections",
         ":test_compiled_modules",
         ":test_default_input_types",
+        ":test_dynamic_fallback",
         ":test_example_tensors",
         ":test_module_fallback",
         ":test_modules_as_engines",
@@ -30,6 +31,7 @@ test_suite(
         ":test_collections",
         ":test_compiled_modules",
         ":test_default_input_types",
+        ":test_dynamic_fallback",
         ":test_example_tensors",
         ":test_module_fallback",
         ":test_modules_as_engines",
@@ -125,6 +127,21 @@ cc_test(
     }),
 )
 
+cc_test(
+    name = "test_dynamic_fallback",
+    srcs = ["test_dynamic_fallback.cpp"],
+    data = [
+        "//tests/modules:jit_models",
+    ],
+    deps = [
+        "//tests/util",
+        "@googletest//:gtest_main",
+    ] + select({
+        ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
+        "//conditions:default": ["@libtorch//:libtorch"],
+    }),
+)
+
 cc_test(
     name = "test_collections",
     srcs = ["test_collections.cpp"],
diff --git a/tests/cpp/test_dynamic_fallback.cpp b/tests/cpp/test_dynamic_fallback.cpp
new file mode 100644
index 0000000000..42ffbba897
--- /dev/null
+++ b/tests/cpp/test_dynamic_fallback.cpp
@@ -0,0 +1,105 @@
+#include <string>
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/script.h"
+#include "torch_tensorrt/torch_tensorrt.h"
+
+TEST(CppAPITest, ResNet18DynamicBatchFallbackCorrectly) {
+  torch::jit::script::Module mod;
+  try {
+    mod = torch::jit::load("tests/modules/resnet18_scripted.jit.pt");
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+    ASSERT_TRUE(false);
+  }
+
+  const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 224, 224}, {4, 3, 224, 224}, {8, 3, 224, 224}};
+  std::vector<torch::jit::IValue> jit_inputs_ivalues;
+  std::vector<torch::jit::IValue> trt_inputs_ivalues;
+  auto in_bs1 = at::randint(5, input_shapes[0], {at::kCUDA});
+  jit_inputs_ivalues.push_back(in_bs1.clone());
+  trt_inputs_ivalues.push_back(in_bs1.clone());
+
+  std::vector<torch_tensorrt::Input> inputs;
+  inputs.push_back(torch_tensorrt::Input(input_shapes[0], input_shapes[1], input_shapes[2]));
+  torch_tensorrt::ts::CompileSpec cfg(inputs);
+  cfg.torch_executed_ops.push_back("aten::add");
+
+  auto jit_results_bs1 = mod.forward(jit_inputs_ivalues).toTensor();
+  // Compile and build the hybrid graph with dynamic shapes
+  auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);
+  auto trt_results_bs1 = trt_mod.forward(trt_inputs_ivalues).toTensor();
+  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results_bs1, trt_results_bs1));
+  jit_inputs_ivalues.clear();
+  trt_inputs_ivalues.clear();
+
+  // Run with batch size of 4
+  auto in_bs4 = at::randint(5, input_shapes[1], {at::kCUDA});
+  jit_inputs_ivalues.push_back(in_bs4.clone());
+  trt_inputs_ivalues.push_back(in_bs4.clone());
+
+  auto jit_results_bs4 = mod.forward(jit_inputs_ivalues).toTensor();
+  auto trt_results_bs4 = trt_mod.forward(trt_inputs_ivalues).toTensor();
+  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results_bs4, trt_results_bs4));
+  jit_inputs_ivalues.clear();
+  trt_inputs_ivalues.clear();
+
+  // Run with batch size of 8
+  auto in_bs8 = at::randint(5, input_shapes[2], {at::kCUDA});
+  jit_inputs_ivalues.push_back(in_bs8.clone());
+  trt_inputs_ivalues.push_back(in_bs8.clone());
+
+  auto jit_results_bs8 = mod.forward(jit_inputs_ivalues).toTensor();
+  auto trt_results_bs8 = trt_mod.forward(trt_inputs_ivalues).toTensor();
+  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results_bs8, trt_results_bs8));
+}
+
+TEST(CppAPITest, ResNet18DynamicShapeFallbackCorrectly) {
+  torch::jit::script::Module mod;
+  try {
+    mod = torch::jit::load("tests/modules/resnet18_scripted.jit.pt");
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+    ASSERT_TRUE(false);
+  }
+
+  const std::vector<std::vector<int64_t>> input_shapes = {{1, 3, 64, 64}, {1, 3, 128, 128}, {1, 3, 224, 224}};
+  std::vector<torch::jit::IValue> jit_inputs_ivalues;
+  std::vector<torch::jit::IValue> trt_inputs_ivalues;
+  auto in_64 = at::randint(5, input_shapes[0], {at::kCUDA});
+  jit_inputs_ivalues.push_back(in_64.clone());
+  trt_inputs_ivalues.push_back(in_64.clone());
+
+  std::vector<torch_tensorrt::Input> inputs;
+  inputs.push_back(torch_tensorrt::Input(input_shapes[0], input_shapes[1], input_shapes[2]));
+  torch_tensorrt::ts::CompileSpec cfg(inputs);
+  cfg.torch_executed_ops.push_back("aten::add");
+
+  auto jit_results_64 = mod.forward(jit_inputs_ivalues).toTensor();
+  // Compile and build the hybrid graph with dynamic shapes
+  auto trt_mod = torch_tensorrt::ts::compile(mod, cfg);
+  auto trt_results_64 = trt_mod.forward(trt_inputs_ivalues).toTensor();
+  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results_64, trt_results_64));
+  jit_inputs_ivalues.clear();
+  trt_inputs_ivalues.clear();
+
+  // Run with input resolution of (1, 3, 128, 128)
+  auto in_128 = at::randint(5, input_shapes[1], {at::kCUDA});
+  jit_inputs_ivalues.push_back(in_128.clone());
+  trt_inputs_ivalues.push_back(in_128.clone());
+
+  auto jit_results_128 = mod.forward(jit_inputs_ivalues).toTensor();
+  auto trt_results_128 = trt_mod.forward(trt_inputs_ivalues).toTensor();
+  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results_128, trt_results_128));
+  jit_inputs_ivalues.clear();
+  trt_inputs_ivalues.clear();
+
+  // Run with input resolution of (1, 3, 256, 256)
+  auto in_256 = at::randint(5, input_shapes[2], {at::kCUDA});
+  jit_inputs_ivalues.push_back(in_256.clone());
+  trt_inputs_ivalues.push_back(in_256.clone());
+
+  auto jit_results_256 = mod.forward(jit_inputs_ivalues).toTensor();
+  auto trt_results_256 = trt_mod.forward(trt_inputs_ivalues).toTensor();
+  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results_256, trt_results_256));
+}
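Taken together, the hunks above replace the old partition(ctx, example_tensor_map) flow: callers now attach the collection input specs and first-use types to the PartitioningCtx, call populateInputIValues() to build min/opt/max (or opt-only, for static inputs) random example tensors, and then call partition(). A minimal sketch of the new call sequence, assuming partitioning_info, a parsed graph g, and inputs_map/input_types maps built as in the updated tests above:

    // Sketch only: wire the specs and first-use types into the context, then partition.
    partitioning_info.collection_input_spec_map = inputs_map;
    torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
    ctx.input_types_map = input_types;
    // Generates min/opt/max example tensors when any input is dynamic, opt-only otherwise.
    torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
    // Segments the graph and runs shape analysis once per populated ShapeMode.
    torch_tensorrt::core::partitioning::partition(&ctx);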