diff --git a/src/Detector/tensorrt_yolo/DFINE_bb.hpp b/src/Detector/tensorrt_yolo/DFINE_bb.hpp index 6c773097..c0edc0b9 100644 --- a/src/Detector/tensorrt_yolo/DFINE_bb.hpp +++ b/src/Detector/tensorrt_yolo/DFINE_bb.hpp @@ -87,8 +87,8 @@ class DFINE_bb_onnx : public YoloONNX float classConf = scores[i]; int64_t classId = labels[i]; - if (classId > 0) - --classId; + //if (classId > 0) + // --classId; if (classConf >= m_params.confThreshold) { diff --git a/thirdparty/ruclip/RuCLIP.cpp b/thirdparty/ruclip/RuCLIP.cpp index 366e05ce..c64b840b 100644 --- a/thirdparty/ruclip/RuCLIP.cpp +++ b/thirdparty/ruclip/RuCLIP.cpp @@ -207,9 +207,13 @@ torch::Tensor CLIPImpl :: EncodeText(torch::Tensor input_ids) torch::Tensor CLIPImpl :: forward(torch::Tensor input_ids, torch::Tensor pixel_values) { + //std::cout << "pixel_values: " << pixel_values.sizes() << ", input_ids: " << input_ids.sizes() << std::endl; + auto image_features = EncodeImage(pixel_values); auto text_features = EncodeText(input_ids); + //std::cout << "image_features: " << image_features.sizes() << ", text_features: " << text_features.sizes() << std::endl; + //normalize features image_features = image_features / image_features.norm(2/*L2*/, -1, true); text_features = text_features / text_features.norm(2/*L2*/, -1, true); diff --git a/thirdparty/ruclip/RuCLIPProcessor.cpp b/thirdparty/ruclip/RuCLIPProcessor.cpp index 887d074d..3bb7242d 100644 --- a/thirdparty/ruclip/RuCLIPProcessor.cpp +++ b/thirdparty/ruclip/RuCLIPProcessor.cpp @@ -18,6 +18,7 @@ inline cv::Mat TorchTensorToCVMat(const torch::Tensor tensor_image, const bool p if (perm) t = t.permute({ 1, 2, 0 }); t = t.mul(255).clamp(0, 255).to(torch::kU8); + t = t.contiguous(); cv::Mat result_img; cv::Mat(static_cast<int>(t.size(0)), static_cast<int>(t.size(1)), CV_MAKETYPE(CV_8U, t.sizes().size() >= 3 ? 
static_cast<int>(t.size(2)) : 1), t.data_ptr()).copyTo(result_img); return result_img; diff --git a/thirdparty/ruclip/RuCLIPProcessor.h b/thirdparty/ruclip/RuCLIPProcessor.h index a99eba4c..78023116 100644 --- a/thirdparty/ruclip/RuCLIPProcessor.h +++ b/thirdparty/ruclip/RuCLIPProcessor.h @@ -108,26 +108,42 @@ class RuCLIPProcessor ///image_features = image_features / image_features.norm(2/*L2*/, -1, true); ///torch::Tensor rel = Relevancy(image_features, text_features, canon_features); ///float lv = rel.index({0,0}).item<float>(); + inline torch::Tensor Relevancy(torch::Tensor embeds, torch::Tensor positives, torch::Tensor negatives) { +#if 0 std::cout << "Relevancy: 0" << std::endl; - auto embeds2 = torch::cat({positives, negatives}); + auto embeds2 = torch::cat({ positives, negatives }); std::cout << "Relevancy: 1" << std::endl; auto logits = /*scale * */torch::mm(embeds, embeds2.t()); //[batch_size x phrases] - std::cout << "Relevancy: 2" << std::endl; - auto positive_vals = logits.index({"...", torch::indexing::Slice(0, positives.sizes()[0])}); // [batch_size x 1] + std::cout << "Relevancy: 2" << std::endl; + auto positive_vals = logits.index({ "...", torch::indexing::Slice(0, positives.sizes()[0]) }); // [batch_size x 1] std::cout << "Relevancy: 3" << std::endl; - auto negative_vals = logits.index({"...", torch::indexing::Slice(positives.sizes()[0], torch::indexing::None)}); // [batch_size x negative_phrase_n] + auto negative_vals = logits.index({ "...", torch::indexing::Slice(positives.sizes()[0], torch::indexing::None) }); // [batch_size x negative_phrase_n] std::cout << "Relevancy: 4" << std::endl; - auto repeated_pos = positive_vals.repeat({1, negatives.sizes()[0]}); //[batch_size x negative_phrase_n] + auto repeated_pos = positive_vals.repeat({ 1, negatives.sizes()[0] }); //[batch_size x negative_phrase_n] std::cout << "Relevancy: 5: repeated_pos: " << repeated_pos.sizes() << ", negative_vals: " << negative_vals.sizes() << std::endl; - auto sims = 
torch::stack({repeated_pos, negative_vals}, -1); //[batch_size x negative_phrase_n x 2] + auto sims = torch::stack({ repeated_pos, negative_vals }, -1); //[batch_size x negative_phrase_n x 2] std::cout << "Relevancy: 6" << std::endl; auto smx = torch::softmax(10 * sims, -1); // [batch_size x negative_phrase_n x 2] std::cout << "Relevancy: 7" << std::endl; - auto best_id = smx.index({"...", 0}).argmin(1); // [batch_size x 2] + auto best_id = smx.index({ "...", 0 }).argmin(1); // [batch_size x 2] std::cout << "Relevancy: 8" << std::endl; - auto result = torch::gather(smx, 1, best_id.index({"...", torch::indexing::None, torch::indexing::None}).expand({best_id.sizes()[0], negatives.sizes()[0], 2}) - ).index({torch::indexing::Slice(), 0, torch::indexing::Slice()});// [batch_size x 2] + auto result = torch::gather(smx, 1, best_id.index({ "...", torch::indexing::None, torch::indexing::None }).expand({ best_id.sizes()[0], negatives.sizes()[0], 2 }) + ).index({ torch::indexing::Slice(), 0, torch::indexing::Slice() });// [batch_size x 2] return result; +#else + auto embeds2 = torch::cat({ positives, negatives }, 0); + auto logits = torch::mm(embeds, embeds2.t()); // [batch_size, 1 + negatives_len] + auto positive_vals = logits.index({ "...", torch::indexing::Slice(0, 1) }); // [batch_size, 1] + auto negative_vals = logits.index({ "...", torch::indexing::Slice(1, torch::indexing::None) }); // [batch_size, negatives_len] + auto repeated_pos = positive_vals.repeat({ 1, negatives.sizes()[0] }); // [batch_size, negatives_len] + auto sims = torch::stack({ repeated_pos, negative_vals }, -1); // [batch_size, negatives_len, 2] + auto smx = torch::softmax(10 * sims, -1); // [batch_size, negatives_len, 2] + //Находим индекс самого сложного негатива (с минимальной вероятностью позитивного класса) + auto best_id = smx.index({ "...", 0 }).argmin(1, /*keepdim=*/true); // [batch_size, 1] + //Собираем результаты для выбранных негативов + auto result = torch::gather(smx, 1, 
best_id.unsqueeze(-1).expand({ -1, -1, 2 })); + return result.squeeze(1); // [batch_size, 2] +#endif }