Skip to content

Commit 78b96ad

Browse files
committed
update
1 parent f7b0b2f commit 78b96ad

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

torch_geometric/data/database.py

Lines changed: 16 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,8 @@
1-
import io
1+
import pickle
22
from abc import ABC, abstractmethod
33
from typing import Any, Iterable, List, Optional, Union
44
from uuid import uuid4
55

6-
import torch
76
from torch import Tensor
87
from tqdm import tqdm
98

@@ -90,14 +89,12 @@ def serialize(data: Any) -> bytes:
9089
if isinstance(data, Tensor):
9190
data = data.clone()
9291

93-
buffer = io.BytesIO()
94-
torch.save(data, buffer)
95-
return buffer.getvalue()
92+
return pickle.dumps(data)
9693

9794
@staticmethod
9895
def deserialize(data: bytes) -> Any:
9996
r"""Deserializes bytes into the original data."""
100-
return torch.load(io.BytesIO(data))
97+
return pickle.loads(data)
10198

10299
def slice_to_range(self, indices: slice) -> range:
103100
start = 0 if indices.start is None else indices.start
@@ -265,7 +262,10 @@ def __init__(self, path: str):
265262

266263
def connect(self):
267264
import rocksdict
268-
self._db = rocksdict.Rdict(self.path)
265+
self._db = rocksdict.Rdict(
266+
self.path,
267+
options=rocksdict.Options(raw_mode=True),
268+
)
269269

270270
def close(self):
271271
if self._db is not None:
@@ -278,17 +278,23 @@ def db(self) -> Any:
278278
raise RuntimeError("No open database connection")
279279
return self._db
280280

281+
@staticmethod
282+
def to_key(index: int) -> bytes:
283+
return index.to_bytes(8, byteorder='big', signed=True)
284+
281285
def insert(self, index: int, data: Any):
282286
# Ensure that data is not a view of a larger tensor:
283287
if isinstance(data, Tensor):
284288
data = data.clone()
285289

286-
self.db[index] = data
290+
self.db[self.to_key(index)] = self.serialize(data)
287291

288292
def get(self, index: int) -> Any:
289-
return self.db[index]
293+
return self.deserialize(self.db[self.to_key(index)])
290294

291295
def _multi_get(self, indices: Union[Iterable[int], Tensor]) -> List[Any]:
292296
if isinstance(indices, Tensor):
293297
indices = indices.tolist()
294-
return self.db[indices]
298+
indices = [self.to_key(index) for index in indices]
299+
data_list = self.db[indices]
300+
return [self.deserialize(data) for data in data_list]

0 commit comments

Comments
 (0)