Skip to content

Commit 7cba07f

Browse files
Kh4Lpuririshi98pre-commit-ci[bot]
authored
Fix circular import in GPSE (#10190)
Co-authored-by: Rishi Puri <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a21f251 commit 7cba07f

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

torch_geometric/nn/models/gpse.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
import torch.nn as nn
1111
import torch.nn.functional as F
12+
from torch.nn import Module
1213
from tqdm import trange
1314

1415
import torch_geometric.transforms as T
@@ -715,8 +716,9 @@ def forward(self, x, pos_enc):
715716

716717

717718
@torch.no_grad()
718-
def gpse_process(model: GPSE, data: Data, rand_type: str, use_vn: bool = True,
719-
bernoulli_thresh: float = 0.5, neighbor_loader: bool = False,
719+
def gpse_process(model: Module, data: Data, rand_type: str,
720+
use_vn: bool = True, bernoulli_thresh: float = 0.5,
721+
neighbor_loader: bool = False,
720722
num_neighbors: List[int] = [30, 20, 10], fillval: int = 5,
721723
layers_mp: int = None, **kwargs) -> torch.Tensor:
722724
r"""Processes the data using the :class:`GPSE` model to generate and append
@@ -731,7 +733,7 @@ def gpse_process(model: GPSE, data: Data, rand_type: str, use_vn: bool = True,
731733
:obj:`precompute_GPSE` on your whole dataset is advised instead.
732734
733735
Args:
734-
model (GPSE): The :class:`GPSE` model.
736+
model (Module): The :class:`GPSE` model.
735737
data (torch_geometric.data.Data): A :class:`~torch_geometric.data.Data`
736738
object.
737739
rand_type (str, optional): Type of random features to use. Options are

torch_geometric/transforms/add_gpse.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from torch.nn import Module
2+
13
from torch_geometric.data import Data
24
from torch_geometric.data.datapipes import functional_transform
3-
from torch_geometric.nn.models.gpse import GPSE
45
from torch_geometric.transforms import BaseTransform, VirtualNode
56

67

@@ -13,15 +14,15 @@ class AddGPSE(BaseTransform):
1314
the actual encodings.
1415
1516
Args:
16-
model (GPSE): The pre-trained GPSE model.
17+
model (Module): The pre-trained GPSE model.
1718
use_vn (bool, optional): Whether to use virtual nodes.
1819
(default: :obj:`True`)
1920
rand_type (str, optional): Type of random features to use. Options are
2021
:obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.
2122
(default: :obj:`NormalSE`)
2223
2324
"""
24-
def __init__(self, model: GPSE, use_vn: bool = True,
25+
def __init__(self, model: Module, use_vn: bool = True,
2526
rand_type: str = 'NormalSE'):
2627
self.model = model
2728
self.use_vn = use_vn

0 commit comments

Comments
 (0)