Skip to content

Conversation

@BlazStojanovic
Copy link
Contributor

Hi all! I added a utility to visualize the HeteroExplanation at the subgraph level. Visualization works in the following way:

  • Node importance is rendered as node size and node opacity
  • Edge importance is rendered as edge width and edge opacity

The user can specify the ranges of node size, edge width, and both opacities. The conversion between HeteroData and a networkx graph is a bit tricky; I would appreciate it if reviewers were particularly strict with that part of the code.

Here's an example of HeteroExplanation visualized in this way:

import os.path as osp
import tempfile

import pytest
import torch
from torch_geometric.data import HeteroData
from torch_geometric.explain import HeteroExplanation

import matplotlib.pyplot as plt


def generate_random_explanation(
    num_papers: int = 21,
    num_authors: int = 8,
    num_institutions: int = 4,
    num_edges_pa: int = 16,  # paper-author edges
    num_edges_ai: int = 8,  # author-institution edges
    num_edges_pp: int = 12,  # paper-paper edges
    seed: int = 42,
) -> HeteroExplanation:
    r"""Generates a random heterogeneous explanation for testing.

    The explanation contains three node types (``'paper'``, ``'author'``
    and ``'institution'``) and three edge types
    (``('paper', 'written_by', 'author')``,
    ``('author', 'affiliated_with', 'institution')`` and
    ``('paper', 'cites', 'paper')``). Each node type stores random features
    :obj:`x` and a :obj:`node_mask` of shape :obj:`[num_nodes, 1]`; each
    edge type stores an :obj:`edge_index` and an :obj:`edge_mask` of shape
    :obj:`[num_edges]`.

    Args:
        num_papers (int): Number of paper nodes. Citation targets are drawn
            from :obj:`num_papers - 1` candidates, so at least two papers
            are needed.
        num_authors (int): Number of author nodes.
        num_institutions (int): Number of institution nodes.
        num_edges_pa (int): Number of paper-author edges.
        num_edges_ai (int): Number of author-institution edges.
        num_edges_pp (int): Number of paper-paper edges.
        seed (int): Random seed for reproducibility.

    Returns:
        HeteroExplanation: The randomly populated explanation.
    """
    torch.manual_seed(seed)

    # NOTE: The order (and shapes) of the random draws below determine the
    # result for a given seed, so keep them stable when editing.

    # Node features (16, 8 and 4 channels respectively).
    x_paper = torch.randn(num_papers, 16)
    x_author = torch.randn(num_authors, 8)
    x_institution = torch.randn(num_institutions, 4)

    # Paper -> author edges: sources index papers, targets index authors.
    edge_index_pa = torch.randint(0, num_papers, (2, num_edges_pa))
    edge_index_pa[1] = torch.randint(0, num_authors, (num_edges_pa, ))

    # Author -> institution edges.
    edge_index_ai = torch.randint(0, num_authors, (2, num_edges_ai))
    edge_index_ai[1] = torch.randint(0, num_institutions, (num_edges_ai, ))

    # Paper -> paper edges (citations) without self-citations: draw each
    # target from `num_papers - 1` candidates and shift it up by one if it
    # is greater than or equal to its source, so it can never equal it.
    edge_index_pp = torch.randint(0, num_papers, (2, num_edges_pp))
    edge_index_pp[1] = torch.randint(0, num_papers - 1, (num_edges_pp, ))
    edge_index_pp[1] += (edge_index_pp[1] >= edge_index_pp[0]).long()

    explanation = HeteroExplanation()

    # Node features.
    explanation['paper'].x = x_paper
    explanation['author'].x = x_author
    explanation['institution'].x = x_institution

    # Edge indices.
    explanation['paper', 'written_by', 'author'].edge_index = edge_index_pa
    explanation['author', 'affiliated_with',
                'institution'].edge_index = edge_index_ai
    explanation['paper', 'cites', 'paper'].edge_index = edge_index_pp

    # Random node importances, one value per node.
    explanation['paper'].node_mask = torch.rand(num_papers, 1)
    explanation['author'].node_mask = torch.rand(num_authors, 1)
    explanation['institution'].node_mask = torch.rand(num_institutions, 1)

    # Random edge importances, one value per edge.
    explanation['paper', 'written_by',
                'author'].edge_mask = torch.rand(num_edges_pa)
    explanation['author', 'affiliated_with',
                'institution'].edge_mask = torch.rand(num_edges_ai)
    explanation['paper', 'cites', 'paper'].edge_mask = torch.rand(num_edges_pp)

    return explanation


