diff --git a/src/chunks/norms.rs b/src/chunks/norms.rs
index 39d3983..2c40bbf 100644
--- a/src/chunks/norms.rs
+++ b/src/chunks/norms.rs
@@ -127,14 +127,32 @@ impl WriteChunk for NdNorms {
     }
 }
 
+/// Prune the embedding norms.
+pub trait PruneNorms {
+    /// Prune the embedding norms. Remap the norm of each word whose original vector is
+    /// tossed to the norm of its nearest remaining vector.
+    fn prune_norms(&self, toss_indices: &[usize], most_similar_indices: &Array1<usize>) -> NdNorms;
+}
+
+impl PruneNorms for NdNorms {
+    fn prune_norms(&self, toss_indices: &[usize], most_similar_indices: &Array1<usize>) -> NdNorms {
+        let mut pruned_norms = self.inner.clone();
+        for (toss_idx, remapped_idx) in toss_indices.iter().zip(most_similar_indices) {
+            pruned_norms[*toss_idx] = pruned_norms[*remapped_idx];
+        }
+        NdNorms::new(pruned_norms)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::io::{Cursor, Read, Seek, SeekFrom};
+    use std::ops::Deref;
 
     use byteorder::{LittleEndian, ReadBytesExt};
-    use ndarray::Array1;
+    use ndarray::{arr1, Array1};
 
-    use super::NdNorms;
+    use super::{NdNorms, PruneNorms};
     use crate::chunks::io::{ReadChunk, WriteChunk};
 
     const LEN: usize = 100;
@@ -174,4 +192,22 @@ mod tests {
         let arr = NdNorms::read_chunk(&mut cursor).unwrap();
         assert_eq!(arr.view(), check_arr.view());
     }
+
+    #[test]
+    fn test_prune_norms() {
+        let original_norms = test_ndnorms();
+        let toss_indices = &[1, 5, 7];
+        let most_similar_indices = arr1(&[2, 6, 8]);
+        let test_ndnorms = original_norms.prune_norms(toss_indices, &most_similar_indices);
+        for (toss_idx, remap_idx) in toss_indices.iter().zip(most_similar_indices.iter()) {
+            assert_eq!(
+                test_ndnorms.deref()[*toss_idx],
+                test_ndnorms.deref()[*remap_idx]
+            );
+            assert_eq!(
+                original_norms.deref()[*remap_idx],
+                test_ndnorms.deref()[*toss_idx]
+            );
+        }
+    }
 }
diff --git a/src/chunks/storage/array.rs b/src/chunks/storage/array.rs
index 5bd1717..710d6ff 100644
--- a/src/chunks/storage/array.rs
+++ b/src/chunks/storage/array.rs
@@ -1,14 +1,17 @@
+use std::collections::HashSet;
 use std::fs::File;
 use std::io::{BufReader, Read, Seek, SeekFrom, Write};
+use std::iter::FromIterator;
 use std::mem::size_of;
 
 #[cfg(target_endian = "big")]
 use byteorder::ByteOrder;
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use memmap::{Mmap, MmapOptions};
-use ndarray::{Array2, ArrayView2, ArrayViewMut2, CowArray, Dimension, Ix1, Ix2};
+use ndarray::{s, Array1, Array2, ArrayView2, ArrayViewMut2, Axis, CowArray, Dimension, Ix1, Ix2};
+use ordered_float::OrderedFloat;
 
-use super::{Storage, StorageView, StorageViewMut};
+use super::{Storage, StoragePrune, StorageView, StorageViewMut, StorageWrap};
 use crate::chunks::io::{ChunkIdentifier, MmapChunk, ReadChunk, TypeId, WriteChunk};
 use crate::io::{Error, ErrorKind, Result};
 use crate::util::padding;
@@ -274,15 +277,60 @@ impl WriteChunk for NdArray {
     }
 }
 
+impl StoragePrune for NdArray {
+    fn prune_storage(&self, toss_indices: &[usize]) -> StorageWrap {
+        let toss_indices: HashSet<usize> = HashSet::from_iter(toss_indices.iter().cloned());
+        let mut keep_indices_all = Vec::new();
+        // Keep every row index that is not tossed.
+        for idx in 0..self.shape().0 {
+            if !toss_indices.contains(&idx) {
+                keep_indices_all.push(idx);
+            }
+        }
+        NdArray::new(self.inner.select(Axis(0), &keep_indices_all).to_owned()).into()
+    }
+
+    fn most_similar(
+        &self,
+        keep_indices: &[usize],
+        toss_indices: &[usize],
+        batch_size: usize,
+    ) -> Array1<usize> {
+        let toss = toss_indices.len();
+        let mut remap_indices = Array1::zeros(toss);
+        let keep_embeds = self.inner.select(Axis(0), keep_indices);
+        let keep_embeds_t = &keep_embeds.t();
+        let toss_embeds = self.inner.select(Axis(0), toss_indices);
+        for n in (0..toss).step_by(batch_size) {
+            let mut offset = n + batch_size;
+            if offset > toss {
+                offset = toss;
+            }
+            #[allow(clippy::deref_addrof)]
+            let batch = toss_embeds.slice(s![n..offset, ..]);
+            let similarity_scores = batch.dot(keep_embeds_t);
+            for (i, row) in similarity_scores.axis_iter(Axis(0)).enumerate() {
+                let most_similar_idx = row
+                    .iter()
+                    .enumerate()
+                    .max_by_key(|(_, &v)| OrderedFloat(v))
+                    .unwrap()
+                    .0;
+                remap_indices[n + i] = most_similar_idx;
+            }
+        }
+        remap_indices
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::io::{Cursor, Read, Seek, SeekFrom};
 
     use byteorder::{LittleEndian, ReadBytesExt};
-    use ndarray::Array2;
+    use ndarray::{arr1, arr2, Array2, Axis};
 
     use crate::chunks::io::{ReadChunk, WriteChunk};
-    use crate::chunks::storage::{NdArray, StorageView};
+    use crate::chunks::storage::{NdArray, Storage, StoragePrune, StorageView};
 
     const N_ROWS: usize = 100;
     const N_COLS: usize = 100;
@@ -326,4 +374,36 @@ mod tests {
         let arr = NdArray::read_chunk(&mut cursor).unwrap();
         assert_eq!(arr.view(), check_arr.view());
     }
+
+    fn test_ndarray_for_pruning() -> NdArray {
+        let test_storage: Array2<f32> = arr2(&[
+            [1.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 1.0, 0.0, 0.0, 0.0],
+            [0.6, 0.8, 0.0, 0.0, 0.0],
+            [0.8, 0.6, 0.0, 0.0, 0.0],
+        ]);
+        NdArray::new(test_storage)
+    }
+
+    #[test]
+    fn test_most_similar() {
+        let storage = test_ndarray_for_pruning();
+        let keep_indices = &[0, 1];
+        let toss_indices = &[2, 3];
+        let test_remap_indices = storage.most_similar(keep_indices, toss_indices, 1);
+        assert_eq!(arr1(&[1, 0]), test_remap_indices);
+    }
+
+    #[test]
+    fn test_prune_storage() {
+        let pruned_storage: Array2<f32> =
+            arr2(&[[1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0]]);
+        let storage = test_ndarray_for_pruning();
+        let toss_indices = &[2, 3];
+        let test_pruned_storage = storage.prune_storage(toss_indices);
+        assert_eq!(pruned_storage.dim(), test_pruned_storage.shape());
+        for (idx, row) in pruned_storage.axis_iter(Axis(0)).enumerate() {
+            assert_eq!(row, test_pruned_storage.embedding(idx));
+        }
+    }
 }
diff --git a/src/chunks/storage/mod.rs b/src/chunks/storage/mod.rs
index 7667403..7552fe0 100644
--- a/src/chunks/storage/mod.rs
+++ b/src/chunks/storage/mod.rs
@@ -1,6 +1,6 @@
 //! Embedding matrix representations.
 
-use ndarray::{ArrayView2, ArrayViewMut2, CowArray, Ix1};
+use ndarray::{Array1, ArrayView2, ArrayViewMut2, CowArray, Ix1};
 
 mod array;
 pub use self::array::{MmapArray, NdArray};
@@ -33,3 +33,17 @@ pub(crate) trait StorageViewMut: Storage {
     /// Get a mutable view of the embedding matrix.
     fn view_mut(&mut self) -> ArrayViewMut2<f32>;
 }
+
+/// Storage that can be pruned.
+pub trait StoragePrune: Storage {
+    /// Prune a storage. Discard the vectors at the given indices.
+    fn prune_storage(&self, toss_indices: &[usize]) -> StorageWrap;
+
+    /// Find the nearest remaining vector for each vector that needs to be tossed.
+    fn most_similar(
+        &self,
+        keep_indices: &[usize],
+        toss_indices: &[usize],
+        batch_size: usize,
+    ) -> Array1<usize>;
+}
diff --git a/src/chunks/storage/quantized.rs b/src/chunks/storage/quantized.rs
index f66fbda..ea27b9f 100644
--- a/src/chunks/storage/quantized.rs
+++ b/src/chunks/storage/quantized.rs
@@ -1,17 +1,20 @@
+use std::collections::HashSet;
 use std::fs::File;
 use std::io::{BufReader, Read, Seek, SeekFrom, Write};
+use std::iter::FromIterator;
 use std::mem::size_of;
 
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use memmap::{Mmap, MmapOptions};
 use ndarray::{
-    Array, Array1, Array2, ArrayView1, ArrayView2, CowArray, Dimension, IntoDimension, Ix1,
+    Array, Array1, Array2, ArrayView1, ArrayView2, Axis, CowArray, Dimension, IntoDimension, Ix1,
 };
+use ordered_float::OrderedFloat;
 use rand::{RngCore, SeedableRng};
 use rand_xorshift::XorShiftRng;
 use reductive::pq::{QuantizeVector, ReconstructVector, TrainPQ, PQ};
 
-use super::{Storage, StorageView};
+use super::{Storage, StoragePrune, StorageView, StorageWrap};
 use crate::chunks::io::{ChunkIdentifier, MmapChunk, ReadChunk, TypeId, WriteChunk};
 use crate::io::{Error, ErrorKind, Result};
 use crate::util::padding;
@@ -561,17 +564,88 @@ impl WriteChunk for MmapQuantizedArray {
     }
 }
 
+impl StoragePrune for QuantizedArray {
+    fn prune_storage(&self, toss_indices: &[usize]) -> StorageWrap {
+        let mut keep_indices_all = Vec::new();
+        let toss_indices: HashSet<usize> = HashSet::from_iter(toss_indices.iter().cloned());
+        for idx in 0..self.quantized_embeddings.shape()[0] {
+            if !toss_indices.contains(&idx) {
+                keep_indices_all.push(idx);
+            }
+        }
+        let norms = if self.norms.is_some() {
+            Some(
+                self.norms
+                    .as_ref()
+                    .unwrap()
+                    .select(Axis(0), &keep_indices_all),
+            )
+        } else {
+            None
+        };
+        let new_storage = QuantizedArray {
+            quantizer: self.quantizer.clone(),
+            quantized_embeddings: self
+                .quantized_embeddings
+                .select(Axis(0), &keep_indices_all)
+                .to_owned(),
+            norms,
+        };
+        new_storage.into()
+    }
+
+    fn most_similar(
+        &self,
+        keep_indices: &[usize],
+        toss_indices: &[usize],
+        _batch_size: usize,
+    ) -> Array1<usize> {
+        // Pairwise dot products between the centroids of each subquantizer.
+        let dists: Vec<Array2<f32>> = self
+            .quantizer
+            .subquantizers()
+            .axis_iter(Axis(0))
+            .map(|quantizer| quantizer.dot(&quantizer.t()))
+            .collect();
+        let keep_quantized_embeddings = self.quantized_embeddings.select(Axis(0), keep_indices);
+        let toss_quantized_embeddings = self.quantized_embeddings.select(Axis(0), toss_indices);
+        let mut remap_indices = Array1::zeros((toss_quantized_embeddings.shape()[0],));
+        for (i, toss_row) in toss_quantized_embeddings.axis_iter(Axis(0)).enumerate() {
+            let mut row_dist = vec![0f32; keep_quantized_embeddings.shape()[0]];
+            for (n, keep_row) in keep_quantized_embeddings.axis_iter(Axis(0)).enumerate() {
+                row_dist[n] = toss_row
+                    .iter()
+                    .zip(keep_row.iter())
+                    .enumerate()
+                    .map(|(id, (&toss_id, &keep_id))| {
+                        dists[id][(toss_id as usize, keep_id as usize)]
+                    })
+                    .sum();
+            }
+
+            remap_indices[i] = row_dist
+                .iter()
+                .enumerate()
+                .max_by_key(|(_, &v)| OrderedFloat(v))
+                .unwrap()
+                .0;
+        }
+        remap_indices
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::fs::File;
     use std::io::{BufReader, Cursor, Read, Seek, SeekFrom};
 
     use byteorder::{LittleEndian, ReadBytesExt};
-    use ndarray::Array2;
+    use ndarray::{arr1, arr2, Array2};
     use reductive::pq::PQ;
 
     use crate::chunks::io::{MmapChunk, ReadChunk, WriteChunk};
-    use crate::chunks::storage::{MmapQuantizedArray, NdArray, Quantize, QuantizedArray, Storage};
+    use crate::chunks::storage::{
+        MmapQuantizedArray, NdArray, Quantize,
+        QuantizedArray, Storage, StoragePrune,
+    };
 
     const N_ROWS: usize = 100;
     const N_COLS: usize = 100;
@@ -676,4 +750,45 @@ mod tests {
         // Check
         storage_eq(&arr, &check_arr);
     }
+
+    fn test_ndarray_for_pruning() -> NdArray {
+        let test_storage: Array2<f32> = arr2(&[
+            [1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 0.8, 0.6, 0.0, 0.0],
+            [0.8, 0.6, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.8, 0.6, 0.0, 0.0, 0.0],
+        ]);
+        NdArray::new(test_storage)
+    }
+
+    #[test]
+    fn test_most_similar() {
+        let ndarray = test_ndarray_for_pruning();
+        let quantized_storage = ndarray.quantize::<PQ<f32>>(3, 2, 5, 1, true);
+        let keep_indices = &[0, 1, 2];
+        let toss_indices = &[3, 4, 5];
+        let test_remap_indices = quantized_storage.most_similar(keep_indices, toss_indices, 1);
+        assert_eq!(arr1(&[2, 0, 1]), test_remap_indices);
+    }
+
+    #[test]
+    fn test_prune_storage() {
+        let ndarray = test_ndarray_for_pruning();
+        let quantized_storage = ndarray.quantize::<PQ<f32>>(3, 2, 5, 1, true);
+        let keep_indices = &[0, 1, 2];
+        let toss_indices = &[3, 4, 5];
+        let test_pruned_storage = quantized_storage.prune_storage(toss_indices);
+        assert_eq!(
+            (ndarray.shape().0 - toss_indices.len(), ndarray.shape().1),
+            test_pruned_storage.shape()
+        );
+        for idx in keep_indices.iter() {
+            assert_eq!(
+                quantized_storage.embedding(*idx),
+                test_pruned_storage.embedding(*idx)
+            );
+        }
+    }
 }
diff --git a/src/chunks/storage/wrappers.rs b/src/chunks/storage/wrappers.rs
index e27b321..ae8b87b 100644
--- a/src/chunks/storage/wrappers.rs
+++ b/src/chunks/storage/wrappers.rs
@@ -2,9 +2,11 @@ use std::fs::File;
 use std::io::{BufReader, Read, Seek, SeekFrom, Write};
 
 use byteorder::{LittleEndian, ReadBytesExt};
-use ndarray::{ArrayView2, CowArray, Ix1};
+use ndarray::{Array1, ArrayView2, CowArray, Ix1};
 
-use super::{MmapArray, MmapQuantizedArray, NdArray, QuantizedArray, Storage, StorageView};
+use super::{
+    MmapArray, MmapQuantizedArray, NdArray, QuantizedArray, Storage, StoragePrune, StorageView,
+};
 use crate::chunks::io::{ChunkIdentifier, MmapChunk, ReadChunk, WriteChunk};
 use crate::io::{Error, ErrorKind, Result};
 
@@ -297,3 +299,114 @@ impl MmapChunk for StorageViewWrap {
         }
     }
 }
+
+pub enum StoragePruneWrap {
+    NdArray(NdArray),
+    QuantizedArray(QuantizedArray),
+}
+
+impl Storage for StoragePruneWrap {
+    fn embedding(&self, idx: usize) -> CowArray<f32, Ix1> {
+        match self {
+            StoragePruneWrap::QuantizedArray(inner) => inner.embedding(idx),
+            StoragePruneWrap::NdArray(inner) => inner.embedding(idx),
+        }
+    }
+
+    fn shape(&self) -> (usize, usize) {
+        match self {
+            StoragePruneWrap::QuantizedArray(inner) => inner.shape(),
+            StoragePruneWrap::NdArray(inner) => inner.shape(),
+        }
+    }
+}
+
+impl From<NdArray> for StoragePruneWrap {
+    fn from(s: NdArray) -> Self {
+        StoragePruneWrap::NdArray(s)
+    }
+}
+
+impl From<QuantizedArray> for StoragePruneWrap {
+    fn from(s: QuantizedArray) -> Self {
+        StoragePruneWrap::QuantizedArray(s)
+    }
+}
+
+impl ReadChunk for StoragePruneWrap {
+    fn read_chunk<R>(read: &mut R) -> Result<Self>
+    where
+        R: Read + Seek,
+    {
+        let chunk_start_pos = read
+            .seek(SeekFrom::Current(0))
+            .map_err(|e| ErrorKind::io_error("Cannot get storage chunk start position", e))?;
+
+        let chunk_id = read
+            .read_u32::<LittleEndian>()
+            .map_err(|e| ErrorKind::io_error("Cannot read storage chunk identifier", e))?;
+        let chunk_id = ChunkIdentifier::try_from(chunk_id)
+            .ok_or_else(|| ErrorKind::Format(format!("Unknown chunk identifier: {}", chunk_id)))
+            .map_err(Error::from)?;
+
+        read.seek(SeekFrom::Start(chunk_start_pos))
ErrorKind::io_error("Cannot seek to storage chunk start position", e))?; + + match chunk_id { + ChunkIdentifier::NdArray => NdArray::read_chunk(read).map(StoragePruneWrap::NdArray), + ChunkIdentifier::QuantizedArray => { + QuantizedArray::read_chunk(read).map(StoragePruneWrap::QuantizedArray) + } + _ => Err(ErrorKind::Format(format!( + "Invalid chunk identifier, expected one of: {} or {}, got: {}", + ChunkIdentifier::NdArray, + ChunkIdentifier::QuantizedArray, + chunk_id + )) + .into()), + } + } +} + +impl WriteChunk for StoragePruneWrap { + fn chunk_identifier(&self) -> ChunkIdentifier { + match self { + StoragePruneWrap::QuantizedArray(inner) => inner.chunk_identifier(), + StoragePruneWrap::NdArray(inner) => inner.chunk_identifier(), + } + } + + fn write_chunk(&self, write: &mut W) -> Result<()> + where + W: Write + Seek, + { + match self { + StoragePruneWrap::QuantizedArray(inner) => inner.write_chunk(write), + StoragePruneWrap::NdArray(inner) => inner.write_chunk(write), + } + } +} + +impl StoragePrune for StoragePruneWrap { + fn prune_storage(&self, toss_indices: &[usize]) -> StorageWrap { + match self { + StoragePruneWrap::NdArray(inner) => inner.prune_storage(toss_indices), + StoragePruneWrap::QuantizedArray(inner) => inner.prune_storage(toss_indices), + } + } + fn most_similar( + &self, + keep_indices: &[usize], + toss_indices: &[usize], + batch_size: usize, + ) -> Array1 { + match self { + StoragePruneWrap::NdArray(inner) => { + inner.most_similar(keep_indices, toss_indices, batch_size) + } + StoragePruneWrap::QuantizedArray(inner) => { + inner.most_similar(keep_indices, toss_indices, batch_size) + } + } + } +} diff --git a/src/chunks/vocab/mod.rs b/src/chunks/vocab/mod.rs index 61e0005..4e6bba2 100644 --- a/src/chunks/vocab/mod.rs +++ b/src/chunks/vocab/mod.rs @@ -4,6 +4,7 @@ use std::collections::HashMap; use std::io::{Read, Seek, Write}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use ndarray::Array1; use crate::io::{Error, ErrorKind, Result}; @@ -65,6 +66,26 @@ impl WordIndex { } } +/// Prune the embedding vocabularies. +pub trait VocabPrune: Vocab { + /// Prune the vocabulary and get a new one. + fn prune_vocab(&self, remapped_indices: HashMap) -> VocabWrap; +} + +/// Handle the indices changes during pruning. +pub trait VocabPruneIndices: Vocab { + /// Seperate the indices of the words whose original vectors need to be tossed from the + /// ones whose vectors need to be kept. + fn part_indices(&self, n_keep: usize) -> (Vec, Vec); + + /// Remap the indices of the words whose original vectors need to be tossed to their + /// closest remainings' indices. + fn create_remapped_indices( + &self, + most_similar_indices: &Array1, + ) -> HashMap; +} + pub(crate) fn create_indices(words: &[String]) -> HashMap { let mut indices = HashMap::new(); diff --git a/src/chunks/vocab/simple.rs b/src/chunks/vocab/simple.rs index 95baed3..83a2e85 100644 --- a/src/chunks/vocab/simple.rs +++ b/src/chunks/vocab/simple.rs @@ -3,9 +3,13 @@ use std::io::{Read, Seek, Write}; use std::mem::size_of; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use ndarray::Array1; use crate::chunks::io::{ChunkIdentifier, ReadChunk, WriteChunk}; -use crate::chunks::vocab::{create_indices, read_vocab_items, write_vocab_items, Vocab, WordIndex}; +use crate::chunks::vocab::{ + create_indices, read_vocab_items, write_vocab_items, Vocab, VocabPrune, VocabPruneIndices, + VocabWrap, WordIndex, +}; use crate::io::{ErrorKind, Result}; /// Vocabulary without subword units. 
@@ -107,13 +111,57 @@ impl WriteChunk for SimpleVocab {
     }
 }
 
+impl VocabPrune for SimpleVocab {
+    fn prune_vocab(&self, remapped_indices: HashMap<String, usize>) -> VocabWrap {
+        let new_vocab = SimpleVocab {
+            words: self.words.clone(),
+            indices: remapped_indices,
+        };
+        new_vocab.into()
+    }
+}
+
+impl VocabPruneIndices for SimpleVocab {
+    fn part_indices(&self, n_keep: usize) -> (Vec<usize>, Vec<usize>) {
+        let keep_indices = self
+            .words()
+            .iter()
+            .take(n_keep)
+            .map(|w| *self.indices.get(w).unwrap())
+            .collect();
+        let toss_indices = self.words()[n_keep..]
+            .iter()
+            .map(|w| *self.indices.get(w).unwrap())
+            .collect();
+        (keep_indices, toss_indices)
+    }
+
+    fn create_remapped_indices(
+        &self,
+        most_similar_indices: &Array1<usize>,
+    ) -> HashMap<String, usize> {
+        let mut remapped_indices = self.indices.clone();
+        for (toss_word, remapped_idx) in self.words()
+            [self.words_len() - most_similar_indices.len()..]
+            .iter()
+            .zip(most_similar_indices)
+        {
+            remapped_indices.insert(toss_word.to_owned(), *remapped_idx);
+        }
+        remapped_indices
+    }
+}
+
 #[cfg(test)]
 mod tests {
+    use std::collections::HashMap;
     use std::io::{Cursor, Read, Seek, SeekFrom};
 
+    use ndarray::arr1;
+
     use super::SimpleVocab;
     use crate::chunks::io::{ReadChunk, WriteChunk};
-    use crate::chunks::vocab::read_chunk_size;
+    use crate::chunks::vocab::{read_chunk_size, VocabPruneIndices};
 
     fn test_simple_vocab() -> SimpleVocab {
         let words = vec![
@@ -149,4 +197,28 @@ mod tests {
             chunk_size as usize
         );
     }
+
+    #[test]
+    fn test_part_indices() {
+        let vocab = test_simple_vocab();
+        let (test_keep_indices, test_toss_indices) = vocab.part_indices(2);
+        assert_eq!(vec![0, 1], test_keep_indices);
+        assert_eq!(vec![2, 3], test_toss_indices);
+    }
+
+    #[test]
+    fn test_create_remapped_indices() {
+        let vocab = test_simple_vocab();
+        let test_remapped_indices = vocab.create_remapped_indices(&arr1(&[1, 0]));
+        let remapped_indices: HashMap<String, usize> = [
+            ("this".to_owned(), 0),
+            ("is".to_owned(), 1),
+            ("a".to_owned(), 1),
+            ("test".to_owned(), 0),
+        ]
+        .iter()
+        .cloned()
+        .collect();
+        assert_eq!(remapped_indices, test_remapped_indices);
+    }
 }
diff --git a/src/chunks/vocab/subword.rs b/src/chunks/vocab/subword.rs
index 1ba06cb..25826bd 100644
--- a/src/chunks/vocab/subword.rs
+++ b/src/chunks/vocab/subword.rs
@@ -4,9 +4,13 @@ use std::io::{Read, Seek, Write};
 use std::mem::size_of;
 
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
+use ndarray::Array1;
 
 use crate::chunks::io::{ChunkIdentifier, ReadChunk, WriteChunk};
-use crate::chunks::vocab::{create_indices, read_vocab_items, write_vocab_items, Vocab, WordIndex};
+use crate::chunks::vocab::{
+    create_indices, read_vocab_items, write_vocab_items, Vocab, VocabPrune, VocabPruneIndices,
+    VocabWrap, WordIndex,
+};
 use crate::compat::fasttext::FastTextIndexer;
 use crate::io::{Error, ErrorKind, Result};
 use crate::subword::{
@@ -474,13 +478,89 @@ where
     Ok(())
 }
 
+impl VocabPrune for FastTextSubwordVocab {
+    fn prune_vocab(&self, remapped_indices: HashMap<String, usize>) -> VocabWrap {
+        let new_vocab = SubwordVocab {
+            indexer: self.indexer,
+            indices: remapped_indices,
+            words: self.words.clone(),
+            min_n: self.min_n,
+            max_n: self.max_n,
+        };
+        new_vocab.into()
+    }
+}
+
+impl VocabPrune for BucketSubwordVocab {
+    fn prune_vocab(&self, remapped_indices: HashMap<String, usize>) -> VocabWrap {
+        let new_vocab = SubwordVocab {
+            indexer: self.indexer,
+            indices: remapped_indices,
+            words: self.words.clone(),
+            min_n: self.min_n,
+            max_n: self.max_n,
+        };
+        new_vocab.into()
+    }
+}
+
+impl VocabPrune for ExplicitSubwordVocab {
+    fn prune_vocab(&self, remapped_indices: HashMap<String, usize>) -> VocabWrap {
+        let new_vocab = SubwordVocab {
+            indexer: self.indexer.clone(),
+            indices: remapped_indices,
+            words: self.words.clone(),
+            min_n: self.min_n,
+            max_n: self.max_n,
+        };
+        new_vocab.into()
+    }
+}
+
+impl<I> VocabPruneIndices for SubwordVocab<I>
+where
+    I: Clone + Indexer,
+{
+    fn part_indices(&self, n_keep: usize) -> (Vec<usize>, Vec<usize>) {
+        let keep_indices = self
+            .words()
+            .iter()
+            .take(n_keep)
+            .map(|w| *self.indices.get(w).unwrap())
+            .collect();
+        let toss_indices = self.words()[n_keep..]
+            .iter()
+            .map(|w| *self.indices.get(w).unwrap())
+            .collect();
+        (keep_indices, toss_indices)
+    }
+
+    fn create_remapped_indices(
+        &self,
+        most_similar_indices: &Array1<usize>,
+    ) -> HashMap<String, usize> {
+        let mut remapped_indices = self.indices.clone();
+        for (toss_word, remapped_idx) in self.words()
+            [self.words_len() - most_similar_indices.len()..]
+            .iter()
+            .zip(most_similar_indices)
+        {
+            remapped_indices.insert(toss_word.to_owned(), *remapped_idx);
+        }
+        remapped_indices
+    }
+}
+
 #[cfg(test)]
 mod tests {
+    use std::collections::HashMap;
     use std::io::{Cursor, Read, Seek, SeekFrom};
 
+    use ndarray::arr1;
+
     use super::{BucketSubwordVocab, FastTextSubwordVocab, SubwordVocab};
     use crate::chunks::io::{ReadChunk, WriteChunk};
-    use crate::chunks::vocab::{read_chunk_size, ExplicitSubwordVocab};
+    use crate::chunks::vocab::{read_chunk_size, ExplicitSubwordVocab, VocabPruneIndices};
     use crate::compat::fasttext::FastTextIndexer;
     use crate::subword::{BucketIndexer, ExplicitIndexer, FinalfusionHashIndexer};
 
@@ -576,4 +656,28 @@ mod tests {
         let vocab = SubwordVocab::read_chunk(&mut cursor).unwrap();
         assert_eq!(vocab, check_vocab);
     }
+
+    #[test]
+    fn test_part_indices() {
+        let vocab = test_subword_vocab();
+        let (test_keep_indices, test_toss_indices) = vocab.part_indices(2);
+        assert_eq!(vec![0, 1], test_keep_indices);
+        assert_eq!(vec![2, 3], test_toss_indices);
+    }
+
+    #[test]
+    fn test_create_remapped_indices() {
+        let vocab = test_subword_vocab();
+        let test_remapped_indices = vocab.create_remapped_indices(&arr1(&[1, 0]));
+        let remapped_indices: HashMap<String, usize> = [
+            ("this".to_owned(), 0),
+            ("is".to_owned(), 1),
+            ("a".to_owned(), 1),
+            ("test".to_owned(), 0),
+        ]
+        .iter()
+        .cloned()
+        .collect();
+        assert_eq!(remapped_indices, test_remapped_indices);
+    }
 }
diff --git a/src/embeddings.rs b/src/embeddings.rs
index 7e436ed..3af404b 100644
--- a/src/embeddings.rs
+++ b/src/embeddings.rs
@@ -13,14 +13,14 @@ use reductive::pq::TrainPQ;
 
 use crate::chunks::io::{ChunkIdentifier, Header, MmapChunk, ReadChunk, WriteChunk};
 use crate::chunks::metadata::Metadata;
-use crate::chunks::norms::NdNorms;
+use crate::chunks::norms::{NdNorms, PruneNorms};
 use crate::chunks::storage::{
     MmapArray, MmapQuantizedArray, NdArray, Quantize as QuantizeStorage, QuantizedArray, Storage,
-    StorageView, StorageViewWrap, StorageWrap,
+    StoragePrune, StorageView, StorageViewWrap, StorageWrap,
 };
 use crate::chunks::vocab::{
-    BucketSubwordVocab, ExplicitSubwordVocab, FastTextSubwordVocab, SimpleVocab, Vocab, VocabWrap,
-    WordIndex,
+    BucketSubwordVocab, ExplicitSubwordVocab, FastTextSubwordVocab, SimpleVocab, Vocab, VocabPrune,
+    VocabPruneIndices, VocabWrap, WordIndex,
 };
 use crate::io::{ErrorKind, MmapEmbeddings, ReadEmbeddings, Result, WriteEmbeddings};
 use crate::util::l2_normalize;
@@ -566,6 +566,49 @@ impl<'a> Iterator for IterWithNorms<'a> {
     }
 }
 
+/// Prune embeddings.
+pub trait Prune {
+    /// Prune embeddings.
+    ///
+    /// Given the number of words to keep, each word whose embedding is
+    /// discarded is remapped to the closest embedding among the words
+    /// that are kept.
+    fn prune(&self, n_keep: usize, batch_size: usize) -> Embeddings<VocabWrap, StorageWrap>;
+}
+
+impl<V, S> Prune for Embeddings<V, S>
+where
+    V: VocabPrune + VocabPruneIndices,
+    S: StoragePrune,
+{
+    fn prune(&self, n_keep: usize, batch_size: usize) -> Embeddings<VocabWrap, StorageWrap> {
+        assert!(
+            n_keep > 0 && n_keep <= self.vocab().words_len(),
+            "Number of vectors to keep should be at least 1 and at most {}.",
+            self.vocab().words_len()
+        );
+        let (keep_indices, toss_indices) = self.vocab().part_indices(n_keep);
+        let most_similar_indices =
+            self.storage()
+                .most_similar(&keep_indices, &toss_indices, batch_size);
+        let new_indices = self.vocab().create_remapped_indices(&most_similar_indices);
+        Embeddings {
+            metadata: self.metadata().cloned(),
+            storage: self.storage().prune_storage(&toss_indices),
+            vocab: self.vocab().prune_vocab(new_indices),
+            norms: if self.norms.is_some() {
+                Some(
+                    self.norms()
+                        .unwrap()
+                        .prune_norms(&toss_indices, &most_similar_indices),
+                )
+            } else {
+                None
+            },
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::fs::File;
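Example usage of the `Prune` API added by this patch (a minimal sketch, not part of the patch itself). It assumes `Prune` is exported from `finalfusion::embeddings`, that the remaining names come from `finalfusion::prelude`, and that `embeddings.fifu` is a finalfusion file with an `NdArray` storage and a `SimpleVocab` of at least 10,000 words, ordered by frequency as is conventional for finalfusion vocabularies:

```rust
use std::fs::File;
use std::io::BufReader;

use finalfusion::prelude::*;
// Introduced by this patch; the exact public path is an assumption.
use finalfusion::embeddings::Prune;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Read embeddings with an in-memory array storage and a simple vocab.
    let mut reader = BufReader::new(File::open("embeddings.fifu")?);
    let embeddings: Embeddings<SimpleVocab, NdArray> =
        Embeddings::read_embeddings(&mut reader)?;

    // Keep the first 10,000 words; every tossed word is remapped to its
    // most similar remaining embedding. The batch size bounds how many
    // tossed vectors are compared against the kept vectors at once.
    let pruned = embeddings.prune(10_000, 1024);
    assert_eq!(10_000, pruned.storage().shape().0);
    Ok(())
}
```

Since the result is an ordinary `Embeddings<VocabWrap, StorageWrap>`, it can be written back out with `write_embeddings`. Note that tossed words keep an entry in the pruned vocabulary: a lookup for a tossed word returns the embedding of its nearest kept neighbor instead of failing.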