Skip to content

Commit cbab312

Browse files
author
NianhengWu
committed
Add prune support
1 parent ee917db commit cbab312

File tree

9 files changed

+438
-13
lines changed

9 files changed

+438
-13
lines changed

src/chunks/norms.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,20 @@ impl WriteChunk for NdNorms {
127127
}
128128
}
129129

130+
pub trait PruneNorms {
131+
fn prune_norms(&self, toss_indices: &[usize], most_similar_indices: &Array1<usize>) -> NdNorms;
132+
}
133+
134+
impl PruneNorms for NdNorms {
135+
fn prune_norms(&self, toss_indices: &[usize], most_similar_indices: &Array1<usize>) -> NdNorms {
136+
let mut pruned_norms = self.inner.clone();
137+
for (toss_idx, remapped_idx) in toss_indices.iter().zip(most_similar_indices) {
138+
pruned_norms[*toss_idx] = pruned_norms[*remapped_idx];
139+
}
140+
NdNorms::new(pruned_norms)
141+
}
142+
}
143+
130144
#[cfg(test)]
131145
mod tests {
132146
use std::io::{Cursor, Read, Seek, SeekFrom};

src/chunks/storage/array.rs

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1+
use std::collections::HashSet;
12
use std::fs::File;
23
use std::io::{BufReader, Read, Seek, SeekFrom, Write};
4+
use std::iter::FromIterator;
35
use std::mem::size_of;
46

57
#[cfg(target_endian = "big")]
68
use byteorder::ByteOrder;
79
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
810
use memmap::{Mmap, MmapOptions};
9-
use ndarray::{Array2, ArrayView2, ArrayViewMut2, CowArray, Dimension, Ix1, Ix2};
11+
use ndarray::{s, Array1, Array2, ArrayView2, ArrayViewMut2, Axis, CowArray, Dimension, Ix1, Ix2};
12+
use ordered_float::OrderedFloat;
1013

11-
use super::{Storage, StorageView, StorageViewMut};
14+
use super::{Storage, StoragePrune, StorageView, StorageViewMut, StorageWrap};
1215
use crate::chunks::io::{ChunkIdentifier, MmapChunk, ReadChunk, TypeId, WriteChunk};
1316
use crate::io::{Error, ErrorKind, Result};
1417
use crate::util::padding;
@@ -274,6 +277,50 @@ impl WriteChunk for NdArray {
274277
}
275278
}
276279