if __name__ == '__main__':
    # Manual smoke test: build a random explanation and render it.
    explanation = generate_random_explanation()

    # Create the figure/axes pair that receives the title below.
    # NOTE(review): this assumes `visualize_explanation_graph` draws onto
    # the current figure rather than opening its own -- confirm.
    fig, ax = plt.subplots(figsize=(12, 8))

    # One label per node, keyed by node type; list lengths follow the
    # number of nodes stored in the explanation.
    node_labels = {
        'paper': [
            f'Paper {i}' for i in range(explanation['paper'].num_nodes)
        ],
        'author': [
            f'Author {i}' for i in range(explanation['author'].num_nodes)
        ],
        'institution': [
            f'Inst {i}'
            for i in range(explanation['institution'].num_nodes)
        ],
    }
    explanation.visualize_explanation_graph(node_labels=node_labels)

    # Title the axes and display the result.
    ax.set_title('Example Visualization with Spring Layout')
    plt.show()
Screenshot 2025-04-17 at 16 41 42

@codecov
Copy link

codecov bot commented Apr 17, 2025

Codecov Report

❌ Patch coverage is 96.42857% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 85.41%. Comparing base (c211214) to head (d48bec0).
⚠️ Report is 73 commits behind head on master.

Files with missing lines Patch % Lines
torch_geometric/visualization/graph.py 97.41% 3 Missing ⚠️
torch_geometric/explain/explanation.py 91.30% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #10207      +/-   ##
==========================================
- Coverage   86.11%   85.41%   -0.70%     
==========================================
  Files         496      496              
  Lines       33655    33984     +329     
==========================================
+ Hits        28981    29029      +48     
- Misses       4674     4955     +281     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@BlazStojanovic BlazStojanovic changed the title Adding ability to visualise HeteroExplanation Adding helper to visualise HeteroExplanation Apr 18, 2025
@wsad1
Copy link
Member

wsad1 commented Apr 24, 2025

This looks good. Can we update this so that the node type doesn't appear over each node? It makes the plot look cluttered, and we can identify node types by color anyway.

@BlazStojanovic
Copy link
Contributor Author

This looks good. Can we update this so that the node type doesn't appear over each node? It makes the plot look cluttered, and we can identify node types by color anyway.

Fixed: when no labels are provided, nodes now default to an empty label. Here's an example of what this looks like with and without provided node labels:
fig_no_labels.pdf
fig_w_labels.pdf

@wsad1 wsad1 requested a review from Copilot April 25, 2025 21:40
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds a new visualization helper to render heterogeneous explanation graphs, with node sizes/opacities and edge widths/opacities controlled by user-defined ranges. Key changes include:

  • Introducing visualize_hetero_graph and its internal helper _visualize_hetero_graph_via_networkx in graph.py.
  • Extending the ExplanationMixin with visualize_explanation_graph to support heterogeneous graphs.
  • Updating the package __init__.py and adding tests in test/explain/test_hetero_explanation.py.

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.

File Description
torch_geometric/visualization/graph.py Added new visualization functions for heterogeneous graphs with networkx backend.
torch_geometric/visualization/__init__.py Updated exports to include visualize_hetero_graph.
torch_geometric/explain/explanation.py Added visualize_explanation_graph method in ExplanationMixin.
test/explain/test_hetero_explanation.py Added tests for the new heterogeneous explanation visualization functionality.
Comments suppressed due to low confidence (1)

torch_geometric/visualization/graph.py:344

  • [nitpick] Consider renaming 'src_mapping' and 'dst_mapping' to 'src_index_mapping' and 'dst_index_mapping' respectively for improved clarity on what the mappings represent.
