Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ include(torch)
include(k2)
include(kaldifeat)
include(kaldi_native_io)
include(hclust-cpp)
if(SHERPA_ENABLE_PORTAUDIO)
include(portaudio)
endif()
Expand Down
47 changes: 47 additions & 0 deletions cmake/hclust-cpp.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Fetch the hclust-cpp sources and add them to the include path.
#
# A pre-downloaded archive is preferred when one exists in any of the
# well-known local locations listed below; otherwise the archive is downloaded
# from the primary URL, with a mirror as fallback. The archive is verified
# against a fixed SHA256 hash in both cases.
#
# Takes no arguments. Side effect: calls include_directories() with the
# populated source directory.
function(download_hclust_cpp)
  include(FetchContent)

  # The latest commit as of 2024.09.29
  set(hclust_cpp_archive "hclust-cpp-2024-09-29.tar.gz")
  set(hclust_cpp_URL "https://github.com/csukuangfj/hclust-cpp/archive/refs/tags/2024-09-29.tar.gz")
  set(hclust_cpp_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/hclust-cpp-2024-09-29.tar.gz")
  set(hclust_cpp_HASH "SHA256=abab51448a3cb54272aae07522970306e0b2cc6479d59d7b19e7aee4d6cedd33")

  # Without Internet access, place the archive in one of these locations
  # before configuring. The first existing file wins.
  set(local_candidates
    $ENV{HOME}/Downloads/${hclust_cpp_archive}
    ${CMAKE_SOURCE_DIR}/${hclust_cpp_archive}
    ${CMAKE_BINARY_DIR}/${hclust_cpp_archive}
    /tmp/${hclust_cpp_archive}
    /star-fj/fangjun/download/github/${hclust_cpp_archive}
  )

  foreach(candidate IN LISTS local_candidates)
    if(EXISTS "${candidate}")
      # Use the local file as the only source; the mirror URL is dropped so
      # that FetchContent never touches the network.
      file(TO_CMAKE_PATH "${candidate}" hclust_cpp_URL)
      message(STATUS "Found local downloaded hclust_cpp: ${hclust_cpp_URL}")
      unset(hclust_cpp_URL2)
      break()
    endif()
  endforeach()

  # hclust_cpp_URL2 is intentionally left unquoted: when it is unset it must
  # expand to nothing rather than to an empty URL argument.
  FetchContent_Declare(hclust_cpp
    URL
      ${hclust_cpp_URL}
      ${hclust_cpp_URL2}
    URL_HASH ${hclust_cpp_HASH}
  )

  FetchContent_GetProperties(hclust_cpp)
  if(NOT hclust_cpp_POPULATED)
    message(STATUS "Downloading hclust_cpp from ${hclust_cpp_URL}")
    FetchContent_Populate(hclust_cpp)
  endif()

  message(STATUS "hclust_cpp is downloaded to ${hclust_cpp_SOURCE_DIR}")
  message(STATUS "hclust_cpp's binary dir is ${hclust_cpp_BINARY_DIR}")
  include_directories(${hclust_cpp_SOURCE_DIR})
endfunction()

download_hclust_cpp()
13 changes: 13 additions & 0 deletions sherpa/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ set(sherpa_srcs
speaker-embedding-extractor-model.cc
speaker-embedding-extractor.cc
speaker-embedding-extractor-impl.cc
#
offline-speaker-diarization-result.cc
offline-speaker-segmentation-model-config.cc
offline-speaker-segmentation-pyannote-model-config.cc
fast-clustering-config.cc
fast-clustering.cc
offline-speaker-segmentation-pyannote-model.cc
./offline-speaker-diarization-impl.cc
./offline-speaker-diarization.cc
)

add_library(sherpa_core ${sherpa_srcs})
Expand Down Expand Up @@ -136,6 +145,9 @@ target_link_libraries(sherpa-vad sherpa_core)
add_executable(sherpa-compute-speaker-similarity sherpa-compute-speaker-similarity.cc)
target_link_libraries(sherpa-compute-speaker-similarity sherpa_core)

add_executable(sherpa-speaker-diarization sherpa-speaker-diarization.cc)
target_link_libraries(sherpa-speaker-diarization sherpa_core)

install(TARGETS
sherpa_core
DESTINATION lib
Expand All @@ -146,5 +158,6 @@ install(
sherpa-version
sherpa-vad
sherpa-compute-speaker-similarity
sherpa-speaker-diarization
DESTINATION bin
)
46 changes: 46 additions & 0 deletions sherpa/csrc/fast-clustering-config.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// sherpa/csrc/fast-clustering-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation

#include "sherpa/csrc/fast-clustering-config.h"

#include <sstream>
#include <string>

#include "sherpa/csrc/macros.h"

namespace sherpa {

// Returns a human-readable description of this config in the form
// "FastClusteringConfig(num_clusters=<n>, threshold=<t>)".
std::string FastClusteringConfig::ToString() const {
  std::ostringstream out;

  out << "FastClusteringConfig("
      << "num_clusters=" << num_clusters << ", "
      << "threshold=" << threshold << ")";

  return out.str();
}

// Registers the command-line options "num-clusters" and "cluster-threshold"
// with the given option parser, binding them to the corresponding members.
void FastClusteringConfig::Register(ParseOptions *po) {
  const char *const num_clusters_help =
      "Number of cluster. If greater than 0, then cluster threshold is "
      "ignored. Please provide it if you know the actual number of "
      "clusters in advance.";

  const char *const threshold_help =
      "If num_clusters is not specified, then it specifies the "
      "distance threshold for clustering. smaller value -> more "
      "clusters. larger value -> fewer clusters";

  po->Register("num-clusters", &num_clusters, num_clusters_help);
  po->Register("cluster-threshold", &threshold, threshold_help);
}

bool FastClusteringConfig::Validate() const {
if (num_clusters < 1 && threshold < 0) {
SHERPA_LOGE("Please provide either num_clusters or threshold");
return false;
}

return true;
}

} // namespace sherpa
39 changes: 39 additions & 0 deletions sherpa/csrc/fast-clustering-config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// sherpa/csrc/fast-clustering-config.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_CSRC_FAST_CLUSTERING_CONFIG_H_
#define SHERPA_CSRC_FAST_CLUSTERING_CONFIG_H_

#include <string>

#include "sherpa/cpp_api/parse-options.h"

namespace sherpa {

// Options controlling how FastClustering cuts the cluster tree.
struct FastClusteringConfig {
  // Desired number of clusters. When it is greater than 0, `threshold` is
  // ignored.
  //
  // Setting it is strongly recommended whenever the number of clusters is
  // known in advance.
  int32_t num_clusters = -1;

  // Distance threshold, used when num_clusters is not given.
  //
  // A smaller value produces more clusters.
  // A larger value produces fewer clusters.
  float threshold = 0.5f;

  FastClusteringConfig() = default;

  FastClusteringConfig(int32_t num_clusters, float threshold)
      : num_clusters{num_clusters}, threshold{threshold} {}

  // Returns a printable description of the current values.
  std::string ToString() const;

  // Registers the command-line options for this config with `po`.
  void Register(ParseOptions *po);

  // Returns true if the current values form a usable configuration.
  bool Validate() const;
};

} // namespace sherpa
#endif // SHERPA_CSRC_FAST_CLUSTERING_CONFIG_H_
86 changes: 86 additions & 0 deletions sherpa/csrc/fast-clustering.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// sherpa/csrc/fast-clustering.cc
//
// Copyright (c) 2024 Xiaomi Corporation

#include "sherpa/csrc/fast-clustering.h"

#include <vector>

#include "torch/torch.h"
//
#include "fastcluster-all-in-one.h" // NOLINT

namespace sherpa {

// Implementation of FastClustering.
//
// Runs complete-linkage agglomerative clustering (via fastclustercpp) on the
// pairwise cosine dissimilarity of the input feature rows.
class FastClustering::Impl {
 public:
  explicit Impl(const FastClusteringConfig &config) : config_(config) {}

  // Clusters the rows of a row-major (num_rows x num_cols) float matrix.
  //
  // @param features  Pointer to the matrix data. It is modified in place:
  //                  every row is scaled to unit L2-norm. An all-zero row is
  //                  left as all zeros.
  // @param num_rows  Number of rows (feature frames).
  // @param num_cols  Number of columns (feature dimension).
  //
  // @return A vector of size num_rows holding the cluster label of each row.
  //         Returns an empty vector if num_rows <= 0 and {0} if
  //         num_rows == 1.
  std::vector<int32_t> Cluster(float *features, int32_t num_rows,
                               int32_t num_cols) const {
    if (num_rows <= 0) {
      return {};
    }

    if (num_rows == 1) {
      return {0};
    }

    // `t` shares memory with `features`, so the in-place ops below also
    // update the caller's buffer.
    torch::Tensor t =
        torch::from_blob(features, {num_rows, num_cols}, torch::kFloat);

    // L2-normalize each row. The norm is clamped away from zero so that an
    // all-zero row does not turn into NaNs (0 / 0), which would otherwise
    // propagate into the distance matrix handed to the clustering routine.
    t.div_(t.norm(2, 1, true).clamp_min(1e-12));

    // Condensed upper-triangular distance matrix with n * (n - 1) / 2
    // entries. The count is computed in 64 bits because the product
    // overflows int32_t for large num_rows.
    const int64_t n = num_rows;
    std::vector<double> distance(n * (n - 1) / 2);

    int64_t k = 0;
    for (int32_t i = 0; i != num_rows; ++i) {
      auto v = t.index({i});
      for (int32_t j = i + 1; j != num_rows; ++j) {
        double cosine_similarity = v.dot(t.index({j})).item().toDouble();
        double cosine_dissimilarity = 1 - cosine_similarity;

        // Rounding can push the similarity slightly above 1; a distance must
        // not be negative.
        if (cosine_dissimilarity < 0) {
          cosine_dissimilarity = 0;
        }

        distance[k] = cosine_dissimilarity;
        ++k;
      }
    }

    std::vector<int32_t> merge(2 * (num_rows - 1));
    std::vector<double> height(num_rows - 1);

    fastclustercpp::hclust_fast(num_rows, distance.data(),
                                fastclustercpp::HCLUST_METHOD_COMPLETE,
                                merge.data(), height.data());

    // Cut the tree either into a fixed number of clusters or at a distance
    // threshold, depending on the config.
    std::vector<int32_t> labels(num_rows);
    if (config_.num_clusters > 0) {
      fastclustercpp::cutree_k(num_rows, merge.data(), config_.num_clusters,
                               labels.data());
    } else {
      fastclustercpp::cutree_cdist(num_rows, merge.data(), height.data(),
                                   config_.threshold, labels.data());
    }

    return labels;
  }

 private:
  FastClusteringConfig config_;
};

FastClustering::FastClustering(const FastClusteringConfig &config)
: impl_(std::make_unique<Impl>(config)) {}

FastClustering::~FastClustering() = default;

std::vector<int32_t> FastClustering::Cluster(float *features, int32_t num_rows,
int32_t num_cols) const {
return impl_->Cluster(features, num_rows, num_cols);
}
} // namespace sherpa
44 changes: 44 additions & 0 deletions sherpa/csrc/fast-clustering.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// sherpa/csrc/fast-clustering.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_CSRC_FAST_CLUSTERING_H_
#define SHERPA_CSRC_FAST_CLUSTERING_H_

#include <memory>
#include <vector>

#include "sherpa/cpp_api/parse-options.h"
#include "sherpa/csrc/fast-clustering-config.h"

namespace sherpa {

// Clusters feature vectors by cosine dissimilarity. The implementation is
// hidden behind a pointer to a private Impl class.
class FastClustering {
 public:
  explicit FastClustering(const FastClusteringConfig &config);
  ~FastClustering();

  /**
   * Assign a cluster label to every row of a feature matrix.
   *
   * The distance between two rows is their cosine dissimilarity, i.e.,
   * 1 - (cosine similarity).
   *
   * @param features  Pointer to a 2-D feature matrix in row major. Each row
   *                  is a feature frame. It is changed in-place: each
   *                  feature frame is converted to a normalized vector,
   *                  i.e., a vector whose L2-norm equals 1.
   * @param num_rows  Number of feature frames.
   * @param num_cols  The feature dimension.
   *
   * @return Return a vector of size num_rows. ans[i] contains the label
   *         for the i-th feature frame, i.e., the i-th row of the feature
   *         matrix.
   */
  std::vector<int32_t> Cluster(float *features, int32_t num_rows,
                               int32_t num_cols) const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};

} // namespace sherpa
#endif // SHERPA_CSRC_FAST_CLUSTERING_H_
26 changes: 26 additions & 0 deletions sherpa/csrc/offline-speaker-diarization-impl.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// sherpa/csrc/offline-speaker-diarization-impl.cc
//
// Copyright (c) 2025 Xiaomi Corporation

#include "sherpa/csrc/offline-speaker-diarization-impl.h"

#include <memory>

#include "sherpa/csrc/macros.h"
#include "sherpa/csrc/offline-speaker-diarization-pyannote-impl.h"

namespace sherpa {

// Factory: builds the implementation that matches the given config.
//
// Currently only the pyannote segmentation model is supported. Returns
// nullptr (after logging an error) if no segmentation model is specified.
std::unique_ptr<OfflineSpeakerDiarizationImpl>
OfflineSpeakerDiarizationImpl::Create(
    const OfflineSpeakerDiarizationConfig &config) {
  const bool has_pyannote_model = !config.segmentation.pyannote.model.empty();

  if (!has_pyannote_model) {
    SHERPA_LOGE("Please specify a speaker segmentation model.");
    return nullptr;
  }

  return std::make_unique<OfflineSpeakerDiarizationPyannoteImpl>(config);
}

} // namespace sherpa
35 changes: 35 additions & 0 deletions sherpa/csrc/offline-speaker-diarization-impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// sherpa/csrc/offline-speaker-diarization-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
#define SHERPA_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_

#include <functional>
#include <memory>

#include "sherpa/csrc/offline-speaker-diarization.h"
namespace sherpa {

// Abstract interface for offline speaker diarization implementations.
// Concrete instances are obtained through Create().
class OfflineSpeakerDiarizationImpl {
 public:
  // Returns an implementation selected from `config`.
  static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create(
      const OfflineSpeakerDiarizationConfig &config);

  virtual ~OfflineSpeakerDiarizationImpl() = default;

  // Returns a sample rate.
  // NOTE(review): presumably the rate expected for the input of Process();
  // confirm against the concrete implementation.
  virtual int32_t SampleRate() const = 0;

  // Updates the configuration.
  //
  // Note: Only config.clustering is used. All other fields in config are
  // ignored.
  virtual void SetConfig(const OfflineSpeakerDiarizationConfig &config) = 0;

  // Runs diarization on `samples` and returns the result.
  //
  // `callback` and `callback_arg` are optional and both default to nullptr.
  virtual OfflineSpeakerDiarizationResult Process(
      torch::Tensor samples,
      OfflineSpeakerDiarizationProgressCallback callback = nullptr,
      void *callback_arg = nullptr) const = 0;
};

} // namespace sherpa

#endif  // SHERPA_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
Loading