280+
impl StoragePrune for NdArray {
281+
fn simple_prune_storage(&self, toss_indices: &[usize]) -> StorageWrap {
282+
let toss_indices: HashSet<usize> = HashSet::from_iter(toss_indices.iter().cloned());
283+
let mut keep_indices_all = Vec::new();
284+
for idx in 0..self.shape().1 {
285+
if !toss_indices.contains(&idx) {
286+
keep_indices_all.push(idx);
287+
}
288+
}
289+
NdArray::new(self.inner.select(Axis(0), &keep_indices_all).to_owned()).into()
290+
}
291+
292+
fn most_similar(
293+
&self,
294+
keep_indices: &[usize],
295+
toss_indices: &[usize],
296+
batch_size: usize,
297+
) -> Array1<usize> {
298+
let toss = toss_indices.len();
299+
let mut remap_indices = Array1::zeros(toss);
300+
let keep_embeds = self.inner.select(Axis(0), keep_indices);
301+
let keep_embeds_t = &keep_embeds.t();
302+
let toss_embeds = self.inner.select(Axis(0), toss_indices);
303+
for n in (0..toss).step_by(batch_size) {
304+
let mut offset = n + batch_size;
305+
if offset > toss {
306+
offset = toss;
307+
}
308+
let batch = toss_embeds.slice(s![n..offset, ..]);
309+
let similarity_scores = batch.dot(keep_embeds_t);
310+
for (i, row) in similarity_scores.axis_iter(Axis(0)).enumerate() {
311+
let dist = row
312+
.iter()
313+
.enumerate()
314+
.max_by_key(|(_, &v)| OrderedFloat(v))
315+
.unwrap()
316+
.0;
317+
remap_indices[n + i] = dist;
318+
}
319+
}
320+
remap_indices
321+
}
322+
}
323+
277324
#[cfg(test)]
278325
mod tests {
279326
use std::io::{Cursor, Read, Seek, SeekFrom};

src/chunks/storage/mod.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Embedding matrix representations.
22
3-
use ndarray::{ArrayView2, ArrayViewMut2, CowArray, Ix1};
3+
use ndarray::{Array1, ArrayView2, ArrayViewMut2, CowArray, Ix1};
44

55
mod array;
66
pub use self::array::{MmapArray, NdArray};
@@ -33,3 +33,13 @@ pub(crate) trait StorageViewMut: Storage {
3333
/// Get a view of the embedding matrix.
3434
fn view_mut(&mut self) -> ArrayViewMut2<f32>;
3535
}
36+
37+
pub trait StoragePrune: Storage {
38+
fn simple_prune_storage(&self, toss_indices: &[usize]) -> StorageWrap;
39+
fn most_similar(
40+
&self,
41+
keep_indices: &[usize],
42+
toss_indices: &[usize],
43+
batch_size: usize,
44+
) -> Array1<usize>;
45+
}

src/chunks/storage/quantized.rs

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1+
use std::collections::HashSet;
12
use std::fs::File;
23
use std::io::{BufReader, Read, Seek, SeekFrom, Write};
4+
use std::iter::FromIterator;
35
use std::mem::size_of;
46

57
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
68
use memmap::{Mmap, MmapOptions};
79
use ndarray::{
8-
Array, Array1, Array2, ArrayView1, ArrayView2, CowArray, Dimension, IntoDimension, Ix1,
10+
Array, Array1, Array2, ArrayView1, ArrayView2, Axis, CowArray, Dimension, IntoDimension, Ix1,
911
};
12+
use ordered_float::OrderedFloat;
1013
use rand::{RngCore, SeedableRng};
1114
use rand_xorshift::XorShiftRng;
1215
use reductive::pq::{QuantizeVector, ReconstructVector, TrainPQ, PQ};
1316

14-
use super::{Storage, StorageView};
17+
use super::{Storage, StoragePrune, StorageView, StorageWrap};
1518
use crate::chunks::io::{ChunkIdentifier, MmapChunk, ReadChunk, TypeId, WriteChunk};
1619
use crate::io::{Error, ErrorKind, Result};
1720
use crate::util::padding;
@@ -561,6 +564,75 @@ impl WriteChunk for MmapQuantizedArray {
561564
}
562565
}
563566

