diff --git a/core/compiler.cpp b/core/compiler.cpp index b684b808f5..72243835dd 100644 --- a/core/compiler.cpp +++ b/core/compiler.cpp @@ -254,6 +254,7 @@ GraphAndMapping ConstructFallbackGraph( // update the input ranges for each segments convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params); + // TODO mapping Inputs Ivalue to flatten one here auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params); auto temp_g = std::make_shared(); auto device_spec = convert_cfg.engine_settings.device; @@ -304,57 +305,72 @@ void MapInputsAndDetermineDTypes( CompileSpec& cfg, std::shared_ptr& g, ir::StaticParams& static_params, - ir::TypeMap& first_use_type_map) { - // Associate input specs with inputs - cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params)); - - for (auto& in : g->inputs()) { - if (static_params.find(in) == static_params.end()) { - ir::Input& spec = cfg.convert_info.inputs.find(in)->second; - auto est_type_opt = first_use_type_map.find(in)->second; - if (est_type_opt && !spec.dtype_is_user_defined) { - // If we can calculate the type from the graph and the type was not defined by the user then use the calculated - // type - LOG_INFO( - "Since input type is not explicitly defined, infering using first tensor calculation\n Found input " - << in->debugName() << " has type " << est_type_opt.value() - << ". If this is incorrect explicitly set dtype for input and file a bug"); - spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value()); - } else if (!est_type_opt && !spec.dtype_is_user_defined) { - // If we cannot calculate the type and the user did not define the type, then default to FP32 - LOG_WARNING( - "Cannot infer input type from calcuations in graph for input " - << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity"); - spec.dtype = nvinfer1::DataType::kFLOAT; - } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) { - if (!est_type_opt) { - LOG_INFO("Cannot infer input tensor dtype in graph. 
Using user provided input dtype settings"); - first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)}; - } else { - if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) { + ir::CollectionTypeMap& first_use_type_map) { + cfg.convert_info.collection_input_spec_map = std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params)); + + auto collection_inputs = ir::get_collection_inputs(g, static_params); + LOG_DEBUG("In MapInputsAndDetermineDTypes, the g->inputs() size is " << g->inputs().size() << ", CollectionInputSpecMap size is " << collection_inputs.size()); + + for (auto in : collection_inputs) { + std::vector& spec = cfg.convert_info.collection_input_spec_map.find(in)->second; + std::vector> est_type_opt; + + auto est_it = first_use_type_map.find(in); + if (est_it != first_use_type_map.end()) { + est_type_opt = first_use_type_map.find(in)->second; + } + // traverse elements in est_type_opt and spec + for (int i = 0; i < est_type_opt.size(); i++) { + if (est_type_opt[i] && !spec[i].dtype_is_user_defined) { + // If we can calculate the type from the graph and the type was not defined by the user then use the calculated + // type + LOG_INFO( + "Since input type is not explicitly defined, inferring using first tensor calculation\n Inferred input " + << in->debugName() << " has type " << est_type_opt[i].value()); + spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value()); + } else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) { + // If we cannot calculate the type and the user did not define the type, then default to FP32 + LOG_WARNING( + "Cannot infer input type from calculations in graph for input " + << in->debugName() << ". Assuming it is Float32. If not, specify input type explicitly"); + spec[i].dtype = nvinfer1::DataType::kFLOAT; + } else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) { + if (!est_type_opt[i]) { + LOG_INFO("Cannot infer input tensor dtype in graph, the compiler is going to use the user setting"); std::stringstream ss; ss << "For input " << in->debugName() << ", found user specified input dtype as "; - ss << cfg.convert_info.inputs.find(in)->second.dtype; - ss << ", however when inspecting the graph, the input type expected was inferred to be "; - ss << est_type_opt.value() << std::endl; - ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype; - ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n"; - ss << "compatibility with PyTorch's data type convention is required.\n"; - ss << "If you do indeed see errors at runtime either:\n"; - ss << "- Remove the dtype spec for " << in->debugName() << std::endl; - ss << "- Disable partial compilation by setting require_full_compilation to True"; + ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype; + ss << ". 
The compiler is going to use the user setting " << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype; auto warn_str = ss.str(); LOG_WARNING(warn_str); + // Overwrite type map with user settings + first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)}; + + } else { + if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) != est_type_opt[i].value()) { + std::stringstream ss; + ss << "For input " << in->debugName() << ", found user specified input dtype as "; + ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype; + ss << ", however when inspecting the graph, the input type expected was inferred to be "; + ss << est_type_opt[i].value() << std::endl; + ss << "The compiler is going to use the user setting " << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype; + ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n"; + ss << "compatibility with PyTorch's data type convention is required.\n"; + ss << "If you do indeed see errors at runtime either:\n"; + ss << "- Remove the dtype spec for " << in->debugName() << std::endl; + ss << "- Disable partial compilation by setting require_full_compilation to True"; + auto warn_str = ss.str(); + LOG_WARNING(warn_str); + // Overwrite type map with user settings + first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)}; + } } - // Overwrite type map with user settings - // We use this map for partitiioning since we need c10::ScalarTypes not nvinfer::DataTypes - first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)}; + } else { + // The user defined the type so no changes are necessary } - } else { - // The user defined the type so no changes are necessary } } - } + // } } uint64_t GetRecommendedWorkspaceSize(const runtime::CudaDevice& device) { @@ -376,7 +392,8 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std:: auto params = graph_and_parameters.second; auto static_params = ir::get_static_params(g->inputs(), params); // Infer the type of an input from the weights of the calculation - auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block()); + // auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block()); + auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block()); // GPU default WS size : 1 GB // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X. 
@@ -416,10 +433,11 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) auto params = graph_and_parameters.second; auto static_params = ir::get_static_params(g->inputs(), params); // Infer the type of an input from the weights of the calculation - auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block()); + auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block()); MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types); auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true); + auto outputIsCollection = conversion::OutputIsCollection(g->block()); if (cfg.partition_info.enabled && (cfg.lower_info.forced_fallback_modules.size() == 0 && cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) { @@ -427,10 +445,12 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) } if (cfg.partition_info.enabled && - !(cfg.lower_info.forced_fallback_modules.size() == 0 && - cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) { - auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types); - auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params); + (!(cfg.lower_info.forced_fallback_modules.size() == 0 && + cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) + || outputIsCollection)) { + + auto collection_input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types); + auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), collection_input_ivalues_map, cfg, static_params); new_g = graph_and_mapping.first; LOG_INFO("Segmented Graph: " << *new_g); @@ -444,6 +464,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg) TORCHTRT_CHECK( conversion::VerifyConverterSupportForBlock(g->block()), "Not all operations in graph are supported by the compiler"); + // TODO find the right auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params); AddEngineToGraph(new_mod, new_g, engine, cuda_device); } diff --git a/core/compiler.h b/core/compiler.h index c1bb85aa3b..c8dc85020b 100644 --- a/core/compiler.h +++ b/core/compiler.h @@ -8,13 +8,15 @@ #include "core/partitioning/partitioning.h" #include "core/runtime/runtime.h" #include "torch/csrc/jit/api/module.h" +#include "torch/csrc/jit/ir/ir.h" namespace torch_tensorrt { namespace core { struct CompileSpec { - CompileSpec(std::vector inputs) : inputs(inputs) {} - std::vector inputs; + CompileSpec(std::vector inputs) : graph_inputs(inputs) {} + CompileSpec(torch::jit::IValue& input_signature) : graph_inputs(input_signature) {} + ir::GraphInputs graph_inputs; conversion::ConversionInfo convert_info; lowering::LowerInfo lower_info; partitioning::PartitionInfo partition_info; diff --git a/core/conversion/conversion.cpp b/core/conversion/conversion.cpp index 8da79b13a3..bddd8fd835 100644 --- a/core/conversion/conversion.cpp +++ b/core/conversion/conversion.cpp @@ -134,7 +134,10 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) { void AddInputs( ConversionCtx* ctx, c10::ArrayRef inputs, - std::unordered_map& input_specs) { + ConversionInfo& conversion_info) { + std::unordered_map& input_specs = conversion_info.inputs; + std::unordered_map> collection_input_spec = conversion_info.collection_input_spec_map; + std::vector 
input_tensors; for (auto in : inputs) { // Disregarding inputs that are not tensors @@ -162,9 +165,15 @@ void AddInputs( for (auto input : input_tensors) { const torch::jit::Value* in = input; TORCHTRT_CHECK( - input_specs.find(in) != input_specs.end(), + input_specs.find(in) != input_specs.end() || collection_input_spec.find(in) != collection_input_spec.end(), "Cannot find an input spec associated with input: " << in->debugName()); - ir::Input& spec = input_specs.find(in)->second; + ir::Input spec; + if (input_specs.find(in) != input_specs.end()) { + spec = input_specs.find(in)->second; + } else { + spec = collection_input_spec.find(in)->second[0]; // assume input is tensor + } + // ir::Input& spec = input_specs.find(in)->second; std::string name = std::string("input_") + std::to_string(ctx->num_inputs); LOG_INFO( @@ -184,6 +193,7 @@ void AddInputs( ctx->input_is_dynamic = true; } + // mapping torch Value to tensorrt iTensor ctx->value_tensor_map[in] = trt_in; ctx->num_inputs += 1; } @@ -404,7 +414,7 @@ void ConvertBlockToNetDef( auto inputs = b->inputs(); AddParamsToCtxValueMap(ctx, static_params); - AddInputs(ctx, inputs, build_info.inputs); + AddInputs(ctx, inputs, build_info); auto nodes = b->nodes(); @@ -545,6 +555,15 @@ std::set ConvertableOpsInBlock(const torch::jit::Block* b) { return convertable_ops; } +bool OutputIsCollection(const torch::jit::Block* b) { + for (auto out: b->outputs()) { + if(out->type()->kind() == torch::jit::TypeKind::TupleType || out->type()->kind() == torch::jit::TypeKind::ListType) { + return true; + } + } + return false; +} + bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors) { auto unsupported_ops = GetUnsupportedOpsInBlock(b); if (unsupported_ops.size() != 0) { diff --git a/core/conversion/conversion.h b/core/conversion/conversion.h index 58c06b42a3..a578c4288e 100644 --- a/core/conversion/conversion.h +++ b/core/conversion/conversion.h @@ -13,6 +13,7 @@ namespace conversion { struct ConversionInfo { ir::InputSpecMap inputs; + ir::CollectionInputSpecMap collection_input_spec_map; BuilderSettings engine_settings; }; @@ -25,6 +26,8 @@ std::string ConvertBlockToEngine( bool OpSupported(const torch::jit::Node* n); +bool OutputIsCollection(const torch::jit::Block* b); + bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors = false); c10::optional EvaluateNode( diff --git a/core/conversion/evaluators/aten.cpp b/core/conversion/evaluators/aten.cpp index 30cdeaa46a..fde9e71e66 100644 --- a/core/conversion/evaluators/aten.cpp +++ b/core/conversion/evaluators/aten.cpp @@ -264,21 +264,6 @@ auto aten_registrations TORCHTRT_UNUSED = }, EvalOptions().validSchemas( {"aten::size(Tensor self) -> (int[])", "aten::size.int(Tensor self, int dim) -> (int)"})}) - .evaluator({c10::Symbol::fromQualString("aten::__getitem__"), - [](const torch::jit::Node* n, kwargs& args) -> c10::optional { - auto list = args.at(n->input(0)).IValue()->to>(); - auto idx = args.at(n->input(1)).unwrapToInt(); - - const int64_t list_size = list.size(); - const int64_t normalized_idx = normalizeIndex(idx, list_size); - TORCHTRT_CHECK( - normalized_idx >= 0 || normalized_idx < list_size, - "List index out of range (aten::__getitem__)"); - return list.get(normalized_idx); - }, - EvalOptions().validSchemas({ - "aten::__getitem__.t(t[](a) list, int idx) -> (t(*))", - })}) .evaluator({c10::Symbol::fromQualString("aten::append"), [](const torch::jit::Node* n, kwargs& args) -> c10::optional { auto list = args.at(n->input(0)).IValue()->to>(); 
diff --git a/core/ir/BUILD b/core/ir/BUILD index a613aaf489..2e9ef7e6a8 100644 --- a/core/ir/BUILD +++ b/core/ir/BUILD @@ -15,7 +15,8 @@ cc_library( srcs = [ "ir.cpp", "Input.cpp", - "StaticParams.cpp" + "StaticParams.cpp", + "GraphInputs.cpp" ], deps = [ "@tensorrt//:nvinfer", diff --git a/core/ir/GraphInputs.cpp b/core/ir/GraphInputs.cpp new file mode 100644 index 0000000000..645624f2f1 --- /dev/null +++ b/core/ir/GraphInputs.cpp @@ -0,0 +1,75 @@ +#include "core/ir/ir.h" +#include "core/util/prelude.h" + +namespace torch_tensorrt { +namespace core { +namespace ir { + +void flatten_dfs(std::vector& flattened_inputs, std::vector>& collection_inputs, + torch::jit::IValue input_ivalue, int level, int index) { + if (input_ivalue.isTuple()) { + auto input_tuple = input_ivalue.toTuple(); + int idx = 0; + if (level == 0) { + collection_inputs.resize(input_tuple->elements().size()); + } + for (auto item: input_tuple->elements()) { + torch::jit::IValue converted_item; + int cur_idx = level < 1 ? idx: index; + flatten_dfs(flattened_inputs, collection_inputs, item, level+1, cur_idx); + idx++; + } + } else if(input_ivalue.isList()) { + auto input_list = input_ivalue.toList().vec(); + if (level == 0) { + collection_inputs.resize(input_list.size()); + } + c10::TypePtr type = input_list[0].type(); + auto converted_elements = c10::impl::GenericList(type); + int idx = 0; + for (auto item: input_list) { + int cur_idx = level < 1 ? idx: index; + flatten_dfs(flattened_inputs, collection_inputs, item, level+1, cur_idx); + idx++; + } + } else if(input_ivalue.isCustomClass()) { + torch_tensorrt::core::ir::Input cur_input = *(input_ivalue.toCustomClass()); + flattened_inputs.push_back(cur_input); + if (level == 0) { // a single value like A + collection_inputs.resize(1); + collection_inputs[0].push_back(cur_input); + } else if (level == 1) { // like A in [A, A] or [(B, B), A] + collection_inputs[index].push_back(cur_input); + } else if (level == 2) { // like A in [(A, A), C] + collection_inputs[index].push_back(cur_input); + } else {// only support 2 level + LOG_ERROR("Input nesting depth exceeds currently supported depth (3), use 1 level: [A, B], or 2 level: [A, (B, C)]"); + } + } +} + + +GraphInputs::GraphInputs(std::vector inputs_) { + LOG_DEBUG("Construct GraphInput with ir::Input"); + inputs = inputs_; + collection_inputs.resize(inputs_.size()); + for (int i = 0; i < inputs_.size(); i++) { + collection_inputs[i].push_back(inputs_[i]); + } +} + +GraphInputs::GraphInputs(torch::jit::IValue& input_signature_) { + LOG_DEBUG("Construct GraphInput with IValue"); + + std::vector flattened_inputs; + std::vector> collection_inputs_; + + flatten_dfs(flattened_inputs, collection_inputs_, input_signature_, 0, 0); + inputs = flattened_inputs; + input_signature = input_signature_; + collection_inputs = collection_inputs_; +} + +} // namespace ir +} // namespace core +} // namespace torch_tensorrt \ No newline at end of file diff --git a/core/ir/StaticParams.cpp b/core/ir/StaticParams.cpp index ac16c72d9f..0073ad2888 100644 --- a/core/ir/StaticParams.cpp +++ b/core/ir/StaticParams.cpp @@ -11,7 +11,10 @@ StaticParams get_static_params(c10::ArrayRef inputs, std::ve StaticParams static_params; auto param_it = params.begin(); for (auto in : inputs) { - if (in->type() != c10::TensorType::get() && param_it != params.end()) { + // handle TensorType, TupleType and ListType + if (in->type() != c10::TensorType::get() && + in->type()->kind() != torch::jit::TypeKind::TupleType && + in->type()->kind() != torch::jit::TypeKind::ListType 
&& param_it != params.end()) { static_params[in] = *param_it; ++param_it; } diff --git a/core/ir/ir.cpp b/core/ir/ir.cpp index 1c1813ea5f..061327c6bc 100644 --- a/core/ir/ir.cpp +++ b/core/ir/ir.cpp @@ -13,6 +13,14 @@ InputSpecMap associate_specs_with_inputs( return pair_input_vals_with_specs(tensor_inputs, specs); } +CollectionInputSpecMap associate_specs_with_collection_inputs( + std::shared_ptr& g, + ir::GraphInputs graph_inputs, + StaticParams& static_params) { + auto tensor_inputs = get_collection_inputs(g, static_params); + return pair_input_vals_with_specs_collection(tensor_inputs, graph_inputs.collection_inputs); +} + InputSpecMap pair_input_vals_with_specs(std::vector vals, std::vector specs) { TORCHTRT_CHECK( vals.size() == specs.size(), @@ -27,12 +35,28 @@ InputSpecMap pair_input_vals_with_specs(std::vector va return a; } +CollectionInputSpecMap pair_input_vals_with_specs_collection(std::vector vals, std::vector>& specs) { + TORCHTRT_CHECK( + vals.size() == specs.size(), + "Expected dimension specifications for all input tensors" + << ", but found " << vals.size() << " input tensors and " << specs.size() << " dimension specs"); + + CollectionInputSpecMap a; + for (size_t i = 0; i < vals.size(); i++) { + LOG_DEBUG("Paring " << i << ": " << vals[i]->debugName() << " : " << specs[i]); + a.insert({vals[i], specs[i]}); + } + return a; +} + std::vector get_tensor_inputs( std::shared_ptr& g, StaticParams& static_params) { std::vector input_tensors; auto inputs = g->inputs(); + LOG_DEBUG("Raw inputs size of get_tensor_inputs: " << inputs.size()); for (auto in : inputs) { + LOG_DEBUG("Handle input of debug name: " << in->debugName()); // Disregarding inputs that are not tensors or are static // // Ex. @@ -40,6 +64,29 @@ std::vector get_tensor_inputs( // input.1:Tensor -> used if (in->type()->isSubtypeOf(c10::TensorType::get()) && static_params.find(in) == static_params.end()) { input_tensors.push_back(in); + } + } + return input_tensors; +} + +std::vector get_collection_inputs( + std::shared_ptr& g, + StaticParams& static_params) { + std::vector input_tensors; + auto inputs = g->inputs(); + LOG_DEBUG("Raw inputs size of get_collection_inputs: " << inputs.size()); + for (auto in : inputs) { + LOG_DEBUG("Handle input of debug name: " << in->debugName()); + if (in->type()->isSubtypeOf(c10::TensorType::get()) && static_params.find(in) == static_params.end()) { + input_tensors.push_back(in); + } else if (in->type()->kind() == torch::jit::TypeKind::TupleType && static_params.find(in) == static_params.end()) { + // } else if (in->type()->isSubtypeOf(c10::TupleType::create()) && static_params.find(in) == static_params.end()) { + input_tensors.push_back(in); // push original tuple + at::ArrayRef unpack_tuple = torch::jit::createTupleUnpack(in); + LOG_DEBUG("get_collection_inputs, tuple size " << unpack_tuple.size()); + } else if (in->type()->kind() == torch::jit::TypeKind::ListType && static_params.find(in) == static_params.end()) { + LOG_DEBUG("get_collection_inputs, list use size " << in->uses().size()); + input_tensors.push_back(in); // push original list } } return input_tensors; @@ -52,9 +99,6 @@ c10::optional get_value_first_calc_dtype_opt(torch::jit::Block* auto b_ins = b->inputs(); std::unordered_set b_in_set(b_ins.begin(), b_ins.end()); - TORCHTRT_ASSERT( - in->type() == c10::TensorType::get(), "Input is not a tensor, cannot check for dtype based on calculation"); - auto consumers = in->uses(); auto search_list = std::vector(consumers.begin(), consumers.end()); @@ -142,16 +186,57 @@ 
c10::optional get_value_first_calc_dtype_opt(torch::jit::Block* TypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b) { TypeMap types; - for (auto i : b->inputs()) { if (i->type() == c10::TensorType::get()) { torch::jit::Value* in = i; types.insert({in, get_value_first_calc_dtype_opt(b, i)}); + } else if(i->type()->cast()) { + // make sure we get the same ptr every time + at::ArrayRef unpack_tuple = torch::jit::createTupleUnpack(i); + LOG_DEBUG("Tuple size " << unpack_tuple.size()); + for (auto item: unpack_tuple) { + torch::jit::Value* in = item; + types.insert({in, get_value_first_calc_dtype_opt(b, i)}); + } + } else if(i->type()->isSubtypeOf(c10::ListType::ofTensors())) { + LOG_INFO("Unsupported type of c10::ListType::ofTensors()"); } } return types; } +CollectionTypeMap get_block_first_calc_dtypes_opt_collection(torch::jit::Block* b) { + CollectionTypeMap types; + for (auto i : b->inputs()) { + if (i->type() == c10::TensorType::get()) { + torch::jit::Value* in = i; + types.insert({in, {get_value_first_calc_dtype_opt(b, i)}}); + + } else if(i->type()->kind() == torch::jit::TypeKind::TupleType) { + // TODO: evaluate the data type of each tuple element + // make sure we get the same ptr every time + // c10::optional tp = get_value_first_calc_dtype_opt(b, i); + at::ArrayRef unpack_tuple = torch::jit::createTupleUnpack(i); + // TODO: calculate the tuple element type, currently we use {} as default datatype + // std::vector> dtypes(unpack_tuple.size(), tp); + std::vector> dtypes(unpack_tuple.size()); + types.insert({i, dtypes}); // insert empty (unknown) dtypes + + } else if(i->type()->kind() == torch::jit::TypeKind::ListType) { + // TODO: determine the list size and the element types + LOG_DEBUG("get_block_first_calc_dtypes_opt ListType: use size " << i->uses().size()); + c10::optional tp = get_value_first_calc_dtype_opt(b, i); + // std::vector> dtypes(i->uses().size()); + std::vector> dtypes(i->uses().size(), tp); + types.insert({i, dtypes}); // insert the inferred dtype for each list element + } + } + return types; +} + +static auto core_input_container = + torch::class_("_torch_tensorrt_core_ir", "Input").def(torch::init<>()); + } // namespace ir } // namespace core } // namespace torch_tensorrt diff --git a/core/ir/ir.h b/core/ir/ir.h index 2d9acccc69..966c747176 100644 --- a/core/ir/ir.h +++ b/core/ir/ir.h @@ -11,9 +11,8 @@ namespace torch_tensorrt { namespace core { namespace ir { -struct Input { - // Input(std::vector shape); - // Input(std::vector min_shape, std::vector opt_shape, std::vector max_shape); +struct Input : torch::CustomClassHolder { + Input() {}; Input( std::vector shape, nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT, @@ -36,27 +35,52 @@ struct Input { nvinfer1::Dims opt; nvinfer1::DataType dtype; nvinfer1::TensorFormat format; + int id; }; +// Add to spec +struct GraphInputs { + GraphInputs(std::vector inputs); + GraphInputs(torch::jit::IValue& input_signature); + torch::jit::IValue input_signature; // nested Input, full input spec + std::vector inputs; // flattened Input + std::vector> collection_inputs; // only support two layer nesting, e.g. 
((a, b), [c, d], e) +}; + +typedef std::pair GraphIO; // Graph input output mapping + using StaticParams = std::map; StaticParams get_static_params(c10::ArrayRef inputs, std::vector params); using InputSpecMap = std::unordered_map; +using CollectionInputSpecMap = std::unordered_map>; +std::vector get_tensor_inputs( + std::shared_ptr& g, + StaticParams& static_params); InputSpecMap associate_specs_with_inputs( std::shared_ptr& g, std::vector specs, StaticParams& static_params); +CollectionInputSpecMap associate_specs_with_collection_inputs( + std::shared_ptr& g, + ir::GraphInputs graph_inputs, + StaticParams& static_params); InputSpecMap pair_input_vals_with_specs(std::vector vals, std::vector specs); +CollectionInputSpecMap pair_input_vals_with_specs_collection(std::vector vals, std::vector>& specs); std::vector get_tensor_inputs( std::shared_ptr& g, StaticParams& static_params); +std::vector get_collection_inputs( + std::shared_ptr& g, + StaticParams& static_params); using TypeMap = std::unordered_map>; +using CollectionTypeMap = std::unordered_map>>; c10::optional get_value_first_calc_dtype_opt(torch::jit::Block* b, torch::jit::Value* in); ir::TypeMap get_block_first_calc_dtypes_opt(torch::jit::Block* b); - +ir::CollectionTypeMap get_block_first_calc_dtypes_opt_collection(torch::jit::Block* b); } // namespace ir } // namespace core } // namespace torch_tensorrt diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index d3296c347c..8bbae296c3 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -33,7 +33,6 @@ void LowerGraph(std::shared_ptr& g, LowerInfo lower_info) { torch::jit::InlineFunctionalGraphs(g); torch::jit::PeepholeOptimize(g, false); torch::jit::FuseLinear(g); - torch::jit::LowerAllTuples(g); if (!lower_info.disable_cse) { torch::jit::EliminateCommonSubexpression(g); } diff --git a/core/partitioning/partitioning.cpp b/core/partitioning/partitioning.cpp index d171ae15c0..93ee4ab2a6 100644 --- a/core/partitioning/partitioning.cpp +++ b/core/partitioning/partitioning.cpp @@ -419,6 +419,15 @@ bool checkLoopEvaluatable(torch::jit::Node* n) { return compile_to_trt; } +bool is_collection(torch::jit::Node* n) { + for (auto out: n->outputs()) { + if(out->type()->kind() == torch::jit::TypeKind::TupleType || out->type()->kind() == torch::jit::TypeKind::ListType) { + return true; + } + } + return false; +} + bool should_run_in_trt(torch::jit::Node* n, const std::unordered_set& torch_ops) { // If the op is not supported by the conversion phase it should run in PyTorch if (!conversion::OpSupported(n)) { @@ -459,18 +468,19 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end()); auto nodes = block->nodes(); + auto reverse_nodes = nodes.reverse(); // merge from output side to input side PartitionedGraph segmented_blocks; // segment the nodes std::vector in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes; - for (const auto n : nodes) { + for (const auto n : reverse_nodes) { // Skip constant nodes as they are resources for both kinds of modules if (n->kind() == torch::jit::prim::Constant) { continue; } - - if (should_run_in_trt(n, forced_fallback_ops)) { - in_prog_trt_blk_nodes.push_back(n); + // the outputs of trt subgraph shouldn't be collections + if (should_run_in_trt(n, forced_fallback_ops) && !(in_prog_trt_blk_nodes.size() == 0 && is_collection(n))) { + in_prog_trt_blk_nodes.insert(in_prog_trt_blk_nodes.begin(), n); // If there is 
an active PyTorch block and we have passed the threshold for a valid TRT // block then segment and reset the active PyTorch block @@ -505,14 +515,14 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes); } if (checkLoopEvaluatable(n)) { - in_prog_trt_blk_nodes.push_back(n); + in_prog_trt_blk_nodes.insert(in_prog_trt_blk_nodes.begin(), n); } else { auto loop_node = std::vector{n}; finalize_block(segmented_blocks, SegmentedBlock::kTorch, loop_node); } continue; } - in_prog_pyt_blk_nodes.push_back(n); + in_prog_pyt_blk_nodes.insert(in_prog_pyt_blk_nodes.begin(), n); } } @@ -527,7 +537,7 @@ PartitionedGraph segment_graph(torch::jit::Block* block, const PartitionInfo& pa in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end()); finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes); } - + std::reverse(segmented_blocks.begin(), segmented_blocks.end()); return segmented_blocks; } diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp index 96b1312062..961831cb47 100644 --- a/core/partitioning/shape_analysis.cpp +++ b/core/partitioning/shape_analysis.cpp @@ -8,27 +8,56 @@ namespace torch_tensorrt { namespace core { namespace partitioning { +at::Tensor generateSingleInput(ir::Input& input, c10::optional& type_opt) { + auto cur_shape = input.input_shape; + std::vector shape; + shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims); + // auto type_opt = types[input.first][i]; + auto type = at::kFloat; + if (type_opt) { + type = type_opt.value(); + } else { + LOG_WARNING("Input type for doing shape analysis could not be determined, defaulting to F32"); + } + auto in = at::randint(5, shape, {at::kCUDA}).to(type); + // ivalue_map[input.first] = in.clone(); + return in; +} + std::unordered_map generateRandomInputs( - std::unordered_map& inputs, - std::unordered_map>& types) { + std::unordered_map>& inputs, + std::unordered_map>>& types) { + // generate random inputs for running pytorch segments std::unordered_map ivalue_map; - uint64_t in_i = 0; + for (auto& input : inputs) { - auto cur_shape = input.second.input_shape; - std::vector shape; - shape.insert(shape.begin(), std::begin(cur_shape.d), std::begin(cur_shape.d) + cur_shape.nbDims); - auto type_opt = types[input.first]; - auto type = at::kFloat; - if (type_opt) { - type = type_opt.value(); + + if (input.first->type()->kind() == torch::jit::TypeKind::ListType) { + // create list + std::vector list; + c10::TypePtr elementType = c10::TensorType::get(); + auto generic_list = c10::impl::GenericList(elementType); + for (int i = 0; i < input.second.size(); i++) { + auto in = generateSingleInput(input.second[i], types[input.first][i]); + generic_list.push_back(in.clone()); + } + ivalue_map[input.first] = c10::IValue(generic_list); + } else if (input.first->type()->kind() == torch::jit::TypeKind::TupleType) { + // create tuple + std::vector list; + for (int i = 0; i < input.second.size(); i++) { + auto in = generateSingleInput(input.second[i], types[input.first][i]); + list.push_back(in.clone()); + } + auto tuple = c10::ivalue::Tuple::create(list); // create tuple ptr + ivalue_map[input.first] = c10::IValue(tuple); } else { - LOG_WARNING("Input type for doing shape analysis could not be determined, defaulting to F32"); + auto in = generateSingleInput(input.second[0], types[input.first][0]); + ivalue_map[input.first] = 
in.clone(); + } - auto in = at::randint(5, shape, {at::kCUDA}).to(type); - ivalue_map[input.first] = in.clone(); - in_i++; } return ivalue_map; } @@ -79,8 +108,10 @@ void getSegmentsOutputByRunning( } else if (input->type()->isSubtypeOf(torch::jit::BoolType::get())) { jit_inputs_ivalues.push_back(ivalues_maps[input].toBool()); } else if (input->type()->kind() == torch::jit::TypeKind::ListType) { - jit_inputs_ivalues.push_back(ivalues_maps[input].toList()); + // create list + jit_inputs_ivalues.push_back(ivalues_maps[input].toList());; } else if (input->type()->kind() == torch::jit::TypeKind::TupleType) { + // create tuple jit_inputs_ivalues.push_back(ivalues_maps[input].toTuple()); } else if (input->type()->kind() == torch::jit::TypeKind::NumberType) { jit_inputs_ivalues.push_back(ivalues_maps[input].toScalar()); @@ -141,6 +172,7 @@ void getSegmentsOutputByRunning( } input_types.push_back(cur_ivalue.toTensor().scalar_type()); } + // TODO: tuple and list inputs in subgraph } seg_block.register_inshapes(input_shapes); diff --git a/core/partitioning/shape_analysis.h b/core/partitioning/shape_analysis.h index 0626490222..2654699a1d 100644 --- a/core/partitioning/shape_analysis.h +++ b/core/partitioning/shape_analysis.h @@ -6,9 +6,10 @@ namespace torch_tensorrt { namespace core { namespace partitioning { + std::unordered_map generateRandomInputs( - std::unordered_map& input_ranges, - std::unordered_map>& input_types); + std::unordered_map>& input_ranges, + std::unordered_map>>& input_types); void runShapeAnalysis( std::vector& segmented_blocks, diff --git a/cpp/include/torch_tensorrt/torch_tensorrt.h b/cpp/include/torch_tensorrt/torch_tensorrt.h index ace05d33f5..e19b9f1408 100644 --- a/cpp/include/torch_tensorrt/torch_tensorrt.h +++ b/cpp/include/torch_tensorrt/torch_tensorrt.h @@ -14,6 +14,7 @@ #include #include #include +#include "torch/custom_class.h" // Just include the .h? #ifndef DOXYGEN_SHOULD_SKIP_THIS @@ -363,7 +364,7 @@ class TORCHTRT_API TensorFormat { * signifying a static input shape or a set of three input shapes representing * the min, optiminal and max input shapes allowed for the engine. 
*/ -struct TORCHTRT_API Input { +struct TORCHTRT_API Input : torch::CustomClassHolder { /// Minimum acceptable input size into the engine std::vector min_shape; /// Optimal input size into the engine (size optimized for given kernels accept any size in min max range) @@ -378,6 +379,7 @@ struct TORCHTRT_API Input { /// Expected tensor format for the input TensorFormat format; + Input() {} /** * @brief Construct a new Input spec object for static input size from * vector, optional arguments allow the user to configure expected input shape @@ -512,6 +514,16 @@ struct TORCHTRT_API Input { bool input_is_dynamic; }; +/** + * @brief A struct to hold complex (nested) inputs + * + * This struct can hold either a nested input signature or a flattened list of inputs. + */ +struct TORCHTRT_API GraphInputs { + torch::jit::IValue input_signature; // nested Input, full input spec + std::vector inputs; // flattened input spec +}; + /** * @brief Get the build information for the library including the dependency * versions @@ -579,18 +591,22 @@ struct TORCHTRT_API CompileSpec { * * @param inputs */ - CompileSpec(std::vector inputs) : inputs(std::move(inputs)) {} - - // Defaults should reflect TensorRT defaults for BuilderConfig + CompileSpec(std::vector inputs); /** - * @brief Specifications for inputs to the engine, can either be a single size or a range defined by min, opt and max - * sizes Users can also specify expected input type as well as tensor memory format + * @brief Construct a new CompileSpec object from an IValue. + * The IValue stores a nested Input signature * - * Order in vector should match call order for the function + * @param input_signature */ - std::vector inputs; + CompileSpec(torch::jit::IValue input_signature); + // Defaults should reflect TensorRT defaults for BuilderConfig + /** + * @brief Specifications for inputs to the engine, can store an IValue holding a nested Input signature + * or a flattened list of Inputs + */ + GraphInputs graph_inputs; /** * @brief The set of precisions TensorRT is allowed to use for kernels during compilation * diff --git a/cpp/src/compile_spec.cpp b/cpp/src/compile_spec.cpp index 3058b23ce0..9447def7e0 100644 --- a/cpp/src/compile_spec.cpp +++ b/cpp/src/compile_spec.cpp @@ -18,18 +18,67 @@ torchtrt::core::runtime::CudaDevice to_internal_cuda_device(Device device); namespace torchscript { CompileSpec::CompileSpec(std::vector> fixed_sizes) { for (auto in : fixed_sizes) { - inputs.push_back(Input(in)); + graph_inputs.inputs.push_back(Input(in)); } } CompileSpec::CompileSpec(std::vector> fixed_sizes) { for (auto in : fixed_sizes) { - inputs.push_back(Input(in)); + graph_inputs.inputs.push_back(Input(in)); + } +} + +CompileSpec::CompileSpec(std::vector inputs) { + graph_inputs.inputs = std::move(inputs); +} + +CompileSpec::CompileSpec(torch::jit::IValue input_signature) { + graph_inputs.input_signature = input_signature; +} + + + +void to_internal_input_signature(torch::jit::IValue input_ivalue, torch::jit::IValue& converted_ivalue) { + if (input_ivalue.isTuple()) { + auto input_tuple = input_ivalue.toTuple(); + std::vector converted_elements; + for (auto item: input_tuple->elements()) { + torch::jit::IValue converted_item; + to_internal_input_signature(item, converted_item); + converted_elements.push_back(converted_item); + auto tuple_ptr = c10::ivalue::Tuple::create(converted_elements); + converted_ivalue = torch::jit::IValue(tuple_ptr); + } + } else if(input_ivalue.isList()) { + auto input_list = input_ivalue.toList().vec(); + c10::TypePtr type = input_list[0].type(); + auto converted_elements = 
c10::impl::GenericList(type); + for (auto item: input_list) { + torch::jit::IValue converted_item; + to_internal_input_signature(item, converted_item); + converted_elements.push_back(converted_item); + } + converted_ivalue = torch::jit::IValue(converted_elements); + } else if(input_ivalue.isCustomClass()) { + torchtrt::core::ir::Input cur_input = to_internal_input(*(input_ivalue.toCustomClass())); + converted_ivalue = torch::jit::IValue(std::move(c10::make_intrusive(cur_input))); + } +} + +torchtrt::core::CompileSpec init_compile_spec(CompileSpec external) { + if (external.graph_inputs.inputs.size() > 0) { + torchtrt::core::CompileSpec internal(to_vec_internal_inputs(external.graph_inputs.inputs)); + return internal; + } else { + torch::jit::IValue converted_input_signature; + to_internal_input_signature(external.graph_inputs.input_signature, converted_input_signature); + torchtrt::core::CompileSpec internal(converted_input_signature); + return internal; } } torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) { - torchtrt::core::CompileSpec internal(to_vec_internal_inputs(external.inputs)); + torchtrt::core::CompileSpec internal = init_compile_spec(external); for (auto p : external.enabled_precisions) { internal.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p)); diff --git a/cpp/src/torch_tensorrt.cpp b/cpp/src/torch_tensorrt.cpp index 42b44833de..93813190ab 100644 --- a/cpp/src/torch_tensorrt.cpp +++ b/cpp/src/torch_tensorrt.cpp @@ -52,4 +52,7 @@ void set_device(const int gpu_id) { // Want to export a much simpler (non CUDA header dependent) API torch_tensorrt::core::set_device(gpu_id); } + +static auto tensorrt_input_container = + torch::class_("_torch_tensorrt", "Input").def(torch::init<>()); } // namespace torch_tensorrt diff --git a/py/torch_tensorrt/csrc/register_tensorrt_classes.cpp b/py/torch_tensorrt/csrc/register_tensorrt_classes.cpp index 53b9fc2cdb..0a9f357c47 100644 --- a/py/torch_tensorrt/csrc/register_tensorrt_classes.cpp +++ b/py/torch_tensorrt/csrc/register_tensorrt_classes.cpp @@ -23,6 +23,13 @@ void RegisterTRTCompileSpec() { ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, input_is_dynamic); ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, explicit_set_dtype); + static auto TORCHTRT_UNUSED TRTGraphInpuTSRegistration = + torch::class_("tensorrt", "_GraphInputs") + .def(torch::init<>()) + .def("__str__", &torch_tensorrt::pyapi::GraphInputs::to_str); + + ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::GraphInputs, input_signature); + static auto TORCHTRT_UNUSED TRTDeviceTSRegistration = torch::class_("tensorrt", "_Device") .def(torch::init<>()) diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp index a89fe692bd..9d2761ba95 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp +++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp @@ -104,6 +104,11 @@ std::string Input::to_str() { return ss.str(); } +std::string GraphInputs::to_str() { + std::stringstream ss; + return ss.str(); +} + std::string to_str(DeviceType value) { switch (value) { case DeviceType::kDLA: @@ -184,13 +189,51 @@ std::string TorchFallback::to_str() { return ss.str(); } -core::CompileSpec CompileSpec::toInternalCompileSpec() { - std::vector internal_inputs; - for (auto i : inputs) { - internal_inputs.push_back(i.toInternalInput()); +void to_internal_input_signature(torch::jit::IValue input_ivalue, 
torch::jit::IValue& converted_ivalue) { + if (input_ivalue.isTuple()) { + auto input_tuple = input_ivalue.toTuple(); + std::vector converted_elements; + for (auto item: input_tuple->elements()) { + torch::jit::IValue converted_item; + to_internal_input_signature(item, converted_item); + converted_elements.push_back(converted_item); + auto tuple_ptr = c10::ivalue::Tuple::create(converted_elements); + converted_ivalue = torch::jit::IValue(tuple_ptr); + } + } else if(input_ivalue.isList()) { + auto input_list = input_ivalue.toList().vec(); + c10::TypePtr type = input_list[0].type(); + auto converted_elements = c10::impl::GenericList(type); + for (auto item: input_list) { + torch::jit::IValue converted_item; + to_internal_input_signature(item, converted_item); + converted_elements.push_back(converted_item); + } + converted_ivalue = torch::jit::IValue(converted_elements); + } else if(input_ivalue.isCustomClass()) { + core::ir::Input cur_input = (*(input_ivalue.toCustomClass())).toInternalInput(); + converted_ivalue = torch::jit::IValue(std::move(c10::make_intrusive(cur_input))); + } +} + +core::CompileSpec init_compile_spec(CompileSpec external) { + if (external.graph_inputs.inputs.size() > 0) { + std::vector internal_inputs; + for (auto i : external.graph_inputs.inputs) { + internal_inputs.push_back(i.toInternalInput()); + } + core::CompileSpec internal(internal_inputs); + return internal; + } else { + torch::jit::IValue converted_input_signature; + to_internal_input_signature(external.graph_inputs.input_signature, converted_input_signature); + core::CompileSpec internal(converted_input_signature); + return internal; } +} - auto info = core::CompileSpec(internal_inputs); +core::CompileSpec CompileSpec::toInternalCompileSpec() { + core::CompileSpec info = init_compile_spec(*this); for (auto p : enabled_precisions) { info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p)); diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.h b/py/torch_tensorrt/csrc/tensorrt_classes.h index 0c80641005..7231efa0fa 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.h +++ b/py/torch_tensorrt/csrc/tensorrt_classes.h @@ -57,6 +57,13 @@ struct Input : torch::CustomClassHolder { std::string to_str(); }; +struct GraphInputs : torch::CustomClassHolder { + torch::jit::IValue input_signature; // nested Input, full input spec + std::vector inputs; // flatten input spec + ADD_FIELD_GET_SET(input_signature, torch::jit::IValue); + std::string to_str(); +}; + enum DeviceType : int8_t { kGPU, kDLA, @@ -156,6 +163,7 @@ struct CompileSpec : torch::CustomClassHolder { ADD_FIELD_GET_SET(ptq_calibrator, nvinfer1::IInt8Calibrator*); std::vector inputs; + GraphInputs graph_inputs; nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr; std::set enabled_precisions = {}; bool sparse_weights = false; diff --git a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp index 6e5f333f78..8e89441f56 100644 --- a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp +++ b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp @@ -178,6 +178,12 @@ PYBIND11_MODULE(_C, m) { .def_readwrite("dtype", &Input::dtype) .def_readwrite("format", &Input::format); + py::class_(m, "GraphInputs") + .def(py::init<>()) + .def("__str__", &torch_tensorrt::pyapi::GraphInputs::to_str) + .def_readwrite("input_signature", &GraphInputs::input_signature) + .def_readwrite("inputs", &GraphInputs::inputs); + py::enum_(m, "dtype", "Enum to specifiy operating precision for engine execution") .value("float", DataType::kFloat, "32 bit 
floating point number") .value("float32", DataType::kFloat, "32 bit floating point number") @@ -292,6 +298,7 @@ PYBIND11_MODULE(_C, m) { .def("__str__", &torch_tensorrt::pyapi::CompileSpec::stringify) .def("_get_calibrator_handle", &CompileSpec::getPTQCalibratorHandle, "[Internal] gets a handle from a calibrator") .def_readwrite("inputs", &CompileSpec::inputs) + .def_readwrite("graph_inputs", &CompileSpec::graph_inputs) .def_readwrite("enabled_precisions", &CompileSpec::enabled_precisions) .def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator) .def_readwrite("refit", &CompileSpec::refit) diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index e406096677..5c046a7d1d 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -5,7 +5,7 @@ from torch_tensorrt import _enums from torch_tensorrt._Input import Input from torch_tensorrt._Device import Device - +from typing import Tuple, List, Dict import warnings @@ -156,6 +156,24 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> _ts_C.TorchFallback: return info +def _parse_collection_input(input_signature: Any) -> _C.GraphInputs.input_signature: + if isinstance(input_signature, tuple): + input_list = [] + for item in input_signature: + input = _parse_collection_input(item) + input_list.append(input) + return tuple(input_list) + elif isinstance(input_signature, list): + input_list = [] + for item in input_signature: + input = _parse_collection_input(item) + input_list.append(input) + return input_list + elif isinstance(input_signature, Input) or isinstance(input_signature, torch.Tensor): + input = Input._from_tensor(input_signature) if isinstance(input_signature, torch.Tensor) else input_signature + return input._to_internal() + else: + raise KeyError("Invalid Input spec") def _parse_compile_spec(compile_spec: Dict[str, Any]) -> _ts_C.CompileSpec: info = _ts_C.CompileSpec() @@ -165,14 +183,19 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> _ts_C.CompileSpec: ) if "inputs" in compile_spec: - if not all([isinstance(i, torch.Tensor) or isinstance(i, Input) for i in compile_spec["inputs"]]): - raise KeyError("Input specs should be either torch_tensorrt.Input or torch.Tensor, found types: {}".format( - [type(i) for i in compile_spec["inputs"]])) - - inputs = [Input._from_tensor(i) if isinstance(i, torch.Tensor) else i for i in compile_spec["inputs"]] - info.inputs = [i._to_internal() for i in inputs] + # if not all([isinstance(i, torch.Tensor) or isinstance(i, Input) for i in compile_spec["inputs"]]): + # raise KeyError("Input specs should be either torch_tensorrt.Input or torch.Tensor, found types: {}".format( + # [type(i) for i in compile_spec["inputs"]])) + + if isinstance(compile_spec["inputs"], list) and all([isinstance(i, torch.Tensor) or isinstance(i, Input) for i in compile_spec["inputs"]]): + inputs = [Input._from_tensor(i) if isinstance(i, torch.Tensor) else i for i in compile_spec["inputs"]] + # from python Input to torch_tensorrt::pyapi::Input + # info.inputs = [i._to_internal() for i in inputs] + info.graph_inputs.inputs = [i._to_internal() for i in inputs] + else: + info.graph_inputs.input_signature = _parse_collection_input(compile_spec["inputs"]) - assert (len(info.inputs) > 0), "Require at least one input definition to compile model" + assert (len(info.graph_inputs.inputs) > 0), "Require at least one input definition to compile model" if "enabled_precisions" in compile_spec: info.enabled_precisions = 
_parse_enabled_precisions(compile_spec["enabled_precisions"]) diff --git a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp index a83d2330e4..e70d5d2b5d 100644 --- a/tests/core/partitioning/test_resolve_nontensor_inputs.cpp +++ b/tests/core/partitioning/test_resolve_nontensor_inputs.cpp @@ -116,11 +116,11 @@ TEST(Partitioning, ResolveNonTensorInputsCorrectly) { inputs.push_back(torch_tensorrt::core::ir::Input({16, 3, 3, 3})); inputs.push_back(torch_tensorrt::core::ir::Input({16})); - std::unordered_map inputs_map; - std::unordered_map> input_types; + std::unordered_map> inputs_map; + std::unordered_map>> input_types; for (size_t i = 0; i < g->inputs().size(); ++i) { - inputs_map.insert({g->inputs()[i], inputs[i]}); - input_types.insert({g->inputs()[i], {at::kFloat}}); + inputs_map.insert({g->inputs()[i], {inputs[i]}}); + input_types.insert({g->inputs()[i], {{at::kFloat}}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); std::vector segmented_blocks = @@ -174,11 +174,11 @@ TEST(Partitioning, ResolveTensorListInputsInTrtCorrectly) { inputs.push_back(torch_tensorrt::core::ir::Input({16, 6, 3, 3})); inputs.push_back(torch_tensorrt::core::ir::Input({16})); - std::unordered_map inputs_map; - std::unordered_map> input_types; + std::unordered_map> inputs_map; + std::unordered_map>> input_types; for (size_t i = 0; i < g->inputs().size(); ++i) { - inputs_map.insert({g->inputs()[i], inputs[i]}); - input_types.insert({g->inputs()[i], {at::kFloat}}); + inputs_map.insert({g->inputs()[i], {inputs[i]}}); + input_types.insert({g->inputs()[i], {{at::kFloat}}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); std::vector segmented_blocks = @@ -255,5 +255,5 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) { torch::jit::script::Module new_mod = torch_tensorrt::core::CompileGraph(mod, cfg); auto fallback_g = new_mod.get_method("forward").graph(); int count = count_trt_engines(fallback_g); - ASSERT_TRUE(count == 2); + ASSERT_TRUE(count == 1); } diff --git a/tests/core/partitioning/test_shape_analysis.cpp b/tests/core/partitioning/test_shape_analysis.cpp index 8effa821ae..d05f10c163 100644 --- a/tests/core/partitioning/test_shape_analysis.cpp +++ b/tests/core/partitioning/test_shape_analysis.cpp @@ -59,11 +59,11 @@ TEST(Partitioning, InferSequentialModelSegmentedBlockShapeCorrectly) { inputs.push_back(torch_tensorrt::core::ir::Input({8, 16, 3, 3})); inputs.push_back(torch_tensorrt::core::ir::Input({8})); - std::unordered_map inputs_map; - std::unordered_map> input_types; + std::unordered_map> inputs_map; + std::unordered_map>> input_types; for (size_t i = 0; i < g->inputs().size(); ++i) { - inputs_map.insert({g->inputs()[i], inputs[i]}); - input_types.insert({g->inputs()[i], {at::kFloat}}); + inputs_map.insert({g->inputs()[i], {inputs[i]}}); + input_types.insert({g->inputs()[i], {{at::kFloat}}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); std::vector segmented_blocks = @@ -109,11 +109,11 @@ TEST(Partitioning, InferBranchModelSegmentedBlockShapeCorrectly) { inputs.push_back(torch_tensorrt::core::ir::Input({16, 32, 3, 3})); inputs.push_back(torch_tensorrt::core::ir::Input({16})); - std::unordered_map inputs_map; - std::unordered_map> input_types; + std::unordered_map> inputs_map; + std::unordered_map>> input_types; for (size_t i = 0; i < 
g->inputs().size(); ++i) { - inputs_map.insert({g->inputs()[i], inputs[i]}); - input_types.insert({g->inputs()[i], {at::kFloat}}); + inputs_map.insert({g->inputs()[i], {inputs[i]}}); + input_types.insert({g->inputs()[i], {{at::kFloat}}}); } auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types); std::vector segmented_blocks = diff --git a/tests/cpp/BUILD b/tests/cpp/BUILD index 3d69afba95..2d545dc8f1 100644 --- a/tests/cpp/BUILD +++ b/tests/cpp/BUILD @@ -18,7 +18,8 @@ test_suite( ":test_multiple_registered_engines", ":test_serialization", ":test_module_fallback", - ":test_example_tensors" + ":test_example_tensors", + ":test_collection" ], ) @@ -32,7 +33,8 @@ test_suite( ":test_multiple_registered_engines", ":test_serialization", ":test_module_fallback", - ":test_example_tensors" + ":test_example_tensors", + ":test_collection" ], ) @@ -122,6 +124,20 @@ cc_test( }) ) +cc_test( + name = "test_collection", + srcs = ["test_collection.cpp"], + data = [ + "//tests/modules:jit_models", + ], + deps = [ + "//tests/util", + "@googletest//:gtest_main", + ] + select({ + ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"], + "//conditions:default": ["@libtorch//:libtorch"], + }) +) cc_test( name = "test_compiled_modules", srcs = ["test_compiled_modules.cpp"], diff --git a/tests/cpp/test_collection.cpp b/tests/cpp/test_collection.cpp new file mode 100644 index 0000000000..9308d951f4 --- /dev/null +++ b/tests/cpp/test_collection.cpp @@ -0,0 +1,363 @@ +#include +#include +#include "gtest/gtest.h" +#include "tests/util/util.h" +#include "torch/script.h" +#include "torch_tensorrt/torch_tensorrt.h" + + +TEST(CppAPITests, TestCollectionNormalInput) { + + std::string path = "tests/modules/normal_model.jit.pt"; + torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf); + std::vector inputs; + inputs.push_back(in0); + inputs.push_back(in0); + + torch::jit::Module mod; + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). + mod = torch::jit::load(path); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + } + mod.eval(); + mod.to(torch::kCUDA); + + + std::vector inputs_; + + for (auto in : inputs) { + inputs_.push_back(torch::jit::IValue(in.clone())); + } + + auto out = mod.forward(inputs_); + LOG_DEBUG("Finish torchscirpt forward"); + + std::vector input_range; + input_range.push_back({in0.sizes(), torch::kF16}); + input_range.push_back({in0.sizes(), torch::kF16}); + torch_tensorrt::ts::CompileSpec compile_settings(input_range); + compile_settings.require_full_compilation = true; + compile_settings.min_block_size = 1; + + // // FP16 execution + compile_settings.enabled_precisions = {torch::kHalf}; + // // Compile module + auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings); + LOG_DEBUG("Finish compile"); + auto trt_out = trt_mod.forward(inputs_); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5)); +} + +TEST(CppAPITests, TestCollectionTupleInput) { + + std::string path = "tests/modules/tuple_input.jit.pt"; + torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf); + + torch::jit::Module mod; + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). 
+ mod = torch::jit::load(path); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + } + mod.eval(); + mod.to(torch::kCUDA); + + std::vector complex_inputs, complex_inputs_list; + std::tuple input_tuple(in0, in0); + + complex_inputs.push_back(input_tuple); + + auto out = mod.forward(complex_inputs); + LOG_DEBUG("Finish torchscirpt forward"); + + auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf); + + auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive(input_shape))); + + + std::tuple input_shape_tuple(input_shape_ivalue, input_shape_ivalue); + + torch::jit::IValue complex_input_shape(input_shape_tuple); + std::tuple input_tuple2(complex_input_shape); + torch::jit::IValue complex_input_shape2(input_tuple2); + + + auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2); + compile_settings.require_full_compilation = false; + compile_settings.min_block_size = 1; + + // // FP16 execution + compile_settings.enabled_precisions = {torch::kHalf}; + // // Compile module + auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings); + LOG_DEBUG("Finish compile"); + auto trt_out = trt_mod.forward(complex_inputs); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5)); +} + + +TEST(CppAPITests, TestCollectionListInput) { + + std::string path = "tests/modules/list_input.jit.pt"; + torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf); + std::vector inputs; + inputs.push_back(in0); + + torch::jit::Module mod; + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). + mod = torch::jit::load(path); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + } + mod.eval(); + mod.to(torch::kCUDA); + + + std::vector inputs_; + + for (auto in : inputs) { + inputs_.push_back(torch::jit::IValue(in.clone())); + } + + std::vector complex_inputs; + auto input_list = c10::impl::GenericList(c10::TensorType::get()); + input_list.push_back(inputs_[0]); + input_list.push_back(inputs_[0]); + + torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list); + + complex_inputs.push_back(input_list_ivalue); + + + auto out = mod.forward(complex_inputs); + LOG_DEBUG("Finish torchscirpt forward"); + + auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf); + auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive(input_shape))); + + c10::TypePtr elementType = input_shape_ivalue.type(); + auto list = c10::impl::GenericList(elementType); + list.push_back(input_shape_ivalue); + list.push_back(input_shape_ivalue); + + + torch::jit::IValue complex_input_shape(list); + std::tuple input_tuple2(complex_input_shape); + torch::jit::IValue complex_input_shape2(input_tuple2); + + auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2); + compile_settings.require_full_compilation = false; + compile_settings.min_block_size = 1; + compile_settings.torch_executed_ops.push_back("aten::__getitem__"); + + // // FP16 execution + compile_settings.enabled_precisions = {torch::kHalf}; + // // Compile module + auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings); + LOG_DEBUG("Finish compile"); + auto trt_out = trt_mod.forward(complex_inputs); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTensor(), trt_out.toTensor(), 1e-5)); +} + + +TEST(CppAPITests, TestCollectionTupleInputOutput) { + + 
+  std::string path = "tests/modules/tuple_input_output.jit.pt";
+  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
+
+  torch::jit::Module mod;
+  try {
+    // Deserialize the ScriptModule from a file using torch::jit::load().
+    mod = torch::jit::load(path);
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+  }
+  mod.eval();
+  mod.to(torch::kCUDA);
+
+  std::vector<torch::jit::IValue> complex_inputs, complex_inputs_list;
+  std::tuple<torch::Tensor, torch::Tensor> input_tuple(in0, in0);
+
+  complex_inputs.push_back(input_tuple);
+
+  auto out = mod.forward(complex_inputs);
+  LOG_DEBUG("Finish torchscript forward");
+
+  auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);
+
+  auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
+
+  std::tuple<torch::jit::IValue, torch::jit::IValue> input_shape_tuple(input_shape_ivalue, input_shape_ivalue);
+
+  torch::jit::IValue complex_input_shape(input_shape_tuple);
+  std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
+  torch::jit::IValue complex_input_shape2(input_tuple2);
+  // torch::jit::IValue complex_input_shape(list);
+
+  auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
+  compile_settings.require_full_compilation = false;
+  compile_settings.min_block_size = 1;
+
+  // compile_settings.torch_executed_ops.push_back("prim::TupleConstruct");
+
+  // FP16 execution
+  compile_settings.enabled_precisions = {torch::kHalf};
+  // Compile module
+  auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
+  LOG_DEBUG("Finish compile");
+  auto trt_out = trt_mod.forward(complex_inputs);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor(), 1e-5));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor(), 1e-5));
+}
+
+TEST(CppAPITests, TestCollectionListInputOutput) {
+  std::string path = "tests/modules/list_input_output.jit.pt";
+  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
+  std::vector<torch::Tensor> inputs;
+  inputs.push_back(in0);
+
+  torch::jit::Module mod;
+  try {
+    // Deserialize the ScriptModule from a file using torch::jit::load().
+    mod = torch::jit::load(path);
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+  }
+  mod.eval();
+  mod.to(torch::kCUDA);
+
+  std::vector<torch::jit::IValue> inputs_;
+  for (auto in : inputs) {
+    inputs_.push_back(torch::jit::IValue(in.clone()));
+  }
+
+  std::vector<torch::jit::IValue> complex_inputs;
+  auto input_list = c10::impl::GenericList(c10::TensorType::get());
+  input_list.push_back(inputs_[0]);
+  input_list.push_back(inputs_[0]);
+
+  torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);
+
+  complex_inputs.push_back(input_list_ivalue);
+
+  auto out = mod.forward(complex_inputs);
+  LOG_DEBUG("Finish torchscript forward");
+
+  auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);
+
+  auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
+
+  c10::TypePtr elementType = input_shape_ivalue.type();
+  auto list = c10::impl::GenericList(elementType);
+  list.push_back(input_shape_ivalue);
+  list.push_back(input_shape_ivalue);
+
+  torch::jit::IValue complex_input_shape(list);
+  std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
+  torch::jit::IValue complex_input_shape2(input_tuple2);
+
+  auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
+  compile_settings.require_full_compilation = false;
+  compile_settings.min_block_size = 1;
+
+  // Need to skip the conversion of __getitem__ and ListConstruct
+  compile_settings.torch_executed_ops.push_back("aten::__getitem__");
+
+  // FP16 execution
+  compile_settings.enabled_precisions = {torch::kHalf};
+  // Compile module
+  auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
+  LOG_DEBUG("Finish compile");
+  auto trt_out = trt_mod.forward(complex_inputs);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toList().vec()[0].toTensor(), trt_out.toList().vec()[0].toTensor(), 1e-5));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toList().vec()[1].toTensor(), trt_out.toList().vec()[1].toTensor(), 1e-5));
+}
+
+TEST(CppAPITests, TestCollectionComplexModel) {
+  std::string path = "tests/modules/complex_model.jit.pt";
+  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
+  std::vector<torch::Tensor> inputs;
+  inputs.push_back(in0);
+
+  torch::jit::Module mod;
+  try {
+    // Deserialize the ScriptModule from a file using torch::jit::load().
+    mod = torch::jit::load(path);
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+  }
+  mod.eval();
+  mod.to(torch::kCUDA);
+
+  std::vector<torch::jit::IValue> inputs_;
+  for (auto in : inputs) {
+    inputs_.push_back(torch::jit::IValue(in.clone()));
+  }
+
+  std::vector<torch::jit::IValue> complex_inputs;
+  auto input_list = c10::impl::GenericList(c10::TensorType::get());
+  input_list.push_back(inputs_[0]);
+  input_list.push_back(inputs_[0]);
+
+  torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);
+
+  complex_inputs.push_back(input_list_ivalue);
+
+  auto out = mod.forward(complex_inputs);
+  LOG_DEBUG("Finish torchscript forward");
+
+  auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);
+
+  auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
+
+  c10::TypePtr elementType = input_shape_ivalue.type();
+  auto list = c10::impl::GenericList(elementType);
+  list.push_back(input_shape_ivalue);
+  list.push_back(input_shape_ivalue);
+
+  torch::jit::IValue complex_input_shape(list);
+  std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
+  torch::jit::IValue complex_input_shape2(input_tuple2);
+
+  auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
+  compile_settings.require_full_compilation = false;
+  compile_settings.min_block_size = 1;
+
+  // Need to skip the conversion of __getitem__ and ListConstruct
+  compile_settings.torch_executed_ops.push_back("aten::__getitem__");
+
+  // FP16 execution
+  compile_settings.enabled_precisions = {torch::kHalf};
+  // Compile module
+  auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
+  LOG_DEBUG("Finish compile");
+  auto trt_out = trt_mod.forward(complex_inputs);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor(), 1e-5));
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor(), 1e-5));
+}
\ No newline at end of file
diff --git a/tests/cpp/test_example_tensors.cpp b/tests/cpp/test_example_tensors.cpp
index 6561cd16a0..7e16f47f70 100644
--- a/tests/cpp/test_example_tensors.cpp
+++ b/tests/cpp/test_example_tensors.cpp
@@ -8,8 +8,8 @@ TEST_P(CppAPITests, InputsFromTensors) {
     jit_inputs_ivalues.push_back(in.clone());
     trt_inputs_ivalues.push_back(in.clone());
   }
