-
Notifications
You must be signed in to change notification settings - Fork 3.9k
Added visualize_graph to HeteroExplanation.
#10207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
for more information, see https://pre-commit.ci
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #10207 +/- ##
==========================================
- Coverage 86.11% 85.41% -0.70%
==========================================
Files 496 496
Lines 33655 33984 +329
==========================================
+ Hits 28981 29029 +48
- Misses 4674 4955 +281 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
HeteroExplanation HeteroExplanation
|
This looks good. Can we update this so that the Node type doesn't appear over each node, this makes it look cluttered, and we can identify node types by color anyway. |
Fixed default for when no labels are provided to be empty node. Here's an example of what this looks like with and without provided node labels: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds a new visualization helper to render heterogeneous explanation graphs, with node sizes/opacities and edge widths/opacities controlled by user-defined ranges. Key changes include:
- Introducing visualize_hetero_graph and its internal helper _visualize_hetero_graph_via_networkx in graph.py.
- Extending the ExplanationMixin with visualize_explanation_graph to support heterogeneous graphs.
- Updating the package init.py and adding tests in test/explain/test_hetero_explanation.py.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| torch_geometric/visualization/graph.py | Added new visualization functions for heterogeneous graphs with networkx backend. |
| torch_geometric/visualization/init.py | Updated exports to include visualize_hetero_graph. |
| torch_geometric/explain/explanation.py | Added visualize_explanation_graph method in ExplanationMixin. |
| test/explain/test_hetero_explanation.py | Added tests for the new heterogeneous explanation visualization functionality. |
Comments suppressed due to low confidence (1)
torch_geometric/visualization/graph.py:344
- [nitpick] Consider renaming 'src_mapping' and 'dst_mapping' to 'src_index_mapping' and 'dst_index_mapping' respectively for improved clarity on what the mappings represent.
src_mapping = {
HeteroExplanation visualize_explanation_graph to HeteroExplanation.
wsad1
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@BlazStojanovic this looks good.
Lets fix the lint issues and merge.
visualize_explanation_graph to HeteroExplanation.visualize_graph to HeteroExplanation.
Hi all! I added a utility to visualize the `HeteroExplanation` at the
subgraph level. Visualization works in the following way:
- Node importance is rendered as node size and node opacity
- Edge importance is rendered as edge width and edge opacity
The user has the ability to specify the ranges of node size, edge width,
and both opacities. The conversion between `HeteroData` and networkx
graph is a bit tricky, would appreciate if reviewers are particularly
strict with that part of the code.
Here's an example of `HeteroExplanation` visualized in this way:
```
import os.path as osp
import tempfile
import pytest
import torch
from torch_geometric.data import HeteroData
from torch_geometric.explain import HeteroExplanation
import matplotlib.pyplot as plt
def generate_random_explanation(
num_papers: int = 21,
num_authors: int = 8,
num_institutions: int = 4,
num_edges_pa: int = 16, # paper-author edges
num_edges_ai: int = 8, # author-institution edges
num_edges_pp: int = 12, # paper-paper edges
seed: int = 42,
) -> HeteroExplanation:
r"""Generates a random heterogeneous explanation for testing.
Args:
num_papers (int): Number of paper nodes.
num_authors (int): Number of author nodes.
num_institutions (int): Number of institution nodes.
num_edges_pa (int): Number of paper-author edges.
num_edges_ai (int): Number of author-institution edges.
num_edges_pp (int): Number of paper-paper edges.
seed (int): Random seed for reproducibility.
"""
torch.manual_seed(seed)
# Create a heterogeneous graph
data = HeteroData()
# Add paper nodes
data['paper'].x = torch.randn(num_papers, 16)
data['paper'].num_nodes = num_papers
# Add author nodes
data['author'].x = torch.randn(num_authors, 8)
data['author'].num_nodes = num_authors
# Add institution nodes
data['institution'].x = torch.randn(num_institutions, 4)
data['institution'].num_nodes = num_institutions
# Add edges between papers and authors
edge_index_pa = torch.randint(0, num_papers, (2, num_edges_pa))
edge_index_pa[1] = torch.randint(0, num_authors, (num_edges_pa,))
data['paper', 'written_by', 'author'].edge_index = edge_index_pa
# Add edges between authors and institutions
edge_index_ai = torch.randint(0, num_authors, (2, num_edges_ai))
edge_index_ai[1] = torch.randint(0, num_institutions, (num_edges_ai,))
data['author', 'affiliated_with', 'institution'].edge_index = edge_index_ai
# Add edges between papers (citations)
edge_index_pp = torch.randint(0, num_papers, (2, num_edges_pp))
# Remove self-citations by ensuring source and target nodes are different
# Generate source and target nodes separately to avoid self-citations
edge_index_pp[1] = torch.randint(0, num_papers-1, (num_edges_pp,)) # Generate targets
edge_index_pp[1] += (edge_index_pp[1] >= edge_index_pp[0]).long() # Shift up if >= source
mask = torch.ones(num_edges_pp, dtype=torch.bool) # Keep all edges
data['paper', 'cites', 'paper'].edge_index = edge_index_pp[:, mask]
# Create explanation masks
explanation = HeteroExplanation()
explanation['paper'].x = data['paper'].x
explanation['author'].x = data['author'].x
explanation['institution'].x = data['institution'].x
# Copy edge indices
explanation['paper', 'written_by', 'author'].edge_index = data['paper', 'written_by', 'author'].edge_index
explanation['author', 'affiliated_with', 'institution'].edge_index = data['author', 'affiliated_with', 'institution'].edge_index
explanation['paper', 'cites', 'paper'].edge_index = data['paper', 'cites', 'paper'].edge_index
# Add random node and edge masks
explanation['paper'].node_mask = torch.rand(num_papers, 1)
explanation['author'].node_mask = torch.rand(num_authors, 1)
explanation['institution'].node_mask = torch.rand(num_institutions, 1)
explanation['paper', 'written_by', 'author'].edge_mask = torch.rand(num_edges_pa)
explanation['author', 'affiliated_with', 'institution'].edge_mask = torch.rand(num_edges_ai)
explanation['paper', 'cites', 'paper'].edge_mask = torch.rand(num_edges_pp)
return explanation
if __name__ == '__main__':
# For interactive testing
explanation = generate_random_explanation()
# Create a single figure with spring layout
fig, ax = plt.subplots(figsize=(12, 8))
# Generate node names that match the number of nodes
num_papers = explanation['paper'].num_nodes
num_authors = explanation['author'].num_nodes
num_institutions = explanation['institution'].num_nodes
explanation.visualize_explanation_graph(
node_labels={
'paper': [f'Paper {i}' for i in range(num_papers)],
'author': [f'Author {i}' for i in range(num_authors)],
'institution': [f'Inst {i}' for i in range(num_institutions)]
}
)
# Add title to the figure
ax.set_title('Example Visualization with Spring Layout')
# Show the figure
plt.show()
```
<img width="1212" alt="Screenshot 2025-04-17 at 16 41 42"
src="https://github.com/user-attachments/assets/fe9ead88-7022-4978-9f33-b0f54e55610e"
/>
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinu Sunil <[email protected]>
Hi all! I added a utility to visualize the
HeteroExplanationat the subgraph level. Visualization works in the following way:The user has the ability to specify the ranges of node size, edge width, and both opacities. The conversion between
HeteroDataand networkx graph is a bit tricky, would appreciate if reviewers are particularly strict with that part of the code.Here's an example of
HeteroExplanationvisualized in this way: