@@ -2,9 +2,11 @@ use std::fs::File;
22use std:: io:: { BufReader , Read , Seek , SeekFrom , Write } ;
33
44use byteorder:: { LittleEndian , ReadBytesExt } ;
5- use ndarray:: { ArrayView2 , CowArray , Ix1 } ;
5+ use ndarray:: { Array1 , ArrayView2 , CowArray , Ix1 } ;
66
7- use super :: { MmapArray , MmapQuantizedArray , NdArray , QuantizedArray , Storage , StorageView } ;
7+ use super :: {
8+ MmapArray , MmapQuantizedArray , NdArray , QuantizedArray , Storage , StoragePrune , StorageView ,
9+ } ;
810use crate :: chunks:: io:: { ChunkIdentifier , MmapChunk , ReadChunk , WriteChunk } ;
911use crate :: io:: { Error , ErrorKind , Result } ;
1012
@@ -297,3 +299,114 @@ impl MmapChunk for StorageViewWrap {
297299 }
298300 }
299301}
302+
303+ pub enum StoragePruneWrap {
304+ NdArray ( NdArray ) ,
305+ QuantizedArray ( QuantizedArray ) ,
306+ }
307+
308+ impl Storage for StoragePruneWrap {
309+ fn embedding ( & self , idx : usize ) -> CowArray < f32 , Ix1 > {
310+ match self {
311+ StoragePruneWrap :: QuantizedArray ( inner) => inner. embedding ( idx) ,
312+ StoragePruneWrap :: NdArray ( inner) => inner. embedding ( idx) ,
313+ }
314+ }
315+
316+ fn shape ( & self ) -> ( usize , usize ) {
317+ match self {
318+ StoragePruneWrap :: QuantizedArray ( inner) => inner. shape ( ) ,
319+ StoragePruneWrap :: NdArray ( inner) => inner. shape ( ) ,
320+ }
321+ }
322+ }
323+
324+ impl From < NdArray > for StoragePruneWrap {
325+ fn from ( s : NdArray ) -> Self {
326+ StoragePruneWrap :: NdArray ( s)
327+ }
328+ }
329+
330+ impl From < QuantizedArray > for StoragePruneWrap {
331+ fn from ( s : QuantizedArray ) -> Self {
332+ StoragePruneWrap :: QuantizedArray ( s)
333+ }
334+ }
335+
336+ impl ReadChunk for StoragePruneWrap {
337+ fn read_chunk < R > ( read : & mut R ) -> Result < Self >
338+ where
339+ R : Read + Seek ,
340+ {
341+ let chunk_start_pos = read
342+ . seek ( SeekFrom :: Current ( 0 ) )
343+ . map_err ( |e| ErrorKind :: io_error ( "Cannot get storage chunk start position" , e) ) ?;
344+
345+ let chunk_id = read
346+ . read_u32 :: < LittleEndian > ( )
347+ . map_err ( |e| ErrorKind :: io_error ( "Cannot read storage chunk identifier" , e) ) ?;
348+ let chunk_id = ChunkIdentifier :: try_from ( chunk_id)
349+ . ok_or_else ( || ErrorKind :: Format ( format ! ( "Unknown chunk identifier: {}" , chunk_id) ) )
350+ . map_err ( Error :: from) ?;
351+
352+ read. seek ( SeekFrom :: Start ( chunk_start_pos) )
353+ . map_err ( |e| ErrorKind :: io_error ( "Cannot seek to storage chunk start position" , e) ) ?;
354+
355+ match chunk_id {
356+ ChunkIdentifier :: NdArray => NdArray :: read_chunk ( read) . map ( StoragePruneWrap :: NdArray ) ,
357+ ChunkIdentifier :: QuantizedArray => {
358+ QuantizedArray :: read_chunk ( read) . map ( StoragePruneWrap :: QuantizedArray )
359+ }
360+ _ => Err ( ErrorKind :: Format ( format ! (
361+ "Invalid chunk identifier, expected one of: {} or {}, got: {}" ,
362+ ChunkIdentifier :: NdArray ,
363+ ChunkIdentifier :: QuantizedArray ,
364+ chunk_id
365+ ) )
366+ . into ( ) ) ,
367+ }
368+ }
369+ }
370+
371+ impl WriteChunk for StoragePruneWrap {
372+ fn chunk_identifier ( & self ) -> ChunkIdentifier {
373+ match self {
374+ StoragePruneWrap :: QuantizedArray ( inner) => inner. chunk_identifier ( ) ,
375+ StoragePruneWrap :: NdArray ( inner) => inner. chunk_identifier ( ) ,
376+ }
377+ }
378+
379+ fn write_chunk < W > ( & self , write : & mut W ) -> Result < ( ) >
380+ where
381+ W : Write + Seek ,
382+ {
383+ match self {
384+ StoragePruneWrap :: QuantizedArray ( inner) => inner. write_chunk ( write) ,
385+ StoragePruneWrap :: NdArray ( inner) => inner. write_chunk ( write) ,
386+ }
387+ }
388+ }
389+
390+ impl StoragePrune for StoragePruneWrap {
391+ fn simple_prune_storage ( & self , toss_indices : & [ usize ] ) -> StorageWrap {
392+ match self {
393+ StoragePruneWrap :: NdArray ( inner) => inner. simple_prune_storage ( toss_indices) ,
394+ StoragePruneWrap :: QuantizedArray ( inner) => inner. simple_prune_storage ( toss_indices) ,
395+ }
396+ }
397+ fn most_similar (
398+ & self ,
399+ keep_indices : & [ usize ] ,
400+ toss_indices : & [ usize ] ,
401+ batch_size : usize ,
402+ ) -> Array1 < usize > {
403+ match self {
404+ StoragePruneWrap :: NdArray ( inner) => {
405+ inner. most_similar ( keep_indices, toss_indices, batch_size)
406+ }
407+ StoragePruneWrap :: QuantizedArray ( inner) => {
408+ inner. most_similar ( keep_indices, toss_indices, batch_size)
409+ }
410+ }
411+ }
412+ }
0 commit comments