diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4f3fc5f4e..5485c13cd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -146,6 +146,7 @@ include(torch)
 include(k2)
 include(kaldifeat)
 include(kaldi_native_io)
+include(hclust-cpp)
 if(SHERPA_ENABLE_PORTAUDIO)
   include(portaudio)
 endif()
diff --git a/cmake/hclust-cpp.cmake b/cmake/hclust-cpp.cmake
new file mode 100644
index 000000000..c84ccafc8
--- /dev/null
+++ b/cmake/hclust-cpp.cmake
@@ -0,0 +1,47 @@
+function(download_hclust_cpp)
+  include(FetchContent)
+
+  # The latest commit as of 2024.09.29
+  set(hclust_cpp_URL  "https://github.com/csukuangfj/hclust-cpp/archive/refs/tags/2024-09-29.tar.gz")
+  set(hclust_cpp_URL2 "https://hf-mirror.com/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/hclust-cpp-2024-09-29.tar.gz")
+  set(hclust_cpp_HASH "SHA256=abab51448a3cb54272aae07522970306e0b2cc6479d59d7b19e7aee4d6cedd33")
+
+  # If you don't have access to the Internet,
+  # please pre-download hclust-cpp
+  set(possible_file_locations
+    $ENV{HOME}/Downloads/hclust-cpp-2024-09-29.tar.gz
+    ${CMAKE_SOURCE_DIR}/hclust-cpp-2024-09-29.tar.gz
+    ${CMAKE_BINARY_DIR}/hclust-cpp-2024-09-29.tar.gz
+    /tmp/hclust-cpp-2024-09-29.tar.gz
+    /star-fj/fangjun/download/github/hclust-cpp-2024-09-29.tar.gz
+  )
+
+  foreach(f IN LISTS possible_file_locations)
+    if(EXISTS ${f})
+      set(hclust_cpp_URL  "${f}")
+      file(TO_CMAKE_PATH "${hclust_cpp_URL}" hclust_cpp_URL)
+      message(STATUS "Found local downloaded hclust_cpp: ${hclust_cpp_URL}")
+      set(hclust_cpp_URL2)
+      break()
+    endif()
+  endforeach()
+
+  FetchContent_Declare(hclust_cpp
+    URL
+      ${hclust_cpp_URL}
+      ${hclust_cpp_URL2}
+    URL_HASH ${hclust_cpp_HASH}
+  )
+
+  FetchContent_GetProperties(hclust_cpp)
+  if(NOT hclust_cpp_POPULATED)
+    message(STATUS "Downloading hclust_cpp from ${hclust_cpp_URL}")
+    FetchContent_Populate(hclust_cpp)
+  endif()
+
+  message(STATUS "hclust_cpp is downloaded to ${hclust_cpp_SOURCE_DIR}")
+  message(STATUS "hclust_cpp's binary dir is ${hclust_cpp_BINARY_DIR}")
+  include_directories(${hclust_cpp_SOURCE_DIR})
+endfunction()
+
+download_hclust_cpp()
diff --git a/sherpa/csrc/CMakeLists.txt b/sherpa/csrc/CMakeLists.txt
index 66573a422..342e8f447 100644
--- a/sherpa/csrc/CMakeLists.txt
+++ b/sherpa/csrc/CMakeLists.txt
@@ -48,6 +48,15 @@ set(sherpa_srcs
   speaker-embedding-extractor-model.cc
   speaker-embedding-extractor.cc
   speaker-embedding-extractor-impl.cc
+  #
+  offline-speaker-diarization-result.cc
+  offline-speaker-segmentation-model-config.cc
+  offline-speaker-segmentation-pyannote-model-config.cc
+  fast-clustering-config.cc
+  fast-clustering.cc
+  offline-speaker-segmentation-pyannote-model.cc
+  offline-speaker-diarization-impl.cc
+  offline-speaker-diarization.cc
 )
 
 add_library(sherpa_core ${sherpa_srcs})
@@ -136,6 +145,9 @@ target_link_libraries(sherpa-vad sherpa_core)
 add_executable(sherpa-compute-speaker-similarity sherpa-compute-speaker-similarity.cc)
 target_link_libraries(sherpa-compute-speaker-similarity sherpa_core)
 
+add_executable(sherpa-speaker-diarization sherpa-speaker-diarization.cc)
+target_link_libraries(sherpa-speaker-diarization sherpa_core)
+
 install(TARGETS sherpa_core
   DESTINATION lib
@@ -146,5 +158,6 @@ install(
     sherpa-version
     sherpa-vad
     sherpa-compute-speaker-similarity
+    sherpa-speaker-diarization
   DESTINATION bin
 )
diff --git a/sherpa/csrc/fast-clustering-config.cc b/sherpa/csrc/fast-clustering-config.cc
new file mode 100644
index 000000000..3f3bcf894
--- /dev/null
+++ b/sherpa/csrc/fast-clustering-config.cc
@@ -0,0 +1,46 @@
+// sherpa/csrc/fast-clustering-config.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa/csrc/fast-clustering-config.h"
+
+#include <sstream>
+#include <string>
+
+#include "sherpa/csrc/macros.h"
+
+namespace sherpa {
+
+std::string FastClusteringConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "FastClusteringConfig(";
+  os << "num_clusters=" << num_clusters << ", ";
+  os << "threshold=" << threshold << ")";
+
+  return os.str();
+}
+
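+// Note: when these options are registered through a prefixed ParseOptions,
+// as OfflineSpeakerDiarizationConfig does with the "clustering" prefix,
+// they appear on the command line as, e.g., --clustering.num-clusters
+// and --clustering.cluster-threshold.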
+void FastClusteringConfig::Register(ParseOptions *po) {
+  po->Register(
+      "num-clusters", &num_clusters,
+      "Number of clusters. If greater than 0, then the cluster threshold is "
+      "ignored. Please provide it if you know the actual number of "
+      "clusters in advance.");
+
+  po->Register("cluster-threshold", &threshold,
+               "If num_clusters is not specified, then it specifies the "
+               "distance threshold for clustering. A smaller value leads to "
+               "more clusters; a larger value leads to fewer clusters.");
+}
+
+bool FastClusteringConfig::Validate() const {
+  if (num_clusters < 1 && threshold < 0) {
+    SHERPA_LOGE("Please provide either num_clusters or threshold");
+    return false;
+  }
+
+  return true;
+}
+
+}  // namespace sherpa
diff --git a/sherpa/csrc/fast-clustering-config.h b/sherpa/csrc/fast-clustering-config.h
new file mode 100644
index 000000000..1d8817d2b
--- /dev/null
+++ b/sherpa/csrc/fast-clustering-config.h
@@ -0,0 +1,39 @@
+// sherpa/csrc/fast-clustering-config.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_CSRC_FAST_CLUSTERING_CONFIG_H_
+#define SHERPA_CSRC_FAST_CLUSTERING_CONFIG_H_
+
+#include <string>
+
+#include "sherpa/cpp_api/parse-options.h"
+
+namespace sherpa {
+
+struct FastClusteringConfig {
+  // If greater than 0, then threshold is ignored.
+  //
+  // We strongly recommend that you set it if you know the number of clusters
+  // in advance.
+  int32_t num_clusters = -1;
+
+  // Distance threshold.
+  //
+  // The smaller it is, the more clusters it will generate.
+  // The larger it is, the fewer clusters it will generate.
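+  //
+  // The distance used is the cosine dissimilarity computed in
+  // fast-clustering.cc, so it lies in [0, 2]. With complete linkage and
+  // the default of 0.5, for example, two clusters stay separate once some
+  // pair of their embeddings has a cosine similarity below 0.5.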
+  float threshold = 0.5;
+
+  FastClusteringConfig() = default;
+
+  FastClusteringConfig(int32_t num_clusters, float threshold)
+      : num_clusters(num_clusters), threshold(threshold) {}
+
+  std::string ToString() const;
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+};
+
+}  // namespace sherpa
+#endif  // SHERPA_CSRC_FAST_CLUSTERING_CONFIG_H_
diff --git a/sherpa/csrc/fast-clustering.cc b/sherpa/csrc/fast-clustering.cc
new file mode 100644
index 000000000..c037d7b22
--- /dev/null
+++ b/sherpa/csrc/fast-clustering.cc
@@ -0,0 +1,86 @@
+// sherpa/csrc/fast-clustering.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa/csrc/fast-clustering.h"
+
+#include <vector>
+
+#include "torch/torch.h"
+//
+#include "fastcluster-all-in-one.h"  // NOLINT
+
+namespace sherpa {
+
+class FastClustering::Impl {
+ public:
+  explicit Impl(const FastClusteringConfig &config) : config_(config) {}
+
+  std::vector<int32_t> Cluster(float *features, int32_t num_rows,
+                               int32_t num_cols) const {
+    if (num_rows <= 0) {
+      return {};
+    }
+
+    if (num_rows == 1) {
+      return {0};
+    }
+
+    torch::Tensor t =
+        torch::from_blob(features, {num_rows, num_cols}, torch::kFloat);
+
+    // torch::nn::functional::normalize(
+    //     t, torch::nn::functional::NormalizeFuncOptions().p(2).dim(1));
+    t.div_(t.norm(2, 1, true));
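+
+    // hclust_fast() takes the pairwise distances as a condensed distance
+    // matrix: only the strict upper triangle, stored row by row, i.e.,
+    // (0,1), (0,2), ..., (0,n-1), (1,2), ..., (n-2,n-1).
+    // The loop below fills exactly that order.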
+
+    std::vector<double> distance((num_rows * (num_rows - 1)) / 2);
+
+    int32_t k = 0;
+    for (int32_t i = 0; i != num_rows; ++i) {
+      auto v = t.index({i});
+      for (int32_t j = i + 1; j != num_rows; ++j) {
+        double cosine_similarity = v.dot(t.index({j})).item().toDouble();
+        double cosine_dissimilarity = 1 - cosine_similarity;
+
+        if (cosine_dissimilarity < 0) {
+          cosine_dissimilarity = 0;
+        }
+
+        distance[k] = cosine_dissimilarity;
+        ++k;
+      }
+    }
+
+    std::vector<int32_t> merge(2 * (num_rows - 1));
+    std::vector<double> height(num_rows - 1);
+
+    fastclustercpp::hclust_fast(num_rows, distance.data(),
+                                fastclustercpp::HCLUST_METHOD_COMPLETE,
+                                merge.data(), height.data());
+
+    std::vector<int32_t> labels(num_rows);
+    if (config_.num_clusters > 0) {
+      fastclustercpp::cutree_k(num_rows, merge.data(), config_.num_clusters,
+                               labels.data());
+    } else {
+      fastclustercpp::cutree_cdist(num_rows, merge.data(), height.data(),
+                                   config_.threshold, labels.data());
+    }
+
+    return labels;
+  }
+
+ private:
+  FastClusteringConfig config_;
+};
+
+FastClustering::FastClustering(const FastClusteringConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+FastClustering::~FastClustering() = default;
+
+std::vector<int32_t> FastClustering::Cluster(float *features,
+                                             int32_t num_rows,
+                                             int32_t num_cols) const {
+  return impl_->Cluster(features, num_rows, num_cols);
+}
+}  // namespace sherpa
diff --git a/sherpa/csrc/fast-clustering.h b/sherpa/csrc/fast-clustering.h
new file mode 100644
index 000000000..7738558bf
--- /dev/null
+++ b/sherpa/csrc/fast-clustering.h
@@ -0,0 +1,44 @@
+// sherpa/csrc/fast-clustering.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_CSRC_FAST_CLUSTERING_H_
+#define SHERPA_CSRC_FAST_CLUSTERING_H_
+
+#include <memory>
+#include <vector>
+
+#include "sherpa/cpp_api/parse-options.h"
+#include "sherpa/csrc/fast-clustering-config.h"
+
+namespace sherpa {
+
+class FastClustering {
+ public:
+  explicit FastClustering(const FastClusteringConfig &config);
+  ~FastClustering();
+
+  /**
+   * @param features Pointer to a 2-D feature matrix in row major. Each row
+   *                 is a feature frame. It is changed in-place: we
+   *                 convert each feature frame to a normalized vector,
+   *                 i.e., the L2 norm of each vector becomes 1.
+   *                 Clustering uses the cosine dissimilarity,
+   *                 which is 1 - (cosine similarity).
+   * @param num_rows Number of feature frames
+   * @param num_cols The feature dimension.
+   *
+   * @return Return a vector of size num_rows. ans[i] contains the label
+   *         for the i-th feature frame, i.e., the i-th row of the feature
+   *         matrix.
+   */
+  std::vector<int32_t> Cluster(float *features, int32_t num_rows,
+                               int32_t num_cols) const;
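+
+  // A minimal usage sketch (here `embeddings` is assumed to be a
+  // std::vector<float> holding num_rows rows of num_cols values):
+  //
+  //   FastClusteringConfig config;       // threshold defaults to 0.5
+  //   FastClustering clustering(config);
+  //   std::vector<int32_t> labels =
+  //       clustering.Cluster(embeddings.data(), num_rows, num_cols);
+  //   // labels[i] is the cluster id of row i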
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa
+#endif  // SHERPA_CSRC_FAST_CLUSTERING_H_
diff --git a/sherpa/csrc/offline-speaker-diarization-impl.cc b/sherpa/csrc/offline-speaker-diarization-impl.cc
new file mode 100644
index 000000000..2a1e6dbaf
--- /dev/null
+++ b/sherpa/csrc/offline-speaker-diarization-impl.cc
@@ -0,0 +1,26 @@
+// sherpa/csrc/offline-speaker-diarization-impl.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa/csrc/offline-speaker-diarization-impl.h"
+
+#include <memory>
+
+#include "sherpa/csrc/macros.h"
+#include "sherpa/csrc/offline-speaker-diarization-pyannote-impl.h"
+
+namespace sherpa {
+
+std::unique_ptr<OfflineSpeakerDiarizationImpl>
+OfflineSpeakerDiarizationImpl::Create(
+    const OfflineSpeakerDiarizationConfig &config) {
+  if (!config.segmentation.pyannote.model.empty()) {
+    return std::make_unique<OfflineSpeakerDiarizationPyannoteImpl>(config);
+  }
+
+  SHERPA_LOGE("Please specify a speaker segmentation model.");
+
+  return nullptr;
+}
+
+}  // namespace sherpa
diff --git a/sherpa/csrc/offline-speaker-diarization-impl.h b/sherpa/csrc/offline-speaker-diarization-impl.h
new file mode 100644
index 000000000..d4f090e00
--- /dev/null
+++ b/sherpa/csrc/offline-speaker-diarization-impl.h
@@ -0,0 +1,35 @@
+// sherpa/csrc/offline-speaker-diarization-impl.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
+#define SHERPA_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
+
+#include <memory>
+#include <vector>
+
+#include "sherpa/csrc/offline-speaker-diarization.h"
+
+namespace sherpa {
+
+class OfflineSpeakerDiarizationImpl {
+ public:
+  static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create(
+      const OfflineSpeakerDiarizationConfig &config);
+
+  virtual ~OfflineSpeakerDiarizationImpl() = default;
+
+  virtual int32_t SampleRate() const = 0;
+
+  // Note: Only config.clustering is used. All other fields in config are
+  // ignored.
+  virtual void SetConfig(const OfflineSpeakerDiarizationConfig &config) = 0;
+
+  virtual OfflineSpeakerDiarizationResult Process(
+      torch::Tensor samples,
+      OfflineSpeakerDiarizationProgressCallback callback = nullptr,
+      void *callback_arg = nullptr) const = 0;
+};
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
diff --git a/sherpa/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa/csrc/offline-speaker-diarization-pyannote-impl.h
new file mode 100644
index 000000000..df6bcd979
--- /dev/null
+++ b/sherpa/csrc/offline-speaker-diarization-pyannote-impl.h
@@ -0,0 +1,645 @@
+// sherpa/csrc/offline-speaker-diarization-pyannote-impl.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+#ifndef SHERPA_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
+#define SHERPA_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
+
+#include <algorithm>
+#include <iostream>
+#include <memory>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "sherpa/csrc/fast-clustering.h"
+#include "sherpa/csrc/math.h"
+#include "sherpa/csrc/offline-speaker-diarization-impl.h"
+#include "sherpa/csrc/offline-speaker-segmentation-pyannote-model.h"
+#include "sherpa/csrc/speaker-embedding-extractor.h"
+
+namespace sherpa {
+
+namespace {  // NOLINT
+
+// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L41
+template <typename T>
+inline void hash_combine(std::size_t *seed, const T &v) {  // NOLINT
+  std::hash<T> hasher;
+  *seed ^= hasher(v) + 0x9e3779b9 + ((*seed) << 6) + ((*seed) >> 2);  // NOLINT
+}
+
+// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L47
+struct PairHash {
+  template <typename T1, typename T2>
+  std::size_t operator()(const std::pair<T1, T2> &pair) const {
+    std::size_t result = 0;
+    hash_combine(&result, pair.first);
+    hash_combine(&result, pair.second);
+    return result;
+  }
+};
+}  // namespace
+
+using Int32Pair = std::pair<int32_t, int32_t>;
+
+class OfflineSpeakerDiarizationPyannoteImpl
+    : public OfflineSpeakerDiarizationImpl {
+ public:
+  ~OfflineSpeakerDiarizationPyannoteImpl() override = default;
+
+  explicit OfflineSpeakerDiarizationPyannoteImpl(
+      const OfflineSpeakerDiarizationConfig &config)
+      : config_(config),
+        segmentation_model_(config_.segmentation),
+        embedding_extractor_(config_.embedding),
+        clustering_(std::make_unique<FastClustering>(config_.clustering)) {
+    InitPowersetMapping();
+    std::cout << "powerset_mapping: " << powerset_mapping_ << "\n";
+  }
+
+  int32_t SampleRate() const override {
+    const auto &meta_data = segmentation_model_.GetModelMetaData();
+
+    return meta_data.sample_rate;
+  }
+
+  void SetConfig(const OfflineSpeakerDiarizationConfig &config) override {
+    if (!config.clustering.Validate()) {
+      SHERPA_LOGE("Invalid clustering config. Skip it");
+      return;
+    }
+    clustering_ = std::make_unique<FastClustering>(config.clustering);
+    config_.clustering = config.clustering;
+  }
+
+  OfflineSpeakerDiarizationResult Process(
+      torch::Tensor samples,
+      OfflineSpeakerDiarizationProgressCallback callback = nullptr,
+      void *callback_arg = nullptr) const override {
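+    // The pipeline, in order:
+    //  1. Run the segmentation model over sliding windows of window_size
+    //     samples to get per-frame activity log-probs (powerset encoded).
+    //  2. Convert powerset classes to per-speaker multi-labels.
+    //  3. Extract one speaker embedding per (chunk, local speaker) pair.
+    //  4. Cluster the embeddings to map local speakers to global speakers.
+    //  5. Re-label frames, merge nearby segments, and drop short ones.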
Given: %d", + static_cast(samples.size(0))); + return {}; + } + + std::cout << "samples.sizes: " << samples.sizes() << "\n"; + torch::Tensor log_probs = RunSpeakerSegmentationModel(samples); + std::cout << "log_probs.sizes: " << log_probs.sizes() << "\n"; + // A chunk is a window_size samples + // log_probs: (num_chunks, num_frames, 7) + // where 7 is the num_powerset_classes + + torch::Tensor labels = ToMultiLabel(log_probs); + std::cout << "labels.sizes: " << labels.sizes() << "\n"; + + // labels.sizes: (num_chunks, num_frames, 3) + // where 3 is num_speakers + + torch::Tensor speakers_per_frame = ComputeSpeakersPerFrame(labels); + if (speakers_per_frame.argmax().item().toInt() == 0) { + SHERPA_LOGE("No speakers found"); + return {}; + } + std::cout << "speakers_per_frame.sizes " << speakers_per_frame.sizes() + << "\n"; + + auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels); + + torch::Tensor embeddings = + ComputeEmbeddings(samples, chunk_speaker_samples_list_pair.second, + std::move(callback), callback_arg); + std::cout << "embedding size: " << embeddings.sizes() << "\n"; + + std::vector cluster_labels = clustering_->Cluster( + embeddings.data_ptr(), embeddings.size(0), embeddings.size(1)); + + int32_t max_cluster_index = + *std::max_element(cluster_labels.begin(), cluster_labels.end()); + + auto chunk_speaker_to_cluster = ConvertChunkSpeakerToCluster( + chunk_speaker_samples_list_pair.first, cluster_labels); + + auto new_labels = + ReLabel(labels, max_cluster_index, chunk_speaker_to_cluster); + std::cout << "new_labels.sizes() " << new_labels.sizes() << "\n"; + + torch::Tensor speaker_count = + ComputeSpeakerCount(new_labels, samples.size(1)); + std::cout << "speaker_count.sizes() " << speaker_count.sizes() << "\n"; + + torch::Tensor final_labels = + FinalizeLabels(speaker_count, speakers_per_frame); + + auto result = ComputeResult(final_labels); + + return result; + } + + torch::Tensor RunSpeakerSegmentationModel(torch::Tensor samples) const { + const auto &meta_data = segmentation_model_.GetModelMetaData(); + int32_t window_size = meta_data.window_size; + int32_t window_shift = meta_data.window_shift; + + int32_t batch_size = samples.size(0); + int32_t num_samples = samples.size(1); + int32_t need_pad = (num_samples < window_size) || + ((num_samples - window_size) % window_shift); + std::cout << "num_samples < window_size: " << (num_samples - window_size) + << "\n"; + std::cout << "((num_samples - window_size) % window_shift): " + << ((num_samples - window_size) % window_shift) << "\n"; + std::cout << "need pad: " << need_pad << "\n"; + + if (need_pad) { + int32_t padding = 0; + if (num_samples < window_size) { + padding = window_size - num_samples; + } else { + padding = window_shift - ((num_samples - window_size) % window_shift); + } + std::cout << "padding size: " << padding << "\n"; + samples = torch::nn::functional::pad( + samples, torch::nn::functional::PadFuncOptions({0, padding, 0, 0}) + .mode(torch::kConstant) + .value(0)); + } + int32_t num_segments = (samples.size(1) - window_size) / window_shift + 1; + + if (need_pad || num_segments > 1) { + samples = samples.as_strided({batch_size, num_segments, window_size}, + {samples.size(1), window_shift, 1}); + } else { + samples = samples.reshape({1, 1, -1}); + } + + samples = samples.reshape({-1, 1, window_size}); + // e.g. 
+    int32_t num_segments = (samples.size(1) - window_size) / window_shift + 1;
+
+    if (need_pad || num_segments > 1) {
+      samples = samples.as_strided({batch_size, num_segments, window_size},
+                                   {samples.size(1), window_shift, 1});
+    } else {
+      samples = samples.reshape({1, 1, -1});
+    }
+
+    samples = samples.reshape({-1, 1, window_size});
+    // e.g. samples.sizes: (264, 1, 160000)
+
+    int32_t max_batch_size = 2;
+    torch::Tensor log_probs;
+    if (samples.size(0) < max_batch_size) {
+      log_probs = segmentation_model_.Forward(samples);
+    } else {
+      std::vector<torch::Tensor> tmp;
+      int32_t n = samples.size(0) / max_batch_size;
+      for (int32_t i = 0; i < n; ++i) {
+        auto this_batch =
+            samples.slice(0, i * max_batch_size, (i + 1) * max_batch_size);
+        std::cout << i << "/" << n << " -> " << this_batch.sizes() << "\n";
+        auto this_log_prob = segmentation_model_.Forward(this_batch);
+        std::cout << "    " << this_log_prob.sizes() << "\n";
+        tmp.push_back(std::move(this_log_prob));
+      }
+
+      if (samples.size(0) % max_batch_size) {
+        auto this_batch = samples.slice(0, n * max_batch_size);
+        std::cout << n << " -> " << this_batch.sizes() << "\n";
+        auto this_log_prob = segmentation_model_.Forward(this_batch);
+        std::cout << "    " << this_log_prob.sizes() << "\n";
+        tmp.push_back(std::move(this_log_prob));
+      }
+
+      log_probs = torch::cat(tmp, 0);
+    }
+    // e.g. log_probs.sizes: (264, 589, 7)
+    std::cout << "log_probs.sizes: " << log_probs.sizes() << "\n";
+
+    return log_probs;
+  }
+
+  // see
+  // https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/utils/powerset.py#L103
+  torch::Tensor ToMultiLabel(torch::Tensor log_probs) const {
+    int32_t num_classes = powerset_mapping_.size(0);
+    auto powerset_probs = torch::nn::functional::one_hot(
+                              torch::argmax(log_probs, -1), num_classes)
+                              .to(torch::kFloat);
+    std::cout << "powerset_probs.sizes: " << powerset_probs.sizes() << "\n";
+    auto labels = torch::matmul(powerset_probs, powerset_mapping_);
+    std::cout << "labels.sizes: " << labels.sizes() << "\n";
+    // labels.sizes: (num_chunks, num_frames, 3)
+    return labels;
+  }
+
+  // Return a 1-D int32 tensor of shape (num_frames,)
+  torch::Tensor ComputeSpeakersPerFrame(torch::Tensor labels) const {
+    const auto &meta_data = segmentation_model_.GetModelMetaData();
+    int32_t window_size = meta_data.window_size;
+    int32_t window_shift = meta_data.window_shift;
+    int32_t receptive_field_shift = meta_data.receptive_field_shift;
+
+    int32_t num_chunks = labels.size(0);
+
+    int32_t num_frames = (window_size + (num_chunks - 1) * window_shift) /
+                             receptive_field_shift +
+                         1;
+    torch::Tensor count = torch::zeros({num_frames}, torch::kFloat);
+    torch::Tensor weight = torch::zeros({num_frames}, torch::kFloat);
+
+    for (int32_t i = 0; i != num_chunks; ++i) {
+      int32_t start =
+          static_cast<float>(i) * window_shift / receptive_field_shift + 0.5;
+
+      int32_t end = start + labels.size(1);
+
+      count.slice(0, start, end).add_(labels.index({i}).sum(1));
+      weight.slice(0, start, end).add_(1);
+    }
+
+    return (count / (weight + 1e-12f) + 0.5).to(torch::kInt);
+  }
+
+  // ans.first: a list of (chunk_id, speaker_id)
+  // ans.second: a list of list of (start_sample_index, end_sample_index)
+  //
+  // ans.first[i] corresponds to ans.second[i]
+  std::pair<std::vector<Int32Pair>, std::vector<std::vector<Int32Pair>>>
+  GetChunkSpeakerSampleIndexes(torch::Tensor labels) const {
+    labels = ExcludeOverlap(labels);
+    // now labels.dtype is changed from float to int32
+
+    std::vector<Int32Pair> chunk_speaker_list;
+    std::vector<std::vector<Int32Pair>> samples_index_list;
+
+    const auto &meta_data = segmentation_model_.GetModelMetaData();
+    int32_t window_size = meta_data.window_size;
+    int32_t window_shift = meta_data.window_shift;
+    int32_t receptive_field_shift = meta_data.receptive_field_shift;
+    int32_t num_speakers = meta_data.num_speakers;
+
+    int32_t num_frames = labels.size(1);
+    int32_t num_chunks = labels.size(0);
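+
+    // A frame index f within a chunk maps back to samples as roughly
+    // f / num_frames * window_size + chunk_index * window_shift, which is
+    // how start_samples and end_samples are computed below.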
+    for (int32_t chunk_index = 0; chunk_index < num_chunks; ++chunk_index) {
+      int32_t sample_offset = chunk_index * window_shift;
+
+      torch::Tensor this_chunk = labels.index({chunk_index}).t();
+      // this_chunk: (num_speakers, num_frames)
+
+      for (int32_t speaker_index = 0; speaker_index != num_speakers;
+           ++speaker_index) {
+        torch::Tensor this_speaker = this_chunk.index({speaker_index});
+        if (this_speaker.sum().item().toInt() < 10) {
+          // skip segments less than 10 frames
+          continue;
+        }
+
+        Int32Pair this_chunk_speaker = {chunk_index, speaker_index};
+        std::vector<Int32Pair> this_speaker_samples;
+
+        bool is_active = false;
+        int32_t start_index = 0;
+
+        auto acc = this_speaker.accessor<int32_t, 1>();
+
+        for (int32_t k = 0; k != num_frames; ++k) {
+          if (acc[k] != 0) {
+            if (!is_active) {
+              is_active = true;
+              start_index = k;
+            }
+          } else if (is_active) {
+            is_active = false;
+
+            int32_t start_samples =
+                static_cast<float>(start_index) / num_frames * window_size +
+                sample_offset;
+            int32_t end_samples =
+                static_cast<float>(k) / num_frames * window_size +
+                sample_offset;
+
+            this_speaker_samples.emplace_back(start_samples, end_samples);
+          }
+        }
+
+        if (is_active) {
+          int32_t start_samples =
+              static_cast<float>(start_index) / num_frames * window_size +
+              sample_offset;
+          int32_t end_samples =
+              static_cast<float>(num_frames - 1) / num_frames * window_size +
+              sample_offset;
+          this_speaker_samples.emplace_back(start_samples, end_samples);
+        }
+
+        chunk_speaker_list.push_back(std::move(this_chunk_speaker));
+        samples_index_list.push_back(std::move(this_speaker_samples));
+      }  // for (int32_t speaker_index = 0;
+    }  // for (int32_t chunk_index = 0;
+
+    return {chunk_speaker_list, samples_index_list};
+  }
+
+ private:
+  void InitPowersetMapping() {
+    const auto &meta_data = segmentation_model_.GetModelMetaData();
+    int32_t num_classes = meta_data.num_classes;
+    int32_t powerset_max_classes = meta_data.powerset_max_classes;
+    int32_t num_speakers = meta_data.num_speakers;
+
+    powerset_mapping_ =
+        torch::zeros({num_classes, num_speakers}, torch::kFloat);
+    auto acc = powerset_mapping_.accessor<float, 2>();
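+
+    // Row 0 is left all-zero (the "no speaker" class); rows 1..num_speakers
+    // are the singletons, followed by all unordered speaker pairs. For
+    // 3 speakers and powerset_max_classes = 2 this yields the 7-row table
+    // shown next to powerset_mapping_ at the bottom of this class.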
+    int32_t k = 1;
+    for (int32_t i = 1; i <= powerset_max_classes; ++i) {
+      if (i == 1) {
+        for (int32_t j = 0; j != num_speakers; ++j, ++k) {
+          acc[k][j] = 1;
+        }
+      } else if (i == 2) {
+        for (int32_t j = 0; j != num_speakers; ++j) {
+          for (int32_t m = j + 1; m < num_speakers; ++m, ++k) {
+            acc[k][j] = 1;
+            acc[k][m] = 1;
+          }
+        }
+      } else {
+        SHERPA_LOGE("powerset_max_classes = %d is currently not supported!",
+                    i);
+        SHERPA_EXIT(-1);
+      }
+    }
+  }
+
+  // If there are multiple speakers at a frame, then this frame is excluded.
+  torch::Tensor ExcludeOverlap(torch::Tensor labels) const {
+    torch::Tensor labels_copy = labels.to(torch::kInt);
+
+    torch::Tensor indexes = labels.sum(-1) > 1;
+    labels_copy.index_put_({indexes}, 0);
+    return labels_copy;
+  }
+
+  torch::Tensor ComputeEmbeddings(
+      torch::Tensor samples,
+      const std::vector<std::vector<Int32Pair>> &sample_indexes,
+      OfflineSpeakerDiarizationProgressCallback callback,
+      void *callback_arg) const {
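+    // For each (chunk, speaker) pair, all of that speaker's active sample
+    // ranges are concatenated into one buffer and fed to the embedding
+    // extractor as a single stream, yielding one embedding per pair.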
+    const auto &meta_data = segmentation_model_.GetModelMetaData();
+    int32_t sample_rate = meta_data.sample_rate;
+    torch::Tensor ans =
+        torch::empty({static_cast<int32_t>(sample_indexes.size()),
+                      embedding_extractor_.Dim()},
+                     torch::kFloat);
+
+    int32_t n = samples.size(1);
+    int32_t k = 0;
+    const float *ptr = samples.data_ptr<float>();
+    for (const auto &v : sample_indexes) {
+      auto stream = embedding_extractor_.CreateStream();
+      std::vector<float> buffer;
+
+      for (const auto &p : v) {
+        int32_t end = (p.second <= n) ? p.second : n;
+        buffer.insert(buffer.end(), ptr + p.first, ptr + end);
+      }
+
+      stream->AcceptSamples(buffer.data(), buffer.size());
+
+      torch::Tensor embedding = embedding_extractor_.Compute(stream.get());
+      ans.index_put_({k}, embedding);
+
+      k += 1;
+
+      if (callback) {
+        callback(k, ans.size(0), callback_arg);
+      }
+    }  // for (const auto &v : sample_indexes)
+
+    return ans;
+  }
+
+  std::unordered_map<Int32Pair, int32_t, PairHash>
+  ConvertChunkSpeakerToCluster(
+      const std::vector<Int32Pair> &chunk_speaker_pair,
+      const std::vector<int32_t> &cluster_labels) const {
+    std::unordered_map<Int32Pair, int32_t, PairHash> ans;
+
+    int32_t k = 0;
+    for (const auto &p : chunk_speaker_pair) {
+      ans[p] = cluster_labels[k];
+      k += 1;
+    }
+
+    return ans;
+  }
+
+  torch::Tensor ReLabel(torch::Tensor labels, int32_t max_cluster_index,
+                        const std::unordered_map<Int32Pair, int32_t, PairHash>
+                            &chunk_speaker_to_cluster) const {
+    int32_t num_chunks = labels.size(0);
+
+    torch::Tensor new_labels = torch::empty(
+        {num_chunks, labels.size(1), max_cluster_index + 1}, torch::kFloat);
+
+    for (int32_t chunk_index = 0; chunk_index < num_chunks; ++chunk_index) {
+      auto this_chunk = labels.index({chunk_index}).t();
+      // this_chunk: (num_speakers, num_frames)
+
+      torch::Tensor new_label = torch::zeros(
+          {this_chunk.size(1), max_cluster_index + 1}, torch::kFloat);
+
+      auto this_chunk_acc = this_chunk.accessor<float, 2>();
+      auto new_label_acc = new_label.accessor<float, 2>();
+
+      for (int32_t speaker_index = 0; speaker_index != this_chunk.size(0);
+           ++speaker_index) {
+        if (chunk_speaker_to_cluster.count({chunk_index, speaker_index}) ==
+            0) {
+          continue;
+        }
+
+        int32_t new_speaker_index =
+            chunk_speaker_to_cluster.at({chunk_index, speaker_index});
+
+        for (int32_t k = 0; k != this_chunk.size(1); ++k) {
+          if (this_chunk_acc[speaker_index][k] == 1) {
+            new_label_acc[k][new_speaker_index] = 1;
+          }
+        }
+      }
+
+      // TODO(fangjun): Optimize it. No need to create a new_label variable
+      new_labels.index_put_({chunk_index}, new_label);
+    }
+
+    return new_labels;
+  }
+
+  torch::Tensor ComputeSpeakerCount(torch::Tensor labels,
+                                    int32_t num_samples) const {
+    const auto &meta_data = segmentation_model_.GetModelMetaData();
+    int32_t window_size = meta_data.window_size;
+    int32_t window_shift = meta_data.window_shift;
+    int32_t receptive_field_shift = meta_data.receptive_field_shift;
+
+    int32_t num_chunks = labels.size(0);
+
+    int32_t num_frames = (window_size + (num_chunks - 1) * window_shift) /
+                             receptive_field_shift +
+                         1;
+
+    torch::Tensor count =
+        torch::zeros({num_frames, labels.size(2)}, torch::kFloat);
+
+    for (int32_t i = 0; i != num_chunks; ++i) {
+      int32_t start =
+          static_cast<float>(i) * window_shift / receptive_field_shift + 0.5;
+      int32_t end = start + labels.size(1);
+
+      count.slice(0, start, end).add_(labels.index({i}));
+    }
+
+    bool has_last_chunk = ((num_samples - window_size) % window_shift) > 0;
+
+    if (!has_last_chunk) {
+      return count.to(torch::kInt);
+    }
+
+    int32_t last_frame = num_samples / receptive_field_shift;
+    return count.slice(0, 0, last_frame).to(torch::kInt);
+  }
+
+  // count: int, (num_frames, num_speakers)
+  // speakers_per_frame: int, (num_frames,)
+  torch::Tensor FinalizeLabels(torch::Tensor count,
+                               torch::Tensor speakers_per_frame) const {
+    int32_t num_rows = count.size(0);
+    int32_t num_cols = count.size(1);
+
+    torch::Tensor ans = torch::zeros({num_rows, num_cols}, torch::kInt);
+
+    auto speaker_acc = speakers_per_frame.accessor<int32_t, 1>();
+    auto ans_acc = ans.accessor<int32_t, 2>();
+
+    for (int32_t i = 0; i != num_rows; ++i) {
+      int32_t k = speaker_acc[i];
+      if (k == 0) {
+        continue;
+      }
+      torch::Tensor values, indexes;
+      std::tie(values, indexes) = count.index({i}).topk(k, 0, true, true);
+
+      auto indexes_acc = indexes.accessor<int64_t, 1>();
+
+      for (int32_t m = 0; m < k; ++m) {
+        ans_acc[i][indexes_acc[m]] = 1;
+      }
+    }
+
+    return ans;
+  }
+
+  void MergeSegments(
+      std::vector<OfflineSpeakerDiarizationSegment> *segments) const {
+    float min_duration_off = config_.min_duration_off;
+    bool changed = true;
+    while (changed) {
+      changed = false;
+      for (int32_t i = 0; i < static_cast<int32_t>(segments->size()) - 1;
+           ++i) {
+        auto s = (*segments)[i].Merge((*segments)[i + 1], min_duration_off);
+        if (s) {
+          (*segments)[i] = s.value();
+          segments->erase(segments->begin() + i + 1);
+
+          changed = true;
+          break;
+        }
+      }
+    }
+  }
+
+  OfflineSpeakerDiarizationResult ComputeResult(
+      torch::Tensor final_labels) const {
+    torch::Tensor final_labels_t = final_labels.t();
+
+    int32_t num_speakers = final_labels_t.size(0);
+    int32_t num_frames = final_labels_t.size(1);
+    auto acc = final_labels_t.accessor<int32_t, 2>();
+
+    const auto &meta_data = segmentation_model_.GetModelMetaData();
+    int32_t window_size = meta_data.window_size;
+    int32_t window_shift = meta_data.window_shift;
+    int32_t receptive_field_shift = meta_data.receptive_field_shift;
+    int32_t receptive_field_size = meta_data.receptive_field_size;
+    int32_t sample_rate = meta_data.sample_rate;
+
+    float scale = static_cast<float>(receptive_field_shift) / sample_rate;
+    float scale_offset = 0.5 * receptive_field_size / sample_rate;
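+
+    // Frame index f corresponds to time f * scale + scale_offset seconds.
+    // For instance, assuming receptive_field_shift = 270 and
+    // receptive_field_size = 721 samples at 16 kHz, each frame advances
+    // by about 16.9 ms and the offset is about 22.5 ms.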
+    OfflineSpeakerDiarizationResult ans;
+
+    for (int32_t speaker_index = 0; speaker_index != num_speakers;
+         ++speaker_index) {
+      std::vector<OfflineSpeakerDiarizationSegment> this_speaker;
+
+      bool is_active = acc[speaker_index][0] > 0;
+      int32_t start_index = is_active ? 0 : -1;
+
+      for (int32_t frame_index = 1; frame_index != num_frames; ++frame_index) {
+        if (is_active) {
+          if (acc[speaker_index][frame_index] == 0) {
+            float start_time = start_index * scale + scale_offset;
+            float end_time = frame_index * scale + scale_offset;
+
+            OfflineSpeakerDiarizationSegment segment(start_time, end_time,
+                                                     speaker_index);
+            this_speaker.push_back(segment);
+
+            is_active = false;
+          }
+        } else if (acc[speaker_index][frame_index] == 1) {
+          is_active = true;
+          start_index = frame_index;
+        }
+      }
+
+      if (is_active) {
+        float start_time = start_index * scale + scale_offset;
+        float end_time = (num_frames - 1) * scale + scale_offset;
+
+        OfflineSpeakerDiarizationSegment segment(start_time, end_time,
+                                                 speaker_index);
+        this_speaker.push_back(segment);
+      }
+
+      // merge segments if the gap between them is less than min_duration_off
+      MergeSegments(&this_speaker);
+
+      for (const auto &seg : this_speaker) {
+        if (seg.Duration() > config_.min_duration_on) {
+          ans.Add(seg);
+        }
+      }
+    }  // for (int32_t speaker_index = 0; speaker_index != num_speakers;
+
+    return ans;
+  }
+
+ private:
+  OfflineSpeakerDiarizationConfig config_;
+  OfflineSpeakerSegmentationPyannoteModel segmentation_model_;
+  SpeakerEmbeddingExtractor embedding_extractor_;
+  std::unique_ptr<FastClustering> clustering_;
+  torch::Tensor powerset_mapping_;  // 2-d float tensor
+  /*
+   0 0 0  // 0
+   1 0 0  // 1
+   0 1 0  // 2
+   0 0 1  // 3
+   1 1 0  // 4
+   1 0 1  // 5
+   0 1 1  // 6
+   */
+};
+
+}  // namespace sherpa
+#endif  // SHERPA_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
diff --git a/sherpa/csrc/offline-speaker-diarization-result.cc b/sherpa/csrc/offline-speaker-diarization-result.cc
new file mode 100644
index 000000000..5099fbdcc
--- /dev/null
+++ b/sherpa/csrc/offline-speaker-diarization-result.cc
@@ -0,0 +1,113 @@
+// sherpa/csrc/offline-speaker-diarization-result.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa/csrc/offline-speaker-diarization-result.h"
+
+#include <algorithm>
+#include <array>
+#include <sstream>
+#include <string>
+#include <unordered_set>
+#include <utility>
+
+#include "sherpa/csrc/macros.h"
+
+namespace sherpa {
+
+OfflineSpeakerDiarizationSegment::OfflineSpeakerDiarizationSegment(
+    float start, float end, int32_t speaker,
+    const std::string &text /*= {}*/) {
+  if (start > end) {
+    SHERPA_LOGE("start %.3f should be less than end %.3f", start, end);
+    SHERPA_EXIT(-1);
+  }
+
+  start_ = start;
+  end_ = end;
+  speaker_ = speaker;
+  text_ = text;
+}
+
+std::optional<OfflineSpeakerDiarizationSegment>
+OfflineSpeakerDiarizationSegment::Merge(
+    const OfflineSpeakerDiarizationSegment &other, float gap) const {
+  if (other.speaker_ != speaker_) {
+    SHERPA_LOGE(
+        "The two segments should have the same speaker. this->speaker: %d, "
+        "other.speaker: %d",
+        speaker_, other.speaker_);
+    return std::nullopt;
+  }
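+
+  // Segments merge when the gap between them is at most `gap` seconds.
+  // E.g., with gap = 0.5, [0.0, 1.0] and [1.4, 2.0] merge into [0.0, 2.0],
+  // while [0.0, 1.0] and [1.6, 2.0] stay separate.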
+
+  if (end_ < other.start_ && end_ + gap >= other.start_) {
+    return OfflineSpeakerDiarizationSegment(start_, other.end_, speaker_);
+  } else if (other.end_ < start_ && other.end_ + gap >= start_) {
+    return OfflineSpeakerDiarizationSegment(other.start_, end_, speaker_);
+  } else {
+    return std::nullopt;
+  }
+}
+
+std::string OfflineSpeakerDiarizationSegment::ToString() const {
+  std::array<char, 128> s{};
+
+  snprintf(s.data(), s.size(), "%.3f -- %.3f speaker_%02d", start_, end_,
+           speaker_);
+
+  std::ostringstream os;
+  os << s.data();
+
+  if (!text_.empty()) {
+    os << " " << text_;
+  }
+
+  return os.str();
+}
+
+void OfflineSpeakerDiarizationResult::Add(
+    const OfflineSpeakerDiarizationSegment &segment) {
+  segments_.push_back(segment);
+}
+
+int32_t OfflineSpeakerDiarizationResult::NumSpeakers() const {
+  std::unordered_set<int32_t> count;
+  for (const auto &s : segments_) {
+    count.insert(s.Speaker());
+  }
+
+  return count.size();
+}
+
+int32_t OfflineSpeakerDiarizationResult::NumSegments() const {
+  return segments_.size();
+}
+
+// Return a list of segments sorted by segment.start time
+std::vector<OfflineSpeakerDiarizationSegment>
+OfflineSpeakerDiarizationResult::SortByStartTime() const {
+  auto ans = segments_;
+  std::sort(ans.begin(), ans.end(), [](const auto &a, const auto &b) {
+    return (a.Start() < b.Start()) ||
+           ((a.Start() == b.Start()) && (a.Speaker() < b.Speaker()));
+  });
+
+  return ans;
+}
+
+std::vector<std::vector<OfflineSpeakerDiarizationSegment>>
+OfflineSpeakerDiarizationResult::SortBySpeaker() const {
+  auto tmp = segments_;
+  std::sort(tmp.begin(), tmp.end(), [](const auto &a, const auto &b) {
+    return (a.Speaker() < b.Speaker()) ||
+           ((a.Speaker() == b.Speaker()) && (a.Start() < b.Start()));
+  });
+
+  std::vector<std::vector<OfflineSpeakerDiarizationSegment>> ans(
+      NumSpeakers());
+  for (auto &s : tmp) {
+    ans[s.Speaker()].push_back(std::move(s));
+  }
+
+  return ans;
+}
+
+}  // namespace sherpa
diff --git a/sherpa/csrc/offline-speaker-diarization-result.h b/sherpa/csrc/offline-speaker-diarization-result.h
new file mode 100644
index 000000000..ce5f711ac
--- /dev/null
+++ b/sherpa/csrc/offline-speaker-diarization-result.h
@@ -0,0 +1,67 @@
+// sherpa/csrc/offline-speaker-diarization-result.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
+#define SHERPA_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
+
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <vector>
+
+namespace sherpa {
+
+class OfflineSpeakerDiarizationSegment {
+ public:
+  OfflineSpeakerDiarizationSegment(float start, float end, int32_t speaker,
+                                   const std::string &text = {});
+
+  // If the gap between the two segments is less than the given gap, then we
+  // merge them and return a new segment. Otherwise, it returns std::nullopt.
+  std::optional<OfflineSpeakerDiarizationSegment> Merge(
+      const OfflineSpeakerDiarizationSegment &other, float gap) const;
+
+  float Start() const { return start_; }
+  float End() const { return end_; }
+  int32_t Speaker() const { return speaker_; }
+  const std::string &Text() const { return text_; }
+  float Duration() const { return end_ - start_; }
+
+  void SetText(const std::string &text) { text_ = text; }
+
+  std::string ToString() const;
+
+ private:
+  float start_;       // in seconds
+  float end_;         // in seconds
+  int32_t speaker_;   // ID of the speaker, starting from 0
+  std::string text_;  // If not empty, it contains the speech recognition
+                      // result of this segment
+};
+
+class OfflineSpeakerDiarizationResult {
+ public:
+  // Add a new segment
+  void Add(const OfflineSpeakerDiarizationSegment &segment);
+
+  // Number of distinct speakers contained in this object at this point
+  int32_t NumSpeakers() const;
+
+  int32_t NumSegments() const;
+
+  // Return a list of segments sorted by segment.start time
+  std::vector<OfflineSpeakerDiarizationSegment> SortByStartTime() const;
+
+  // ans.size() == NumSpeakers().
+  // ans[i] is for speaker_i and is sorted by start time
+  std::vector<std::vector<OfflineSpeakerDiarizationSegment>> SortBySpeaker()
+      const;
+
+ private:
+  std::vector<OfflineSpeakerDiarizationSegment> segments_;
+};
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
diff --git a/sherpa/csrc/offline-speaker-diarization.cc b/sherpa/csrc/offline-speaker-diarization.cc
new file mode 100644
index 000000000..b629cfb55
--- /dev/null
+++ b/sherpa/csrc/offline-speaker-diarization.cc
@@ -0,0 +1,96 @@
+// sherpa/csrc/offline-speaker-diarization.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include "sherpa/csrc/offline-speaker-diarization.h"
+
+#include <sstream>
+#include <utility>
+
+#include "sherpa/csrc/macros.h"
+#include "sherpa/csrc/offline-speaker-diarization-impl.h"
+
+namespace sherpa {
+
+void OfflineSpeakerDiarizationConfig::Register(ParseOptions *po) {
+  ParseOptions po_segmentation("segmentation", po);
+  segmentation.Register(&po_segmentation);
+
+  ParseOptions po_embedding("embedding", po);
+  embedding.Register(&po_embedding);
+
+  ParseOptions po_clustering("clustering", po);
+  clustering.Register(&po_clustering);
+
+  po->Register("min-duration-on", &min_duration_on,
+               "If a segment is shorter than this value (in seconds), then "
+               "it is discarded. Set it to 0 so that no segment is "
+               "discarded.");
+
+  po->Register("min-duration-off", &min_duration_off,
+               "If the gap between two segments of the same speaker is less "
+               "than this value (in seconds), then these two segments are "
+               "merged into a single segment. We do it recursively.");
+}
+
+bool OfflineSpeakerDiarizationConfig::Validate() const {
+  if (!segmentation.Validate()) {
+    return false;
+  }
+
+  if (!embedding.Validate()) {
+    return false;
+  }
+
+  if (!clustering.Validate()) {
+    return false;
+  }
+
+  if (min_duration_on < 0) {
+    SHERPA_LOGE("min_duration_on %.3f is negative", min_duration_on);
+    return false;
+  }
+
+  if (min_duration_off < 0) {
+    SHERPA_LOGE("min_duration_off %.3f is negative", min_duration_off);
+    return false;
+  }
+
+  return true;
+}
+
+std::string OfflineSpeakerDiarizationConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflineSpeakerDiarizationConfig(";
+  os << "segmentation=" << segmentation.ToString() << ", ";
+  os << "embedding=" << embedding.ToString() << ", ";
+  os << "clustering=" << clustering.ToString() << ", ";
+  os << "min_duration_on=" << min_duration_on << ", ";
+  os << "min_duration_off=" << min_duration_off << ")";
+
+  return os.str();
+}
+
+OfflineSpeakerDiarization::OfflineSpeakerDiarization(
+    const OfflineSpeakerDiarizationConfig &config)
+    : impl_(OfflineSpeakerDiarizationImpl::Create(config)) {}
+
+OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default;
+
+int32_t OfflineSpeakerDiarization::SampleRate() const {
+  return impl_->SampleRate();
+}
+
+void OfflineSpeakerDiarization::SetConfig(
+    const OfflineSpeakerDiarizationConfig &config) {
+  impl_->SetConfig(config);
+}
+
+OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process(
+    torch::Tensor samples,
+    OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/,
+    void *callback_arg /*= nullptr*/) const {
+  return impl_->Process(samples, std::move(callback), callback_arg);
+}
+
+}  // namespace sherpa
diff --git a/sherpa/csrc/offline-speaker-diarization.h b/sherpa/csrc/offline-speaker-diarization.h
new file mode 100644
index 000000000..abc18bf8e
--- /dev/null
+++ b/sherpa/csrc/offline-speaker-diarization.h
@@ -0,0 +1,82 @@
+// sherpa/csrc/offline-speaker-diarization.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
+#define SHERPA_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+
+#include "sherpa/cpp_api/parse-options.h"
+#include "sherpa/csrc/fast-clustering-config.h"
+#include "sherpa/csrc/offline-speaker-diarization-result.h"
+#include "sherpa/csrc/offline-speaker-segmentation-model-config.h"
+#include "sherpa/csrc/speaker-embedding-extractor.h"
+
+namespace sherpa {
+
+struct OfflineSpeakerDiarizationConfig {
+  OfflineSpeakerSegmentationModelConfig segmentation;
+  SpeakerEmbeddingExtractorConfig embedding;
+  FastClusteringConfig clustering;
+
+  // If a segment is shorter than this value, then it is discarded.
+  float min_duration_on = 0.3;  // in seconds
+
+  // If the gap between two segments of the same speaker is less than this
+  // value, then these two segments are merged into a single segment.
+  // We do this recursively.
+  float min_duration_off = 0.5;  // in seconds
+
+  OfflineSpeakerDiarizationConfig() = default;
+
+  OfflineSpeakerDiarizationConfig(
+      const OfflineSpeakerSegmentationModelConfig &segmentation,
+      const SpeakerEmbeddingExtractorConfig &embedding,
+      const FastClusteringConfig &clustering, float min_duration_on,
+      float min_duration_off)
+      : segmentation(segmentation),
+        embedding(embedding),
+        clustering(clustering),
+        min_duration_on(min_duration_on),
+        min_duration_off(min_duration_off) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+  std::string ToString() const;
+};
+
+class OfflineSpeakerDiarizationImpl;
+
+using OfflineSpeakerDiarizationProgressCallback =
+    std::function<int32_t(int32_t num_processed_chunks,
+                          int32_t num_total_chunks, void *arg)>;
+
+class OfflineSpeakerDiarization {
+ public:
+  explicit OfflineSpeakerDiarization(
+      const OfflineSpeakerDiarizationConfig &config);
+
+  ~OfflineSpeakerDiarization();
+
+  // Expected sample rate of the input audio samples
+  int32_t SampleRate() const;
+
+  // Note: Only config.clustering is used. All other fields in config are
+  // ignored.
+  void SetConfig(const OfflineSpeakerDiarizationConfig &config);
+
+  // @param samples 2-D tensor of shape (batch_size, num_samples)
+  OfflineSpeakerDiarizationResult Process(
+      torch::Tensor samples,
+      OfflineSpeakerDiarizationProgressCallback callback = nullptr,
+      void *callback_arg = nullptr) const;
+
+ private:
+  std::unique_ptr<OfflineSpeakerDiarizationImpl> impl_;
+};
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
diff --git a/sherpa/csrc/offline-speaker-segmentation-model-config.cc b/sherpa/csrc/offline-speaker-segmentation-model-config.cc
new file mode 100644
index 000000000..9784681e2
--- /dev/null
+++ b/sherpa/csrc/offline-speaker-segmentation-model-config.cc
@@ -0,0 +1,46 @@
+// sherpa/csrc/offline-speaker-segmentation-model-config.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+#include "sherpa/csrc/offline-speaker-segmentation-model-config.h"
+
+#include <sstream>
+#include <string>
+
+#include "sherpa/csrc/macros.h"
+
+namespace sherpa {
+
+void OfflineSpeakerSegmentationModelConfig::Register(ParseOptions *po) {
+  pyannote.Register(po);
+
+  po->Register("use-gpu", &use_gpu, "true to use GPU.");
+
+  po->Register("debug", &debug,
+               "true to print model information while loading it.");
+}
+
+bool OfflineSpeakerSegmentationModelConfig::Validate() const {
+  if (pyannote.model.empty()) {
+    SHERPA_LOGE("You have to provide at least one speaker segmentation model");
+    return false;
+  }
+
+  return pyannote.Validate();
+}
+
+std::string OfflineSpeakerSegmentationModelConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflineSpeakerSegmentationModelConfig(";
+  os << "pyannote=" << pyannote.ToString() << ", ";
+  os << "use_gpu=" << (use_gpu ? "True" : "False") << ", ";
+  os << "debug=" << (debug ? "True" : "False") << ")";
+
+  return os.str();
+}
+
+}  // namespace sherpa
diff --git a/sherpa/csrc/offline-speaker-segmentation-model-config.h b/sherpa/csrc/offline-speaker-segmentation-model-config.h
new file mode 100644
index 000000000..4563d63e1
--- /dev/null
+++ b/sherpa/csrc/offline-speaker-segmentation-model-config.h
@@ -0,0 +1,36 @@
+// sherpa/csrc/offline-speaker-segmentation-model-config.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+#ifndef SHERPA_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_
+#define SHERPA_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_
+
+#include <string>
+
+#include "sherpa/cpp_api/parse-options.h"
+#include "sherpa/csrc/offline-speaker-segmentation-pyannote-model-config.h"
+
+namespace sherpa {
+
+struct OfflineSpeakerSegmentationModelConfig {
+  OfflineSpeakerSegmentationPyannoteModelConfig pyannote;
+
+  bool use_gpu = false;
+  bool debug = false;
+
+  OfflineSpeakerSegmentationModelConfig() = default;
+
+  explicit OfflineSpeakerSegmentationModelConfig(
+      const OfflineSpeakerSegmentationPyannoteModelConfig &pyannote,
+      bool use_gpu, bool debug)
+      : pyannote(pyannote), use_gpu(use_gpu), debug(debug) {}
+
+  void Register(ParseOptions *po);
+
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_
diff --git a/sherpa/csrc/offline-speaker-segmentation-pyannote-model-config.cc b/sherpa/csrc/offline-speaker-segmentation-pyannote-model-config.cc
new file mode 100644
index 000000000..8dd9f39e3
--- /dev/null
+++ b/sherpa/csrc/offline-speaker-segmentation-pyannote-model-config.cc
@@ -0,0 +1,38 @@
+// sherpa/csrc/offline-speaker-segmentation-pyannote-model-config.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+#include "sherpa/csrc/offline-speaker-segmentation-pyannote-model-config.h"
+
+#include <sstream>
+#include <string>
+
+#include "sherpa/csrc/file-utils.h"
+#include "sherpa/csrc/macros.h"
+
+namespace sherpa {
+
+void OfflineSpeakerSegmentationPyannoteModelConfig::Register(
+    ParseOptions *po) {
+  po->Register("pyannote-model", &model,
+               "Path to model.pt of the Pyannote segmentation model.");
+}
+
+bool OfflineSpeakerSegmentationPyannoteModelConfig::Validate() const {
+  if (!FileExists(model)) {
+    SHERPA_LOGE("Pyannote segmentation model: '%s' does not exist",
+                model.c_str());
+    return false;
+  }
+
+  return true;
+}
+
+std::string OfflineSpeakerSegmentationPyannoteModelConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflineSpeakerSegmentationPyannoteModelConfig(";
+  os << "model=\"" << model << "\")";
+
+  return os.str();
+}
+
+}  // namespace sherpa
diff --git a/sherpa/csrc/offline-speaker-segmentation-pyannote-model-config.h b/sherpa/csrc/offline-speaker-segmentation-pyannote-model-config.h
new file mode 100644
index 000000000..603dc2b96
--- /dev/null
+++ b/sherpa/csrc/offline-speaker-segmentation-pyannote-model-config.h
@@ -0,0 +1,30 @@
+// sherpa/csrc/offline-speaker-segmentation-pyannote-model-config.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_
+#define SHERPA_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_
+#include <string>
+
+#include "sherpa/cpp_api/parse-options.h"
+
+namespace sherpa {
+
+struct OfflineSpeakerSegmentationPyannoteModelConfig {
+  std::string model;
+
+  OfflineSpeakerSegmentationPyannoteModelConfig() = default;
+
+  explicit OfflineSpeakerSegmentationPyannoteModelConfig(
+      const std::string &model)
+      : model(model) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_
diff --git a/sherpa/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h b/sherpa/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h
new file mode 100644
index 000000000..3f239aa1a
--- /dev/null
+++ b/sherpa/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h
@@ -0,0 +1,29 @@
+// sherpa/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#ifndef SHERPA_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
+#define SHERPA_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
+
+#include <cstdint>
+#include <string>
+
+namespace sherpa {
+
+// If you are not sure what each field means, please
+// have a look at the Python file in the model directory that
+// you have downloaded.
+struct OfflineSpeakerSegmentationPyannoteModelMetaData {
+  int32_t sample_rate = 0;
+  int32_t window_size = 0;            // in samples
+  int32_t window_shift = 0;           // in samples
+  int32_t receptive_field_size = 0;   // in samples
+  int32_t receptive_field_shift = 0;  // in samples
+  int32_t num_speakers = 0;
+  int32_t powerset_max_classes = 0;
+  int32_t num_classes = 0;
+};
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
diff --git a/sherpa/csrc/offline-speaker-segmentation-pyannote-model.cc b/sherpa/csrc/offline-speaker-segmentation-pyannote-model.cc
new file mode 100644
index 000000000..7a6efa537
--- /dev/null
+++ b/sherpa/csrc/offline-speaker-segmentation-pyannote-model.cc
@@ -0,0 +1,119 @@
+// sherpa/csrc/offline-speaker-segmentation-pyannote-model.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa/csrc/offline-speaker-segmentation-pyannote-model.h"
+
+#include <cstdlib>
+#include <memory>
+#include <sstream>
+#include <string>
+
+#include "sherpa/cpp_api/macros.h"
+#include "sherpa/csrc/macros.h"
+#include "torch/script.h"
+
+namespace sherpa {
+
+class OfflineSpeakerSegmentationPyannoteModel::Impl {
+ public:
+  explicit Impl(const OfflineSpeakerSegmentationModelConfig &config)
+      : config_(config) {
+    torch::jit::ExtraFilesMap meta_data{
+        {"model_type", {}},
+        {"num_speakers", {}},
+        {"powerset_max_classes", {}},
+        {"num_classes", {}},
+        {"sample_rate", {}},
+        {"window_size", {}},
+        {"receptive_field_size", {}},
+        {"receptive_field_shift", {}},
+        {"version", {}},
+        {"maintainer", {}},
+    };
+
+    if (config.use_gpu) {
+      device_ = torch::Device{torch::kCUDA};
+    }
+
+    model_ = torch::jit::load(config.pyannote.model, device_, meta_data);
+    model_.eval();
+
+    if (meta_data.at("model_type") != "pyannote-segmentation-3.0") {
+      SHERPA_LOGE(
+          "Expected model_type 'pyannote-segmentation-3.0'. Given: '%s'",
+          meta_data.at("model_type").c_str());
+      SHERPA_EXIT(-1);
+    }
+    InitMetaData(meta_data);
+
+    if (config.debug) {
+      std::ostringstream os;
+      os << "----------meta_data for pyannote-segmentation-3.0------\n";
+      os << "sample_rate: " << meta_data_.sample_rate << " Hz\n";
+      os << "window_size: " << meta_data_.window_size << " samples\n";
+      os << "window_shift: " << meta_data_.window_shift << " samples\n";
+      os << "receptive_field_size: " << meta_data_.receptive_field_size
+         << " samples\n";
+      os << "receptive_field_shift: " << meta_data_.receptive_field_shift
+         << " samples\n";
+      os << "num_speakers: " << meta_data_.num_speakers << "\n";
+      os << "powerset_max_classes: " << meta_data_.powerset_max_classes
+         << "\n";
+      os << "num_classes: " << meta_data_.num_classes << "\n";
+      SHERPA_LOGE("%s", os.str().c_str());
+    }
+  }
+
+  const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
+      const {
+    return meta_data_;
+  }
+
+  torch::Tensor Forward(torch::Tensor x) {
+    InferenceMode no_grad;
+
+    return model_.run_method("forward", x).toTensor();
+  }
+
+ private:
+  void InitMetaData(const torch::jit::ExtraFilesMap &m) {
+    meta_data_.sample_rate = atoi(m.at("sample_rate").c_str());
+    meta_data_.window_size = atoi(m.at("window_size").c_str());
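+    // window_shift is not stored in the model's meta data; it is derived
+    // below as 10% of window_size (e.g., 16000 samples for a
+    // 160000-sample window).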
Given: '%s'", + meta_data.at("model_type").c_str()); + SHERPA_EXIT(-1); + } + InitMetaData(meta_data); + + if (config.debug) { + std::ostringstream os; + os << "----------meta_data for pyannote-segmentation-3.0------\n"; + os << "sample_rate: " << meta_data_.sample_rate << " s\n"; + os << "window_size: " << meta_data_.window_size << " samples\n"; + os << "window_shift: " << meta_data_.window_shift << " samples\n"; + os << "receptive_field_size: " << meta_data_.receptive_field_size + << " samples\n"; + os << "receptive_field_shift: " << meta_data_.receptive_field_shift + << " samples\n"; + os << "num_speakers: " << meta_data_.num_speakers << "\n"; + os << "powerset_max_classes: " << meta_data_.powerset_max_classes << "\n"; + os << "num_classes: " << meta_data_.num_classes << "\n"; + SHERPA_LOGE("%s", os.str().c_str()); + } + } + + const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData() + const { + return meta_data_; + } + + torch::Tensor Forward(torch::Tensor x) { + InferenceMode no_grad; + + return model_.run_method("forward", x).toTensor(); + } + + private: + void InitMetaData(const torch::jit::ExtraFilesMap &m) { + meta_data_.sample_rate = atoi(m.at("sample_rate").c_str()); + meta_data_.window_size = atoi(m.at("window_size").c_str()); + meta_data_.window_shift = + static_cast(0.1 * meta_data_.window_size); + meta_data_.receptive_field_size = + atoi(m.at("receptive_field_size").c_str()); + meta_data_.receptive_field_shift = + atoi(m.at("receptive_field_shift").c_str()); + meta_data_.num_speakers = atoi(m.at("num_speakers").c_str()); + meta_data_.powerset_max_classes = + atoi(m.at("powerset_max_classes").c_str()); + meta_data_.num_classes = atoi(m.at("num_classes").c_str()); + } + + private: + OfflineSpeakerSegmentationModelConfig config_; + OfflineSpeakerSegmentationPyannoteModelMetaData meta_data_; + torch::jit::Module model_; + torch::Device device_{torch::kCPU}; +}; + +OfflineSpeakerSegmentationPyannoteModel:: + OfflineSpeakerSegmentationPyannoteModel( + const OfflineSpeakerSegmentationModelConfig &config) + : impl_(std::make_unique(config)) {} + +OfflineSpeakerSegmentationPyannoteModel:: + ~OfflineSpeakerSegmentationPyannoteModel() = default; + +const OfflineSpeakerSegmentationPyannoteModelMetaData & +OfflineSpeakerSegmentationPyannoteModel::GetModelMetaData() const { + return impl_->GetModelMetaData(); +} + +torch::Tensor OfflineSpeakerSegmentationPyannoteModel::Forward( + torch::Tensor x) const { + return impl_->Forward(x); +} + +} // namespace sherpa diff --git a/sherpa/csrc/offline-speaker-segmentation-pyannote-model.h b/sherpa/csrc/offline-speaker-segmentation-pyannote-model.h new file mode 100644 index 000000000..2cf397025 --- /dev/null +++ b/sherpa/csrc/offline-speaker-segmentation-pyannote-model.h @@ -0,0 +1,40 @@ +// sherpa/csrc/offline-speaker-segmentation-pyannote-model.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ +#define SHERPA_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_ + +#include + +#include "sherpa/csrc/offline-speaker-segmentation-model-config.h" +#include "sherpa/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h" +#include "torch/script.h" + +namespace sherpa { + +class OfflineSpeakerSegmentationPyannoteModel { + public: + explicit OfflineSpeakerSegmentationPyannoteModel( + const OfflineSpeakerSegmentationModelConfig &config); + + ~OfflineSpeakerSegmentationPyannoteModel(); + + const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData() + const; 
+   */
+  torch::Tensor Forward(torch::Tensor x) const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa
+
+#endif  // SHERPA_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
diff --git a/sherpa/csrc/sherpa-speaker-diarization.cc b/sherpa/csrc/sherpa-speaker-diarization.cc
new file mode 100644
index 000000000..76c811344
--- /dev/null
+++ b/sherpa/csrc/sherpa-speaker-diarization.cc
@@ -0,0 +1,55 @@
+// sherpa/csrc/sherpa-speaker-diarization.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include <chrono>  // NOLINT
+#include <iostream>
+
+#include "sherpa/cpp_api/parse-options.h"
+#include "sherpa/csrc/fbank-features.h"
+#include "sherpa/csrc/offline-speaker-diarization.h"
+#include "torch/torch.h"
+
+int32_t main(int32_t argc, char *argv[]) {
+  const char *kUsageMessage = R"usage(
+This program uses a speaker segmentation model and a speaker embedding
+extractor model for speaker diarization.
+
+Usage:
+
+sherpa-speaker-diarization \
+  --segmentation.use-gpu=false \
+  --num-threads=1 \
+  --embedding.model=./3d_speaker-speech_eres2netv2_sv_zh-cn_16k-common.pt \
+  --segmentation.pyannote-model=./sherpa-pyannote-segmentation-3-0/model.pt \
+  ./foo.wav
+)usage";
+
+  int32_t num_threads = 1;
+  sherpa::ParseOptions po(kUsageMessage);
+  sherpa::OfflineSpeakerDiarizationConfig config;
+  config.Register(&po);
+  po.Register("num-threads", &num_threads, "Number of threads for PyTorch");
+  po.Read(argc, argv);
+
+  if (po.NumArgs() != 1) {
+    std::cerr << "Please provide only 1 test wave\n";
+    exit(-1);
+  }
+
+  std::cerr << config.ToString() << "\n";
+  if (!config.Validate()) {
+    std::cerr << "Please check your config\n";
+    return -1;
+  }
+
+  sherpa::OfflineSpeakerDiarization sd(config);
+
+  int32_t sr = 16000;
+  torch::Tensor samples = sherpa::ReadWave(po.GetArg(1), sr).first;
+  auto result = sd.Process(samples.unsqueeze(0)).SortByStartTime();
+
+  for (const auto &r : result) {
+    std::cout << r.ToString() << "\n";
+  }
+
+  return 0;
+}
diff --git a/sherpa/csrc/sherpa-vad.cc b/sherpa/csrc/sherpa-vad.cc
index 38b963f9d..c381a023a 100644
--- a/sherpa/csrc/sherpa-vad.cc
+++ b/sherpa/csrc/sherpa-vad.cc
@@ -12,7 +12,7 @@ int32_t main(int32_t argc, char *argv[]) {
   const char *kUsageMessage = R"usage(
-This program uses a VAD models to add timestamps to a audio file
+This program uses a VAD model to add timestamps to an audio file
 
 Usage:
 
 sherpa-vad \
diff --git a/sherpa/csrc/silero-vad-model.cc b/sherpa/csrc/silero-vad-model.cc
index 0aaa8c3c6..efd0ac0bb 100644
--- a/sherpa/csrc/silero-vad-model.cc
+++ b/sherpa/csrc/silero-vad-model.cc
@@ -3,6 +3,7 @@
 // Copyright (c) 2025 Xiaomi Corporation
 #include "sherpa/csrc/silero-vad-model.h"
 
+#include "sherpa/cpp_api/macros.h"
 #include "sherpa/csrc/macros.h"
 
 namespace sherpa {
@@ -38,6 +39,8 @@ class SileroVadModel::Impl {
   torch::Device Device() const { return device_; }
 
   torch::Tensor Run(torch::Tensor samples) {
+    InferenceMode no_grad;
+
     torch::Tensor sample_rate = torch::tensor(
         {config_.sample_rate}, torch::dtype(torch::kInt).device(device_));
 
diff --git a/sherpa/csrc/speaker-embedding-extractor-model.cc b/sherpa/csrc/speaker-embedding-extractor-model.cc
index 0793dd5e4..f562a8c7e 100644
--- a/sherpa/csrc/speaker-embedding-extractor-model.cc
+++ b/sherpa/csrc/speaker-embedding-extractor-model.cc
@@ -8,6 +8,7 @@
 #include <string>
 #include <vector>
 
+#include "sherpa/cpp_api/macros.h"
 #include "sherpa/csrc/macros.h"
 #include "sherpa/csrc/speaker-embedding-extractor-model-meta-data.h"
 
@@ -38,6 +39,7 @@ class SpeakerEmbeddingExtractorModel::Impl {
   }
 
   torch::Tensor Compute(torch::Tensor x) {
+    InferenceMode no_grad;
     return model_.run_method("forward", x).toTensor();
   }