11import warnings
2- from typing import Any , Dict , List , Optional
2+ from typing import Any , Dict , List , Optional , Type
33
44import torch
55from torch import Tensor
@@ -58,10 +58,10 @@ def hook(model: torch.nn.Module, inputs: Any, outputs: Any) -> None:
5858
5959def get_embeddings_hetero (
6060 model : torch .nn .Module ,
61- supported_models : Optional [List [torch .nn .Module ]] = None ,
61+ supported_models : Optional [List [Type [ torch .nn .Module ] ]] = None ,
6262 * args : Any ,
6363 ** kwargs : Any ,
64- ) -> Dict [str , List [Tensor ]]:
64+ ) -> Dict [NodeType , List [Tensor ]]:
6565 """Returns the output embeddings of all
6666 :class:`~torch_geometric.nn.conv.MessagePassing` layers in a heterogeneous
6767 :obj:`model`, organized by edge type.
@@ -73,8 +73,9 @@ def get_embeddings_hetero(
7373
7474 Args:
7575 model (torch.nn.Module): The heterogeneous GNN model.
76- supported_models (Optional[List[torch.nn.Module]]): A list of supported
77- heterogenous models. (default: :obj:`None`)
76+ supported_models (List[Type[torch.nn.Module]], optional): A list of
77+ supported model classes. If not provided, defaults to
78+ [HGTConv, HANConv, HeteroConv].
7879 *args: Arguments passed to the model.
7980 **kwargs (optional): Additional keyword arguments passed to the model.
8081
0 commit comments