Add prune support #120
Changes from all commits
@@ -1,14 +1,17 @@
+use std::collections::HashSet;
 use std::fs::File;
 use std::io::{BufReader, Read, Seek, SeekFrom, Write};
+use std::iter::FromIterator;
 use std::mem::size_of;

 #[cfg(target_endian = "big")]
 use byteorder::ByteOrder;
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use memmap::{Mmap, MmapOptions};
-use ndarray::{Array2, ArrayView2, ArrayViewMut2, CowArray, Dimension, Ix1, Ix2};
+use ndarray::{s, Array1, Array2, ArrayView2, ArrayViewMut2, Axis, CowArray, Dimension, Ix1, Ix2};
+use ordered_float::OrderedFloat;

-use super::{Storage, StorageView, StorageViewMut};
+use super::{Storage, StoragePrune, StorageView, StorageViewMut, StorageWrap};
 use crate::chunks::io::{ChunkIdentifier, MmapChunk, ReadChunk, TypeId, WriteChunk};
 use crate::io::{Error, ErrorKind, Result};
 use crate::util::padding;

@@ -274,15 +277,60 @@ impl WriteChunk for NdArray {
     }
 }

+impl StoragePrune for NdArray {
+    fn prune_storage(&self, toss_indices: &[usize]) -> StorageWrap {
+        let toss_indices: HashSet<usize> = HashSet::from_iter(toss_indices.iter().cloned());
+        let mut keep_indices_all = Vec::new();
+        for idx in 0..self.shape().0 {
+            if !toss_indices.contains(&idx) {
+                keep_indices_all.push(idx);
+            }
+        }
+        NdArray::new(self.inner.select(Axis(0), &keep_indices_all).to_owned()).into()
+    }
+
+    fn most_similar(
+        &self,
+        keep_indices: &[usize],
+        toss_indices: &[usize],
+        batch_size: usize,
+    ) -> Array1<usize> {
+        let toss = toss_indices.len();
+        let mut remap_indices = Array1::zeros(toss);
+        let keep_embeds = self.inner.select(Axis(0), keep_indices);
+        let keep_embeds_t = &keep_embeds.t();
+        let toss_embeds = self.inner.select(Axis(0), toss_indices);
+        for n in (0..toss).step_by(batch_size) {
+            let mut offset = n + batch_size;
+            if offset > toss {
+                offset = toss;
+            }
+            #[allow(clippy::deref_addrof)]
+            let batch = toss_embeds.slice(s![n..offset, ..]);

Review comment: You can add this above the line to silence clippy.
Review comment: Btw, this is another case where clippy complained on my laptop and not on CI.

+            let similarity_scores = batch.dot(keep_embeds_t);
+            for (i, row) in similarity_scores.axis_iter(Axis(0)).enumerate() {
+                let dist = row
+                    .iter()
+                    .enumerate()
+                    .max_by_key(|(_, &v)| OrderedFloat(v))
+                    .unwrap()
+                    .0;
+                remap_indices[n + i] = dist;
+            }
+        }
+        remap_indices
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::io::{Cursor, Read, Seek, SeekFrom};

     use byteorder::{LittleEndian, ReadBytesExt};
-    use ndarray::Array2;
+    use ndarray::{arr1, arr2, Array2, Axis};

     use crate::chunks::io::{ReadChunk, WriteChunk};
-    use crate::chunks::storage::{NdArray, StorageView};
+    use crate::chunks::storage::{NdArray, Storage, StoragePrune, StorageView};

     const N_ROWS: usize = 100;
     const N_COLS: usize = 100;

@@ -326,4 +374,36 @@ mod tests {
         let arr = NdArray::read_chunk(&mut cursor).unwrap();
         assert_eq!(arr.view(), check_arr.view());
     }
+
+    fn test_ndarray_for_pruning() -> NdArray {
+        let test_storage: Array2<f32> = arr2(&[
+            [1.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 1.0, 0.0, 0.0, 0.0],
+            [0.6, 0.8, 0.0, 0.0, 0.0],
+            [0.8, 0.6, 0.0, 0.0, 0.0],
+        ]);
+        NdArray::new(test_storage)
+    }
+
+    #[test]
+    fn test_most_similar() {
+        let storage = test_ndarray_for_pruning();
+        let keep_indices = &[0, 1];
+        let toss_indices = &[2, 3];
+        let test_remap_indices = storage.most_similar(keep_indices, toss_indices, 1);
+        assert_eq!(arr1(&[1, 0]), test_remap_indices);
+    }
+
+    #[test]
+    fn test_prune_storage() {
+        let pruned_storage: Array2<f32> =
+            arr2(&[[1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0]]);
+        let storage = test_ndarray_for_pruning();
+        let toss_indices = &[2, 3];
+        let test_pruned_storage = storage.prune_storage(toss_indices);
+        assert_eq!(pruned_storage.dim(), test_pruned_storage.shape());
+        for (idx, row) in pruned_storage.axis_iter(Axis(0)).enumerate() {
+            assert_eq!(row, test_pruned_storage.embedding(idx));
+        }
+    }
 }
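Taken together, `prune_storage` and `most_similar` let a caller shrink an embedding matrix and redirect every discarded row to its most similar kept row. In `test_most_similar` above, the row [0.6, 0.8, 0.0, 0.0, 0.0] has dot products 0.6 and 0.8 with the two kept rows, so it maps to kept position 1, while [0.8, 0.6, 0.0, 0.0, 0.0] maps to position 0, giving the expected [1, 0]. Note that the dot product is used as the similarity score, so this behaves like cosine similarity only when rows are (near) unit length, as they are in the tests. The sketch below shows one way the new API might be driven from calling code; it is illustrative only, and the import path is an assumption based on the crate-internal `use` statements in this diff.

    // Illustrative sketch only, not part of this PR. The module path below is
    // assumed from the crate-internal imports shown above; adjust to the
    // crate's public re-exports as needed.
    use finalfusion::chunks::storage::{NdArray, Storage, StoragePrune};
    use ndarray::arr2;

    fn main() {
        // Four unit-length rows; the last two will be tossed.
        let storage = NdArray::new(arr2(&[
            [1.0f32, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.6, 0.8, 0.0],
            [0.8, 0.6, 0.0],
        ]));

        let keep_indices = [0usize, 1];
        let toss_indices = [2usize, 3];

        // For each tossed row, the position (into `keep_indices`) of its most
        // similar kept row, computed in batches of 1.
        let remap = storage.most_similar(&keep_indices, &toss_indices, 1);
        assert_eq!(remap[0], 1); // [0.6, 0.8, ...] is closest to kept row 1
        assert_eq!(remap[1], 0); // [0.8, 0.6, ...] is closest to kept row 0

        // Storage with the tossed rows removed; only rows 0 and 1 remain.
        let pruned = storage.prune_storage(&toss_indices);
        assert_eq!(pruned.shape(), (2, 3));
    }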
@@ -1,6 +1,6 @@
 //! Embedding matrix representations.

-use ndarray::{ArrayView2, ArrayViewMut2, CowArray, Ix1};
+use ndarray::{Array1, ArrayView2, ArrayViewMut2, CowArray, Ix1};

 mod array;
 pub use self::array::{MmapArray, NdArray};

@@ -33,3 +33,17 @@ pub(crate) trait StorageViewMut: Storage {
     /// Get a view of the embedding matrix.
     fn view_mut(&mut self) -> ArrayViewMut2<f32>;
 }
+
+/// Storage that can be pruned.
+pub trait StoragePrune: Storage {

Review comment: Add rustdoc.

+    /// Prune a storage. Discard the vectors that need to be pruned off, based on their indices.
+    fn prune_storage(&self, toss_indices: &[usize]) -> StorageWrap;
+
+    /// Find the nearest kept vector for each vector that needs to be tossed.
+    fn most_similar(
+        &self,
+        keep_indices: &[usize],
+        toss_indices: &[usize],
+        batch_size: usize,
+    ) -> Array1<usize>;
+}
Review comment: Please add rustdoc for the trait and for the `prune_norms` method,
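One detail worth spelling out in that rustdoc: as implemented for `NdArray` above, `most_similar` returns, for each tossed row, a position into `keep_indices` rather than an original row index (the test only hides this distinction because `keep_indices` is `[0, 1]`). A hypothetical helper such as the one below, not part of this PR, could translate those positions back to original indices:

    // Hypothetical helper, not part of this PR: map the positions returned by
    // `most_similar` back to row indices in the unpruned storage. Assumes the
    // contract of the NdArray implementation above, where each returned value
    // indexes into `keep_indices`.
    use ndarray::Array1;

    fn remap_to_original(keep_indices: &[usize], positions: &Array1<usize>) -> Vec<usize> {
        positions.iter().map(|&pos| keep_indices[pos]).collect()
    }

With `keep_indices = [0, 1]` and the `[1, 0]` result from `test_most_similar`, this yields `[1, 0]`: tossed rows 2 and 3 are redirected to original rows 1 and 0.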