Skip to content

Add prune support #120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions src/chunks/norms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,32 @@ impl WriteChunk for NdNorms {
}
}

/// Prune the embedding norms.
pub trait PruneNorms {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add rustdoc for the trait and for the prune_norms method,

/// Prune the embedding norms. Remap the norms of the words whose original vectors need to be
/// tossed to their nearest remaining vectors' norms.
fn prune_norms(&self, toss_indices: &[usize], most_similar_indices: &Array1<usize>) -> NdNorms;
}

impl PruneNorms for NdNorms {
    fn prune_norms(&self, toss_indices: &[usize], most_similar_indices: &Array1<usize>) -> NdNorms {
        // Copy the current norms, then overwrite each tossed entry with the
        // norm of its most similar remaining vector.
        let mut remapped = self.inner.clone();
        for (&toss, &keep) in toss_indices.iter().zip(most_similar_indices.iter()) {
            remapped[toss] = remapped[keep];
        }
        NdNorms::new(remapped)
    }
}

#[cfg(test)]
mod tests {
use std::io::{Cursor, Read, Seek, SeekFrom};
use std::ops::Deref;

use byteorder::{LittleEndian, ReadBytesExt};
use ndarray::Array1;
use ndarray::{arr1, Array1};

use super::NdNorms;
use super::{NdNorms, PruneNorms};
use crate::chunks::io::{ReadChunk, WriteChunk};

const LEN: usize = 100;
Expand Down Expand Up @@ -174,4 +192,22 @@ mod tests {
let arr = NdNorms::read_chunk(&mut cursor).unwrap();
assert_eq!(arr.view(), check_arr.view());
}

#[test]
fn test_prune_norms() {
    let original_norms = test_ndnorms();
    let toss_indices = &[1, 5, 7];
    let most_similar_indices = arr1(&[2, 6, 8]);
    let pruned = original_norms.prune_norms(toss_indices, &most_similar_indices);
    // Each tossed norm must now equal the norm of its remap target, both
    // within the pruned result and relative to the original norms.
    for (&toss_idx, &remap_idx) in toss_indices.iter().zip(most_similar_indices.iter()) {
        assert_eq!(pruned.deref()[toss_idx], pruned.deref()[remap_idx]);
        assert_eq!(original_norms.deref()[remap_idx], pruned.deref()[toss_idx]);
    }
}
}
88 changes: 84 additions & 4 deletions src/chunks/storage/array.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use std::collections::HashSet;
use std::fs::File;
use std::io::{BufReader, Read, Seek, SeekFrom, Write};
use std::iter::FromIterator;
use std::mem::size_of;

#[cfg(target_endian = "big")]
use byteorder::ByteOrder;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use memmap::{Mmap, MmapOptions};
use ndarray::{Array2, ArrayView2, ArrayViewMut2, CowArray, Dimension, Ix1, Ix2};
use ndarray::{s, Array1, Array2, ArrayView2, ArrayViewMut2, Axis, CowArray, Dimension, Ix1, Ix2};
use ordered_float::OrderedFloat;

use super::{Storage, StorageView, StorageViewMut};
use super::{Storage, StoragePrune, StorageView, StorageViewMut, StorageWrap};
use crate::chunks::io::{ChunkIdentifier, MmapChunk, ReadChunk, TypeId, WriteChunk};
use crate::io::{Error, ErrorKind, Result};
use crate::util::padding;
Expand Down Expand Up @@ -274,15 +277,60 @@ impl WriteChunk for NdArray {
}
}

impl StoragePrune for NdArray {
fn prune_storage(&self, toss_indices: &[usize]) -> StorageWrap {
let toss_indices: HashSet<usize> = HashSet::from_iter(toss_indices.iter().cloned());
let mut keep_indices_all = Vec::new();
for idx in 0..self.shape().1 - 1 {
if !toss_indices.contains(&idx) {
keep_indices_all.push(idx);
}
}
NdArray::new(self.inner.select(Axis(0), &keep_indices_all).to_owned()).into()
}

fn most_similar(
&self,
keep_indices: &[usize],
toss_indices: &[usize],
batch_size: usize,
) -> Array1<usize> {
let toss = toss_indices.len();
let mut remap_indices = Array1::zeros(toss);
let keep_embeds = self.inner.select(Axis(0), keep_indices);
let keep_embeds_t = &keep_embeds.t();
let toss_embeds = self.inner.select(Axis(0), toss_indices);
for n in (0..toss).step_by(batch_size) {
let mut offset = n + batch_size;
if offset > toss {
offset = toss;
}
#[allow(clippy::deref_addrof)]
let batch = toss_embeds.slice(s![n..offset, ..]);
Copy link
Author

@RealNicolasBourbaki RealNicolasBourbaki Nov 11, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clippy complains at here, but this seems to be a bug

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can add this above the line to silence clippy #[allow(clippy::deref_addrof)], I think we have used that at other parts as well, tho I never looked deeper into why that happens.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, this is another case where clippy complained on my laptop and not on CI

let similarity_scores = batch.dot(keep_embeds_t);
for (i, row) in similarity_scores.axis_iter(Axis(0)).enumerate() {
let dist = row
.iter()
.enumerate()
.max_by_key(|(_, &v)| OrderedFloat(v))
.unwrap()
.0;
remap_indices[n + i] = dist;
}
}
remap_indices
}
}

