Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Detector/tensorrt_yolo/DFINE_bb.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ class DFINE_bb_onnx : public YoloONNX
float classConf = scores[i];
int64_t classId = labels[i];

if (classId > 0)
--classId;
//if (classId > 0)
// --classId;

if (classConf >= m_params.confThreshold)
{
Expand Down
4 changes: 4 additions & 0 deletions thirdparty/ruclip/RuCLIP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,13 @@ torch::Tensor CLIPImpl :: EncodeText(torch::Tensor input_ids)

torch::Tensor CLIPImpl :: forward(torch::Tensor input_ids, torch::Tensor pixel_values)
{
//std::cout << "pixel_values: " << pixel_values.sizes() << ", input_ids: " << input_ids.sizes() << std::endl;

auto image_features = EncodeImage(pixel_values);
auto text_features = EncodeText(input_ids);

//std::cout << "image_features: " << image_features.sizes() << ", text_features: " << text_features.sizes() << std::endl;

//normalize features
image_features = image_features / image_features.norm(2/*L2*/, -1, true);
text_features = text_features / text_features.norm(2/*L2*/, -1, true);
Expand Down
1 change: 1 addition & 0 deletions thirdparty/ruclip/RuCLIPProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ inline cv::Mat TorchTensorToCVMat(const torch::Tensor tensor_image, const bool p
if (perm)
t = t.permute({ 1, 2, 0 });
t = t.mul(255).clamp(0, 255).to(torch::kU8);
t = t.contiguous();
cv::Mat result_img;
cv::Mat(static_cast<int>(t.size(0)), static_cast<int>(t.size(1)), CV_MAKETYPE(CV_8U, t.sizes().size() >= 3 ? static_cast<int>(t.size(2)) : 1), t.data_ptr()).copyTo(result_img);
return result_img;
Expand Down
34 changes: 25 additions & 9 deletions thirdparty/ruclip/RuCLIPProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,26 +108,42 @@ class RuCLIPProcessor
///image_features = image_features / image_features.norm(2/*L2*/, -1, true);
///torch::Tensor rel = Relevancy(image_features, text_features, canon_features);
///float lv = rel.index({0,0}).item<float>();

/// Scores each row of `embeds` against a positive phrase relative to a set of negative phrases.
///
/// The active (#else) branch multiplies `embeds` by the transposed concatenation of
/// `positives` and `negatives`, pairs the positive similarity with every negative
/// similarity, applies a softmax over each pair, and returns the pair belonging to the
/// negative for which the positive-class probability is the smallest.
///
/// @param embeds     Query embeddings; used as the left operand of torch::mm, so it must be 2-D.
///                   NOTE(review): presumably L2-normalized, as in the usage example above - confirm.
/// @param positives  Positive phrase embeddings. NOTE(review): the active branch slices column
///                   [0, 1) as "the positive", so it appears to assume exactly one row - confirm.
/// @param negatives  Negative phrase embeddings; its row count drives the repeat of the positive column.
/// @return Tensor annotated in the code as [batch_size, 2]; index 0 of the last dimension is the
///         positive-class softmax value, index 1 the value for the selected negative.
inline torch::Tensor Relevancy(torch::Tensor embeds, torch::Tensor positives, torch::Tensor negatives)
{
#if 0
// Disabled earlier implementation with step-by-step tracing to std::cout.
// NOTE(review): several statements in this branch appear twice (e.g. `embeds2`, `positive_vals`,
// `result` are each declared twice), so it would not compile if re-enabled as-is.
std::cout << "Relevancy: 0" << std::endl;
auto embeds2 = torch::cat({positives, negatives});
auto embeds2 = torch::cat({ positives, negatives });
std::cout << "Relevancy: 1" << std::endl;
auto logits = /*scale * */torch::mm(embeds, embeds2.t()); //[batch_size x phrases]
std::cout << "Relevancy: 2" << std::endl;
auto positive_vals = logits.index({"...", torch::indexing::Slice(0, positives.sizes()[0])}); // [batch_size x 1]
std::cout << "Relevancy: 2" << std::endl;
auto positive_vals = logits.index({ "...", torch::indexing::Slice(0, positives.sizes()[0]) }); // [batch_size x 1]
std::cout << "Relevancy: 3" << std::endl;
auto negative_vals = logits.index({"...", torch::indexing::Slice(positives.sizes()[0], torch::indexing::None)}); // [batch_size x negative_phrase_n]
auto negative_vals = logits.index({ "...", torch::indexing::Slice(positives.sizes()[0], torch::indexing::None) }); // [batch_size x negative_phrase_n]
std::cout << "Relevancy: 4" << std::endl;
auto repeated_pos = positive_vals.repeat({1, negatives.sizes()[0]}); //[batch_size x negative_phrase_n]
auto repeated_pos = positive_vals.repeat({ 1, negatives.sizes()[0] }); //[batch_size x negative_phrase_n]
std::cout << "Relevancy: 5: repeated_pos: " << repeated_pos.sizes() << ", negative_vals: " << negative_vals.sizes() << std::endl;
auto sims = torch::stack({repeated_pos, negative_vals}, -1); //[batch_size x negative_phrase_n x 2]
auto sims = torch::stack({ repeated_pos, negative_vals }, -1); //[batch_size x negative_phrase_n x 2]
std::cout << "Relevancy: 6" << std::endl;
auto smx = torch::softmax(10 * sims, -1); // [batch_size x negative_phrase_n x 2]
std::cout << "Relevancy: 7" << std::endl;
auto best_id = smx.index({"...", 0}).argmin(1); // [batch_size x 2]
auto best_id = smx.index({ "...", 0 }).argmin(1); // [batch_size x 2]
std::cout << "Relevancy: 8" << std::endl;
auto result = torch::gather(smx, 1, best_id.index({"...", torch::indexing::None, torch::indexing::None}).expand({best_id.sizes()[0], negatives.sizes()[0], 2})
).index({torch::indexing::Slice(), 0, torch::indexing::Slice()});// [batch_size x 2]
auto result = torch::gather(smx, 1, best_id.index({ "...", torch::indexing::None, torch::indexing::None }).expand({ best_id.sizes()[0], negatives.sizes()[0], 2 })
).index({ torch::indexing::Slice(), 0, torch::indexing::Slice() });// [batch_size x 2]
return result;
#else
// Stack positive and negative phrase embeddings along dim 0 so one matmul yields all similarities.
auto embeds2 = torch::cat({ positives, negatives }, 0);
auto logits = torch::mm(embeds, embeds2.t()); // [batch_size, 1 + negatives_len]
// Column 0 is treated as the positive similarity; every remaining column as a negative one.
auto positive_vals = logits.index({ "...", torch::indexing::Slice(0, 1) }); // [batch_size, 1]
auto negative_vals = logits.index({ "...", torch::indexing::Slice(1, torch::indexing::None) }); // [batch_size, negatives_len]
// Tile the positive column so it can be paired element-wise with each negative.
auto repeated_pos = positive_vals.repeat({ 1, negatives.sizes()[0] }); // [batch_size, negatives_len]
auto sims = torch::stack({ repeated_pos, negative_vals }, -1); // [batch_size, negatives_len, 2]
// Similarities are multiplied by a fixed factor of 10 before the pairwise softmax;
// the rationale for this particular value is not visible here.
auto smx = torch::softmax(10 * sims, -1); // [batch_size, negatives_len, 2]
// Find the index of the hardest negative (the one with the minimum positive-class probability)
auto best_id = smx.index({ "...", 0 }).argmin(1, /*keepdim=*/true); // [batch_size, 1]
// Gather the results for the selected negatives
auto result = torch::gather(smx, 1, best_id.unsqueeze(-1).expand({ -1, -1, 2 }));
// Drop the singleton dimension left by gathering a single negative per row.
return result.squeeze(1); // [batch_size, 2]
#endif
}
Loading