diff --git a/core/lowering/passes/remove_dropout.cpp b/core/lowering/passes/remove_dropout.cpp index c2f90d1737..1b9db30260 100644 --- a/core/lowering/passes/remove_dropout.cpp +++ b/core/lowering/passes/remove_dropout.cpp @@ -1,4 +1,4 @@ -#include <torch/csrc/jit/passes/subgraph_rewrite.h> +#include "torch/csrc/jit/passes/dead_code_elimination.h" #include "core/util/prelude.h" @@ -7,86 +7,52 @@ namespace core { namespace lowering { namespace passes { -void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) { - std::string dropout_pattern = R"IR( - graph(%input, %4, %5): - %6 = aten::dropout(%input, %4, %5) - return (%6))IR"; - std::string no_dropout_pattern = R"IR( - graph(%input, %4, %5): - return (%input))IR"; - - torch::jit::SubgraphRewriter remove_dropout; - remove_dropout.RegisterRewritePattern(dropout_pattern, no_dropout_pattern); - remove_dropout.runOnGraph(graph); - - std::string dropout_inplace_pattern = R"IR( - graph(%input, %4, %5): - %6 = aten::dropout_(%input, %4, %5) - return (%6))IR"; - std::string no_dropout_inplace_pattern = R"IR( - graph(%input, %4, %5): - return (%input))IR"; - - torch::jit::SubgraphRewriter remove_dropout_inplace_pattern; - remove_dropout_inplace_pattern.RegisterRewritePattern(dropout_inplace_pattern, no_dropout_inplace_pattern); - remove_dropout_inplace_pattern.runOnGraph(graph); - - // remove feature_dropout - std::string feature_dropout_pattern = R"IR( - graph(%input, %4, %5): - %6 = aten::feature_dropout(%input, %4, %5) - return (%6))IR"; - std::string no_feature_dropout_pattern = R"IR( - graph(%input, %4, %5): - return (%input))IR"; - - torch::jit::SubgraphRewriter remove_feature_dropout_pattern; - remove_feature_dropout_pattern.RegisterRewritePattern(feature_dropout_pattern, no_feature_dropout_pattern); - remove_feature_dropout_pattern.runOnGraph(graph); - - // remove feature_dropout inplace - std::string feature_dropout_inplace_pattern = R"IR( - graph(%input, %4, %5): - %6 = aten::feature_dropout_(%input, %4, %5) - return (%6))IR"; - std::string 
no_feature_dropout_inplace_pattern = R"IR( - graph(%input, %4, %5): - return (%input))IR"; - - torch::jit::SubgraphRewriter remove_feature_dropout_inplace_pattern; - remove_feature_dropout_inplace_pattern.RegisterRewritePattern( - feature_dropout_inplace_pattern, no_feature_dropout_inplace_pattern); - remove_feature_dropout_inplace_pattern.runOnGraph(graph); - - // remove feature_alpha_dropout - std::string feature_alpha_dropout_pattern = R"IR( - graph(%input, %4, %5): - %6 = aten::feature_alpha_dropout(%input, %4, %5) - return (%6))IR"; - std::string no_feature_alpha_dropout_pattern = R"IR( - graph(%input, %4, %5): - return (%input))IR"; - - torch::jit::SubgraphRewriter remove_feature_alpha_dropout_pattern; - remove_feature_alpha_dropout_pattern.RegisterRewritePattern( - feature_alpha_dropout_pattern, no_feature_alpha_dropout_pattern); - remove_feature_alpha_dropout_pattern.runOnGraph(graph); - - // remove feature_alpha_dropout inplace - std::string feature_alpha_dropout_inplace_pattern = R"IR( - graph(%input, %4, %5): - %6 = aten::feature_alpha_dropout_(%input, %4, %5) - return (%6))IR"; - std::string no_feature_alpha_dropout_inplace_pattern = R"IR( - graph(%input, %4, %5): - return (%input))IR"; - - torch::jit::SubgraphRewriter remove_feature_alpha_dropout_inplace_pattern; - remove_feature_alpha_dropout_inplace_pattern.RegisterRewritePattern( - feature_alpha_dropout_inplace_pattern, no_feature_alpha_dropout_inplace_pattern); - remove_feature_alpha_dropout_inplace_pattern.runOnGraph(graph); +// Schemas for dropout variants +const std::unordered_set<c10::Symbol> DropoutNodeKinds = { + c10::Symbol::fromQualString("aten::dropout"), + c10::Symbol::fromQualString("aten::dropout_"), + c10::Symbol::fromQualString("aten::feature_dropout"), + c10::Symbol::fromQualString("aten::feature_dropout_"), + c10::Symbol::fromQualString("aten::feature_alpha_dropout"), + c10::Symbol::fromQualString("aten::feature_alpha_dropout_"), +}; + +void removeDropoutInBlock(torch::jit::Block* block) { + /* 
+ Function adapted from: + torch/csrc/jit/passes/remove_dropout.cpp + + Modified for conciseness, documentation, and allowing new variants of dropout operators to be quickly added + */ + std::vector<torch::jit::Node*> dropout_nodes_to_remove; + + for (auto node : block->nodes()) { + // Remove dropout for each member block within a node + for (auto block : node->blocks()) { + removeDropoutInBlock(block); + } + + // For each node having a dropout-variant Schema, remove the node + if (DropoutNodeKinds.find(node->kind()) != DropoutNodeKinds.end()) { + // Extract input and output tensors of dropout operator + auto input_value = node->inputs()[0]; + auto output_value = node->outputs()[0]; + + output_value->replaceAllUsesWith(input_value); + dropout_nodes_to_remove.push_back(node); + } + } + + // Delete dropout nodes + for (auto del_node : dropout_nodes_to_remove) { + del_node->destroy(); + } +} +void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) { + // Remove all instances of dropout variants from graph + removeDropoutInBlock(graph->block()); + torch::jit::EliminateDeadCode(graph); LOG_GRAPH("Post remove dropout: " << *graph); } diff --git a/tests/core/lowering/test_remove_dropout_pass.cpp b/tests/core/lowering/test_remove_dropout_pass.cpp index baeba192d0..76e85d661a 100644 --- a/tests/core/lowering/test_remove_dropout_pass.cpp +++ b/tests/core/lowering/test_remove_dropout_pass.cpp @@ -32,6 +32,32 @@ TEST(LoweringPasses, RemoveDropoutLowersCorrectly) { ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty()); } +TEST(LoweringPasses, RemoveDropoutNestedLowersCorrectly) { + std::string source_graph = R"IR( + graph(%x.1): + %3 : float = prim::Constant[value=0.5]() + %4 : bool = prim::Constant[value=0]() + %y.1 : Tensor = aten::dropout(%x.1, %3, %4) + %z.1 : Tensor = aten::dropout(%y.1, %3, %4) + %12 : Tensor = aten::relu(%z.1) + return (%12))IR"; + std::string target_graph = R"IR( + graph(%x.1): + %11 : Tensor = aten::relu(%x.1) + return (%11))IR"; + + 
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level( + torch_tensorrt::core::util::logging::LogLevel::kGRAPH); + auto sg = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(source_graph, sg.get()); + torch_tensorrt::core::lowering::passes::RemoveDropout(sg); + + auto tg = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(target_graph, tg.get()); + + ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty()); +} + TEST(LoweringPasses, RemoveDropoutInplaceLowersCorrectly) { std::string source_graph = R"IR( graph(%x.1): @@ -132,6 +158,32 @@ TEST(LoweringPasses, RemoveFeatureAlphaDropoutLowersCorrectly) { ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty()); } +TEST(LoweringPasses, RemoveFeatureAlphaDropoutNestedLowersCorrectly) { + std::string source_graph = R"IR( + graph(%x.1): + %3 : float = prim::Constant[value=0.5]() + %4 : bool = prim::Constant[value=0]() + %y.1 : Tensor = aten::feature_alpha_dropout(%x.1, %3, %4) + %z.1 : Tensor = aten::feature_alpha_dropout(%y.1, %3, %4) + %12 : Tensor = aten::relu(%z.1) + return (%12))IR"; + std::string target_graph = R"IR( + graph(%x.1): + %11 : Tensor = aten::relu(%x.1) + return (%11))IR"; + + torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level( + torch_tensorrt::core::util::logging::LogLevel::kGRAPH); + auto sg = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(source_graph, sg.get()); + torch_tensorrt::core::lowering::passes::RemoveDropout(sg); + + auto tg = std::make_shared<torch::jit::Graph>(); + torch::jit::parseIR(target_graph, tg.get()); + + ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty()); +} + TEST(LoweringPasses, RemoveFeatureAlphaDropoutInplaceLowersCorrectly) { std::string source_graph = R"IR( graph(%x.1):