#[cfg(test)]
mod tests {
use std::io::{Cursor, Read, Seek, SeekFrom};

use byteorder::{LittleEndian, ReadBytesExt};
use ndarray::Array2;
use ndarray::{arr1, arr2, Array2, Axis};

use crate::chunks::io::{ReadChunk, WriteChunk};
use crate::chunks::storage::{NdArray, StorageView};
use crate::chunks::storage::{NdArray, Storage, StoragePrune, StorageView};

const N_ROWS: usize = 100;
const N_COLS: usize = 100;
Expand Down Expand Up @@ -326,4 +374,36 @@ mod tests {
let arr = NdArray::read_chunk(&mut cursor).unwrap();
assert_eq!(arr.view(), check_arr.view());
}

/// Build a small 4-row storage with a known neighbour structure:
/// row 2 ([0.6, 0.8, …]) is closest to row 1, row 3 ([0.8, 0.6, …]) to row 0.
fn test_ndarray_for_pruning() -> NdArray {
    let rows: Array2<f32> = arr2(&[
        [1.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0, 0.0],
        [0.6, 0.8, 0.0, 0.0, 0.0],
        [0.8, 0.6, 0.0, 0.0, 0.0],
    ]);
    NdArray::new(rows)
}

#[test]
fn test_most_similar() {
    let storage = test_ndarray_for_pruning();
    // Tossed rows 2 and 3 should remap to kept rows 1 and 0 respectively.
    let remap = storage.most_similar(&[0, 1], &[2, 3], 1);
    assert_eq!(arr1(&[1, 0]), remap);
}

#[test]
fn test_prune_storage() {
    let expected: Array2<f32> =
        arr2(&[[1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0]]);
    let storage = test_ndarray_for_pruning();
    let pruned = storage.prune_storage(&[2, 3]);
    assert_eq!(expected.dim(), pruned.shape());
    // Compare the surviving rows one by one against the expected matrix.
    for (idx, row) in expected.axis_iter(Axis(0)).enumerate() {
        assert_eq!(row, pruned.embedding(idx));
    }
}
}
16 changes: 15 additions & 1 deletion src/chunks/storage/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Embedding matrix representations.

use ndarray::{ArrayView2, ArrayViewMut2, CowArray, Ix1};
use ndarray::{Array1, ArrayView2, ArrayViewMut2, CowArray, Ix1};

mod array;
pub use self::array::{MmapArray, NdArray};
Expand Down Expand Up @@ -33,3 +33,17 @@ pub(crate) trait StorageViewMut: Storage {
/// Get a view of the embedding matrix.
fn view_mut(&mut self) -> ArrayViewMut2<f32>;
}

/// Storage that can be pruned.
pub trait StoragePrune: Storage {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add rustdoc.

/// Prune a storage. Discard the vectors which need to be pruned off based on their indices.
fn prune_storage(&self, toss_indices: &[usize]) -> StorageWrap;

/// Find a nearest vector for each vector that need to be tossed.
fn most_similar(
&self,
keep_indices: &[usize],
toss_indices: &[usize],
batch_size: usize,
) -> Array1<usize>;
}
123 changes: 119 additions & 4 deletions src/chunks/storage/quantized.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
use std::collections::HashSet;
use std::fs::File;
use std::io::{BufReader, Read, Seek, SeekFrom, Write};
use std::iter::FromIterator;
use std::mem::size_of;

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use memmap::{Mmap, MmapOptions};
use ndarray::{
Array, Array1, Array2, ArrayView1, ArrayView2, CowArray, Dimension, IntoDimension, Ix1,
Array, Array1, Array2, ArrayView1, ArrayView2, Axis, CowArray, Dimension, IntoDimension, Ix1,
};
use ordered_float::OrderedFloat;
use rand::{RngCore, SeedableRng};
use rand_xorshift::XorShiftRng;
use reductive::pq::{QuantizeVector, ReconstructVector, TrainPQ, PQ};

use super::{Storage, StorageView};
use super::{Storage, StoragePrune, StorageView, StorageWrap};
use crate::chunks::io::{ChunkIdentifier, MmapChunk, ReadChunk, TypeId, WriteChunk};
use crate::io::{Error, ErrorKind, Result};
use crate::util::padding;
Expand Down Expand Up @@ -561,17 +564,88 @@ impl WriteChunk for MmapQuantizedArray {
}
}

