Skip to content

Commit 6f61c6f

Browse files
authored
Merge pull request #1259 from pytorch/assorted_small_fixes
Assorted small fixes
2 parents 679ea21 + 460fc9b commit 6f61c6f

File tree

5 files changed

+99
-34
lines changed

5 files changed

+99
-34
lines changed

core/conversion/converters/impl/element_wise.cpp

100755 → 100644
File mode changed.

core/conversion/converters/impl/unary.cpp

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,8 @@ namespace impl {
1111
namespace {
1212

1313
auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
14-
{"aten::abs (Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
15-
auto in = args[0].ITensor();
14+
{"aten::abs(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
15+
auto in = args[0].ITensorOrFreeze(ctx);
1616
bool unary_supported_input = in->getType() == nvinfer1::DataType::kFLOAT ||
1717
in->getType() == nvinfer1::DataType::kHALF || in->getType() == nvinfer1::DataType::kINT8;
1818
if (unary_supported_input) {
@@ -23,6 +23,9 @@ auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern
2323
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
2424
return true;
2525
} else {
26+
LOG_GRAPH(
27+
"Tensor is of unsupported type "
28+
<< in->getType() << " for IUnaryLayer::kABS. Using backup implementation via IElementWise (max(x, -x)");
2629
// For types not supported by kABS, use an elementwise implementation abs(x) = max(x, -1 * x)
2730
at::Tensor neg_one = torch::full({1}, -1).to(util::TRTDataTypeToScalarType(in->getType()));
2831
auto neg_one_const = tensor_to_const(ctx, neg_one);
@@ -50,7 +53,7 @@ auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern
5053
auto unary##_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( \
5154
{"aten::" #unary "(Tensor self) -> Tensor", \
5255
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { \
53-
auto in = args[0].ITensor(); \
56+
auto in = args[0].ITensorOrFreeze(ctx); \
5457
auto unary = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::trt_type); \
5558
\
5659
TORCHTRT_CHECK(unary, "Unable to create " #unary " layer from node: " << *n); \

core/conversion/evaluators/aten.cpp

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -516,7 +516,7 @@ auto aten_registrations TORCHTRT_UNUSED =
516516
auto self = args.at(n->input(0)).IValue();
517517
auto obj = args.at(n->input(1)).IValue();
518518

519-
return self->isSameIdentity(*obj);
519+
return self->is(*obj);
520520
},
521521
EvalOptions().validSchemas({
522522
"aten::__is__(t1 self, t2 obj) -> bool",
@@ -527,7 +527,7 @@ auto aten_registrations TORCHTRT_UNUSED =
527527
auto self = args.at(n->input(0)).IValue();
528528
auto obj = args.at(n->input(1)).IValue();
529529

530-
return !self->isSameIdentity(*obj);
530+
return !self->is(*obj);
531531
},
532532
EvalOptions().validSchemas({
533533
"aten::__isnot__(t1 self, t2 obj) -> bool",
@@ -806,9 +806,19 @@ auto aten_registrations TORCHTRT_UNUSED =
806806
return 0;
807807
}
808808
},
809-
EvalOptions().validSchemas({"aten::__range_length(int lo, int hi, int step) -> int"})});
809+
EvalOptions().validSchemas({"aten::__range_length(int lo, int hi, int step) -> int"})})
810+
.evaluator(
811+
{c10::Symbol::fromQualString("aten::__derive_index"),
812+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
813+
auto idx = args.at(n->input(0)).unwrapToInt();
814+
auto start = args.at(n->input(1)).unwrapToInt();
815+
auto step = args.at(n->input(2)).unwrapToInt();
816+
return start + idx * step;
817+
},
818+
EvalOptions().validSchemas({"aten::__derive_index(int idx, int start, int step) -> int"})});
819+
810820
} // namespace
811821
} // namespace evaluators
812822
} // namespace conversion
813823
} // namespace core
814-
} // namespace torch_tensorrt
824+
} // namespace torch_tensorrt

core/conversion/evaluators/prim.cpp

