Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -351,20 +351,10 @@ bool check_condition_true_pattern(const std::shared_ptr<op::v0::Result>& cond_re
const auto& condition_map = condition_matcher.get_pattern_value_map();
const auto& cond_const =
ov::as_type_ptr<op::v0::Constant>(condition_map.at(cond_const_label).get_node_shared_ptr());
if (!cond_const) {
bool cond_value = false;
if (!ov::op::util::get_constant_value(cond_const, cond_value) || !cond_value) {
return false;
}
if (ov::shape_size(cond_const->get_shape()) != 1)
return false;
const auto& type = cond_const->get_output_element_type(0);
if (type != ov::element::boolean) {
return false;
}
bool cond_value = cond_const->cast_vector<bool>()[0];
if (!cond_value) {
return false;
}

// number of iteration is retrieve from the first input port
num_iters_output = loop->input_value(0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ uint64_t get_new_param_idx(const std::vector<uint64_t>& remove_parameter_idxs, u
for (auto remove_idx : remove_parameter_idxs) {
FRONT_END_GENERAL_CHECK(old_idx != remove_idx,
"[TensorFlow Frontend] internal error: incorrect old_idx for "
"TensorListSliceInputAndConcatOutputReplacer transformation");
"TensorListInLoopOptimization transformation");
if (remove_idx < old_idx) {
++num_removed;
}
Expand All @@ -181,7 +181,7 @@ uint64_t get_new_param_idx(const std::vector<uint64_t>& remove_parameter_idxs, u
// compute shifted index
FRONT_END_GENERAL_CHECK(num_removed <= old_idx,
"[TensorFlow Frontend] internal error: incorrect new parameter index computation "
"TensorListSliceInputAndConcatOutputReplacer transformation");
"TensorListInLoopOptimization transformation");
return old_idx - num_removed;
}
} // namespace
Expand Down Expand Up @@ -478,7 +478,7 @@ ov::frontend::tensorflow::pass::TensorListInLoopOptimization::TensorListInLoopOp
std::dynamic_pointer_cast<TensorListSetItem>(body_result->get_input_node_shared_ptr(0));
FRONT_END_GENERAL_CHECK(tensor_list_set_item,
"[TensorFlow Frontend] internal error: tensor_list_set_item is nullptr in "
"TensorListSliceInputAndConcatOutputReplacer");
"TensorListInLoopOptimization");
// unsqueeze newly generated data at this iteration
// that will be concatenated
auto new_data = tensor_list_set_item->input_value(2);
Expand All @@ -501,13 +501,13 @@ ov::frontend::tensorflow::pass::TensorListInLoopOptimization::TensorListInLoopOp
const auto& body_param = body_params[param_idx];
FRONT_END_GENERAL_CHECK(body_param->get_output_target_inputs(0).size() == 1,
"[TensorFlow Frontend] internal error: tensor list must have only consumer "
"TensorListGetItem operation in TensorListSliceInputAndConcatOutputReplacer");
"TensorListGetItem operation in TensorListInLoopOptimization");
auto target_input = *(body_param->get_output_target_inputs(0).begin());
auto tensor_list_get_item =
std::dynamic_pointer_cast<TensorListGetItem>(target_input.get_node()->shared_from_this());
FRONT_END_GENERAL_CHECK(tensor_list_get_item,
"[TensorFlow Frontend] internal error: tensor list must have only consumer "
"TensorListGetItem operation in TensorListSliceInputAndConcatOutputReplacer");
"TensorListGetItem operation in TensorListInLoopOptimization");

auto new_shape = body_param->get_output_partial_shape(0);
if (new_shape.rank().is_static() && new_shape.rank().get_length() > 0) {
Expand Down