impl StoragePrune for QuantizedArray {
    fn prune_storage(&self, toss_indices: &[usize]) -> StorageWrap {
        let toss_indices: HashSet<usize> = toss_indices.iter().cloned().collect();
        let n_rows = self.quantized_embeddings.shape()[0];
        // Keep every row that is not tossed. Iterate over the full row range:
        // the previous `0..n_rows - 1` silently dropped the last embedding.
        let keep_indices_all: Vec<usize> = (0..n_rows)
            .filter(|idx| !toss_indices.contains(idx))
            .collect();
        // Carry over the norms of the surviving rows, if norms are stored.
        let norms = self
            .norms
            .as_ref()
            .map(|norms| norms.select(Axis(0), &keep_indices_all));
        QuantizedArray {
            quantizer: self.quantizer.clone(),
            quantized_embeddings: self
                .quantized_embeddings
                .select(Axis(0), &keep_indices_all),
            norms,
        }
        .into()
    }

    fn most_similar(
        &self,
        keep_indices: &[usize],
        toss_indices: &[usize],
        _batch_size: usize,
    ) -> Array1<usize> {
        // Precompute, per subquantizer, the pairwise dot products between all
        // centroids; the similarity of two quantized rows is then the sum of
        // the centroid-pair similarities over the subquantizers.
        let dists: Vec<Array2<f32>> = self
            .quantizer
            .subquantizers()
            .axis_iter(Axis(0))
            .map(|quantizer| quantizer.dot(&quantizer.t()))
            .collect();
        let keep_quantized = self.quantized_embeddings.select(Axis(0), keep_indices);
        let toss_quantized = self.quantized_embeddings.select(Axis(0), toss_indices);
        let mut remap_indices = Array1::zeros(toss_quantized.shape()[0]);
        for (i, toss_row) in toss_quantized.axis_iter(Axis(0)).enumerate() {
            // Pick the kept row with the highest summed centroid similarity.
            let best = keep_quantized
                .axis_iter(Axis(0))
                .map(|keep_row| {
                    toss_row
                        .iter()
                        .zip(keep_row.iter())
                        .enumerate()
                        .map(|(sq, (&toss_c, &keep_c))| {
                            dists[sq][(toss_c as usize, keep_c as usize)]
                        })
                        .sum::<f32>()
                })
                .enumerate()
                .max_by_key(|&(_, sim)| OrderedFloat(sim))
                .expect("most_similar called with empty keep_indices")
                .0;
            remap_indices[i] = best;
        }
        remap_indices
    }
}

#[cfg(test)]
mod tests {
use std::fs::File;
use std::io::{BufReader, Cursor, Read, Seek, SeekFrom};

use byteorder::{LittleEndian, ReadBytesExt};
use ndarray::Array2;
use ndarray::{arr1, arr2, Array2};
use reductive::pq::PQ;

use crate::chunks::io::{MmapChunk, ReadChunk, WriteChunk};
use crate::chunks::storage::{MmapQuantizedArray, NdArray, Quantize, QuantizedArray, Storage};
use crate::chunks::storage::{
MmapQuantizedArray, NdArray, Quantize, QuantizedArray, Storage, StoragePrune,
};

const N_ROWS: usize = 100;
const N_COLS: usize = 100;
Expand Down Expand Up @@ -676,4 +750,45 @@ mod tests {
// Check
storage_eq(&arr, &check_arr);
}

/// Build a 6-row storage to quantize and prune: rows 3–5 are nearest to
/// rows 2, 0 and 1 respectively.
fn test_ndarray_for_pruning() -> NdArray {
    let rows: Array2<f32> = arr2(&[
        [1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.8, 0.6, 0.0, 0.0],
        [0.8, 0.6, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.8, 0.6, 0.0, 0.0, 0.0],
    ]);
    NdArray::new(rows)
}

#[test]
fn test_most_similar() {
    let ndarray = test_ndarray_for_pruning();
    let quantized = ndarray.quantize::<PQ<f32>>(3, 2, 5, 1, true);
    // Tossed rows 3, 4, 5 should remap to kept rows 2, 0, 1 respectively.
    let remap = quantized.most_similar(&[0, 1, 2], &[3, 4, 5], 1);
    assert_eq!(arr1(&[2, 0, 1]), remap);
}

#[test]
fn test_prune_storage() {
    let ndarray = test_ndarray_for_pruning();
    let quantized = ndarray.quantize::<PQ<f32>>(3, 2, 5, 1, true);
    let keep_indices = &[0, 1, 2];
    let toss_indices = &[3, 4, 5];
    let pruned = quantized.prune_storage(toss_indices);
    // Pruning removes exactly the tossed rows and keeps the width.
    assert_eq!(
        (ndarray.shape().0 - toss_indices.len(), ndarray.shape().1),
        pruned.shape()
    );
    // The surviving (reconstructed) embeddings must be unchanged.
    for &idx in keep_indices {
        assert_eq!(quantized.embedding(idx), pruned.embedding(idx));
    }
}
}
Loading