567+
impl StoragePrune for QuantizedArray {
568+
fn simple_prune_storage(&self, toss_indices: &[usize]) -> StorageWrap {
569+
let mut keep_indices_all = Vec::new();
570+
let toss_indices: HashSet<usize> = HashSet::from_iter(toss_indices.iter().cloned());
571+
for idx in 0..self.quantized_embeddings.shape()[0] {
572+
if !toss_indices.contains(&idx) {
573+
keep_indices_all.push(idx);
574+
}
575+
}
576+
let norms = if self.norms.is_some() {
577+
Some(
578+
self.norms
579+
.as_ref()
580+
.unwrap()
581+
.select(Axis(0), &keep_indices_all),
582+
)
583+
} else {
584+
None
585+
};
586+
let new_storage = QuantizedArray {
587+
quantizer: self.quantizer.clone(),
588+
quantized_embeddings: self
589+
.quantized_embeddings
590+
.select(Axis(0), &keep_indices_all)
591+
.to_owned(),
592+
norms,
593+
};
594+
new_storage.into()
595+
}
596+
597+
fn most_similar(
598+
&self,
599+
keep_indices: &[usize],
600+
toss_indices: &[usize],
601+
_batch_size: usize,
602+
) -> Array1<usize> {
603+
let dists: Vec<Array2<f32>> = self
604+
.quantizer
605+
.subquantizers()
606+
.axis_iter(Axis(0))
607+
.map(|quantizer| quantizer.dot(&quantizer.t()))
608+
.collect();
609+
let keep_quantized_embeddings = self.quantized_embeddings.select(Axis(0), keep_indices);
610+
let toss_quantized_embeddings = self.quantized_embeddings.select(Axis(0), toss_indices);
611+
let mut remap_indices = Array1::zeros((toss_quantized_embeddings.shape()[0],));
612+
for (i, toss_row) in toss_quantized_embeddings.axis_iter(Axis(0)).enumerate() {
613+
let mut row_dist = vec![0f32; keep_quantized_embeddings.shape()[0]];
614+
for (n, keep_row) in keep_quantized_embeddings.axis_iter(Axis(0)).enumerate() {
615+
row_dist[n] = toss_row
616+
.iter()
617+
.zip(keep_row.iter())
618+
.enumerate()
619+
.map(|(id, (&toss_id, &keep_id))| {
620+
dists[id][(toss_id as usize, keep_id as usize)]
621+
})
622+
.sum();
623+
}
624+
625+
remap_indices[i] = row_dist
626+
.iter()
627+
.enumerate()
628+
.max_by_key(|(_, &v)| OrderedFloat(v))
629+
.unwrap()
630+
.0;
631+
}
632+
remap_indices
633+
}
634+
}
635+
564636
#[cfg(test)]
565637
mod tests {
566638
use std::fs::File;

src/chunks/storage/wrappers.rs

Lines changed: 115 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ use std::fs::File;
22
use std::io::{BufReader, Read, Seek, SeekFrom, Write};
33

44
use byteorder::{LittleEndian, ReadBytesExt};
5-
use ndarray::{ArrayView2, CowArray, Ix1};
5+
use ndarray::{Array1, ArrayView2, CowArray, Ix1};
66

7-
use super::{MmapArray, MmapQuantizedArray, NdArray, QuantizedArray, Storage, StorageView};
7+
use super::{
8+
MmapArray, MmapQuantizedArray, NdArray, QuantizedArray, Storage, StoragePrune, StorageView,
9+
};
810
use crate::chunks::io::{ChunkIdentifier, MmapChunk, ReadChunk, WriteChunk};
911
use crate::io::{Error, ErrorKind, Result};
1012

@@ -297,3 +299,114 @@ impl MmapChunk for StorageViewWrap {
297299
}
298300
}
299301
}
302+
303+
pub enum StoragePruneWrap {
304+
NdArray(NdArray),
305+
QuantizedArray(QuantizedArray),
306+
}
307+
308+
impl Storage for StoragePruneWrap {
309+
fn embedding(&self, idx: usize) -> CowArray<f32, Ix1> {
310+
match self {
311+
StoragePruneWrap::QuantizedArray(inner) => inner.embedding(idx),
312+
StoragePruneWrap::NdArray(inner) => inner.embedding(idx),
313+
}
314+
}
315+
316+
fn shape(&self) -> (usize, usize) {
317+
match self {
318+
StoragePruneWrap::QuantizedArray(inner) => inner.shape(),
319+
StoragePruneWrap::NdArray(inner) => inner.shape(),
320+
}
321+
}
322+
}
323+
324+
impl From<NdArray> for StoragePruneWrap {
325+
fn from(s: NdArray) -> Self {
326+
StoragePruneWrap::NdArray(s)
327+
}
328+
}
329+
330+
impl From<QuantizedArray> for StoragePruneWrap {
331+
fn from(s: QuantizedArray) -> Self {
332+
StoragePruneWrap::QuantizedArray(s)
333+
}
334+
}
335+
336+
impl ReadChunk for StoragePruneWrap {
337+
fn read_chunk<R>(read: &mut R) -> Result<Self>
338+
where
339+
R: Read + Seek,
340+
{
341+
let chunk_start_pos = read
342+
.seek(SeekFrom::Current(0))
343+
.map_err(|e| ErrorKind::io_error("Cannot get storage chunk start position", e))?;
344+
345+
let chunk_id = read
346+
.read_u32::<LittleEndian>()
347+
.map_err(|e| ErrorKind::io_error("Cannot read storage chunk identifier", e))?;
348+
let chunk_id = ChunkIdentifier::try_from(chunk_id)
349+
.ok_or_else(|| ErrorKind::Format(format!("Unknown chunk identifier: {}", chunk_id)))
350+
.map_err(Error::from)?;
351+
352+
read.seek(SeekFrom::Start(chunk_start_pos))
353+
.map_err(|e| ErrorKind::io_error("Cannot seek to storage chunk start position", e))?;
354+
355+
match chunk_id {
356+
ChunkIdentifier::NdArray => NdArray::read_chunk(read).map(StoragePruneWrap::NdArray),
357+
ChunkIdentifier::QuantizedArray => {
358+
QuantizedArray::read_chunk(read).map(StoragePruneWrap::QuantizedArray)
359+
}
360+
_ => Err(ErrorKind::Format(format!(
361+
"Invalid chunk identifier, expected one of: {} or {}, got: {}",
362+
ChunkIdentifier::NdArray,
363+
ChunkIdentifier::QuantizedArray,
364+
chunk_id
365+
))
366+
.into()),
367+
}
368+
}
369+
}
370+
371+
impl WriteChunk for StoragePruneWrap {
372+
fn chunk_identifier(&self) -> ChunkIdentifier {
373+
match self {
374+
StoragePruneWrap::QuantizedArray(inner) => inner.chunk_identifier(),
375+
StoragePruneWrap::NdArray(inner) => inner.chunk_identifier(),
376+
}
377+
}
378+
379+
fn write_chunk<W>(&self, write: &mut W) -> Result<()>
380+
where
381+
W: Write + Seek,
382+
{
383+
match self {
384+
StoragePruneWrap::QuantizedArray(inner) => inner.write_chunk(write),
385+
StoragePruneWrap::NdArray(inner) => inner.write_chunk(write),
386+
}
387+
}
388+
}
389+
390+
impl StoragePrune for StoragePruneWrap {
391+
fn simple_prune_storage(&self, toss_indices: &[usize]) -> StorageWrap {
392+
match self {
393+
StoragePruneWrap::NdArray(inner) => inner.simple_prune_storage(toss_indices),
394+
StoragePruneWrap::QuantizedArray(inner) => inner.simple_prune_storage(toss_indices),
395+
}
396+
}
397+
fn most_similar(
398+
&self,
399+
keep_indices: &[usize],
400+
toss_indices: &[usize],
401+
batch_size: usize,
402+
) -> Array1<usize> {
403+
match self {
404+
StoragePruneWrap::NdArray(inner) => {
405+
inner.most_similar(keep_indices, toss_indices, batch_size)
406+
}
407+
StoragePruneWrap::QuantizedArray(inner) => {
408+
inner.most_similar(keep_indices, toss_indices, batch_size)
409+
}
410+
}
411+
}
412+
}

src/chunks/vocab/mod.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::collections::HashMap;
44
use std::io::{Read, Seek, Write};
55

66
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
7+
use ndarray::Array1;
78

89
use crate::io::{Error, ErrorKind, Result};
910

@@ -65,6 +66,18 @@ impl WordIndex {
6566
}
6667
}
6768

69+
pub trait VocabPrune: Vocab {
70+
fn prune_vocab(&self, remapped_indices: HashMap<String, usize>) -> VocabWrap;
71+
}
72+
73+
pub trait VocabPruneIndices: Vocab {
74+
fn part_indices(&self, n_keep: usize) -> (Vec<usize>, Vec<usize>);
75+
fn create_remapped_indices(
76+
&self,
77+
most_similar_indices: &Array1<usize>,
78+
) -> HashMap<String, usize>;
79+
}
80+
6881
pub(crate) fn create_indices(words: &[String]) -> HashMap<String, usize> {
6982
let mut indices = HashMap::new();
7083

0 commit comments

Comments
 (0)