Commit 298c3a3

Merge pull request #1252 from inocsin/scatter
feat: support scatter.value and scatter.src
2 parents a64a3ac + 01e6541

File tree: 6 files changed, +158 −24 lines
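
For reference, the two `aten::scatter` overloads this commit adds converter support for have the following semantics in plain libtorch. A minimal illustrative sketch (not part of the diff; shapes mirror the new tests):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  auto data = torch::randn({5, 5});
  auto index = torch::randint(0, 5, {2, 2}, torch::kLong);

  // aten::scatter.value: out[i][index[i][j]] = 100, along dim = 1
  auto by_value = data.scatter(1, index, 100);

  // aten::scatter.src: out[i][index[i][j]] = src[i][j], along dim = 1
  auto src = torch::randn({2, 2});
  auto by_src = data.scatter(1, index, src);

  std::cout << by_value << "\n" << by_src << "\n";
  return 0;
}
```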

core/conversion/converters/converter_util.cpp

Lines changed: 21 additions & 0 deletions
@@ -363,6 +363,27 @@ nvinfer1::ITensor* get_slice_size(
   return size_itensor;
 }
 
+nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s) {
+  nvinfer1::ITensor* out;
+  if (s.isIntegral(false)) {
+    auto s_int = s.to<int64_t>();
+    auto s_t = torch::tensor({s_int}).to(at::kInt);
+    out = tensor_to_const(ctx, s_t);
+  } else if (s.isBoolean()) {
+    auto s_bool = s.to<bool>();
+    auto s_t = torch::tensor({s_bool}).to(at::kBool);
+    out = tensor_to_const(ctx, s_t);
+  } else if (s.isFloatingPoint()) {
+    auto other_float = s.to<float>();
+    auto s_t = torch::tensor({other_float});
+    out = tensor_to_const(ctx, s_t);
+  } else {
+    out = nullptr;
+    TORCHTRT_THROW_ERROR("Unsupported data type for scalar. Found: (" << s.type() << ")");
+  }
+  return out;
+}
+
 } // namespace converters
 } // namespace conversion
 } // namespace core
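
This helper (moved here from `element_wise.cpp` so other converters can reuse it) wraps an `at::Scalar` as a one-element TensorRT constant. Note the integral branch casts to `at::kInt`: TensorRT has no 64-bit integer weight type, so an int32 constant stands in for the scalar. A standalone libtorch sketch of just the dtype mapping (hypothetical `scalar_as_tensor`, no TensorRT involved):

```cpp
#include <torch/torch.h>

// Maps an at::Scalar to a 1-element tensor the way scalar_to_tensor does,
// before tensor_to_const would wrap it as a TRT constant.
at::Tensor scalar_as_tensor(const at::Scalar& s) {
  if (s.isIntegral(/*includeBool=*/false)) {
    return torch::tensor({s.to<int64_t>()}).to(at::kInt); // int64 -> int32
  } else if (s.isBoolean()) {
    return torch::tensor({s.to<bool>()}).to(at::kBool);
  }
  return torch::tensor({s.to<float>()}); // floating point -> float32
}
```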

core/conversion/converters/converter_util.h

Lines changed: 2 additions & 0 deletions
@@ -80,6 +80,8 @@ nvinfer1::ITensor* get_slice_size(
     int nbdims,
     std::string const& name);
 
+nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s);
+
 } // namespace converters
 } // namespace conversion
 } // namespace core

core/conversion/converters/impl/element_wise.cpp

Lines changed: 1 addition & 20 deletions
@@ -25,26 +25,7 @@ nvinfer1::ITensor* clamp_util(
   return clamp_layer_out;
 }
 
-nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s) {
-  nvinfer1::ITensor* out;
-  if (s.isIntegral(false)) {
-    auto s_int = s.to<int64_t>();
-    auto s_t = torch::tensor({s_int}).to(at::kInt);
-    out = tensor_to_const(ctx, s_t);
-  } else if (s.isBoolean()) {
-    auto s_bool = s.to<bool>();
-    auto s_t = torch::tensor({s_bool}).to(at::kBool);
-    out = tensor_to_const(ctx, s_t);
-  } else if (s.isFloatingPoint()) {
-    auto other_float = s.to<float>();
-    auto s_t = torch::tensor({other_float});
-    out = tensor_to_const(ctx, s_t);
-  } else {
-    out = nullptr;
-    TORCHTRT_THROW_ERROR("Unsupported data type for scalar. Found: (" << s.type() << ")");
-  }
-  return out;
-}
+
 
 auto element_wise_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()

core/conversion/converters/impl/select.cpp

Lines changed: 47 additions & 0 deletions
@@ -464,6 +464,53 @@ auto select_registrations TORCHTRT_UNUSED =
                auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
                LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
                return true;
+             }})
+        .pattern(
+            {"aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto self = args[0].ITensorOrFreeze(ctx);
+               int dim = args[1].unwrapToInt();
+               auto index = args[2].ITensorOrFreeze(ctx);
+               auto index_dim = index->getDimensions();
+               std::vector<int64_t> dim_vec;
+               for (int i = 0; i < index_dim.nbDims; i++) {
+                 dim_vec.push_back(index_dim.d[i]);
+               }
+               auto value = args[3].unwrapToScalar() * torch::ones(dim_vec);
+               auto value_tensor = tensor_to_const(ctx, value, "");
+               if (self->getType() != value_tensor->getType()) {
+                 value_tensor = castITensor(ctx, value_tensor, self->getType());
+               }
+
+               auto layer = ctx->net->addScatter(*self, *index, *value_tensor, nvinfer1::ScatterMode::kELEMENT);
+               layer->setAxis(dim);
+
+               TORCHTRT_CHECK(layer, "Unable to create layer for aten::scatter.value");
+
+               layer->setName(util::node_info(n).c_str());
+
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
+               LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+               return true;
+             }})
+        .pattern(
+            {"aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto self = args[0].ITensorOrFreeze(ctx);
+               int dim = args[1].unwrapToInt();
+               auto index = args[2].ITensorOrFreeze(ctx);
+               auto src = args[3].ITensorOrFreeze(ctx);
+
+               auto layer = ctx->net->addScatter(*self, *index, *src, nvinfer1::ScatterMode::kELEMENT);
+               layer->setAxis(dim);
+
+               TORCHTRT_CHECK(layer, "Unable to create layer for aten::scatter.src");
+
+               layer->setName(util::node_info(n).c_str());
+
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
+               LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+               return true;
              }});
 
 } // namespace
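
Both converters lower to TensorRT's `IScatterLayer` (available since TensorRT 8.2) in `kELEMENT` mode, which matches the `out[i][index[i][j]] = updates[i][j]` semantics of `aten::scatter`. One caveat visible in the diff: `TORCHTRT_CHECK(layer, ...)` runs after `layer->setAxis(dim)`, so a null layer would crash before the check fires; checking before first use would be safer. A minimal sketch of the underlying TensorRT call (hypothetical helper; all tensors are assumed to already belong to `net`):

```cpp
#include <NvInfer.h>

// Element-mode scatter: out[i][index[i][j]] = updates[i][j] along `dim`.
nvinfer1::ITensor* add_element_scatter(
    nvinfer1::INetworkDefinition* net,
    nvinfer1::ITensor* data,
    nvinfer1::ITensor* index,
    nvinfer1::ITensor* updates,
    int32_t dim) {
  auto* layer = net->addScatter(*data, *index, *updates, nvinfer1::ScatterMode::kELEMENT);
  if (layer == nullptr) {
    return nullptr; // creation failed; check before touching the layer
  }
  layer->setAxis(dim); // the aten::scatter `dim` argument
  return layer->getOutput(0);
}
```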

core/lowering/passes/op_aliasing.cpp

Lines changed: 14 additions & 4 deletions
@@ -16,15 +16,25 @@ void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph) {
     graph(%s, %o):
       %1 : Tensor = aten::div(%s, %o)
       return (%1))IR";
-  ;
-
-  // TODO
-  // complete other element wise pass
 
   torch::jit::SubgraphRewriter true_divide_to_div;
   true_divide_to_div.RegisterRewritePattern(true_divide_pattern, div_pattern);
   true_divide_to_div.runOnGraph(graph);
   LOG_GRAPH("Post map true_divide -> div: " << *graph);
+
+  std::string scatter_sub_pattern = R"IR(
+    graph(%data, %dim, %index, %value):
+      %o : Tensor = aten::scatter_(%data, %dim, %index, %value)
+      return (%o))IR";
+  std::string scatter_pattern = R"IR(
+    graph(%data, %dim, %index, %value):
+      %o : Tensor = aten::scatter(%data, %dim, %index, %value)
+      return (%o))IR";
+
+  torch::jit::SubgraphRewriter rewrite_scatter;
+  rewrite_scatter.RegisterRewritePattern(scatter_sub_pattern, scatter_pattern);
+  rewrite_scatter.runOnGraph(graph);
+  LOG_GRAPH("Post map scatter_ -> scatter: " << *graph);
 }
 
 } // namespace passes
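
This pass matters because idiomatic PyTorch code usually calls the in-place `aten::scatter_`, for which no converter exists; rewriting it to the functional `aten::scatter` lets the converters above apply. A self-contained sketch of the same rewrite in isolation (assumes linking against libtorch):

```cpp
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

int main() {
  // Parse a toy graph that uses the in-place variant.
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
    graph(%data, %dim, %index, %value):
      %o : Tensor = aten::scatter_(%data, %dim, %index, %value)
      return (%o))IR", g.get());

  // Register scatter_ -> scatter and run the rewriter.
  torch::jit::SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      R"IR(
    graph(%data, %dim, %index, %value):
      %o : Tensor = aten::scatter_(%data, %dim, %index, %value)
      return (%o))IR",
      R"IR(
    graph(%data, %dim, %index, %value):
      %o : Tensor = aten::scatter(%data, %dim, %index, %value)
      return (%o))IR");
  rewriter.runOnGraph(g);

  g->dump(); // the graph now contains aten::scatter instead of aten::scatter_
  return 0;
}
```

Note the rewrite is only sound when the mutated input is not read again after the `scatter_` call, since the functional form leaves `%data` unchanged.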

tests/core/conversion/converters/test_select.cpp

Lines changed: 73 additions & 0 deletions
@@ -855,3 +855,76 @@ TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) {
     ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
   }
 }
+
+TEST(Converters, ScatterValueConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%data : Tensor,
+          %index.1 : Tensor):
+      %value : int = prim::Constant[value=100]()
+      %dim : int = prim::Constant[value=1]()
+      %5 : NoneType = prim::Constant()
+      %6 : bool = prim::Constant[value=0]()
+      %7 : int = prim::Constant[value=4]()
+      %index : Tensor = aten::to(%index.1, %7, %6, %6, %5)
+      %10 : Tensor = aten::scatter(%data, %dim, %index, %value)
+      return (%10))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto index = at::randint(0, 5, {2, 2}, {at::kCUDA});
+  auto data = at::randn({5, 5}, {at::kCUDA});
+
+  auto jit_index = at::clone(index);
+  auto jit_data = at::clone(data);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_data, jit_index});
+
+  auto trt_index = at::clone(index);
+  auto trt_data = at::clone(data);
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_data, trt_index});
+
+  for (size_t i = 0; i < jit_results.size(); i++) {
+    auto trt = trt_results[i].reshape(jit_results[i].sizes());
+    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
+  }
+}
+
+TEST(Converters, ScatterSrcConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%data : Tensor,
+          %src : Tensor,
+          %index.1 : Tensor):
+      %dim : int = prim::Constant[value=1]()
+      %5 : NoneType = prim::Constant()
+      %6 : bool = prim::Constant[value=0]()
+      %7 : int = prim::Constant[value=4]()
+      %index : Tensor = aten::to(%index.1, %7, %6, %6, %5)
+      %10 : Tensor = aten::scatter(%data, %dim, %index, %src)
+      return (%10))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto index = at::randint(0, 4, {2, 2}, {at::kCUDA});
+  auto data = at::randn({5, 5}, {at::kCUDA});
+  auto src = at::randn({2, 2}, {at::kCUDA});
+
+  auto jit_index = at::clone(index);
+  auto jit_data = at::clone(data);
+  auto jit_src = at::clone(src);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_data, jit_src, jit_index});
+
+  auto trt_index = at::clone(index);
+  auto trt_data = at::clone(data);
+  auto trt_src = at::clone(src);
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_data, trt_src, trt_index});
+
+  for (size_t i = 0; i < jit_results.size(); i++) {
+    auto trt = trt_results[i].reshape(jit_results[i].sizes());
+    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
+  }
+}
