Skip to content

Commit 28bce22

Browse files
authored
Merge pull request #1029 from blchu/bitwise_not
feat (//core/conversion): Add converter for torch.bitwise_not
2 parents 7d84caf + e699800 commit 28bce22

File tree

5 files changed

+105
-1
lines changed

5 files changed

+105
-1
lines changed

core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ cc_library(
5454
"NodeConverterRegistry.cpp",
5555
"impl/activation.cpp",
5656
"impl/batch_norm.cpp",
57+
"impl/bitwise.cpp",
5758
"impl/cast.cpp",
5859
"impl/concat.cpp",
5960
"impl/constant.cpp",
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#include "core/conversion/converters/converters.h"
2+
#include "core/util/prelude.h"
3+
4+
#include <torch/torch.h>
5+
6+
namespace torch_tensorrt {
7+
namespace core {
8+
namespace conversion {
9+
namespace converters {
10+
namespace impl {
11+
12+
auto bitwise_not_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
13+
{"aten::bitwise_not(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
14+
auto in = args[0].ITensorOrFreeze(ctx);
15+
nvinfer1::ILayer* out;
16+
17+
if (in->getType() == nvinfer1::DataType::kINT32) {
18+
// Integer case, using ~x = -x - 1
19+
auto neg_one = torch::tensor({-1}, util::TRTDataTypeToScalarType(in->getType()));
20+
auto neg_one_const = tensor_to_const(ctx, neg_one);
21+
auto neg = add_elementwise(
22+
ctx,
23+
nvinfer1::ElementWiseOperation::kPROD,
24+
in,
25+
neg_one_const,
26+
util::node_info(n) + std::string("_Negation"));
27+
TORCHTRT_CHECK(neg, "Unable to create prod layer from node: " << *n);
28+
out = add_elementwise(
29+
ctx,
30+
nvinfer1::ElementWiseOperation::kSUM,
31+
neg->getOutput(0),
32+
neg_one_const,
33+
util::node_info(n) + std::string("_SubOne"));
34+
TORCHTRT_CHECK(out, "Unable to create sum layer from node: " << *n);
35+
} else if (in->getType() == nvinfer1::DataType::kBOOL) {
36+
// Boolean case
37+
out = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kNOT);
38+
TORCHTRT_CHECK(out, "Unable to create logical not layer from node: " << *n);
39+
} else {
40+
LOG_ERROR("Input tensor must be 32 bit integer or boolean");
41+
return false;
42+
}
43+
44+
out->setName(util::node_info(n).c_str());
45+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out->getOutput(0));
46+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
47+
48+
return true;
49+
}});
50+
51+
} // namespace impl
52+
} // namespace converters
53+
} // namespace conversion
54+
} // namespace core
55+
} // namespace torch_tensorrt

tests/core/conversion/converters/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ converter_test(
1515
name = "test_batch_norm",
1616
)
1717

18+
converter_test(
19+
name = "test_bitwise",
20+
)
21+
1822
converter_test(
1923
name = "test_instance_norm",
2024
)
@@ -136,6 +140,7 @@ test_suite(
136140
tests = [
137141
":test_activation",
138142
":test_batch_norm",
143+
":test_bitwise",
139144
":test_instance_norm",
140145
":test_cast",
141146
":test_clone",
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "gtest/gtest.h"
4+
#include "tests/util/util.h"
5+
#include "torch/csrc/jit/ir/irparser.h"
6+
7+
// Returns the TorchScript IR for a graph that applies aten::bitwise_not
// to its single tensor input and returns the result.
std::string gen_test_graph() {
  std::string ir = R"IR(
      graph(%0: Tensor):
        %3 : Tensor = aten::bitwise_not(%0)
        return (%3))IR";
  return ir;
}
13+
14+
// Instantiates a converter test for aten::bitwise_not on the given dtype
// token ("Integer" or "Boolean"). The graph is run through both the
// TorchScript interpreter and a TensorRT engine, and the outputs must be
// bit-exact after casting to int.
#define test_bitwise_not(dtype)                                                           \
  TEST(Converters, ATenBitwiseNot##dtype##ConvertsCorrectly) {                            \
    const auto graph = gen_test_graph();                                                  \
                                                                                          \
    auto g = std::make_shared<torch::jit::Graph>();                                       \
    torch::jit::parseIR(graph, g.get());                                                  \
                                                                                          \
    at::Tensor in;                                                                        \
    if (strcmp(#dtype, "Integer") == 0)                                                   \
      in = at::randint(-128, 128, {10}, {at::kCUDA}).toType(at::kInt);                    \
    if (strcmp(#dtype, "Boolean") == 0)                                                   \
      /* randint's upper bound is exclusive: use 2 so both true and false occur */        \
      in = at::randint(0, 2, {10}, {at::kCUDA}).toType(at::kBool);                        \
    auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});           \
    auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});            \
                                                                                          \
    in = at::clone(in);                                                                   \
    params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});                \
    auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});      \
                                                                                          \
    auto jit_int = jit_results[0].toType(at::kInt);                                       \
    auto trt_int = trt_results[0].toType(at::kInt);                                       \
                                                                                          \
    ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_int, trt_int));             \
  }

test_bitwise_not(Integer);
test_bitwise_not(Boolean);

#undef test_bitwise_not

tests/util/run_graph_engine.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "core/ir/ir.h"
55
#include "core/runtime/runtime.h"
66
#include "core/util/prelude.h"
7+
#include "core/util/trt_util.h"
78
#include "cuda_runtime_api.h"
89
#include "torch/csrc/jit/ir/ir.h"
910
#include "torch/csrc/jit/ir/irparser.h"
@@ -19,7 +20,7 @@ namespace util {
1920
// Converts example tensors into engine input specs, carrying over both the
// static shape and the TRT data type of each tensor.
std::vector<core::ir::Input> toInputs(std::vector<at::Tensor> ten) {
  std::vector<core::ir::Input> a;
  a.reserve(ten.size()); // one allocation instead of log(n) regrowths
  // const& avoids copying (and refcount-bumping) each at::Tensor per iteration
  for (const auto& i : ten) {
    a.push_back(core::ir::Input(core::util::toVec(i.sizes()), core::util::ScalarTypeToTRTDataType(i.scalar_type())));
  }
  return a;
}

0 commit comments

Comments
 (0)