src_mapping = {

@wsad1 wsad1 changed the title Adding helper to visualise HeteroExplanation Added visualize_explanation_graph to HeteroExplanation. Apr 25, 2025
Copy link
Member

@wsad1 wsad1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BlazStojanovic this looks good.
Let's fix the lint issues and merge.

@wsad1 wsad1 changed the title Added visualize_explanation_graph to HeteroExplanation. Added visualize_graph to HeteroExplanation. Apr 28, 2025
@wsad1 wsad1 enabled auto-merge (squash) April 28, 2025 19:43
@wsad1 wsad1 merged commit 32829eb into pyg-team:master Apr 28, 2025
17 checks passed
chrisn-pik pushed a commit to chrisn-pik/pytorch_geometric that referenced this pull request Jun 30, 2025
Hi all! I added a utility to visualize the `HeteroExplanation` at the
subgraph level. Visualization works in the following way:
- Node importance is rendered as node size and node opacity
- Edge importance is rendered as edge width and edge opacity

The user has the ability to specify the ranges of node size, edge width,
and both opacities. The conversion between `HeteroData` and a networkx
graph is a bit tricky; I would appreciate it if reviewers were
particularly strict with that part of the code.

Here's an example of `HeteroExplanation` visualized in this way:
```
import os.path as osp
import tempfile

import pytest
import torch
from torch_geometric.data import HeteroData
from torch_geometric.explain import HeteroExplanation

import matplotlib.pyplot as plt


def generate_random_explanation(
    num_papers: int = 21,
    num_authors: int = 8,
    num_institutions: int = 4,
    num_edges_pa: int = 16,  # paper-author edges
    num_edges_ai: int = 8,  # author-institution edges
    num_edges_pp: int = 12,  # paper-paper edges
    seed: int = 42,
) -> HeteroExplanation:
    r"""Generates a random heterogeneous explanation for testing.

    The explanation contains three node types (``'paper'``, ``'author'``
    and ``'institution'``) and three edge types
    (``('paper', 'written_by', 'author')``,
    ``('author', 'affiliated_with', 'institution')`` and
    ``('paper', 'cites', 'paper')``). Each node type stores random features
    :obj:`x` and a :obj:`node_mask` of shape :obj:`[num_nodes, 1]`; each
    edge type stores an :obj:`edge_index` and an :obj:`edge_mask` of shape
    :obj:`[num_edges]`.

    Args:
        num_papers (int): Number of paper nodes. Citation targets are drawn
            from :obj:`num_papers - 1` candidates, so at least two papers
            are needed.
        num_authors (int): Number of author nodes.
        num_institutions (int): Number of institution nodes.
        num_edges_pa (int): Number of paper-author edges.
        num_edges_ai (int): Number of author-institution edges.
        num_edges_pp (int): Number of paper-paper edges.
        seed (int): Random seed for reproducibility.

    Returns:
        HeteroExplanation: The randomly populated explanation.
    """
    torch.manual_seed(seed)

    # NOTE: The order (and shapes) of the random draws below determine the
    # result for a given seed, so keep them stable when editing.

    # Node features (16, 8 and 4 channels respectively).
    x_paper = torch.randn(num_papers, 16)
    x_author = torch.randn(num_authors, 8)
    x_institution = torch.randn(num_institutions, 4)

    # Paper -> author edges: sources index papers, targets index authors.
    edge_index_pa = torch.randint(0, num_papers, (2, num_edges_pa))
    edge_index_pa[1] = torch.randint(0, num_authors, (num_edges_pa, ))

    # Author -> institution edges.
    edge_index_ai = torch.randint(0, num_authors, (2, num_edges_ai))
    edge_index_ai[1] = torch.randint(0, num_institutions, (num_edges_ai, ))

    # Paper -> paper edges (citations) without self-citations: draw each
    # target from `num_papers - 1` candidates and shift it up by one if it
    # is greater than or equal to its source, so it can never equal it.
    edge_index_pp = torch.randint(0, num_papers, (2, num_edges_pp))
    edge_index_pp[1] = torch.randint(0, num_papers - 1, (num_edges_pp, ))
    edge_index_pp[1] += (edge_index_pp[1] >= edge_index_pp[0]).long()

    explanation = HeteroExplanation()

    # Node features.
    explanation['paper'].x = x_paper
    explanation['author'].x = x_author
    explanation['institution'].x = x_institution

    # Edge indices.
    explanation['paper', 'written_by', 'author'].edge_index = edge_index_pa
    explanation['author', 'affiliated_with',
                'institution'].edge_index = edge_index_ai
    explanation['paper', 'cites', 'paper'].edge_index = edge_index_pp

    # Random node importances, one value per node.
    explanation['paper'].node_mask = torch.rand(num_papers, 1)
    explanation['author'].node_mask = torch.rand(num_authors, 1)
    explanation['institution'].node_mask = torch.rand(num_institutions, 1)

    # Random edge importances, one value per edge.
    explanation['paper', 'written_by',
                'author'].edge_mask = torch.rand(num_edges_pa)
    explanation['author', 'affiliated_with',
                'institution'].edge_mask = torch.rand(num_edges_ai)
    explanation['paper', 'cites', 'paper'].edge_mask = torch.rand(num_edges_pp)

    return explanation


if __name__ == '__main__':
    # Manual smoke test: build a random explanation and render it.
    explanation = generate_random_explanation()

    # Create the figure/axes pair that receives the title below.
    # NOTE(review): this assumes `visualize_explanation_graph` draws onto
    # the current figure rather than opening its own -- confirm.
    fig, ax = plt.subplots(figsize=(12, 8))

    # One label per node, keyed by node type; list lengths follow the
    # number of nodes stored in the explanation.
    node_labels = {
        'paper': [
            f'Paper {i}' for i in range(explanation['paper'].num_nodes)
        ],
        'author': [
            f'Author {i}' for i in range(explanation['author'].num_nodes)
        ],
        'institution': [
            f'Inst {i}'
            for i in range(explanation['institution'].num_nodes)
        ],
    }
    explanation.visualize_explanation_graph(node_labels=node_labels)

    # Title the axes and display the result.
    ax.set_title('Example Visualization with Spring Layout')
    plt.show()
```

<img width="1212" alt="Screenshot 2025-04-17 at 16 41 42"
src="https://github.com/user-attachments/assets/fe9ead88-7022-4978-9f33-b0f54e55610e"
/>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinu Sunil <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants