2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added a `Database` interface and `SQLite` implementation ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051))
- Added a `Database` interface and `SQLiteDatabase`/`RocksDatabase` implementations ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051), [#8052](https://github.com/pyg-team/pytorch_geometric/pull/8052))
- Added support for weighted/biased sampling in `NeighborLoader`/`LinkNeighborLoader` ([#8038](https://github.com/pyg-team/pytorch_geometric/pull/8038))
- Added the `MixHopConv` layer and a corresponding example ([#8025](https://github.com/pyg-team/pytorch_geometric/pull/8025))
- Added the option to pass keyword arguments to the underlying normalization layers within `BasicGNN` and `MLP` ([#8024](https://github.com/pyg-team/pytorch_geometric/pull/8024), [#8033](https://github.com/pyg-team/pytorch_geometric/pull/8033))
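For context, a minimal usage sketch of the `Database` interface mentioned in the entry above, based on the calls exercised in the tests in this PR. The paths are placeholders, and `RocksDatabase` additionally requires the `rocksdict` package:

```python
import torch

from torch_geometric.data.database import RocksDatabase, SQLiteDatabase

db = RocksDatabase('/tmp/rocks.db')  # or: SQLiteDatabase('/tmp/sqlite.db', name='test_table')

db.insert(0, torch.randn(5))                # single write
db.multi_insert([1, 2], torch.randn(2, 5))  # batched write

out = db.multi_get([1, 2])                  # batched read, returns a list
assert torch.equal(out[0], db.get(1))

db.close()
```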
47 changes: 39 additions & 8 deletions test/data/test_database.py
@@ -3,7 +3,7 @@
import pytest
import torch

from torch_geometric.data.database import SQLiteDatabase
from torch_geometric.data.database import RocksDatabase, SQLiteDatabase
from torch_geometric.profile import benchmark
from torch_geometric.testing import withPackage

@@ -35,6 +35,32 @@ def test_sqlite_database(tmp_path, batch_size):
db.close()


@withPackage('rocksdict')
@pytest.mark.parametrize('batch_size', [None, 1])
def test_rocks_database(tmp_path, batch_size):
path = osp.join(tmp_path, 'rocks.db')
db = RocksDatabase(path)
assert str(db) == 'RocksDatabase()'
with pytest.raises(NotImplementedError):
len(db)

data = torch.randn(5)
db.insert(0, data)
assert torch.equal(db.get(0), data)

indices = torch.tensor([1, 2])
data_list = torch.randn(2, 5)
db.multi_insert(indices, data_list, batch_size=batch_size)

out_list = db.multi_get(indices, batch_size=batch_size)
assert isinstance(out_list, list)
assert len(out_list) == 2
assert torch.equal(out_list[0], data_list[0])
assert torch.equal(out_list[1], data_list[1])

db.close()


@withPackage('sqlite3')
def test_database_syntactic_sugar(tmp_path):
path = osp.join(tmp_path, 'sqlite.db')
@@ -72,14 +98,19 @@ def test_database_syntactic_sugar(tmp_path):
args = parser.parse_args()

data = torch.randn(args.numel, 128)

tmp_dir = tempfile.TemporaryDirectory()

path = osp.join(tmp_dir.name, 'sqlite.db')
db = SQLiteDatabase(path, name='test_table')
sqlite_db = SQLiteDatabase(path, name='test_table')
t = time.perf_counter()
sqlite_db.multi_insert(range(args.numel), data, batch_size=100, log=True)
print(f'Initialized SQLiteDB in {time.perf_counter() - t:.2f} seconds')

path = osp.join(tmp_dir.name, 'rocks.db')
rocks_db = RocksDatabase(path)
t = time.perf_counter()
db.multi_insert(range(args.numel), data, batch_size=100, log=True)
print(f'Initialized DB in {time.perf_counter() - t:.2f} seconds')
rocks_db.multi_insert(range(args.numel), data, batch_size=100, log=True)
print(f'Initialized RocksDB in {time.perf_counter() - t:.2f} seconds')

def in_memory_get(data):
index = torch.randint(0, args.numel, (128, ))
@@ -90,9 +121,9 @@ def db_get(db):
return db[index]

benchmark(
funcs=[in_memory_get, db_get],
func_names=['In-Memory', 'SQLite'],
args=[(data, ), (db, )],
funcs=[in_memory_get, db_get, db_get],
func_names=['In-Memory', 'SQLite', 'RocksDB'],
args=[(data, ), (sqlite_db, ), (rocks_db, )],
num_steps=50,
num_warmups=5,
)
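The collapsed `test_database_syntactic_sugar` test above exercises the indexing sugar on the `Database` base class. A rough sketch of that sugar, assuming `__getitem__` dispatches to `get()`/`multi_get()` in the same way the `__setitem__` shown below dispatches to the insert methods:

```python
import torch

from torch_geometric.data.database import SQLiteDatabase

db = SQLiteDatabase('/tmp/sqlite.db', name='test_table')

db[0] = torch.randn(5)                        # single insert via __setitem__
db[torch.tensor([1, 2])] = torch.randn(2, 5)  # batched insert via __setitem__

single = db[0]   # single read
batch = db[1:3]  # slices are expanded through slice_to_range()

db.close()
```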
75 changes: 64 additions & 11 deletions torch_geometric/data/database.py
@@ -1,9 +1,8 @@
import io
import pickle
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional, Union
from uuid import uuid4

import torch
from torch import Tensor
from tqdm import tqdm

@@ -90,14 +89,12 @@ def serialize(data: Any) -> bytes:
if isinstance(data, Tensor):
data = data.clone()

buffer = io.BytesIO()
torch.save(data, buffer)
return buffer.getvalue()
return pickle.dumps(data)

@staticmethod
def deserialize(data: bytes) -> Any:
r"""Deserializes bytes into the original data."""
return torch.load(io.BytesIO(data))
return pickle.loads(data)

def slice_to_range(self, indices: slice) -> range:
start = 0 if indices.start is None else indices.start
@@ -132,7 +129,10 @@ def __setitem__(
self.multi_insert(key, value)

def __repr__(self) -> str:
return f'{self.__class__.__name__}({len(self)})'
try:
return f'{self.__class__.__name__}({len(self)})'
except NotImplementedError:
return f'{self.__class__.__name__}()'


class SQLiteDatabase(Database):
@@ -161,10 +161,11 @@ def connect(self):
self._cursor = self._connection.cursor()

def close(self):
self._connection.commit()
self._connection.close()
self._connection = None
self._cursor = None
if self._connection is not None:
self._connection.commit()
self._connection.close()
self._connection = None
self._cursor = None

@property
def cursor(self) -> Any:
@@ -245,3 +246,55 @@ def __len__(self) -> int:
query = f'SELECT COUNT(*) FROM {self.name}'
self.cursor.execute(query)
return self.cursor.fetchone()[0]


class RocksDatabase(Database):
def __init__(self, path: str):
super().__init__()

import rocksdict

self.path = path

self._db: Optional[rocksdict.Rdict] = None

self.connect()

def connect(self):
import rocksdict
self._db = rocksdict.Rdict(
self.path,
options=rocksdict.Options(raw_mode=True),
)

def close(self):
if self._db is not None:
self._db.close()
self._db = None

@property
def db(self) -> Any:
if self._db is None:
raise RuntimeError("No open database connection")
return self._db

@staticmethod
def to_key(index: int) -> bytes:
return index.to_bytes(8, byteorder='big', signed=True)

def insert(self, index: int, data: Any):
# Ensure that data is not a view of a larger tensor:
if isinstance(data, Tensor):
data = data.clone()

self.db[self.to_key(index)] = self.serialize(data)

def get(self, index: int) -> Any:
return self.deserialize(self.db[self.to_key(index)])

def _multi_get(self, indices: Union[Iterable[int], Tensor]) -> List[Any]:
if isinstance(indices, Tensor):
indices = indices.tolist()
indices = [self.to_key(index) for index in indices]
data_list = self.db[indices]
return [self.deserialize(data) for data in data_list]
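A side note on the key layout used above: `to_key` encodes each integer index as a fixed-width, 8-byte, big-endian, signed value, so RocksDB's lexicographic key ordering matches numeric ordering for non-negative indices. A small illustrative check (mirroring the static method, not part of the PR):

```python
def to_key(index: int) -> bytes:
    # Fixed-width big-endian encoding keeps keys sortable for index >= 0:
    return index.to_bytes(8, byteorder='big', signed=True)

assert len(to_key(0)) == 8
assert to_key(1) < to_key(2) < to_key(256)  # byte order follows numeric order
```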