Skip to content

Commit 79f97f6

Browse files
rusty1s authored and JakubPietrakIntel committed
RocksDatabase implementation (#8052)
1 parent 2612cc8 commit 79f97f6

File tree

3 files changed

+104
-20
lines changed

3 files changed

+104
-20
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
77

88
### Added
99

10-
- Added a `Database` interface and `SQLite` implementation ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051))
10+
- Added a `Database` interface and `SQLiteDatabase`/`RocksDatabase` implementations ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051), [#8052](https://github.com/pyg-team/pytorch_geometric/pull/8052))
1111
- Added support for weighted/biased sampling in `NeighborLoader`/`LinkNeighborLoader` ([#8038](https://github.com/pyg-team/pytorch_geometric/pull/8038))
1212
- Added the `MixHopConv` layer and a corresponding example ([#8025](https://github.com/pyg-team/pytorch_geometric/pull/8025))
1313
- Added the option to pass keyword arguments to the underlying normalization layers within `BasicGNN` and `MLP` ([#8024](https://github.com/pyg-team/pytorch_geometric/pull/8024), [#8033](https://github.com/pyg-team/pytorch_geometric/pull/8033))

test/data/test_database.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
import torch
55

6-
from torch_geometric.data.database import SQLiteDatabase
6+
from torch_geometric.data.database import RocksDatabase, SQLiteDatabase
77
from torch_geometric.profile import benchmark
88
from torch_geometric.testing import withPackage
99

@@ -35,6 +35,32 @@ def test_sqlite_database(tmp_path, batch_size):
3535
db.close()
3636

3737

38+
@withPackage('rocksdict')
@pytest.mark.parametrize('batch_size', [None, 1])
def test_rocks_database(tmp_path, batch_size):
    # Round-trip check for `RocksDatabase`: a single insert/get followed by
    # a batched multi-insert/multi-get, with and without explicit batching.
    db = RocksDatabase(osp.join(tmp_path, 'rocks.db'))

    # The representation carries no length, and `len()` is expected to be
    # unsupported for this backend:
    assert str(db) == 'RocksDatabase()'
    with pytest.raises(NotImplementedError):
        len(db)

    # Single entry:
    single = torch.randn(5)
    db.insert(0, single)
    assert torch.equal(db.get(0), single)

    # Multiple entries:
    keys = torch.tensor([1, 2])
    values = torch.randn(2, 5)
    db.multi_insert(keys, values, batch_size=batch_size)

    fetched = db.multi_get(keys, batch_size=batch_size)
    assert isinstance(fetched, list)
    assert len(fetched) == 2
    for got, expected in zip(fetched, values):
        assert torch.equal(got, expected)

    db.close()
62+
63+
3864
@withPackage('sqlite3')
3965
def test_database_syntactic_sugar(tmp_path):
4066
path = osp.join(tmp_path, 'sqlite.db')
@@ -72,14 +98,19 @@ def test_database_syntactic_sugar(tmp_path):
7298
args = parser.parse_args()
7399

74100
data = torch.randn(args.numel, 128)
75-
76101
tmp_dir = tempfile.TemporaryDirectory()
102+
77103
path = osp.join(tmp_dir.name, 'sqlite.db')
78-
db = SQLiteDatabase(path, name='test_table')
104+
sqlite_db = SQLiteDatabase(path, name='test_table')
105+
t = time.perf_counter()
106+
sqlite_db.multi_insert(range(args.numel), data, batch_size=100, log=True)
107+
print(f'Initialized SQLiteDB in {time.perf_counter() - t:.2f} seconds')
79108

109+
path = osp.join(tmp_dir.name, 'rocks.db')
110+
rocks_db = RocksDatabase(path)
80111
t = time.perf_counter()
81-
db.multi_insert(range(args.numel), data, batch_size=100, log=True)
82-
print(f'Initialized DB in {time.perf_counter() - t:.2f} seconds')
112+
rocks_db.multi_insert(range(args.numel), data, batch_size=100, log=True)
113+
print(f'Initialized RocksDB in {time.perf_counter() - t:.2f} seconds')
83114

84115
def in_memory_get(data):
85116
index = torch.randint(0, args.numel, (128, ))
@@ -90,9 +121,9 @@ def db_get(db):
90121
return db[index]
91122

92123
benchmark(
93-
funcs=[in_memory_get, db_get],
94-
func_names=['In-Memory', 'SQLite'],
95-
args=[(data, ), (db, )],
124+
funcs=[in_memory_get, db_get, db_get],
125+
func_names=['In-Memory', 'SQLite', 'RocksDB'],
126+
args=[(data, ), (sqlite_db, ), (rocks_db, )],
96127
num_steps=50,
97128
num_warmups=5,
98129
)

torch_geometric/data/database.py

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
import io
1+
import pickle
22
from abc import ABC, abstractmethod
33
from typing import Any, Iterable, List, Optional, Union
44
from uuid import uuid4
55

6-
import torch
76
from torch import Tensor
87
from tqdm import tqdm
98

@@ -90,14 +89,12 @@ def serialize(data: Any) -> bytes:
9089
if isinstance(data, Tensor):
9190
data = data.clone()
9291

93-
buffer = io.BytesIO()
94-
torch.save(data, buffer)
95-
return buffer.getvalue()
92+
return pickle.dumps(data)
9693

9794
@staticmethod
9895
def deserialize(data: bytes) -> Any:
9996
r"""Deserializes bytes into the original data."""
100-
return torch.load(io.BytesIO(data))
97+
return pickle.loads(data)
10198

10299
def slice_to_range(self, indices: slice) -> range:
103100
start = 0 if indices.start is None else indices.start
@@ -132,7 +129,10 @@ def __setitem__(
132129
self.multi_insert(key, value)
133130

134131
def __repr__(self) -> str:
135-
return f'{self.__class__.__name__}({len(self)})'
132+
try:
133+
return f'{self.__class__.__name__}({len(self)})'
134+
except NotImplementedError:
135+
return f'{self.__class__.__name__}()'
136136

137137

138138
class SQLiteDatabase(Database):
@@ -161,10 +161,11 @@ def connect(self):
161161
self._cursor = self._connection.cursor()
162162

163163
def close(self):
164-
self._connection.commit()
165-
self._connection.close()
166-
self._connection = None
167-
self._cursor = None
164+
if self._connection is not None:
165+
self._connection.commit()
166+
self._connection.close()
167+
self._connection = None
168+
self._cursor = None
168169

169170
@property
170171
def cursor(self) -> Any:
@@ -245,3 +246,55 @@ def __len__(self) -> int:
245246
query = f'SELECT COUNT(*) FROM {self.name}'
246247
self.cursor.execute(query)
247248
return self.cursor.fetchone()[0]
249+
250+
251+
class RocksDatabase(Database):
    r"""A :class:`Database` implementation that stores its entries in a
    RocksDB key-value store accessed through the :obj:`rocksdict` package.

    The store is opened in raw mode, so both keys and values are plain bytes:
    keys are produced by :meth:`to_key` and values by :meth:`serialize`.
    This class does not define :meth:`__len__`.

    Args:
        path (str): The path at which the database is stored.
    """
    def __init__(self, path: str):
        super().__init__()

        # Imported lazily so that `rocksdict` stays an optional dependency;
        # a missing package surfaces here as an `ImportError`.
        import rocksdict

        self.path = path

        self._db: Optional[rocksdict.Rdict] = None

        self.connect()

    def connect(self):
        r"""Opens the store at :obj:`self.path` in raw (bytes-only) mode."""
        import rocksdict
        self._db = rocksdict.Rdict(
            self.path,
            options=rocksdict.Options(raw_mode=True),
        )

    def close(self):
        r"""Closes the store if it is open. Safe to call more than once."""
        if self._db is not None:
            self._db.close()
            self._db = None

    @property
    def db(self) -> Any:
        r"""The open store handle.

        Raises:
            RuntimeError: If the database has been closed.
        """
        if self._db is None:
            raise RuntimeError("No open database connection")
        return self._db

    @staticmethod
    def to_key(index: int) -> bytes:
        r"""Encodes :obj:`index` as an 8-byte, big-endian, signed key.

        Any object that supports :meth:`__index__` is accepted (*e.g.*,
        built-in integers, NumPy integers or 0-dim integer tensors), while
        non-integral values such as floats are rejected.

        Raises:
            TypeError: If :obj:`index` cannot be interpreted as an integer.
            OverflowError: If :obj:`index` does not fit into 8 signed bytes.
        """
        from operator import index as as_int

        return as_int(index).to_bytes(8, byteorder='big', signed=True)

    def insert(self, index: int, data: Any):
        r"""Serializes :obj:`data` and stores it under :obj:`index`."""
        # NOTE `serialize` already clones tensors so that views of larger
        # tensors are not stored with their full storage; cloning here as
        # well would only copy the data a second time.
        self.db[self.to_key(index)] = self.serialize(data)

    def get(self, index: int) -> Any:
        r"""Loads and deserializes the entry stored under :obj:`index`."""
        return self.deserialize(self.db[self.to_key(index)])

    def _multi_get(self, indices: Union[Iterable[int], Tensor]) -> List[Any]:
        r"""Loads and deserializes the entries stored under :obj:`indices`
        via a single lookup with a list of keys.

        Args:
            indices (Iterable[int] or torch.Tensor): The indices to look up.

        Returns:
            List[Any]: The deserialized entries (presumably in the order of
            :obj:`indices`; this relies on the ordering guarantees of the
            underlying store).
        """
        if isinstance(indices, Tensor):
            indices = indices.tolist()
        keys = [self.to_key(index) for index in indices]
        return [self.deserialize(data) for data in self.db[keys]]

0 commit comments

Comments
 (0)