|
38 | 38 | #include "openvino/pass/graph_rewrite.hpp" |
39 | 39 | #include "openvino/pass/manager.hpp" |
40 | 40 | #include "openvino/pass/pattern/matcher.hpp" |
41 | | -#include "openvino/pass/pattern/op/branch.hpp" |
42 | 41 | #include "openvino/pass/pattern/op/label.hpp" |
43 | 42 | #include "openvino/pass/pattern/op/optional.hpp" |
44 | 43 | #include "openvino/pass/pattern/op/or.hpp" |
@@ -310,16 +309,10 @@ TEST(pattern, matcher) { |
310 | 309 | ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a})); |
311 | 310 |
|
312 | 311 | auto abs = make_shared<op::v0::Abs>(a); |
313 | | - auto any = std::make_shared<pattern::op::Skip>(a); |
314 | | - ASSERT_TRUE(n.match(any, abs)); |
315 | | - ASSERT_EQ(n.get_matched_nodes(), (NodeVector{abs, a})); |
316 | 312 |
|
317 | 313 | auto false_pred = [](std::shared_ptr<Node> /* no */) { |
318 | 314 | return false; |
319 | 315 | }; |
320 | | - auto any_false = std::make_shared<pattern::op::Skip>(a, false_pred); |
321 | | - ASSERT_TRUE(n.match(any_false, a)); |
322 | | - ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a, a})); |
323 | 316 |
|
324 | 317 | auto pattern = std::make_shared<pattern::op::Label>(a); |
325 | 318 | ASSERT_TRUE(n.match(pattern, a)); |
@@ -371,39 +364,6 @@ TEST(pattern, matcher) { |
371 | 364 | ASSERT_FALSE(n.match(std::make_shared<op::v1::Add>(abs, b), std::make_shared<op::v1::Add>(b, b))); |
372 | 365 | ASSERT_EQ(n.get_matched_nodes(), (NodeVector{})); |
373 | 366 |
|
374 | | - auto add_absb = std::make_shared<op::v1::Add>(abs, b); |
375 | | - ASSERT_TRUE(n.match(std::make_shared<op::v1::Add>(any, b), add_absb)); |
376 | | - ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, a, b})); |
377 | | - |
378 | | - ASSERT_TRUE(n.match(std::make_shared<op::v1::Add>(pattern, b), add_absb)); |
379 | | - ASSERT_EQ(n.get_pattern_map()[pattern], abs); |
380 | | - ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b})); |
381 | | - |
382 | | - ASSERT_TRUE(n.match(std::make_shared<op::v1::Add>(b, pattern), add_absb)); |
383 | | - ASSERT_EQ(n.get_pattern_map()[pattern], abs); |
384 | | - ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b})); |
385 | | - |
386 | | - auto c = make_shared<op::v0::Parameter>(element::i32, shape); |
387 | | - auto mul_add_absb = std::make_shared<op::v1::Multiply>(c, add_absb); |
388 | | - ASSERT_TRUE( |
389 | | - n.match(std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(b, pattern)), mul_add_absb)); |
390 | | - ASSERT_EQ(n.get_pattern_map()[pattern], abs); |
391 | | - ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, b})); |
392 | | - |
393 | | - ASSERT_TRUE(n.match(std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(any, b)), |
394 | | - mul_add_absb)); // nested any |
395 | | - ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, a, b})); |
396 | | - ASSERT_TRUE(n.match(std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(any, b)), |
397 | | - std::make_shared<op::v1::Multiply>(std::make_shared<op::v1::Add>(b, abs), |
398 | | - c))); // permutations w/ any |
399 | | - auto mul_c_add_ab = make_shared<op::v1::Multiply>(c, add_ab); |
400 | | - ASSERT_TRUE(n.match(std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(any_false, b)), |
401 | | - std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(a, b)))); // |
402 | | - // nested any |
403 | | - ASSERT_TRUE(n.match(std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(any_false, b)), |
404 | | - mul_c_add_ab)); // permutations w/ any_false |
405 | | - ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_c_add_ab, c, add_ab, a, a, b})); |
406 | | - |
407 | 367 | auto iconst1_0 = construct_constant_node(1); |
408 | 368 | auto iconst1_1 = construct_constant_node(1); |
409 | 369 | ASSERT_TRUE(n.match(make_shared<op::v1::Multiply>(pattern, iconst1_0), |
@@ -462,18 +422,6 @@ TEST(pattern, matcher) { |
462 | 422 | std::make_shared<op::v1::Subtract>(a, b)}), |
463 | 423 | std::make_shared<op::v1::Subtract>(a, b))); |
464 | 424 |
|
465 | | - // Branch |
466 | | - { |
467 | | - auto branch = std::make_shared<pattern::op::Branch>(); |
468 | | - auto star = std::make_shared<pattern::op::Or>(OutputVector{branch, std::make_shared<pattern::op::True>()}); |
469 | | - auto pattern = std::make_shared<op::v1::Add>(star, star); |
470 | | - branch->set_destination(pattern); |
471 | | - auto arg = |
472 | | - std::make_shared<op::v1::Add>(std::make_shared<op::v1::Add>(a, b), std::make_shared<op::v1::Add>(b, a)); |
473 | | - ASSERT_TRUE(n.match(pattern, std::make_shared<op::v1::Add>(arg, a))); |
474 | | - ASSERT_EQ(n.get_matched_nodes().size(), 4); |
475 | | - } |
476 | | - |
477 | 425 | // strict mode |
478 | 426 | { |
479 | 427 | TestMatcher sm(Output<Node>{}, "TestMatcher", true); |
@@ -959,47 +907,6 @@ TEST(pattern, test_sort) { |
959 | 907 | } |
960 | 908 | } |
961 | 909 |
|
962 | | -TEST(pattern, label_on_skip) { |
963 | | - const auto zero = std::string{"0"}; |
964 | | - const auto is_zero = [&zero](const Output<Node>& node) { |
965 | | - if (const auto c = as_type_ptr<op::v0::Constant>(node.get_node_shared_ptr())) { |
966 | | - return (c->get_all_data_elements_bitwise_identical() && c->convert_value_to_string(0) == zero); |
967 | | - } else { |
968 | | - return false; |
969 | | - } |
970 | | - }; |
971 | | - |
972 | | - Shape shape{2, 2}; |
973 | | - auto a = make_shared<op::v0::Parameter>(element::i32, shape); |
974 | | - auto b = make_shared<op::v0::Parameter>(element::i32, Shape{}); |
975 | | - auto iconst = op::v0::Constant::create(element::i32, Shape{}, {0.0f}); |
976 | | - auto label = std::make_shared<pattern::op::Label>(iconst); |
977 | | - auto const_label = std::make_shared<pattern::op::Label>(iconst, is_zero, NodeVector{iconst}); |
978 | | - |
979 | | - auto bcst_pred = [](std::shared_ptr<Node> n) { |
980 | | - return ov::as_type_ptr<op::v1::Broadcast>(n) != nullptr; |
981 | | - }; |
982 | | - |
983 | | - auto shape_const = ov::op::v0::Constant::create(element::u64, Shape{shape.size()}, shape); |
984 | | - auto axes_const = ov::op::v0::Constant::create(element::u8, Shape{}, {0}); |
985 | | - auto bcst = std::make_shared<pattern::op::Skip>(OutputVector{const_label, shape_const, axes_const}, bcst_pred); |
986 | | - auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst}); |
987 | | - auto matcher = |
988 | | - std::make_shared<pattern::Matcher>(std::make_shared<op::v1::Multiply>(label, bcst_label), "label_on_skip"); |
989 | | - |
990 | | - auto const_broadcast = make_shared<op::v1::Broadcast>(iconst, shape_const); |
991 | | - std::shared_ptr<Node> mul = std::make_shared<op::v1::Multiply>(a, const_broadcast); |
992 | | - std::shared_ptr<Node> mul_scalar = std::make_shared<op::v1::Multiply>(b, iconst); |
993 | | - ASSERT_TRUE(matcher->match(mul)); |
994 | | - ASSERT_EQ(matcher->get_pattern_map()[bcst_label], const_broadcast); |
995 | | - ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst); |
996 | | - ASSERT_EQ(matcher->get_pattern_map()[label], a); |
997 | | - ASSERT_TRUE(matcher->match(mul_scalar)); |
998 | | - ASSERT_EQ(matcher->get_pattern_map()[bcst_label], iconst); |
999 | | - ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst); |
1000 | | - ASSERT_EQ(matcher->get_pattern_map()[label], b); |
1001 | | -} |
1002 | | - |
1003 | 910 | TEST(pattern, is_contained_match) { |
1004 | 911 | Shape shape{}; |
1005 | 912 | auto a = make_shared<op::v0::Parameter>(element::i32, shape); |
|
0 commit comments