Lines changed: 10 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -270,36 +270,19 @@ auto prim_registrations =
270270
.evaluator(
271271
{torch::jit::prim::TupleConstruct,
272272
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
273-
auto num_inputs = n->inputs().size();
274273
c10::IValue tuple = c10::ivalue::Tuple::create();
275-
switch (num_inputs) {
276-
case 0:
277-
tuple = c10::ivalue::Tuple::create();
278-
break;
279-
case 1:
280-
tuple = c10::ivalue::Tuple::create(std::move((*args.at(n->input(0)).IValue())));
281-
break;
282-
case 2: {
283-
tuple = c10::ivalue::Tuple::create(
284-
std::move(*(args.at(n->input(0)).IValue())), std::move(*(args.at(n->input(1)).IValue())));
285-
break;
286-
}
287-
case 3: {
288-
tuple = c10::ivalue::Tuple::create(
289-
std::move(*(args.at(n->input(0)).IValue())),
290-
std::move(*(args.at(n->input(1)).IValue())),
291-
std::move(*(args.at(n->input(2)).IValue())));
292-
break;
293-
}
294-
default: {
295-
std::vector<c10::IValue> elems;
296-
for (size_t i = 0; i < num_inputs; i++) {
297-
elems.push_back(*(args.at(n->input(i)).IValue()));
298-
}
299-
tuple = c10::ivalue::Tuple::create(std::move(elems));
300-
break;
274+
std::vector<c10::IValue> elems;
275+
for (auto in : n->inputs()) {
276+
if (args.at(in).isITensor()) {
277+
auto tensor_holder = TensorContainer();
278+
tensor_holder.hold_tensor(args.at(in).ITensor());
279+
auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
280+
elems.push_back(std::move(ival));
281+
} else {
282+
elems.push_back(*(args.at(in).IValue()));
301283
}
302284
}
285+
tuple = c10::ivalue::Tuple::create(std::move(elems));
303286
return c10::optional<torch::jit::IValue>(std::move(tuple));
304287
}})
305288
.evaluator(

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 69 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -797,3 +797,72 @@ TEST(Evaluators, PowFloatIntEvaluatesCorrectly) {
797797

798798
ASSERT_TRUE(jit_results[0] == trt_results[0]);
799799
}
800+
801+
TEST(Evaluators, DeriveIndexEvaluatesCorrectly) {
802+
const auto graph = R"IR(
803+
graph():
804+
%1 : int = prim::Constant[value=9]()
805+
%2 : int = prim::Constant[value=4]()
806+
%3 : int = prim::Constant[value=2]()
807+
%4 : int = aten::__derive_index(%1, %2, %3)
808+
return (%4))IR";
809+
810+
auto g = std::make_shared<torch::jit::Graph>();
811+
torch::jit::parseIR(graph, g.get());
812+
813+
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
814+
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
815+
816+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
817+
}
818+
819+
TEST(Evaluators, IsTrueEvaluatesCorrectly) {
820+
const auto graph = R"IR(
821+
graph():
822+
%1 : int = prim::Constant[value=1]()
823+
%2 : int = prim::Constant[value=1]()
824+
%4 : bool = aten::__is__(%1, %2)
825+
return (%4))IR";
826+
827+
auto g = std::make_shared<torch::jit::Graph>();
828+
torch::jit::parseIR(graph, g.get());
829+
830+
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
831+
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
832+
833+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
834+
}
835+
836+
TEST(Evaluators, IsFalseEvaluatesCorrectly) {
837+
const auto graph = R"IR(
838+
graph():
839+
%1 : int = prim::Constant[value=9]()
840+
%2 : None = prim::Constant()
841+
%4 : bool = aten::__is__(%1, %2)
842+
return (%4))IR";
843+
844+
auto g = std::make_shared<torch::jit::Graph>();
845+
torch::jit::parseIR(graph, g.get());
846+
847+
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
848+
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
849+
850+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
851+
}
852+
853+
TEST(Evaluators, IsNotTrueEvaluatesCorrectly) {
854+
const auto graph = R"IR(
855+
graph():
856+
%1 : int = prim::Constant[value=1]()
857+
%2 : None = prim::Constant()
858+
%4 : bool = aten::__isnot__(%1, %2)
859+
return (%4))IR";
860+
861+
auto g = std::make_shared<torch::jit::Graph>();
862+
torch::jit::parseIR(graph, g.get());
863+
864+
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
865+
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
866+
867+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
868+
}

0 commit comments

Comments (0)