Skip to content

Commit cf9d018

Browse files
author
XinweiHe
committed
fix
1 parent 4edf6b0 commit cf9d018

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

torch_geometric/utils/embedding.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Any, Dict, List, Optional
2+
from typing import Any, Dict, List, Optional, Type
33

44
import torch
55
from torch import Tensor
@@ -58,10 +58,10 @@ def hook(model: torch.nn.Module, inputs: Any, outputs: Any) -> None:
5858

5959
def get_embeddings_hetero(
6060
model: torch.nn.Module,
61-
supported_models: Optional[List[torch.nn.Module]] = None,
61+
supported_models: Optional[List[Type[torch.nn.Module]]] = None,
6262
*args: Any,
6363
**kwargs: Any,
64-
) -> Dict[str, List[Tensor]]:
64+
) -> Dict[NodeType, List[Tensor]]:
6565
"""Returns the output embeddings of all
6666
:class:`~torch_geometric.nn.conv.MessagePassing` layers in a heterogeneous
6767
:obj:`model`, organized by edge type.
@@ -73,8 +73,9 @@ def get_embeddings_hetero(
7373
7474
Args:
7575
model (torch.nn.Module): The heterogeneous GNN model.
76-
supported_models (Optional[List[torch.nn.Module]]): A list of supported
77-
heterogenous models. (default: :obj:`None`)
76+
supported_models (List[Type[torch.nn.Module]], optional): A list of
77+
supported model classes. If not provided, defaults to
78+
[HGTConv, HANConv, HeteroConv].
7879
*args: Arguments passed to the model.
7980
**kwargs (optional): Additional keyword arguments passed to the model.
8081

0 commit comments

Comments
 (0)