-
-  auto spec = torch_tensorrt::ts::CompileSpec({trt_inputs_ivalues[0].toTensor()});
+  std::vector<torch_tensorrt::Input> inputs = {trt_inputs_ivalues[0].toTensor()};
+  auto spec = torch_tensorrt::ts::CompileSpec(inputs);
 
   auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
   torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
diff --git a/tests/modules/hub.py b/tests/modules/hub.py
index 7b707f5785..a2adc3ab4b 100644
--- a/tests/modules/hub.py
+++ b/tests/modules/hub.py
@@ -4,6 +4,7 @@ import torchvision.models as models
 import timm
 from transformers import BertModel, BertTokenizer, BertConfig
+from typing import Tuple, List, Dict
 
 torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
 
@@ -217,3 +218,94 @@ def forward(self, x):
     traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
     torch.jit.save(traced_model, "bert_base_uncased_traced.jit.pt")
+
+# Collection input/output models
+class Normal(nn.Module):
+    def __init__(self):
+        super(Normal, self).__init__()
+
+    def forward(self, x, y):
+        r = x + y
+        return r
+
+class TupleInput(nn.Module):
+    def __init__(self):
+        super(TupleInput, self).__init__()
+
+    def forward(self, z: Tuple[torch.Tensor, torch.Tensor]):
+        r = z[0] + z[1]
+        return r
+
+class ListInput(nn.Module):
+    def __init__(self):
+        super(ListInput, self).__init__()
+
+    def forward(self, z: List[torch.Tensor]):
+        r = z[0] + z[1]
+        return r
+
+class TupleInputOutput(nn.Module):
+    def __init__(self):
+        super(TupleInputOutput, self).__init__()
+
+    def forward(self, z: Tuple[torch.Tensor, torch.Tensor]):
+        r1 = z[0] + z[1]
+        r2 = z[0] - z[1]
+        r = (r1, r2)
+        return r
+
+class ListInputOutput(nn.Module):
+    def __init__(self):
+        super(ListInputOutput, self).__init__()
+
+    def forward(self, z: List[torch.Tensor]):
+        r1 = z[0] + z[1]
+        r2 = z[0] - z[1]
+        r = [r1, r2]
+        return r
+
+class ComplexModel(nn.Module):
+    def __init__(self):
+        super(ComplexModel, self).__init__()
+        self.list_model = ListInputOutput()
+        self.tuple_model = TupleInputOutput()
+
+    def forward(self, z: List[torch.Tensor]):
+        r1 = z[0] + z[1]
+        r2 = z[0] - z[1]
+        r3 = (r1, r2)
+        r4 = [r2, r1]
+        tuple_out = self.tuple_model(r3)
+        list_out = self.list_model(r4)
+        r = (tuple_out[1], list_out[0])
+        return r
+
+normal_model = Normal()
+normal_model_ts = torch.jit.script(normal_model)
+normal_model_ts.to("cuda").eval()
+torch.jit.save(normal_model_ts, "normal_model.jit.pt")
+
+tuple_input = TupleInput()
+tuple_input_ts = torch.jit.script(tuple_input)
+tuple_input_ts.to("cuda").eval()
+torch.jit.save(tuple_input_ts, "tuple_input.jit.pt")
+
+list_input = ListInput()
+list_input_ts = torch.jit.script(list_input)
+list_input_ts.to("cuda").eval()
+torch.jit.save(list_input_ts, "list_input.jit.pt")
+
+tuple_input = TupleInputOutput()
+tuple_input_ts = torch.jit.script(tuple_input)
+tuple_input_ts.to("cuda").eval()
+torch.jit.save(tuple_input_ts, "tuple_input_output.jit.pt")
+
+list_input = ListInputOutput()
+list_input_ts = torch.jit.script(list_input)
+list_input_ts.to("cuda").eval()
+torch.jit.save(list_input_ts, "list_input_output.jit.pt")
+
+complex_model = ComplexModel()
+complex_model_ts = torch.jit.script(complex_model)
+complex_model_ts.to("cuda").eval()
+torch.jit.save(complex_model_ts, "complex_model.jit.pt")
diff --git a/tests/py/test_collection.py b/tests/py/test_collection.py
new file mode 100644
index 0000000000..23e15c99b3
--- /dev/null
+++ b/tests/py/test_collection.py
@@ -0,0 +1,60 @@
+import torch
+import torch.nn as nn
+import torch_tensorrt as torchtrt
+from typing import Tuple, List, Dict
+
+class Normal(nn.Module):
+    def __init__(self):
+        super(Normal, self).__init__()
+
+    def forward(self, x, y):
+        r = x + y
+        return r
+
+class TupleInputOutput(nn.Module):
+    def __init__(self):
+        super(TupleInputOutput, self).__init__()
+
+    def forward(self, z: Tuple[torch.Tensor, torch.Tensor]):
+        r1 = z[0] + z[1]
+        r2 = z[0] - z[1]
+        r = (r1, r2)
+        return r
+
+input = torch.randn((1, 3, 224, 224)).to("cuda")
+normal_model = Normal()
+scripted_model = torch.jit.script(normal_model)
+
+compile_spec = {
+    "inputs": [torchtrt.Input(input.shape, dtype=torch.float, format=torch.contiguous_format),
+               torchtrt.Input(input.shape, dtype=torch.float, format=torch.contiguous_format)],
+    "device": {
+        "device_type": torchtrt.DeviceType.GPU,
+        "gpu_id": 0,
+    },
+    "enabled_precisions": {torch.float}
+}
+
+trt_mod = torchtrt.ts.compile(scripted_model, **compile_spec)
+same = (trt_mod(input, input) - scripted_model(input, input)).abs().max()
+print(same.cpu())
+
+# input = torch.randn((1, 3, 224, 224)).to("cuda")
+# tuple_model = TupleInputOutput()
+# scripted_model = torch.jit.script(tuple_model)
+
+# compile_spec = {
+#     "inputs": (torchtrt.Input(input.shape, dtype=torch.float, format=torch.contiguous_format),
+#                torchtrt.Input(input.shape, dtype=torch.float, format=torch.contiguous_format)),
+#     "device": {
+#         "device_type": torchtrt.DeviceType.GPU,
+#         "gpu_id": 0,
+#     },
+#     "enabled_precisions": {torch.float}
+# }
+
+# trt_mod = torchtrt.ts.compile(scripted_model, **compile_spec)
+# same = (trt_mod((input, input))[0] - scripted_model((input, input))[0]).abs().max()
+# print(same.cpu())
+