1- import io
1+ import pickle
22from abc import ABC , abstractmethod
33from typing import Any , Iterable , List , Optional , Union
44from uuid import uuid4
55
6- import torch
76from torch import Tensor
87from tqdm import tqdm
98
@@ -90,14 +89,12 @@ def serialize(data: Any) -> bytes:
9089 if isinstance (data , Tensor ):
9190 data = data .clone ()
9291
93- buffer = io .BytesIO ()
94- torch .save (data , buffer )
95- return buffer .getvalue ()
92+ return pickle .dumps (data )
9693
@staticmethod
def deserialize(data: bytes) -> Any:
    r"""Deserializes bytes into the original data.

    Args:
        data: A pickled byte string.

    Returns:
        The object reconstructed by :func:`pickle.loads`.
    """
    # NOTE(security): unpickling can execute arbitrary code. Only feed this
    # method bytes that originate from a trusted source.
    return pickle.loads(data)
10198
10299 def slice_to_range (self , indices : slice ) -> range :
103100 start = 0 if indices .start is None else indices .start
@@ -265,7 +262,10 @@ def __init__(self, path: str):
265262
def connect(self):
    r"""Opens the RocksDB database located at ``self.path``.

    The resulting handle is stored in ``self._db``.
    """
    # Imported lazily so that `rocksdict` is only needed once a connection
    # is actually opened.
    import rocksdict

    # Raw mode: keys and values are exchanged with the store as plain bytes,
    # which is how the other methods of this class pass them.
    options = rocksdict.Options(raw_mode=True)
    self._db = rocksdict.Rdict(self.path, options=options)
269269
270270 def close (self ):
271271 if self ._db is not None :
@@ -278,17 +278,23 @@ def db(self) -> Any:
278278 raise RuntimeError ("No open database connection" )
279279 return self ._db
280280
@staticmethod
def to_key(index: int) -> bytes:
    r"""Encodes an integer index as a fixed-width byte key.

    The index is written as an 8-byte, big-endian, signed integer. For
    non-negative indices, the byte-wise ordering of the keys therefore
    matches the numeric ordering of the indices.

    Args:
        index: The integer index to encode.

    Returns:
        The 8-byte key.

    Raises:
        OverflowError: If :obj:`index` does not fit into a signed 64-bit
            integer.
    """
    return index.to_bytes(8, byteorder='big', signed=True)
284+
def insert(self, index: int, data: Any):
    r"""Stores :obj:`data` in the database under :obj:`index`.

    Args:
        index: The integer index; converted to a byte key via
            ``self.to_key``.
        data: The object to store; converted to bytes via
            ``self.serialize``.
    """
    # No explicit `data.clone()` here: `serialize` already clones tensors
    # before pickling (so that a view does not drag its whole underlying
    # storage along), and cloning twice would only waste a full copy.
    self.db[self.to_key(index)] = self.serialize(data)
287291
def get(self, index: int) -> Any:
    r"""Loads the object stored under :obj:`index`.

    Args:
        index: The integer index; converted to a byte key via
            ``self.to_key``.

    Returns:
        The stored bytes decoded via ``self.deserialize``.
    """
    key = self.to_key(index)
    return self.deserialize(self.db[key])
290294
291295 def _multi_get (self , indices : Union [Iterable [int ], Tensor ]) -> List [Any ]:
292296 if isinstance (indices , Tensor ):
293297 indices = indices .tolist ()
294- return self .db [indices ]
298+ indices = [self .to_key (index ) for index in indices ]
299+ data_list = self .db [indices ]
300+ return [self .deserialize (data ) for data in data_list ]
0 commit comments