Skip to content

Conversation

@Linnore
Copy link
Contributor

@Linnore Linnore commented Jul 17, 2024

As discussed here #7555 (reply in thread), the GAT paper mentioned using skip-connection to reach the reported metrics on PPI dataset. This PR adds the option of residual to enable skip-connection.

Specifically, we consider the following for adding residual:

  1. Define the actual_out_channels = heads * out_channels if concat else out_channels
  2. If in_channels == actual_out_channels, then residual is just x before the convolution. Otherwise, residual is a linear projection of the input x onto the dimension of actual_out_channels.

For bipartite case where the input is (x_src, x_dst), the above consideration will be applied for x_dst.

@Linnore
Copy link
Contributor Author

Linnore commented Jul 17, 2024

An example gat_ppi.py is added under examples/gat_ppi.py. Running the example with argument --residual and --no-residual can show that the implementation of residual is effective.

Note that for the case in_channels == actual_out_channels described above, there is no fully connected layer added as in examples/ppi.py to get the residual; instead, an identity mapping is used. Though having slower convergence speed, this is a reasonable design and is the same design as DGL's GATConv just for reference.

I am not sure whether the residual is correctly implemented for the bipartite input though. It's probably problematic when the bipartite input is (x_src, x_dst=None). In GATv2Conv, such input should be asserted and should not be allowed, but GATConv seems not having the same checking.

Feel free to comment!

@Linnore
Copy link
Contributor Author

Linnore commented Jul 21, 2024

The Linting check fails due to torch_geometric/config_mixin.py:85: error: "Type[DataclassInstance]" has no attribute "_target_" [attr-defined]. It is not related to this PR.

Copy link
Member

@zechengz zechengz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general LGTM. Maybe @rusty1s can take a look

@Linnore Linnore changed the title Added the residual option for GATConv and GATv2Conv. Add the residual option for GATConv and GATv2Conv. Jul 24, 2024
@Linnore Linnore changed the title Add the residual option for GATConv and GATv2Conv. Added the residual option for GATConv and GATv2Conv. Jul 24, 2024
@rusty1s rusty1s merged commit 2ab9971 into pyg-team:master Jul 29, 2024
Copy link
Contributor Author

@Linnore Linnore left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question about adding linear transformation when in_channels == total_out_channels

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants