feat(//core/conversion): Add support for aten::size with dynamically shaped models for the TorchScript backend. #1647
Merged: 15 commits, Apr 6, 2023

Changes from 4 commits
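For context, the scenario this PR targets is a TorchScript graph that calls aten::size on an input compiled with a dynamic shape range. Below is a minimal sketch of that workflow using the Torch-TensorRT C++ frontend; the model path, shapes, and module contents are illustrative assumptions, not part of this PR.

#include "torch/script.h"
#include "torch_tensorrt/torch_tensorrt.h"

int main() {
  // A TorchScript module whose forward uses aten::size, e.g.
  //   return x.reshape({x.size(0), -1});
  auto mod = torch::jit::load("model.ts");  // hypothetical path
  mod.to(torch::kCUDA);

  // Dynamic batch dimension: min/opt/max shapes differ in dim 0.
  auto input = torch_tensorrt::Input(
      std::vector<int64_t>{1, 3, 224, 224},
      std::vector<int64_t>{8, 3, 224, 224},
      std::vector<int64_t>{32, 3, 224, 224});
  auto spec = torch_tensorrt::ts::CompileSpec({input});

  auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
  auto out = trt_mod.forward({torch::randn({16, 3, 224, 224}, torch::kCUDA)});
  return 0;
}

Before this change, the aten::size evaluator could only fold sizes that were statically known; the diffs below let size values flow through conversion as TensorRT shape tensors instead.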
2 changes: 1 addition & 1 deletion core/conversion/conversion.cpp
@@ -68,7 +68,7 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::
return {};
}
}
- auto eval = evaluators::EvalNode(n, eval_args);
+ auto eval = evaluators::EvalNode(ctx, n, eval_args);
return eval;
}

1 change: 1 addition & 0 deletions core/conversion/converters/BUILD
@@ -62,6 +62,7 @@ cc_library(
"impl/constant_pad.cpp",
"impl/conv_deconv.cpp",
"impl/cumsum.cpp",
"impl/dual_ops.cpp",
"impl/einsum.cpp",
"impl/element_wise.cpp",
"impl/expand.cpp",
54 changes: 27 additions & 27 deletions core/conversion/converters/impl/shuffle.cpp
@@ -67,33 +67,33 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
.pattern(
{"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
- auto in = args[0].ITensorOrFreeze(ctx);
- auto in_shape = util::toVec(in->getDimensions());
- std::vector<int64_t> new_shape;
- if (ctx->input_is_dynamic) {
- new_shape = util::toVec(args[1].unwrapToIntList().vec());
- int nbDynamicDims = 0;
- for (size_t i = 0; i < new_shape.size(); i++) {
- if (in_shape[i] == -1)
- nbDynamicDims++;
- }
- if (nbDynamicDims > 1) {
- TORCHTRT_THROW_ERROR(
- "Resize is currently not supported when target shape contains more than one dynamic dimension");
- }
- } else {
- new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
- }
-
- auto shuffle = ctx->net->addShuffle(*in);
- TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
- shuffle->setReshapeDimensions(util::toDims(new_shape));
- shuffle->setName(util::node_info(n).c_str());
-
- auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
- LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
-
- return true;
+ auto in = args[0].ITensorOrFreeze(ctx);
+ auto in_shape = util::toVec(in->getDimensions());
+ std::vector<int64_t> new_shape;
+ nvinfer1::ITensor* shape_tensor = nullptr;
+ if (ctx->input_is_dynamic) {
+ // The target shape arrives as one scalar ITensor per dimension; concatenate them into a single 1-D shape tensor.
+ auto shape_tensors = args[1].unwrapToITensorList();
+ auto concat_layer = ctx->net->addConcatenation(shape_tensors.data(), static_cast<int32_t>(shape_tensors.size()));
+ TORCHTRT_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n);
+ concat_layer->setAxis(static_cast<int32_t>(0));
+ shape_tensor = concat_layer->getOutput(0);
+ } else {
+ new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
+ }
+ auto shuffle = ctx->net->addShuffle(*in);
+ TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
+ shuffle->setName(util::node_info(n).c_str());
+
+ if (ctx->input_is_dynamic) {
+ shuffle->setInput(1, *shape_tensor);
+ } else {
+ shuffle->setReshapeDimensions(util::toDims(new_shape));
+ }
+
+ auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
+ LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+
+ return true;
}})
.pattern(
{"aten::view(Tensor(a) self, int[] size) -> (Tensor(a))",
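The dynamic branch above follows TensorRT's standard recipe for runtime reshapes: concatenate one scalar shape tensor per output dimension into a 1-D shape tensor, then supply it as the shuffle layer's second input. A condensed, self-contained sketch of just that recipe (function and parameter names are illustrative):

#include <vector>
#include "NvInfer.h"

// Reshape `input` to a shape computed at runtime. `dim_tensors` holds one
// scalar (1-element, INT32) shape ITensor per output dimension.
nvinfer1::ITensor* reshape_to_runtime_shape(
    nvinfer1::INetworkDefinition* network,
    nvinfer1::ITensor* input,
    std::vector<nvinfer1::ITensor*>& dim_tensors) {
  auto concat = network->addConcatenation(
      dim_tensors.data(), static_cast<int32_t>(dim_tensors.size()));
  concat->setAxis(0);  // the only axis of a 1-D shape tensor
  auto shuffle = network->addShuffle(*input);
  shuffle->setInput(1, *concat->getOutput(0));  // dims resolved at execution time
  return shuffle->getOutput(0);
}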
4 changes: 2 additions & 2 deletions core/conversion/evaluators/NodeEvaluatorRegistry.cpp
@@ -114,9 +114,9 @@ std::vector<std::string> getEvaluatorList() {
return get_evaluator_registry().GetRegisteredEvaluatorList();
}

- c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args) {
+ c10::optional<torch::jit::IValue> EvalNode(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
auto evaluator = get_evaluator_registry().GetEvaluator(n);
- return evaluator(n, args);
+ return evaluator(ctx, n, args);
}

void register_node_evaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
134 changes: 91 additions & 43 deletions core/conversion/evaluators/aten.cpp

Large diffs are not rendered by default.
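Since the aten.cpp diff is collapsed, here is a hedged reconstruction of its core change, inferred from the surrounding files: in dynamic-shape mode, aten::size can no longer be folded to constant ints, so the evaluator emits an IShapeLayer and returns the resulting shape tensor boxed in a TensorContainer IValue. This is a sketch of the idea, not the PR's exact code; names and the handling of the per-dimension overload are assumptions.

// Hypothetical ctx-aware aten::size evaluator (sketch only).
auto size_eval = [](ConversionCtx* ctx, const torch::jit::Node* n,
                    kwargs& args) -> c10::optional<torch::jit::IValue> {
  auto tensor_var = args.at(n->input(0));
  if (ctx->input_is_dynamic && tensor_var.isITensor()) {
    // Shape is unknown until runtime: materialize it as a shape tensor.
    auto shape_layer = ctx->net->addShape(*tensor_var.ITensor());
    TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
    auto holder = TensorContainer();
    holder.hold_tensor(shape_layer->getOutput(0));
    return c10::IValue(c10::make_intrusive<TensorContainer>(holder));
  }
  // Static case: dimensions are known at conversion time.
  if (tensor_var.isITensor()) {
    auto dims = util::toVec(tensor_var.ITensor()->getDimensions());
    return torch::jit::IValue(c10::List<int64_t>(dims));
  }
  auto dims = tensor_var.unwrapToTensor().sizes().vec();
  return torch::jit::IValue(c10::List<int64_t>(dims));
};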

6 changes: 3 additions & 3 deletions core/conversion/evaluators/eval_macros.h
@@ -5,7 +5,7 @@
#define DEFINE_GENERIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \
auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
{c10::Symbol::fromQualString(node_kind), \
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
if (args.at(n->input(0)).IValue()->isInt()) { \
auto a = args.at(n->input(0)).unwrapToInt(); \
if (args.at(n->input(1)).IValue()->isInt()) { \
@@ -80,7 +80,7 @@
#define DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \
auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
{c10::Symbol::fromQualString(node_kind), \
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
if (args.at(n->input(0)).IValue()->isInt()) { \
auto a = args.at(n->input(0)).unwrapToInt(); \
if (args.at(n->input(1)).IValue()->isInt()) { \
@@ -127,7 +127,7 @@
#define DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(node_kind, node_name, operation, type, schemas) \
auto node_kind##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
{c10::Symbol::fromQualString(node_name), \
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
auto a = args.at(n->input(0)).unwrapTo<type>(); \
auto b = args.at(n->input(1)).unwrapTo<type>(); \
return operation; \
6 changes: 4 additions & 2 deletions core/conversion/evaluators/evaluators.h
@@ -7,6 +7,8 @@
#include "torch/csrc/jit/ir/ir.h"

#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "core/conversion/conversionctx/ConversionCtx.h"
#include "core/conversion/converters/converter_util.h"
#include "core/conversion/var/Var.h"

namespace torch_tensorrt {
@@ -33,7 +35,7 @@ inline bool constTypesOnly(kwargs& args) {
// to use the node itself to pull out arguments.
// This means that you should iterate over node inputs vs. the args
// when writing evaluators
- typedef std::function<c10::optional<torch::jit::IValue>(const torch::jit::Node*, kwargs&)> NodeEvaluator;
+ typedef std::function<c10::optional<torch::jit::IValue>(ConversionCtx*, const torch::jit::Node*, kwargs&)> NodeEvaluator;

struct EvalOptions {
std::set<c10::TypePtr> blacklisted_output_types;
@@ -72,7 +74,7 @@ struct EvalRegistration {
: kind(_kind), evaluator(_evaluator), options(_options){};
};

- c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args);
+ c10::optional<torch::jit::IValue> EvalNode(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args);
bool shouldEvalAtConversionTime(const torch::jit::Node* n);
std::vector<std::string> getEvaluatorList();
void register_node_evaluator(torch::jit::NodeKind node_kind, NodeEvaluator evaluator);
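Under the new typedef, every evaluator lambda gains a leading ConversionCtx* parameter, which is what lets an evaluator emit TensorRT layers instead of only folding constants. A minimal registration sketch under the new signature (the aten::len choice and body are illustrative, not from this PR):

// Sketch: a ctx-aware evaluator registration (inside core/conversion/evaluators).
auto example_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator(
    {c10::Symbol::fromQualString("aten::len"),
     [](ConversionCtx* ctx, const torch::jit::Node* n,
        kwargs& args) -> c10::optional<torch::jit::IValue> {
       // ctx is available here for network access (e.g. ctx->net->addShape(...)),
       // even though this particular evaluator only folds a constant.
       auto list = args.at(n->input(0)).IValue()->toList();
       return torch::jit::IValue(static_cast<int64_t>(list.size()));
     }});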
41 changes: 25 additions & 16 deletions core/conversion/evaluators/prim.cpp
@@ -24,30 +24,31 @@ auto prim_registrations =
RegisterNodeEvaluators()
.evaluator(
{torch::jit::prim::Constant,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
if (n->output()->type()->kind() == at::FunctionType::Kind) {
return {};
}
return evaluators::toIValue(n->output());
}})
.evaluator(
{torch::jit::prim::NumToTensor,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
return evaluators::scalar_to_tensor(args.at(n->input(0)).IValue()->toScalar());
}})
.evaluator(
{torch::jit::prim::ListUnpack,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
const torch::jit::IValue* outputs = args.at(n->input()).IValue();
auto outputVec = outputs->toList().vec();
return std::move(c10::ivalue::Tuple::create(outputVec));
}})
.evaluator(
{torch::jit::prim::ListConstruct,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
const auto num_inputs = n->inputs().size();
if (constTypesOnly(args)) {
+ LOG_DEBUG("==== CONST TYPES ARGS ==== ");
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
if (torch::jit::IntType::get() == lt->getElementType()) {
c10::List<int64_t> list;
Expand Down Expand Up @@ -89,6 +90,7 @@ auto prim_registrations =
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
}
} else {
+ LOG_DEBUG("==== NON CONST TYPES ==== ");
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
c10::TypePtr elementType = lt->getElementType();
auto list = c10::impl::GenericList(elementType);
@@ -103,8 +105,15 @@
if (args.at(in).IValue()->isNone()) {
auto ival = torch::jit::IValue();
list.emplace_back(std::move(ival));
+ } else if (args.at(in).IValue()->isInt()) {
+ LOG_DEBUG("==== INT TYPE ITENSOR ==== ");
+ auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, torch::tensor({args.at(in).unwrapToInt()}));
+ auto tensor_holder = TensorContainer();
+ tensor_holder.hold_tensor(itensor);
+ auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+ list.emplace_back(std::move(ival));
} else {
list.emplace_back(std::move(args.at(in).unwrapToTensor()));
}
}
}
@@ -113,7 +122,7 @@
}})
.evaluator(
{c10::Symbol::fromQualString("prim::dtype"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto input = args.at(n->input(0));
if (input.isITensor()) {
auto trt_dtype = input.ITensor()->getType();
@@ -136,7 +145,7 @@
})})
.evaluator(
{c10::Symbol::fromQualString("prim::min"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
if (n->inputs().size() == 1) {
auto a = args.at(n->input(0)).unwrapToIntList();
int64_t min = std::numeric_limits<int64_t>::max();
@@ -198,7 +207,7 @@
})})
.evaluator(
{c10::Symbol::fromQualString("prim::max"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
if (n->inputs().size() == 1) {
auto a = args.at(n->input(0)).unwrapToIntList();
int64_t max = std::numeric_limits<int64_t>::min();
@@ -260,7 +269,7 @@
})})
.evaluator(
{c10::Symbol::fromQualString("prim::shape"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
LOG_WARNING("There may be undefined behavior using dynamic shape and prim::shape");
auto tensor_var = args.at(n->input(0));
if (tensor_var.isITensor()) {
@@ -274,7 +283,7 @@
EvalOptions().validSchemas({"prim::shape(Tensor a) -> (int[])"})})
.evaluator(
{torch::jit::prim::TupleConstruct,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
c10::IValue tuple = c10::ivalue::Tuple::create();
std::vector<c10::IValue> elems;
for (auto in : n->inputs()) {
@@ -292,7 +301,7 @@
}})
.evaluator(
{torch::jit::prim::TupleIndex,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
auto tuple = args.at(n->input(0)).IValue()->toTuple();
int64_t idx = args.at(n->input(1)).IValue()->toInt();
@@ -302,24 +311,24 @@
EvalOptions().validSchemas({"prim::TupleIndex(Any tup, int i) -> (Any)"})})
.evaluator(
{torch::jit::prim::TupleUnpack,
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
auto output = args.at(n->input()).IValue()->toTuple();
return c10::optional<torch::jit::IValue>(std::move(output));
}})
.evaluator(
{c10::Symbol::fromQualString("prim::unchecked_cast"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
return *(args.at(n->input(0)).IValue());
}})
.evaluator(
{c10::Symbol::fromQualString("prim::Uninitialized"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
return c10::IValue::uninitialized();
}})
.evaluator(
{c10::Symbol::fromQualString("prim::RaiseException"),
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto exception = args.at(n->input(0)).IValue();
TORCHTRT_THROW_ERROR("Error from TorchScript: " << *exception);
return {};
Expand All @@ -328,4 +337,4 @@ auto prim_registrations =
} // namespace evaluators
} // namespace conversion
} // namespace core
} // namespace torch_tensorrt
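The new isInt branch in prim::ListConstruct relies on the TensorContainer pattern: a raw nvinfer1::ITensor* is boxed in a torch::CustomClassHolder subclass so it can travel inside an IValue list, and Var::unwrapToITensorList (below) reverses the trip. A condensed sketch of the round trip, mirroring the code above (function and variable names are illustrative):

#include "core/conversion/converters/converter_util.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"

using namespace torch_tensorrt::core::conversion;

void box_and_unbox_int(ConversionCtx* ctx, c10::impl::GenericList& list, int64_t dim_value) {
  // Box: freeze the int as a 1-element constant ITensor, wrap it in an IValue.
  auto itensor = converters::tensor_to_const(ctx, torch::tensor({dim_value}));
  auto holder = TensorContainer();
  holder.hold_tensor(itensor);
  list.emplace_back(c10::IValue(c10::make_intrusive<TensorContainer>(holder)));

  // Unbox: recover the raw ITensor* (cf. Var::unwrapToITensorList below).
  nvinfer1::ITensor* recovered =
      list.get(list.size() - 1).toCustomClass<TensorContainer>()->tensor();
  (void)recovered;  // in practice consumed by a downstream converter
}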
20 changes: 20 additions & 0 deletions core/conversion/var/Var.cpp
@@ -146,6 +146,26 @@ bool Var::isITensor() const {
}
}

+ bool Var::isITensorList() const {
+ return type_ == Type::kITensor;
+ }
+
+ std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList() {
+ TORCHTRT_CHECK(
+ isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
+ auto ivalue_list = ptr_.ivalue->toList();
+ std::vector<nvinfer1::ITensor*> outputs;
+ for (size_t i = 0; i < ivalue_list.size(); i++) {
+ auto element = ivalue_list.get(i).toCustomClass<TensorContainer>()->tensor();
+ outputs.push_back(element);
+ }
+ return outputs;
+ }
+
bool Var::isIValue() const {
if (type_ == Type::kIValue) {
return true;
2 changes: 2 additions & 0 deletions core/conversion/var/Var.h
@@ -43,6 +43,7 @@ class Var : torch::CustomClassHolder {
c10::Scalar unwrapToScalar();
c10::List<int64_t> unwrapToIntList(c10::List<int64_t> default_val);
c10::List<int64_t> unwrapToIntList();
+ std::vector<nvinfer1::ITensor*> unwrapToITensorList();
c10::List<double> unwrapToDoubleList(c10::List<double> default_val);
c10::List<double> unwrapToDoubleList();
c10::List<bool> unwrapToBoolList(c10::List<bool> default_val);
@@ -58,6 +59,7 @@

bool isIValue() const;
bool isITensor() const;
+ bool isITensorList() const;
bool isNone() const;
Var::Type type() const;
std::string type_name() const;
5 changes: 2 additions & 3 deletions py/requirements.txt
@@ -1,7 +1,6 @@
numpy
pybind11==2.6.2
- --extra-index-url https://download.pytorch.org/whl/nightly/cu117
- torch==2.0.0.dev20230103+cu117
- torchvision==0.15.0.dev20230103+cu117
+ torch==1.13.0
+ torchvision==0.14.0
--extra-index-url https://pypi.ngc.nvidia.com
tensorrt==8.5.1.7